In [1]:
from IPython.display import Audio
import torch
import torchaudio
from torchaudio.utils import download_asset
import matplotlib.pyplot as plt
from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
from mir_eval import separation
from torchaudio.transforms import Fade
from scipy.io.wavfile import write, read
import numpy as np
import wave
import os

In [2]:
# Input track to separate (local WAV file).
importpath = "C:/Users/jakes/Downloads/Nina Simone - Sinnerman.wav"
# Folder that receives the separated stems; the trailing slash is kept because
# the value may be concatenated directly with file names.
exportfolder = "C:/Users/jakes/Music/PythonStems/Nina Simone Sinnerman/"

def separate_sources(
        model,
        mix,
        segment=10.,
        overlap=0.1,
        device=None,
):
    """
    Apply model to a given mixture chunk by chunk and cross-fade the results.

    The mixture is cut into overlapping chunks, each chunk is run through the
    model, faded at its edges, and accumulated into the output so that
    neighbouring chunks blend linearly over the overlap region.

    Chunk and overlap sizes are derived from the module-level ``sample_rate``,
    which must therefore be set before this function is called.

    Args:
        model: separation model; must expose ``sources`` (its length gives the
            number of output stems) and ``forward(chunk)``.
        mix (torch.Tensor): mixture of shape ``(batch, channels, length)``.
        segment (float): segment length in seconds.
        overlap (float): overlap between consecutive chunks, as a fraction of
            one second of audio (``overlap * sample_rate`` frames).
        device (torch.device, str, or None): if provided, device on which to
            execute the computation, otherwise `mix.device` is assumed.
            When `device` is different from `mix.device`, only local computations will
            be on `device`, while the entire tracks will be stored on `mix.device`.

    Returns:
        torch.Tensor: separated stems of shape
        ``(batch, len(model.sources), channels, length)``, stored on
        ``mix.device``.
    """
    if device is None:
        device = mix.device
    else:
        device = torch.device(device)

    batch, channels, length = mix.shape

    # Truncate once so the fade lengths and the chunk step use the same count.
    overlap_frames = int(overlap * sample_rate)
    chunk_len = int(sample_rate * segment * (1 + overlap))
    start = 0
    end = chunk_len
    fade = Fade(fade_in_len=0, fade_out_len=overlap_frames, fade_shape='linear')

    # The full-length result lives next to the input; only chunks visit `device`.
    final = torch.zeros(batch, len(model.sources), channels, length, device=mix.device)

    while start < length - overlap_frames:
        if end >= length:
            # Last chunk (possibly also the first one for a short track):
            # nothing follows it, so its tail must not be faded out.
            fade.fade_out_len = 0
        chunk = mix[:, :, start:end].to(device)
        with torch.no_grad():
            out = model.forward(chunk)
        out = fade(out)
        final[:, :, :, start:end] += out.to(final.device)
        if start == 0:
            # Every chunk after the first begins inside the previous chunk's
            # fade-out region, so it needs a matching fade-in.
            fade.fade_in_len = overlap_frames
            start += chunk_len - overlap_frames
        else:
            start += chunk_len
        end += chunk_len
    return final

# Load the pretrained Hybrid Demucs bundle and place the model on the GPU when
# one is available.
bundle = HDEMUCS_HIGH_MUSDB_PLUS
model = bundle.get_model()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
# `separate_sources` reads this module-level value to size its chunks, so it
# must remain the rate of the audio that is actually handed to the model.
sample_rate = bundle.sample_rate

# `importpath` is a local file, so it is loaded directly instead of going
# through `download_asset`, which is meant for torchaudio-hosted assets.
SAMPLE_SONG = importpath
waveform, file_sample_rate = torchaudio.load(SAMPLE_SONG)
if file_sample_rate != sample_rate:
    # Bring the track to the bundle's sample rate before separation.
    waveform = torchaudio.functional.resample(waveform, file_sample_rate, sample_rate)
# Tensor.to returns a new tensor; the result has to be kept.
waveform = waveform.to(device)
mixture = waveform

# parameters
segment: int = 10
overlap = 0.1

print("Separating track")

# Standardize the input with the statistics of the channel-averaged signal and
# undo the scaling on the separated stems afterwards.
ref = waveform.mean(0)
waveform = (waveform - ref.mean()) / ref.std()  # normalization

sources = separate_sources(
    model,
    waveform[None],
    device=device,
    segment=segment,
    overlap=overlap,
)[0]
sources = sources * ref.std() + ref.mean()

sources_list = model.sources
sources = list(sources)

# Map each stem name reported by the model to its separated waveform.
audios = dict(zip(sources_list, sources))

# Create the export folder if needed without changing the working directory.
os.makedirs(exportfolder, exist_ok=True)

# Stem name -> output file name (the "other" stem is exported as melodies.wav).
stem_files = {
    'drums': 'drums.wav',
    'bass': 'bass.wav',
    'vocals': 'vocals.wav',
    'other': 'melodies.wav',
}

# NOTE(review): IPython's Audio rescales the signal by default (normalize=True),
# so each stem is written at its own peak level — confirm this is intended.
for stem, filename in stem_files.items():
    # Audio converts its input through NumPy, which requires a CPU tensor.
    stem_audio = audios[stem].cpu()
    with open(os.path.join(exportfolder, filename), 'wb') as f:
        f.write(Audio(stem_audio, rate=sample_rate).data)


Separating track
