In [None]:
import os

# An empty CUDA_VISIBLE_DEVICES hides every CUDA device from this process.
# The assignment sits before `import torch`, presumably so that torch never
# sees a GPU and the rest of the notebook runs on the CPU.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

import torch

In [None]:
# Bring in the `HDEMUCS_HIGH_MUSDB_PLUS` pipeline bundle. When the import
# fails, print an installation hint if (and only if) we are on Google Colab,
# then let the original ModuleNotFoundError propagate.
try:
    from torchaudio.pipelines import HDEMUCS_HIGH_MUSDB_PLUS
except ModuleNotFoundError:
    try:
        import google.colab
    except ModuleNotFoundError:
        # Not on Colab: nothing extra to say.
        pass
    else:
        print(
            """
            To enable running this notebook in Google Colab, install nightly
            torch and torchaudio builds by adding the following code block to the top
            of the notebook before running it:
            !pip3 uninstall -y torch torchvision torchaudio
            !pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
            !pip3 install mir_eval
            """
        )
    # Re-raise the torchaudio import failure in every case.
    raise

In [None]:
# Use the first CUDA device when torch reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the model from the bundle and place it on the chosen device.
bundle = HDEMUCS_HIGH_MUSDB_PLUS
model = bundle.get_model()
model.to(device)

# Rate reported by the bundle; `separate_sources` reads this module-level name.
sample_rate = bundle.sample_rate
print(f"Sample rate: {sample_rate}")

In [None]:
from torchaudio.transforms import Fade


def separate_sources(
        model,
        mix,
        segment=10.,
        overlap=0.1,
        device=None,
):
    """
    Apply ``model`` to ``mix`` chunk by chunk and cross-fade the chunks together.

    The mixture is cut into consecutive chunks that overlap by
    ``int(overlap * sample_rate)`` frames. Each chunk is run through the
    model, a linear fade is applied over the overlapping frames, and the
    result is accumulated into the output, so the fade-out of one chunk and
    the fade-in of the next sum back to unit gain. The first chunk gets no
    fade-in and the chunk that reaches the end of the mixture gets no
    fade-out.

    Frame counts are derived from the module-level ``sample_rate`` (it is not
    a parameter of this function).

    Args:
        model: callable separation model exposing a ``sources`` sequence.
            Its output must be addable into a
            ``(batch, len(model.sources), channels, frames)`` slice.
        mix (Tensor): mixture of shape ``(batch, channels, length)``.
        segment (float): nominal segment length in seconds. Each chunk spans
            ``sample_rate * segment * (1 + overlap)`` frames.
        overlap (float): besides lengthening each chunk (see ``segment``),
            ``overlap * sample_rate`` is the number of frames shared and
            cross-faded between neighbouring chunks.
        device (torch.device, str, or None): if provided, device on which
            each chunk is processed and on which the result is allocated,
            otherwise ``mix.device`` is used.

    Returns:
        Tensor: shape ``(batch, len(model.sources), channels, length)``,
        allocated on ``device``. A mixture that is not longer than the
        overlap is not processed and comes back as zeros.

    Raises:
        ValueError: if a chunk would not be longer than the overlap, in which
            case the chunk window could never advance.
    """
    if device is None:
        device = mix.device
    else:
        device = torch.device(device)

    batch, channels, length = mix.shape

    chunk_len = int(sample_rate * segment * (1 + overlap))
    # One integer overlap is used for the fades AND for the chunk offsets, so
    # that the fade-out of a chunk always lines up with the fade-in of the
    # next one, even when overlap * sample_rate is not a whole number.
    overlap_len = int(overlap * sample_rate)
    if chunk_len <= overlap_len:
        raise ValueError(
            f"chunk length ({chunk_len} frames) must exceed the overlap "
            f"({overlap_len} frames); increase `segment` or reduce `overlap`"
        )

    fade = Fade(fade_in_len=0, fade_out_len=overlap_len, fade_shape='linear')

    final = torch.zeros(batch, len(model.sources), channels, length, device=device)

    start = 0
    end = chunk_len
    while start < length - overlap_len:
        # Decide the fade-out BEFORE processing the chunk: the chunk that
        # reaches the end of the mixture must not be faded out. Doing this
        # only after the first chunk would silence the tail of any mixture
        # that fits in a single chunk.
        fade.fade_out_len = 0 if end >= length else overlap_len

        # Run the model on `device`, as the `device` argument promises.
        chunk = mix[:, :, start:end].to(device)
        with torch.no_grad():
            out = model(chunk)
        final[:, :, :, start:end] += fade(out)

        if start == 0:
            # Every later chunk starts inside its predecessor's fade-out.
            fade.fade_in_len = overlap_len
            start += chunk_len - overlap_len
        else:
            start += chunk_len
        end += chunk_len
    return final

In [None]:
import torchaudio
from tqdm import tqdm
from scipy.io.wavfile import write

# Chunking settings that `get_split` forwards to `separate_sources`.
segment = 10  # forwarded as `segment` (documented there as a length in seconds)
overlap = 0.1  # forwarded as `overlap`
length = 10  # NOTE(review): not referenced anywhere else in the visible notebook

def get_split(s):
    """
    Load an audio file and split it into the sources of the module-level ``model``.

    Steps: load the file with ``torchaudio.load``; resample it to the
    module-level ``sample_rate`` when the file's own rate differs (that is the
    rate ``separate_sources`` uses to size its chunks); standardise it with
    the mean and standard deviation of the channel-averaged signal; run
    ``separate_sources`` with the module-level ``model``, ``device``,
    ``segment`` and ``overlap``; undo the standardisation.

    Args:
        s: path of the audio file, passed straight to ``torchaudio.load``.

    Returns:
        tuple: ``(audios, rate)`` where ``audios`` maps every name in
        ``model.sources`` to its separated tensor of shape
        ``(channels, length)`` and ``rate`` is the sample rate of those
        tensors.
    """
    waveform, file_rate = torchaudio.load(s)
    if file_rate != sample_rate:
        # Chunk and fade sizes downstream are computed from `sample_rate`,
        # so bring the audio to that rate instead of feeding it as-is.
        waveform = torchaudio.functional.resample(waveform, file_rate, sample_rate)
    # NOTE(review): the original kept a disabled
    # `torch.concat([waveform, waveform])` line here, presumably to duplicate
    # the channel of a mono file; channel count is still passed through as-is.
    waveform = waveform.to(device)

    # Standardise with the statistics of the channel-averaged signal.
    ref = waveform.mean(0)
    mean = ref.mean()
    std = ref.std()
    if std == 0:
        # Constant (e.g. silent) input: dividing by zero would yield NaNs.
        std = torch.ones_like(std)
    waveform = (waveform - mean) / std

    sources = separate_sources(
        model,
        waveform[None],
        device=device,
        segment=segment,
        overlap=overlap,
    )[0]
    # Map the separated signals back to the scale of the input.
    sources = sources * std + mean

    audios = dict(zip(model.sources, sources))
    return audios, sample_rate

In [None]:
# Separate the track: `splitted` maps source name -> tensor and `sample_rate_`
# is the rate returned with it. The path is relative, so the MP3 has to be in
# the notebook's working directory.
splitted, sample_rate_ = get_split('Radiohead - Creep.mp3')

In [None]:
# Show which source names were produced (the keys come from `model.sources`).
splitted.keys()

In [None]:
# Gain applied to each separated source, keyed by source name.
ratio = dict(
    drums=0.7,
    bass=1.0,
    other=0.7,
    vocals=1.0,
)

In [None]:
# Weighted remix: scale every separated source by its gain and add them up.
c = sum(splitted[name] * ratio[name] for name in splitted)

c

In [None]:
import IPython.display as ipd

# Play the remix inline. `c.numpy()[0]` keeps only index 0 of the first axis
# (presumably the first audio channel -- confirm), so just that slice is played.
ipd.Audio(c.numpy()[0], rate = sample_rate_)

In [None]:
# Peek at the vocals source as a NumPy array (index 0 of its first axis).
splitted['vocals'].numpy()[0]

In [None]:
from scipy.io.wavfile import write

# Export part of the vocals source (index 0 of its first axis) to a WAV file:
# samples 60 * sample_rate_ up to 120 * sample_rate_, i.e. seconds 60 to 120.
write('creep-vocal.wav', sample_rate_, splitted['vocals'].numpy()[0][60 * sample_rate_ : 120 * sample_rate_])

In [None]:
# NOTE(review): absolute, machine-specific path to a file that no cell in the
# visible notebook creates -- presumably produced by an external tool from
# creep-vocal.wav. It will not resolve unless the same path exists locally.
ipd.Audio('/home/husein/ssd3/so-vits-svc/results/creep-vocal.wav_0key_speaker_0.flac')