In [1]:
import torch
import torchaudio
import matplotlib.pyplot as plt
from IPython.display import Audio, Video
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
from mir_eval import separation
from torchaudio.transforms import Fade

In [None]:
def plot_specgram(waveform, sample_rate, title="Spectrogram"):
    """Draw one matplotlib spectrogram panel per channel of a waveform.

    Args:
        waveform (torch.Tensor): 2-D tensor whose first dimension indexes
            channels and whose second indexes frames. It may live on any
            device and may track gradients; it is detached and copied to the
            CPU before plotting.
        sample_rate (int): sampling rate passed to ``Axes.specgram`` as ``Fs``.
        title (str): figure title.
    """
    # `.numpy()` only works on CPU tensors that do not require grad, so detach
    # and move first; for a plain CPU tensor both calls are no-ops.
    waveform = waveform.detach().cpu().numpy()

    # Unpacking (rather than shape[0]) keeps the original 2-D requirement.
    num_channels, _ = waveform.shape

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        # plt.subplots returns a bare Axes for a single panel; wrap it so the
        # loop below can index it uniformly.
        axes = [axes]
    for c in range(num_channels):
        axes[c].specgram(waveform[c], Fs=sample_rate)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
    figure.suptitle(title)
    plt.show(block=False)

In [None]:
def plot_waveform(waveform, sample_rate, title="waveform"):
    """Plot every channel of a waveform against time in seconds.

    Args:
        waveform (torch.Tensor): 2-D tensor whose first dimension indexes
            channels and whose second indexes frames. It may live on any
            device and may track gradients; it is detached and copied to the
            CPU before plotting.
        sample_rate (int): sampling rate used to convert frame indices to
            seconds on the x axis.
        title (str): figure title; defaults to the previously hard-coded
            ``"waveform"``.
    """
    # `.numpy()` only works on CPU tensors that do not require grad, so detach
    # and move first; for a plain CPU tensor both calls are no-ops.
    waveform = waveform.detach().cpu().numpy()

    num_channels, num_frames = waveform.shape
    time_axis = torch.arange(0, num_frames) / sample_rate

    figure, axes = plt.subplots(num_channels, 1)
    if num_channels == 1:
        # plt.subplots returns a bare Axes for a single panel; wrap it so the
        # loop below can index it uniformly.
        axes = [axes]
    for c in range(num_channels):
        axes[c].plot(time_axis, waveform[c], linewidth=1)
        axes[c].grid(True)
        if num_channels > 1:
            axes[c].set_ylabel(f"Channel {c+1}")
    figure.suptitle(title)
    plt.show(block=False)

In [None]:
# Environment report: library versions, FFmpeg extension status and CUDA availability.
# The trailing comments are presumably the values the author observed — confirm against your setup.
print(torch.__version__) # 2.0.0
print(torchaudio.__version__) # 2.0.0
# NOTE(review): `_FFMPEG_INITIALIZED` is a private torchaudio attribute — confirm it exists in the installed version.
print(torchaudio._extension._FFMPEG_INITIALIZED) # True
print(torch.cuda.is_available()) # True

In [None]:
# Directory that holds the demo track: the full mixture plus its four stems.
# Keeping it in one place avoids repeating (and mistyping) the long path.
TRACK_DIR = 'data/demo/A Classic Education - NightOwl'

# File paths (plain strings) of the mixture and of each reference stem.
mixture = f'{TRACK_DIR}/mixture.wav'
bass = f'{TRACK_DIR}/bass.wav'
drums = f'{TRACK_DIR}/drums.wav'
other = f'{TRACK_DIR}/other.wav'
vocals = f'{TRACK_DIR}/vocals.wav'

# Print the file metadata reported by torchaudio for the mixture.
metadata = torchaudio.info(mixture)
print(metadata)

In [None]:
# Load the mixture file: `waveform` receives the audio tensor and `sample_rate` its sampling rate.
waveform, sample_rate = torchaudio.load(mixture)


In [None]:
# Load the bass stem the same way, keeping its own sampling rate in `bass_sample_rate`.
bass_waveform, bass_sample_rate = torchaudio.load(bass)

In [None]:
# Plotting the whole waveform takes about a minute, so only the first 10000 frames are drawn here.
plot_waveform(waveform[:, :10000], sample_rate)

In [None]:
# Spectrogram of the full mixture, one panel per channel.
plot_specgram(waveform, sample_rate)

In [None]:
# Same trimmed view (first 10000 frames) for the bass stem.
plot_waveform(bass_waveform[:, :10000], bass_sample_rate)

In [None]:
# Spectrogram of the full bass stem, one panel per channel.
plot_specgram(bass_waveform, bass_sample_rate)

In [None]:
# Embed an audio player for the mixture; when cells run top to bottom, `mixture` is still the file path here.
Audio(mixture)

In [None]:
# Pretrained Hybrid Demucs bundle; it supplies both the model and the sample rate reported below.
bundle = HDEMUCS_HIGH_MUSDB_PLUS
sample_rate = bundle.sample_rate

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the model and place it on the chosen device.
model = bundle.get_model()
model.to(device)

print(f"Sample rate: {sample_rate}")

In [None]:
def separate_sources(
        model,
        mix,
        segment=10.,
        overlap=0.1,
        device=None,
):
    """
    Separate a mixture into sources by running the model chunk by chunk.

    The mixture is cut into overlapping chunks. Each chunk is passed through
    ``model``, a linear fade is applied at the chunk edges, and the faded
    outputs are summed into the result so that neighbouring chunks cross-fade
    over their shared region.

    Chunk and fade lengths are derived from the notebook-level ``sample_rate``
    variable, which must hold the sampling rate of ``mix`` when this is called.

    Args:
        model: separation model. It must expose ``sources`` (its length sets the
            size of the source dimension of the result) and map a
            ``(batch, channels, frames)`` tensor to an output that can be added
            into a ``(batch, sources, channels, frames)`` tensor. The model is
            not moved by this function, so it must already be on ``device``.
        mix (torch.Tensor): mixture of shape ``(batch, channels, frames)``.
        segment (float): nominal segment length in seconds; each chunk spans
            ``segment * (1 + overlap)`` seconds.
        overlap (float): besides lengthening each chunk as described above,
            ``overlap * sample_rate`` frames are shared and cross-faded between
            consecutive chunks.
        device (torch.device, str, or None): device on which each chunk is
            processed and on which the result is allocated. Defaults to
            ``mix.device``.

    Returns:
        torch.Tensor: tensor of shape
        ``(batch, len(model.sources), channels, frames)`` on ``device``.
    """
    if device is None:
        device = mix.device
    else:
        device = torch.device(device)

    batch, channels, length = mix.shape

    chunk_len = int(sample_rate * segment * (1 + overlap))
    # Truncate once so the stride and both fade lengths use the same frame
    # count; otherwise a fractional value would shift the cross-fade by a frame.
    overlap_frames = int(overlap * sample_rate)

    # A track that fits entirely in the first chunk has no following chunk to
    # cross-fade into, so its tail must not be faded out.
    first_fade_out = 0 if chunk_len >= length else overlap_frames
    fade = Fade(fade_in_len=0, fade_out_len=first_fade_out, fade_shape='linear')

    final = torch.zeros(batch, len(model.sources), channels, length, device=device)

    start = 0
    end = chunk_len
    while start < length - overlap_frames:
        # Move only the current chunk to the compute device; a no-op when
        # `mix` is already there.
        chunk = mix[:, :, start:end].to(device)
        with torch.no_grad():
            out = model(chunk)
        out = fade(out)
        final[:, :, :, start:end] += out
        if start == 0:
            # Every chunk after the first begins inside the previous chunk's
            # fade-out region and therefore needs a matching fade-in.
            fade.fade_in_len = overlap_frames
            start += chunk_len - overlap_frames
        else:
            start += chunk_len
        end += chunk_len
        if end >= length:
            # The next chunk reaches the end of the track: keep its tail intact.
            fade.fade_out_len = 0
    return final

In [None]:
def plot_spectrogram(stft, title="Spectrogram"):
    """Show the magnitude of an STFT tensor, in decibels, as an image.

    Args:
        stft (torch.Tensor): tensor whose absolute value is plotted; it is
            handed to ``imshow``, so it is presumably 2-D — confirm against
            the caller.
        title (str): figure title.
    """
    # 20 * log10(|stft|), with a small offset so exact zeros do not give -inf.
    decibels = 20 * torch.log10(stft.abs() + 1e-8).numpy()

    fig, ax = plt.subplots(1, 1)
    image = ax.imshow(
        decibels,
        cmap="viridis",
        vmin=-60,
        vmax=0,
        origin="lower",
        aspect="auto",
    )
    fig.suptitle(title)
    plt.colorbar(image, ax=ax)
    plt.show()

In [None]:
# `mixture` starts out as the path of the mixture file and is rebound to the
# loaded tensor below. Remember the path while it is still a string so this
# cell can be re-run without handing a tensor to `torchaudio.load`.
if isinstance(mixture, str):
    mixture_path = mixture
waveform, sample_rate = torchaudio.load(mixture_path)  # set `mixture` to another file path to separate a different song
waveform = waveform.to(device)
mixture = waveform

# parameters
segment = 10
overlap = 0.1

print("Separating track")

# Normalize with the mean and standard deviation of the channel-averaged
# signal; the same statistics are used afterwards to undo the scaling.
ref = waveform.mean(0)
ref_mean = ref.mean()
ref_std = ref.std()
waveform = (waveform - ref_mean) / ref_std  # normalization

# `waveform[None]` adds the leading batch dimension and `[0]` removes it again.
sources = separate_sources(
    model,
    waveform[None],
    device=device,
    segment=segment,
    overlap=overlap,
)[0]
sources = sources * ref_std + ref_mean

# Map each source name reported by the model to its separated waveform.
sources_list = model.sources
sources = list(sources)

audios = dict(zip(sources_list, sources))

In [None]:
# STFT settings shared by every spectrogram plotted from here on.
# NOTE(review): a hop of 4 samples gives a very dense time axis — confirm this is intended.
N_FFT, N_HOP = 4096, 4

# power=None keeps the complex-valued STFT instead of a power spectrogram.
stft = torchaudio.transforms.Spectrogram(n_fft=N_FFT, hop_length=N_HOP, power=None)

In [None]:
def output_results(original_source: torch.Tensor, predicted_source: torch.Tensor, source: str):
    """Print the SDR of a separated source, plot its spectrogram and return a player.

    Relies on the notebook-level ``stft`` transform, ``plot_spectrogram`` helper
    and ``sample_rate`` variable.

    Args:
        original_source: reference waveform for the source.
        predicted_source: separated waveform to be scored, plotted and played.
        source: source name, used in the spectrogram title.

    Returns:
        IPython ``Audio`` object for ``predicted_source`` at ``sample_rate``.
    """
    reference = original_source.detach().numpy()
    estimate = predicted_source.detach().numpy()
    # Element 0 of the mir_eval result is averaged and reported as the SDR.
    sdr = separation.bss_eval_sources(reference, estimate)[0]
    print("SDR score is:", sdr.mean())

    # Only index 0 along the first dimension of the STFT is plotted.
    plot_spectrogram(stft(predicted_source)[0], title=f'Spectrogram {source}')
    return Audio(predicted_source, rate=sample_rate)


# Window of the track, in seconds, that is scored and played back.
segment_start = 150
segment_end = 155

# The same window expressed in frames.
frame_start = segment_start * sample_rate
frame_end = segment_end * sample_rate

# Reference stems. All four come from the same directory as the mixture that
# was separated, so each estimate is compared against a matching reference
# (the bass path previously pointed at a different dataset directory).
drums_original = 'data/demo/A Classic Education - NightOwl/drums.wav'
bass_original = 'data/demo/A Classic Education - NightOwl/bass.wav'
vocals_original = 'data/demo/A Classic Education - NightOwl/vocals.wav'
other_original = 'data/demo/A Classic Education - NightOwl/other.wav'

# For each source: `<name>_spec` is the separated estimate cut to the window
# and moved to the CPU; `<name>` is the full reference stem loaded from disk.
# NOTE: `drums`, `bass`, `vocals` and `other` held file paths earlier in the
# notebook and are rebound to tensors here; each load also overwrites
# `sample_rate`.
drums_spec = audios["drums"][:, frame_start: frame_end].cpu()
drums, sample_rate = torchaudio.load(drums_original)

bass_spec = audios["bass"][:, frame_start: frame_end].cpu()
bass, sample_rate = torchaudio.load(bass_original)

vocals_spec = audios["vocals"][:, frame_start: frame_end].cpu()
vocals, sample_rate = torchaudio.load(vocals_original)

other_spec = audios["other"][:, frame_start: frame_end].cpu()
other, sample_rate = torchaudio.load(other_original)

# The same window of the mixture tensor, for comparison.
mix_spec = mixture[:, frame_start: frame_end].cpu()

In [None]:
# Spectrogram and audio player for the selected window of the mixture.
plot_spectrogram(stft(mix_spec)[0], title="Spectrogram Mixture")
Audio(data=mix_spec, rate=sample_rate)

In [None]:
# Stray expression `b` removed: no such name is defined in this notebook, so the cell raised NameError on Restart & Run All.

In [None]:
# Vocals clip: score the separated vocals against the reference stem over the same frame window.
output_results(
    original_source=vocals[:, frame_start:frame_end],
    predicted_source=vocals_spec,
    source="vocals",
)