In [1]:
import argparse
import glob
import os
import sys
import time
import warnings
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Literal, NamedTuple

import librosa
import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
from tqdm import tqdm

from utils import demix_track, demix_track_demucs, get_model_from_config

In [2]:
@contextmanager
def measure_time(text: str):
    start_time = time.time()
    try:
        yield
    finally:
        elapsed_time = time.time() - start_time
        print(f"{text}: {elapsed_time:.2f} sec")

In [3]:
def get_device():
    if torch.cuda.is_available():
        return torch.device(f"cuda:0")
    else:
        print("CUDA is not avilable. Run inference on CPU. It will be very slow...")
        return "cpu"


device = get_device()
device

device(type='cuda', index=0)

In [4]:
@dataclass
class Model:
    type: str
    config_path: str
    checkpoint_path: str

    # I don't know how to type these
    # model: ???
    # config: ???

    def load_model(self):
        print(f"Loading {self.type} model: {self.checkpoint_path}")
        model, config = get_model_from_config(self.type, self.config_path)

        state_dict = torch.load(self.checkpoint_path)
        if self.type == "htdemucs":
            # Fix for htdemucs pround etrained models
            if "state" in state_dict:
                state_dict = state_dict["state"]
        model.load_state_dict(state_dict)

        model = model.to(device)
        model.eval()

        self.model = model
        self.config = config

    def demix(self, mix: np.ndarray) -> dict[str, np.ndarray]:
        mix = torch.tensor(mix.T, dtype=torch.float32)
        if self.type == "htdemucs":
            res = demix_track_demucs(self.config, self.model, mix, device)
        else:
            res = demix_track(self.config, self.model, mix, device)

        for k in res:
            res[k] = res[k].T

        return res


# Vocal model: BS Roformer (viperx edition)
VOCAL_MODEL = Model(
    "bs_roformer",
    "configs/viperx/model_bs_roformer_ep_317_sdr_12.9755.yaml",
    "results/model_bs_roformer_ep_317_sdr_12.9755.ckpt",
)

# Single stem model: BS Roformer (viperx edition)
OTHER_MODEL = Model(
    "bs_roformer",
    "configs/viperx/model_bs_roformer_ep_937_sdr_10.5309.yaml",
    "results/model_bs_roformer_ep_937_sdr_10.5309.ckpt",
)

# Single stem model: HTDemucs4 FT Drums
DRUMS_MODEL = Model(
    "htdemucs", "configs/config_musdb18_htdemucs.yaml", "results/f7e0c4bc-ba3fe64a.th"
)

# Single stem model: HTDemucs4 FT Bass
BASS_MODEL = Model(
    "htdemucs", "configs/config_musdb18_htdemucs.yaml", "results/d12395a8-e57c48e6.th"
)

with measure_time("Load models"):
    VOCAL_MODEL.load_model()
    OTHER_MODEL.load_model()
    DRUMS_MODEL.load_model()
    BASS_MODEL.load_model()

Loading bs_roformer model: results/model_bs_roformer_ep_317_sdr_12.9755.ckpt
Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda


  return _VF.stft(input, n_fft, hop_length, win_length, window,  # type: ignore[attr-defined]


Loading bs_roformer model: results/model_bs_roformer_ep_937_sdr_10.5309.ckpt
Loading htdemucs model: results/f7e0c4bc-ba3fe64a.th
Loading htdemucs model: results/d12395a8-e57c48e6.th
Load models: 3.12 sec


In [5]:
print(f"{VOCAL_MODEL.config.training.instruments = }")
print(f"{OTHER_MODEL.config.training.instruments = }")
print(f"{DRUMS_MODEL.config.training.instruments = }")
print(f"{BASS_MODEL.config.training.instruments = }")

VOCAL_MODEL.config.training.instruments = ['vocals', 'other']
OTHER_MODEL.config.training.instruments = ['vocals', 'other']
DRUMS_MODEL.config.training.instruments = ['drums', 'bass', 'other', 'vocals']
BASS_MODEL.config.training.instruments = ['drums', 'bass', 'other', 'vocals']


In [None]:
def load_audio(path: str):
    # mix, sr = sf.read(path)
    mix, sr = librosa.load(path, sr=44100, mono=False)
    mix = mix.T

    # Convert mono to stereo if needed
    if len(mix.shape) == 1:
        mix = np.stack([mix, mix], axis=-1)

    return mix, sr


def preview_audio(mix, sr):
    import IPython.display

    return IPython.display.Audio(data=mix.T, rate=sr, normalize=False)


path = R"D:\Soundtracks\Electronic\Cametek (Camellia)\[KCCD-007] [2019.08.12] Confetto x かめりあ - ごーいん!\05. インターネットが遅いさん (Super-Slow-Internet-san).mp3"


try:
    mix, sr = load_audio(path)
except Exception as e:
    print(f"Can't read track: {path}")
    print(f"Error message: {e}")
    raise

preview_audio(mix, sr)

In [7]:
vocals = VOCAL_MODEL.demix(mix)['vocals']
vocals

  out = F.scaled_dot_product_attention(


array([[0., 0.],
       [0., 0.],
       [0., 0.],
       ...,
       [0., 0.],
       [0., 0.],
       [0., 0.]], dtype=float32)

In [None]:
inst = mix - vocals
preview_audio(inst, sr)

In [None]:
other = OTHER_MODEL.demix(inst)['other']
preview_audio(other, sr)

In [None]:
drum_and_bass = inst - other
preview_audio(drum_and_bass, sr)

In [17]:
res = DRUMS_MODEL.demix(mix)
res.keys()

dict_keys(['drums', 'bass', 'other', 'vocals'])

In [None]:
preview_audio(res['drums'], sr)

In [None]:
bass = BASS_MODEL.demix(drum_and_bass)["bass"]
preview_audio(bass, sr)

In [None]:
preview_audio(drum_and_bass - bass, sr)

In [None]:
drums = DRUMS_MODEL.demix(drum_and_bass - bass)["drums"]
preview_audio(drums, sr)

In [None]:
residual = drum_and_bass - bass - drums
preview_audio(residual, sr)

In [None]:
preview_audio(residual + bass + drums + other + vocals - mix, sr)

In [None]:
from pathlib import Path

x = Path(path)
x.with_stem(x.stem + '_vocals').with_suffix('.flac')
