# Audio Enhancement using Demucs

This notebook demonstrates how to enhance audio using the `Demucs` library. We will:
1. Load and preprocess an audio file.
2. Apply the Demucs model to enhance the audio.
3. Visualize and output the enhanced audio.

## Explanation
Demucs is a deep learning model for music source separation: it splits a mix into stems such as drums, bass, other and vocals. Keeping only the vocals stem turns it into a rough audio-enhancement step, and it is particularly effective at isolating singing or speech from background music. This can help speech-to-text processing focus on the spoken words rather than the accompaniment. Because the model was trained on music, it is typically less effective against non-musical background noise such as traffic, wind or room hum.


## Step 1: Install Requirements

Install `torch`, `numpy`, `matplotlib` and `soundfile` from PyPI, and `demucs` from its actively maintained GitHub repository (`adefossez/demucs`).

In [None]:
# Setup installers
commands = [
    ("PIP_ROOT_USER_ACTION=ignore pip install -q torch", "Install torch"),
    ("PIP_ROOT_USER_ACTION=ignore pip install -q numpy", "Install numpy"),
    ("PIP_ROOT_USER_ACTION=ignore pip install -q matplotlib", "Install matplotlib"),
    ("PIP_ROOT_USER_ACTION=ignore pip install -q soundfile", "Install soundfile"),
    ("PIP_ROOT_USER_ACTION=ignore pip install -q -U git+https://github.com/adefossez/demucs#egg=demucs", "Install demucs from current maintainer repo")
]

# Import the utils module which sets up the environment
from modules import utils
from modules import disable_warnings

# Use LogTools
log_tools = utils.LogTools()

# Execute
log_tools.command_state(commands)
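
A quick way to confirm the installs succeeded, without importing the heavy packages, is to query their installed versions with the standard library's `importlib.metadata`. This is a minimal sketch; `installed_versions` is a hypothetical helper, not part of the notebook's `modules` package.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


report = installed_versions(["torch", "numpy", "matplotlib", "soundfile", "demucs"])
for name, ver in report.items():
    print(f"{name:12s} {ver or 'NOT INSTALLED'}")
```

Any package reported as `NOT INSTALLED` points to a failed command in the install cell.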

## Step 2: Load Libraries and Discover GPU Resources

In [None]:
# Import necessary libraries
import torch
import numpy as np
import soundfile as sf
import matplotlib.pyplot as plt
from demucs import pretrained
from demucs.apply import apply_model

# Check to see what GPU resources are available
def get_best_device():
    if torch.cuda.is_available():
        print("Using CUDA")
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        print("Using MPS")
        return torch.device("mps")
    else:
        print("Using CPU")
        return torch.device("cpu")

device = get_best_device()
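
On Apple Silicon, some operators used by `htdemucs` may not be implemented for the `mps` device in every PyTorch release. PyTorch's `PYTORCH_ENABLE_MPS_FALLBACK` environment variable makes such operators run on the CPU instead of raising an error. A minimal sketch, assuming you are on MPS and hit an unsupported-operator error: the variable is read when PyTorch initialises, so set it before `torch` is first imported (for example at the top of the Step 1 cell) and restart the kernel.

```python
import os

# Must be set before the first `import torch` in this kernel.
# "1" tells PyTorch to run operators missing on MPS on the CPU instead of failing.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")
print("PYTORCH_ENABLE_MPS_FALLBACK =", os.environ["PYTORCH_ENABLE_MPS_FALLBACK"])
```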

## Step 3: Load the Audio File

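We start by loading an audio file with `soundfile`, which reads formats supported by libsndfile such as WAV and FLAC. The pretrained `htdemucs` model works on stereo audio at 44.1 kHz (`model.samplerate`), so mono or lower-rate recordings need converting before they reach the model.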

In [None]:
import IPython.display as ipd

# Load the audio file as float32 samples
audio_filepath = "../../test_pcm.wav"
audio, sample_rate = sf.read(audio_filepath, dtype='float32')

# The rest of the notebook treats the signal as mono: down-mix multi-channel files
if audio.ndim > 1:
    audio = audio.mean(axis=1)

# Plot the audio waveform
time_axis = np.arange(len(audio)) / sample_rate
plt.figure(figsize=(15, 5))
plt.plot(time_axis, audio)
plt.title('Audio Waveform')
plt.xlabel('Time [s]')
plt.ylabel('Amplitude')
plt.show()

# Play the original audio from memory
ipd.Audio(audio, rate=sample_rate)
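
Before the model sees it, the loaded signal has to become a `(channels, samples)` array with two channels at 44.1 kHz. The NumPy sketch below shows that conversion; `to_model_input` is a hypothetical helper, and it uses plain linear interpolation (`np.interp`), which is cruder than the sinc-based resampling in `demucs.audio.convert_audio`.

```python
import numpy as np


def to_model_input(audio, sample_rate, target_rate=44100, channels=2):
    """Return float32 audio shaped (channels, samples) at target_rate.

    Sketch only: linear interpolation, not band-limited resampling.
    """
    audio = np.asarray(audio, dtype=np.float64)
    if audio.ndim == 1:
        audio = audio[:, None]            # (samples,) -> (samples, 1)
    mono = audio.mean(axis=1)             # down-mix soundfile's (samples, channels) layout
    n_out = int(round(len(mono) * target_rate / sample_rate))
    t_in = np.arange(len(mono)) / sample_rate
    t_out = np.arange(n_out) / target_rate
    resampled = np.interp(t_out, t_in, mono)
    # Duplicate the mono signal to every output channel
    return np.tile(resampled, (channels, 1)).astype(np.float32)


# One second of a 440 Hz tone at 16 kHz
sr = 16000
tone = np.sin(2 * np.pi * 440 * np.arange(sr) / sr)
x = to_model_input(tone, sr)
print(x.shape)  # (2, 44100)
```

Duplicating a mono signal to both channels adds no information; it only matches the input layout the stereo model was trained on.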

## Step 4: Apply Demucs

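Next, we run the pretrained `htdemucs` model, which separates the mix into four stems (drums, bass, other and vocals), and keep the vocals stem as the enhanced audio.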
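
With `split=True`, `apply_model` does not push the whole file through the network at once: it cuts the signal into fixed-length segments that overlap by the `overlap` fraction (0.25 here), runs the model on each, and blends the overlapping regions with weights that fade in and out. The sketch below illustrates the idea in NumPy with triangular weights; `overlap_add` and its `process` callback are hypothetical names, and Demucs' own weighting (which also depends on `transition_power`) differs in detail.

```python
import numpy as np


def overlap_add(signal, process, segment, overlap=0.25):
    """Run `process` on overlapping chunks of a 1-D signal and cross-fade the results."""
    signal = np.asarray(signal, dtype=np.float64)
    stride = max(1, int(segment * (1 - overlap)))
    out = np.zeros_like(signal)
    total_weight = np.zeros_like(signal)

    # Triangular weights: 1 at the chunk edges, peaking in the middle
    weight = np.concatenate([np.arange(1, segment // 2 + 1),
                             np.arange(segment - segment // 2, 0, -1)]).astype(np.float64)

    for start in range(0, len(signal), stride):
        chunk = signal[start:start + segment]
        n = len(chunk)  # the final chunk may be shorter than `segment`
        out[start:start + n] += weight[:n] * process(chunk)
        total_weight[start:start + n] += weight[:n]

    # Divide by the summed weights so overlapping regions become a weighted average
    return out / total_weight


rng = np.random.default_rng(0)
x = rng.standard_normal(1000)
y = overlap_add(x, lambda chunk: 0.5 * chunk, segment=256, overlap=0.25)
print(np.allclose(y, 0.5 * x))  # True
```

Because each output sample is divided by the total weight that reached it, a pass-through `process` reproduces the input exactly; the ramps only matter when neighbouring chunks disagree, as real model outputs can near segment boundaries.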

In [None]:
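# Apply Demucs
from demucs.audio import convert_audio

# Load the pretrained Hybrid Transformer Demucs model on the device chosen in Step 2
# (if an operator is unsupported on MPS, set device = torch.device("cpu") and re-run)
model = pretrained.get_model('htdemucs')
model.to(device)
model.eval()

# htdemucs expects (batch, channels, samples) with model.audio_channels (2) channels
# at model.samplerate (44.1 kHz). soundfile returns (samples,) or (samples, channels),
# so transpose to (channels, samples), then duplicate channels and resample as needed.
wav = torch.tensor(np.atleast_2d(audio.T), dtype=torch.float32)
wav = convert_audio(wav, sample_rate, model.samplerate, model.audio_channels)
waveform_tensor = wav.unsqueeze(0).contiguous()

with torch.no_grad():
    # Output shape: (batch, sources, channels, samples); keep the first batch item
    sources = apply_model(model, waveform_tensor, split=True, overlap=0.25, device=device)[0]

# Select the vocals stem by name rather than by position
vocals = sources[model.sources.index('vocals')].cpu()

# Down-mix to mono and resample back to the original rate so later cells
# can keep using sample_rate
vocals_mono = convert_audio(vocals, model.samplerate, sample_rate, 1)[0].numpy()

# Peak-normalise; the small epsilon avoids dividing by zero on silent output
vocals_mono = vocals_mono / (np.max(np.abs(vocals_mono)) + 1e-8)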

# Plot the enhanced audio waveform
plt.figure(figsize=(15, 5))
plt.plot(np.linspace(0, len(vocals_mono) / sample_rate, num=len(vocals_mono)), vocals_mono)
plt.title('Enhanced Audio Waveform')
plt.xlabel('Time [s]')
plt.ylabel('Amplitude')
plt.show()
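
`apply_model` returns one waveform per stem, stacked in the order given by `model.sources` (for `htdemucs`, `['drums', 'bass', 'other', 'vocals']`), which is why index 3 holds the vocals. Looking the index up by name keeps working if you switch to a model with a different stem list. Summing the remaining stems gives the accompaniment, similar to what the Demucs command-line `--two-stems` option produces. The NumPy sketch below uses a synthetic `(sources, channels, samples)` array in place of real model output; `pick_stem` and `peak_normalize` are hypothetical helpers.

```python
import numpy as np

STEM_NAMES = ["drums", "bass", "other", "vocals"]  # htdemucs order (model.sources)


def pick_stem(sources, names, stem):
    """Return the (channels, samples) waveform for `stem` and the sum of all other stems."""
    idx = names.index(stem)
    target = sources[idx]
    rest = np.delete(sources, idx, axis=0).sum(axis=0)
    return target, rest


def peak_normalize(x, eps=1e-8):
    """Scale so the largest absolute sample is at most 1; eps avoids 0/0 on silence."""
    return x / (np.max(np.abs(x)) + eps)


# Synthetic stand-in for apply_model output: 4 stems x 2 channels x 1000 samples
rng = np.random.default_rng(0)
fake_sources = rng.standard_normal((4, 2, 1000)).astype(np.float32)

vocals, background = pick_stem(fake_sources, STEM_NAMES, "vocals")
vocals_mono = peak_normalize(vocals.mean(axis=0))
print(vocals.shape, background.shape, float(np.max(np.abs(vocals_mono))))
```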


## Step 5: Save & Listen to Enhanced Audio

We optionally save the enhanced audio to a new file, then compare the original and enhanced waveforms and listen to both side by side.

In [None]:
import IPython.display as ipd
from IPython.display import display, HTML

# Optionally save the enhanced audio (disabled by default because this demo
# only needs the in-memory result)
output_filepath = "../enhanced_audio.wav"
SAVE_OUTPUT = False
if SAVE_OUTPUT:
    sf.write(output_filepath, vocals_mono, sample_rate)
    print(f"Enhanced audio saved to {output_filepath}")

# Function to convert a matplotlib plot to a base64 encoded PNG image
def plt_to_base64(x, y, title):
    """Convert a matplotlib plot to a base64 encoded PNG image."""
    import io
    import base64
    plt.figure(figsize=(7, 5))
    plt.plot(x, y)
    plt.title(title)
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    buf = io.BytesIO()
    plt.savefig(buf, format='png')
    buf.seek(0)
    image_base64 = base64.b64encode(buf.read()).decode('utf-8')
    plt.close()
    return image_base64

# Generate the waveforms for the original and noise-reduced audio
time_original = np.linspace(0, len(audio) / sample_rate, num=len(audio))
time_reduced = np.linspace(0, len(vocals_mono) / sample_rate, num=len(vocals_mono))

# Create the HTML layout for plots and audio widgets side by side
html_content = f"""
<div style="display: flex; justify-content: space-around; align-items: flex-start;">
    <div>
        <h4>Original Audio</h4>
        <img src="data:image/png;base64,{plt_to_base64(time_original, audio, 'Original Audio')}" alt="Original Audio Waveform"/>
        <br>
        {ipd.Audio(audio, rate=sample_rate)._repr_html_()}
    </div>
    <div>
        <h4>Noise-Reduced Audio</h4>
        <img src="data:image/png;base64,{plt_to_base64(time_reduced, vocals_mono, 'Demucs Filtered Vocals')}" alt="Demucs Filtered Vocal Waveform"/>
        <br>
        {ipd.Audio(vocals_mono, rate=sample_rate)._repr_html_()}
    </div>
</div>
"""

# Display the HTML content
display(HTML(html_content))
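
Listening and eyeballing waveforms is subjective. One simple numeric check, sketched below, is to compute the RMS level of short frames in dB relative to full scale and compare the curves for the two signals. `frame_rms_db` is a hypothetical helper; the demo uses a synthetic quiet-then-loud tone rather than the notebook's audio.

```python
import numpy as np


def frame_rms_db(signal, sample_rate, frame_ms=50.0, floor_db=-120.0):
    """RMS level of consecutive non-overlapping frames, in dB relative to full scale (1.0)."""
    frame_len = max(1, int(sample_rate * frame_ms / 1000))
    n_frames = len(signal) // frame_len  # drop the trailing partial frame
    frames = np.asarray(signal[: n_frames * frame_len], dtype=np.float64)
    frames = frames.reshape(n_frames, frame_len)
    rms = np.sqrt(np.mean(frames ** 2, axis=1))
    # Clamp to floor_db so silent frames don't produce -inf
    return np.maximum(20 * np.log10(np.maximum(rms, 1e-12)), floor_db)


# Synthetic demo: one quiet second followed by one loud second of a 220 Hz tone
sr = 16000
t = np.arange(sr) / sr
quiet_then_loud = np.concatenate([0.01 * np.sin(2 * np.pi * 220 * t),
                                  0.5 * np.sin(2 * np.pi * 220 * t)])
levels = frame_rms_db(quiet_then_loud, sr)
print(levels.shape)  # (40,)
```

In the notebook you would call `frame_rms_db(audio, sample_rate)` and `frame_rms_db(vocals_mono, sample_rate)` and plot both; frames containing only background music would be expected to drop in the enhanced signal. Since `vocals_mono` is peak-normalised, compare the shape of the curves rather than their absolute levels.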

## Step 6: Free up Resources
*Remove any local files and free up GPU resources.*


In [None]:
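import gc
import os

# Remove the output file if Step 5 wrote one
if os.path.exists(output_filepath):
    os.remove(output_filepath)
    print("Local files deleted")

# Drop references to the model and its outputs, then release cached CUDA memory
del model, sources
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
print("GPU memory freed")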

## Conclusion

In this notebook, we demonstrated how to enhance audio using the `Demucs` library: we loaded and preprocessed the audio, separated it with the pretrained `htdemucs` model, kept the vocals stem as the enhanced signal, and compared it with the original.