In [50]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import librosa
import soundfile as sf
from torch.utils.data import Dataset, DataLoader, Subset
from tqdm import tqdm
import musdb
from neural_net import UNet
from torch.optim.lr_scheduler import OneCycleLR
from torch.cuda.amp import autocast, GradScaler
# ── Hyperparameters & Paths ───────────────────────────────────────────────────
N_FFT         = 1024
HOP_LENGTH    = 256
MAX_FRAMES    = 512
OVERLAP       = 64
BATCH_SIZE    = 16
LR            = 1e-3
EPOCHS        = 50
DATA_DIR      = "musdb18wavs/train"
MODEL_PATH    = "models/vocal_separator_unet.pt"

DEVICE = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cpu")
)
print(f"➡️ Using device: {DEVICE}")

➡️ Using device: mps


In [None]:
def convert_tracks_to_wav():
    """
        convert the musdb18 dataset to .wav format
    """
    for track in mus.tracks:
        track_name = track.name.replace(" ", "")
        train_test = track.subset
        track_folder = os.path.join(f"musdb18wavs/{train_test}", track_name)
        os.makedirs(track_folder, exist_ok=True)
        mixture_path = os.path.join(track_folder, "mixture.wav")
        sf.write(mixture_path, track.audio, track.rate)
        for source in ["vocals", "drums", "bass", "other"]:
            if source in track.targets:
                source_audio = track.targets[source].audio
                source_path = os.path.join(track_folder, f"{source}.wav")
                sf.write(source_path, source_audio, track.rate)    


In [4]:
def get_stft(path):
    y, sr = librosa.load(path, sr=None)
    stft = librosa.stft(y, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann')
    return stft, sr

In [5]:
def istft(stft):
    return librosa.istft(stft, hop_length=HOP_LENGTH)

In [None]:
def display_stft(stft_matrix, title):
    plt.figure(figsize=(10, 4))
    librosa.display.specshow(librosa.amplitude_to_db(np.abs(stft_matrix), ref=np.max),
                             y_axis='log', x_axis='time')
    plt.title(title)
    plt.colorbar(format='%+2.0f dB')
    plt.tight_layout()
    plt.show()

In [51]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import librosa

class RemoveVocalsDataset(Dataset):
    def __init__(self, base_dir):
        self.segments = []
        step = MAX_FRAMES - OVERLAP

        for track in os.listdir(base_dir):
            mix_p   = os.path.join(base_dir, track, "mixture.wav")
            vocal_p = os.path.join(base_dir, track, "vocals.wav")
            if not (os.path.exists(mix_p) and os.path.exists(vocal_p)):
                continue

            # load mono audio
            mix_y,  _ = librosa.load(mix_p,   sr=None, mono=True)
            voc_y,  _ = librosa.load(vocal_p, sr=None, mono=True)

            # STFT → magnitude
            mix_stft = librosa.stft(mix_y,  n_fft=N_FFT, hop_length=HOP_LENGTH, window="hann")
            voc_stft = librosa.stft(voc_y,  n_fft=N_FFT, hop_length=HOP_LENGTH, window="hann")
            mix_mag  = np.abs(mix_stft)
            voc_mag  = np.abs(voc_stft)

            # background magnitude & mask
            bg_mag = np.clip(mix_mag - voc_mag, 0.0, None)
            mask   = np.clip(bg_mag / (mix_mag + 1e-8), 0.0, 1.0)

            # log‐scale mixture input
            mix_log = np.log1p(mix_mag)

            T = mix_log.shape[1]
            for t0 in range(0, T - MAX_FRAMES + 1, step):
                m_chunk    = mix_log[:,   t0 : t0 + MAX_FRAMES]
                mask_chunk = mask    [:,   t0 : t0 + MAX_FRAMES]

                self.segments.append((
                    torch.from_numpy(m_chunk).unsqueeze(0).float(),    # [1, freq, frames]
                    torch.from_numpy(mask_chunk).unsqueeze(0).float()  # [1, freq, frames]
                ))

    def __len__(self):
        return len(self.segments)

    def __getitem__(self, idx):
        return self.segments[idx]


In [31]:
class SeparationDataset(Dataset):
    def __init__(self, base_dir):
        self.segments = []
        step = MAX_FRAMES - OVERLAP
        for track in os.listdir(base_dir):
            mix_p   = os.path.join(base_dir, track, "mixture.wav")
            vocal_p = os.path.join(base_dir, track, "vocals.wav")
            if not (os.path.exists(mix_p) and os.path.exists(vocal_p)):
                continue
            mix_y,   _ = librosa.load(mix_p, sr=None, mono=True)
            vocal_y, _ = librosa.load(vocal_p, sr=None, mono=True)
            mix_stft   = librosa.stft(mix_y,   n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann')
            vocal_stft = librosa.stft(vocal_y, n_fft=N_FFT, hop_length=HOP_LENGTH, window='hann')
            mix_mag    = np.abs(mix_stft)
            vocal_mag  = np.abs(vocal_stft)

            T = mix_mag.shape[1]
            for t0 in range(0, T - MAX_FRAMES + 1, step):
                m_chunk = mix_mag[:,  t0:t0+MAX_FRAMES]
                v_chunk = vocal_mag[:,t0:t0+MAX_FRAMES]
                is_silent = (v_chunk.sum() < 1e-3)
                if is_silent and np.random.rand() > 0.5:
                    continue

                mask    = np.clip(v_chunk / (m_chunk + 1e-8), 0.0, 1.0)
                mix_log = np.log1p(m_chunk)
                self.segments.append((
                    torch.from_numpy(mix_log).unsqueeze(0).float(),
                    torch.from_numpy(mask).unsqueeze(0).float()
                ))

    def __len__(self):
        return len(self.segments)
    def __getitem__(self, idx):
        return self.segments[idx]

In [None]:
def spectral_convergence(true_spec, est_spec):
    true_mag = true_spec.abs()
    est_mag  = est_spec.abs()
    num      = torch.norm(true_mag - est_mag, p='fro')
    den      = torch.norm(true_mag,           p='fro')
    return num / (den + 1e-8)

def magnitude_loss(true_spec, est_spec):
    return F.l1_loss(est_spec.abs(), true_spec.abs())

def waveform_loss(true_wave, est_wave):
    return F.l1_loss(est_wave, true_wave)

In [52]:
full_ds      = RemoveVocalsDataset(DATA_DIR)
indices      = torch.randperm(len(full_ds))[:1000].tolist()
subset_ds    = Subset(full_ds, indices)
loader       = DataLoader(full_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)

In [53]:
model2 = UNet().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.L1Loss()

for epoch in range(1, EPOCHS + 1):
    model.train()
    running_loss = 0.0

    for mix_log, mask_tgt in tqdm(loader, desc=f"Epoch {epoch}/{EPOCHS}"):
        mix_log  = mix_log.to(DEVICE)
        mask_tgt = mask_tgt.to(DEVICE)

        pred = model(mix_log)
        loss = criterion(pred, mask_tgt)

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(loader)
    print(f"Epoch {epoch} avg loss: {avg_loss:.4f}")

Epoch 1/50: 100%|██████████| 546/546 [03:19<00:00,  2.73it/s]


Epoch 1 avg loss: 0.2466


Epoch 2/50: 100%|██████████| 546/546 [03:12<00:00,  2.83it/s]


Epoch 2 avg loss: 0.2373


Epoch 3/50: 100%|██████████| 546/546 [03:13<00:00,  2.83it/s]


Epoch 3 avg loss: 0.2304


Epoch 4/50: 100%|██████████| 546/546 [03:15<00:00,  2.80it/s]


Epoch 4 avg loss: 0.2275


Epoch 5/50: 100%|██████████| 546/546 [03:20<00:00,  2.72it/s]


Epoch 5 avg loss: 0.2245


Epoch 6/50: 100%|██████████| 546/546 [03:17<00:00,  2.76it/s]


Epoch 6 avg loss: 0.2229


Epoch 7/50: 100%|██████████| 546/546 [03:21<00:00,  2.70it/s]


Epoch 7 avg loss: 0.2193


Epoch 8/50: 100%|██████████| 546/546 [03:24<00:00,  2.68it/s]


Epoch 8 avg loss: 0.2161


Epoch 9/50: 100%|██████████| 546/546 [03:22<00:00,  2.70it/s]


Epoch 9 avg loss: 0.2154


Epoch 10/50: 100%|██████████| 546/546 [03:22<00:00,  2.70it/s]


Epoch 10 avg loss: 0.2134


Epoch 11/50: 100%|██████████| 546/546 [03:23<00:00,  2.69it/s]


Epoch 11 avg loss: 0.2155


Epoch 12/50: 100%|██████████| 546/546 [03:23<00:00,  2.69it/s]


Epoch 12 avg loss: 0.2118


Epoch 13/50: 100%|██████████| 546/546 [03:22<00:00,  2.69it/s]


Epoch 13 avg loss: 0.2102


Epoch 14/50: 100%|██████████| 546/546 [03:22<00:00,  2.69it/s]


Epoch 14 avg loss: 0.2089


Epoch 15/50: 100%|██████████| 546/546 [03:23<00:00,  2.69it/s]


Epoch 15 avg loss: 0.2078


Epoch 16/50: 100%|██████████| 546/546 [03:22<00:00,  2.70it/s]


Epoch 16 avg loss: 0.2080


Epoch 17/50: 100%|██████████| 546/546 [03:22<00:00,  2.70it/s]


Epoch 17 avg loss: 0.2069


Epoch 18/50: 100%|██████████| 546/546 [03:22<00:00,  2.70it/s]


Epoch 18 avg loss: 0.2044


Epoch 19/50: 100%|██████████| 546/546 [03:23<00:00,  2.68it/s]


Epoch 19 avg loss: 0.2062


Epoch 20/50: 100%|██████████| 546/546 [03:23<00:00,  2.68it/s]


Epoch 20 avg loss: 0.2036


Epoch 21/50: 100%|██████████| 546/546 [03:21<00:00,  2.70it/s]


Epoch 21 avg loss: 0.2046


Epoch 22/50: 100%|██████████| 546/546 [03:22<00:00,  2.70it/s]


Epoch 22 avg loss: 0.2016


Epoch 23/50: 100%|██████████| 546/546 [03:23<00:00,  2.69it/s]


Epoch 23 avg loss: 0.2006


Epoch 24/50: 100%|██████████| 546/546 [03:22<00:00,  2.70it/s]


Epoch 24 avg loss: 0.2018


Epoch 25/50: 100%|██████████| 546/546 [03:22<00:00,  2.70it/s]


Epoch 25 avg loss: 0.2004


Epoch 26/50: 100%|██████████| 546/546 [03:20<00:00,  2.72it/s]


Epoch 26 avg loss: 0.1992


Epoch 27/50: 100%|██████████| 546/546 [03:23<00:00,  2.68it/s]


Epoch 27 avg loss: 0.1981


Epoch 28/50: 100%|██████████| 546/546 [03:26<00:00,  2.64it/s]


Epoch 28 avg loss: 0.1978


Epoch 29/50: 100%|██████████| 546/546 [03:24<00:00,  2.67it/s]


Epoch 29 avg loss: 0.1976


Epoch 30/50: 100%|██████████| 546/546 [03:25<00:00,  2.66it/s]


Epoch 30 avg loss: 0.1972


Epoch 31/50: 100%|██████████| 546/546 [03:27<00:00,  2.63it/s]


Epoch 31 avg loss: 0.1984


Epoch 32/50: 100%|██████████| 546/546 [03:25<00:00,  2.65it/s]


Epoch 32 avg loss: 0.1958


Epoch 33/50: 100%|██████████| 546/546 [03:25<00:00,  2.66it/s]


Epoch 33 avg loss: 0.1955


Epoch 34/50: 100%|██████████| 546/546 [03:25<00:00,  2.65it/s]


Epoch 34 avg loss: 0.1952


Epoch 35/50: 100%|██████████| 546/546 [03:27<00:00,  2.63it/s]


Epoch 35 avg loss: 0.1951


Epoch 36/50: 100%|██████████| 546/546 [03:26<00:00,  2.65it/s]


Epoch 36 avg loss: 0.1941


Epoch 37/50: 100%|██████████| 546/546 [03:25<00:00,  2.66it/s]


Epoch 37 avg loss: 0.1927


Epoch 38/50: 100%|██████████| 546/546 [03:26<00:00,  2.64it/s]


Epoch 38 avg loss: 0.1931


Epoch 39/50: 100%|██████████| 546/546 [03:25<00:00,  2.65it/s]


Epoch 39 avg loss: 0.1927


Epoch 40/50: 100%|██████████| 546/546 [03:25<00:00,  2.66it/s]


Epoch 40 avg loss: 0.1917


Epoch 41/50: 100%|██████████| 546/546 [03:25<00:00,  2.65it/s]


Epoch 41 avg loss: 0.1914


Epoch 42/50: 100%|██████████| 546/546 [03:25<00:00,  2.66it/s]


Epoch 42 avg loss: 0.1920


Epoch 43/50: 100%|██████████| 546/546 [03:23<00:00,  2.68it/s]


Epoch 43 avg loss: 0.1913


Epoch 44/50: 100%|██████████| 546/546 [03:25<00:00,  2.66it/s]


Epoch 44 avg loss: 0.1912


Epoch 45/50: 100%|██████████| 546/546 [03:23<00:00,  2.68it/s]


Epoch 45 avg loss: 0.1904


Epoch 46/50: 100%|██████████| 546/546 [03:21<00:00,  2.70it/s]


Epoch 46 avg loss: 0.1910


Epoch 47/50: 100%|██████████| 546/546 [03:25<00:00,  2.66it/s]


Epoch 47 avg loss: 0.1892


Epoch 48/50: 100%|██████████| 546/546 [03:25<00:00,  2.65it/s]


Epoch 48 avg loss: 0.1887


Epoch 49/50: 100%|██████████| 546/546 [03:24<00:00,  2.67it/s]


Epoch 49 avg loss: 0.1898


Epoch 50/50: 100%|██████████| 546/546 [03:21<00:00,  2.71it/s]


Epoch 50 avg loss: 0.1887


In [35]:
model = UNet().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = nn.L1Loss()

for epoch in range(1, EPOCHS + 1):
    model.train()
    running_loss = 0.0

    for mix_log, mask_tgt in tqdm(loader, desc=f"Epoch {epoch}/{EPOCHS}"):
        mix_log  = mix_log.to(DEVICE)
        mask_tgt = mask_tgt.to(DEVICE)

        pred = model(mix_log)
        loss = criterion(pred, mask_tgt)

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()

    avg_loss = running_loss / len(loader)
    print(f"Epoch {epoch} avg loss: {avg_loss:.4f}")

Epoch 1/100: 100%|██████████| 505/505 [02:55<00:00,  2.88it/s]


Epoch 1 avg loss: 0.3418


Epoch 2/100: 100%|██████████| 505/505 [02:51<00:00,  2.94it/s]


Epoch 2 avg loss: 0.2957


Epoch 3/100: 100%|██████████| 505/505 [02:55<00:00,  2.87it/s]


Epoch 3 avg loss: 0.2776


Epoch 4/100: 100%|██████████| 505/505 [02:54<00:00,  2.89it/s]


Epoch 4 avg loss: 0.2689


Epoch 5/100: 100%|██████████| 505/505 [02:55<00:00,  2.87it/s]


Epoch 5 avg loss: 0.2609


Epoch 6/100: 100%|██████████| 505/505 [02:55<00:00,  2.87it/s]


Epoch 6 avg loss: 0.2584


Epoch 7/100: 100%|██████████| 505/505 [02:56<00:00,  2.86it/s]


Epoch 7 avg loss: 0.2550


Epoch 8/100: 100%|██████████| 505/505 [02:56<00:00,  2.86it/s]


Epoch 8 avg loss: 0.2528


Epoch 9/100: 100%|██████████| 505/505 [02:54<00:00,  2.89it/s]


Epoch 9 avg loss: 0.2504


Epoch 10/100: 100%|██████████| 505/505 [02:56<00:00,  2.86it/s]


Epoch 10 avg loss: 0.2477


Epoch 11/100: 100%|██████████| 505/505 [02:56<00:00,  2.85it/s]


Epoch 11 avg loss: 0.2465


Epoch 12/100: 100%|██████████| 505/505 [02:56<00:00,  2.87it/s]


Epoch 12 avg loss: 0.2454


Epoch 13/100: 100%|██████████| 505/505 [02:56<00:00,  2.87it/s]


Epoch 13 avg loss: 0.2449


Epoch 14/100: 100%|██████████| 505/505 [02:59<00:00,  2.81it/s]


Epoch 14 avg loss: 0.2433


Epoch 15/100: 100%|██████████| 505/505 [03:02<00:00,  2.77it/s]


Epoch 15 avg loss: 0.2408


Epoch 16/100: 100%|██████████| 505/505 [03:02<00:00,  2.77it/s]


Epoch 16 avg loss: 0.2406


Epoch 17/100: 100%|██████████| 505/505 [03:02<00:00,  2.77it/s]


Epoch 17 avg loss: 0.2395


Epoch 18/100: 100%|██████████| 505/505 [03:00<00:00,  2.79it/s]


Epoch 18 avg loss: 0.2389


Epoch 19/100: 100%|██████████| 505/505 [03:00<00:00,  2.79it/s]


Epoch 19 avg loss: 0.2393


Epoch 20/100: 100%|██████████| 505/505 [03:01<00:00,  2.78it/s]


Epoch 20 avg loss: 0.2377


Epoch 21/100: 100%|██████████| 505/505 [03:00<00:00,  2.79it/s]


Epoch 21 avg loss: 0.2361


Epoch 22/100: 100%|██████████| 505/505 [03:01<00:00,  2.78it/s]


Epoch 22 avg loss: 0.2358


Epoch 23/100: 100%|██████████| 505/505 [03:01<00:00,  2.78it/s]


Epoch 23 avg loss: 0.2360


Epoch 24/100: 100%|██████████| 505/505 [03:01<00:00,  2.79it/s]


Epoch 24 avg loss: 0.2346


Epoch 25/100: 100%|██████████| 505/505 [03:01<00:00,  2.79it/s]


Epoch 25 avg loss: 0.2339


Epoch 26/100: 100%|██████████| 505/505 [03:00<00:00,  2.80it/s]


Epoch 26 avg loss: 0.2341


Epoch 27/100: 100%|██████████| 505/505 [03:00<00:00,  2.80it/s]


Epoch 27 avg loss: 0.2339


Epoch 28/100: 100%|██████████| 505/505 [03:01<00:00,  2.78it/s]


Epoch 28 avg loss: 0.2327


Epoch 29/100: 100%|██████████| 505/505 [03:01<00:00,  2.79it/s]


Epoch 29 avg loss: 0.2319


Epoch 30/100: 100%|██████████| 505/505 [03:02<00:00,  2.76it/s]


Epoch 30 avg loss: 0.2316


Epoch 31/100: 100%|██████████| 505/505 [03:01<00:00,  2.79it/s]


Epoch 31 avg loss: 0.2313


Epoch 32/100: 100%|██████████| 505/505 [03:00<00:00,  2.80it/s]


Epoch 32 avg loss: 0.2311


Epoch 33/100: 100%|██████████| 505/505 [03:00<00:00,  2.79it/s]


Epoch 33 avg loss: 0.2301


Epoch 34/100: 100%|██████████| 505/505 [03:01<00:00,  2.78it/s]


Epoch 34 avg loss: 0.2299


Epoch 35/100: 100%|██████████| 505/505 [03:01<00:00,  2.79it/s]


Epoch 35 avg loss: 0.2299


Epoch 36/100: 100%|██████████| 505/505 [03:00<00:00,  2.79it/s]


Epoch 36 avg loss: 0.2296


Epoch 37/100: 100%|██████████| 505/505 [03:00<00:00,  2.80it/s]


Epoch 37 avg loss: 0.2286


Epoch 38/100: 100%|██████████| 505/505 [03:00<00:00,  2.80it/s]


Epoch 38 avg loss: 0.2279


Epoch 39/100: 100%|██████████| 505/505 [03:00<00:00,  2.79it/s]


Epoch 39 avg loss: 0.2281


Epoch 40/100: 100%|██████████| 505/505 [03:01<00:00,  2.78it/s]


Epoch 40 avg loss: 0.2279


Epoch 41/100: 100%|██████████| 505/505 [03:01<00:00,  2.79it/s]


Epoch 41 avg loss: 0.2283


Epoch 42/100: 100%|██████████| 505/505 [03:01<00:00,  2.78it/s]


Epoch 42 avg loss: 0.2269


Epoch 43/100: 100%|██████████| 505/505 [03:01<00:00,  2.79it/s]


Epoch 43 avg loss: 0.2262


Epoch 44/100: 100%|██████████| 505/505 [03:00<00:00,  2.79it/s]


Epoch 44 avg loss: 0.2268


Epoch 45/100: 100%|██████████| 505/505 [03:01<00:00,  2.79it/s]


Epoch 45 avg loss: 0.2263


Epoch 46/100: 100%|██████████| 505/505 [03:00<00:00,  2.80it/s]


Epoch 46 avg loss: 0.2262


Epoch 47/100: 100%|██████████| 505/505 [03:00<00:00,  2.79it/s]


Epoch 47 avg loss: 0.2249


Epoch 48/100: 100%|██████████| 505/505 [03:01<00:00,  2.78it/s]


Epoch 48 avg loss: 0.2254


Epoch 49/100: 100%|██████████| 505/505 [03:01<00:00,  2.78it/s]


Epoch 49 avg loss: 0.2252


Epoch 50/100: 100%|██████████| 505/505 [03:00<00:00,  2.79it/s]


Epoch 50 avg loss: 0.2238


Epoch 51/100: 100%|██████████| 505/505 [03:00<00:00,  2.80it/s]


Epoch 51 avg loss: 0.2239


Epoch 52/100: 100%|██████████| 505/505 [03:01<00:00,  2.79it/s]


Epoch 52 avg loss: 0.2236


Epoch 53/100: 100%|██████████| 505/505 [03:00<00:00,  2.80it/s]


Epoch 53 avg loss: 0.2242


Epoch 54/100: 100%|██████████| 505/505 [03:01<00:00,  2.78it/s]


Epoch 54 avg loss: 0.2228


Epoch 55/100: 100%|██████████| 505/505 [03:02<00:00,  2.77it/s]


Epoch 55 avg loss: 0.2228


Epoch 56/100: 100%|██████████| 505/505 [02:55<00:00,  2.87it/s]


Epoch 56 avg loss: 0.2243


Epoch 57/100: 100%|██████████| 505/505 [02:54<00:00,  2.90it/s]


Epoch 57 avg loss: 0.2228


Epoch 58/100: 100%|██████████| 505/505 [02:55<00:00,  2.87it/s]


Epoch 58 avg loss: 0.2230


Epoch 59/100: 100%|██████████| 505/505 [02:57<00:00,  2.85it/s]


Epoch 59 avg loss: 0.2218


Epoch 60/100: 100%|██████████| 505/505 [02:59<00:00,  2.82it/s]


Epoch 60 avg loss: 0.2219


Epoch 61/100: 100%|██████████| 505/505 [02:57<00:00,  2.84it/s]


Epoch 61 avg loss: 0.2218


Epoch 62/100: 100%|██████████| 505/505 [02:57<00:00,  2.84it/s]


Epoch 62 avg loss: 0.2220


Epoch 63/100: 100%|██████████| 505/505 [03:03<00:00,  2.75it/s]


Epoch 63 avg loss: 0.2213


Epoch 64/100: 100%|██████████| 505/505 [03:03<00:00,  2.76it/s]


Epoch 64 avg loss: 0.2208


Epoch 65/100: 100%|██████████| 505/505 [03:03<00:00,  2.75it/s]


Epoch 65 avg loss: 0.2219


Epoch 66/100: 100%|██████████| 505/505 [03:03<00:00,  2.76it/s]


Epoch 66 avg loss: 0.2217


Epoch 67/100: 100%|██████████| 505/505 [03:03<00:00,  2.76it/s]


Epoch 67 avg loss: 0.2203


Epoch 68/100: 100%|██████████| 505/505 [03:02<00:00,  2.76it/s]


Epoch 68 avg loss: 0.2201


Epoch 69/100: 100%|██████████| 505/505 [03:02<00:00,  2.77it/s]


Epoch 69 avg loss: 0.2206


Epoch 70/100: 100%|██████████| 505/505 [03:03<00:00,  2.75it/s]


Epoch 70 avg loss: 0.2202


Epoch 71/100: 100%|██████████| 505/505 [03:02<00:00,  2.76it/s]


Epoch 71 avg loss: 0.2208


Epoch 72/100: 100%|██████████| 505/505 [03:04<00:00,  2.74it/s]


Epoch 72 avg loss: 0.2195


Epoch 73/100: 100%|██████████| 505/505 [03:02<00:00,  2.76it/s]


Epoch 73 avg loss: 0.2199


Epoch 74/100: 100%|██████████| 505/505 [03:03<00:00,  2.75it/s]


Epoch 74 avg loss: 0.2190


Epoch 75/100: 100%|██████████| 505/505 [03:02<00:00,  2.77it/s]


Epoch 75 avg loss: 0.2186


Epoch 76/100: 100%|██████████| 505/505 [03:03<00:00,  2.75it/s]


Epoch 76 avg loss: 0.2199


Epoch 77/100: 100%|██████████| 505/505 [03:02<00:00,  2.76it/s]


Epoch 77 avg loss: 0.2186


Epoch 78/100: 100%|██████████| 505/505 [03:02<00:00,  2.77it/s]


Epoch 78 avg loss: 0.2187


Epoch 79/100: 100%|██████████| 505/505 [03:03<00:00,  2.76it/s]


Epoch 79 avg loss: 0.2181


Epoch 80/100: 100%|██████████| 505/505 [03:04<00:00,  2.74it/s]


Epoch 80 avg loss: 0.2189


Epoch 81/100: 100%|██████████| 505/505 [03:03<00:00,  2.75it/s]


Epoch 81 avg loss: 0.2180


Epoch 82/100: 100%|██████████| 505/505 [03:05<00:00,  2.73it/s]


Epoch 82 avg loss: 0.2182


Epoch 83/100: 100%|██████████| 505/505 [03:04<00:00,  2.74it/s]


Epoch 83 avg loss: 0.2177


Epoch 84/100: 100%|██████████| 505/505 [03:05<00:00,  2.72it/s]


Epoch 84 avg loss: 0.2175


Epoch 85/100: 100%|██████████| 505/505 [03:03<00:00,  2.75it/s]


Epoch 85 avg loss: 0.2178


Epoch 86/100: 100%|██████████| 505/505 [03:03<00:00,  2.75it/s]


Epoch 86 avg loss: 0.2169


Epoch 87/100: 100%|██████████| 505/505 [03:02<00:00,  2.76it/s]


Epoch 87 avg loss: 0.2170


Epoch 88/100: 100%|██████████| 505/505 [03:03<00:00,  2.76it/s]


Epoch 88 avg loss: 0.2167


Epoch 89/100: 100%|██████████| 505/505 [03:03<00:00,  2.75it/s]


Epoch 89 avg loss: 0.2173


Epoch 90/100: 100%|██████████| 505/505 [03:04<00:00,  2.74it/s]


Epoch 90 avg loss: 0.2157


Epoch 91/100: 100%|██████████| 505/505 [03:05<00:00,  2.73it/s]


Epoch 91 avg loss: 0.2162


Epoch 92/100: 100%|██████████| 505/505 [03:03<00:00,  2.75it/s]


Epoch 92 avg loss: 0.2166


Epoch 93/100: 100%|██████████| 505/505 [03:03<00:00,  2.75it/s]


Epoch 93 avg loss: 0.2168


Epoch 94/100: 100%|██████████| 505/505 [03:04<00:00,  2.74it/s]


Epoch 94 avg loss: 0.2159


Epoch 95/100: 100%|██████████| 505/505 [03:04<00:00,  2.74it/s]


Epoch 95 avg loss: 0.2154


Epoch 96/100: 100%|██████████| 505/505 [03:02<00:00,  2.77it/s]


Epoch 96 avg loss: 0.2156


Epoch 97/100: 100%|██████████| 505/505 [03:04<00:00,  2.74it/s]


Epoch 97 avg loss: 0.2157


Epoch 98/100: 100%|██████████| 505/505 [03:05<00:00,  2.72it/s]


Epoch 98 avg loss: 0.2146


Epoch 99/100: 100%|██████████| 505/505 [03:04<00:00,  2.73it/s]


Epoch 99 avg loss: 0.2150


Epoch 100/100: 100%|██████████| 505/505 [03:06<00:00,  2.70it/s]

Epoch 100 avg loss: 0.2148





In [None]:
    model = UNet().to(DEVICE)
    state = torch.load("/Users/drewmedina/Documents/GitHub/AI-Stem-Separation/models/vocal_separator_unet6.pth", map_location=DEVICE)
    model.load_state_dict(state)
    model.to(DEVICE)
    model.eval()

  state = torch.load("/Users/drewmedina/Documents/GitHub/AI-Stem-Separation/models/vocal_separator_unet4.pth", map_location=DEVICE)


UNet(
  (encoder): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (middle): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runnin

In [56]:
import numpy as np
import torch
import librosa
import soundfile as sf

def separate_vocals(mixture_path, output_path="predicted_vocals.wav"):
    S, sr = get_stft(mixture_path)
    mag   = np.abs(S)
    phase = np.angle(S)
    T_tot    = mag.shape[1]
    pred_mag = np.zeros_like(mag)
    counts   = np.zeros(T_tot)
    step     = MAX_FRAMES - OVERLAP
    t        = 0
    while t + MAX_FRAMES <= T_tot:
        seg     = mag[:, t : t + MAX_FRAMES]            # (513,256)
        mix_log = np.log1p(seg)[None, None, :, :]       # (1,1,513,256)
        m_t = torch.from_numpy(mix_log).float().to(DEVICE)
        with torch.no_grad():
            out  = model(m_t)                           # (1,1,513,256)
            pred = out.squeeze(0).squeeze(0).cpu().numpy()  # (513,256)
        pred_mag[:, t : t + MAX_FRAMES] += seg * pred
        counts[  t : t + MAX_FRAMES] += 1
        t += step
    if t < T_tot:
        rem = mag[:, t:]
        pad = MAX_FRAMES - rem.shape[1]
        seg = np.pad(rem, ((0,0),(0,pad)), mode="constant")
        mix_log = np.log1p(seg)[None, None, :, :]
        m_t     = torch.from_numpy(mix_log).float().to(DEVICE)
        with torch.no_grad():
            out  = model(m_t)                          
            pred = out.squeeze(0).squeeze(0).cpu().numpy()[:, : rem.shape[1]]
        pred_mag[:, t:] += rem * pred
        counts[t:]     += 1
    counts = np.maximum(counts, 1.0)
    pred_mag /= counts[None, :]
    S_est = pred_mag * np.exp(1j * phase)
    y_est = istft(S_est)
    sf.write(output_path, y_est, sr)
    print(f"🎉 Saved separated vocals to {output_path}")
separate_vocals("figures (mastered).wav")


🎉 Saved separated vocals to predicted_vocals.wav


In [None]:
MODEL_SAVE_PATH = "models/vocal_separator_unet5.pth"
os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
torch.save(model.state_dict(), MODEL_SAVE_PATH)
print(f"Model saved to {MODEL_SAVE_PATH}")


In [26]:
import museval
mus = musdb.DB(root="musdb18", subsets="test")
# we'll collect all track‐wise scores here
results = museval.EvalStore(
    frames_agg='median',
    tracks_agg='median'
)
for i, track in enumerate(mus.tracks):
    print(f"▶ Processing {track}")
    track_path = f"musdb18wavs/test/{track}/mixture.wav"
    track_path = track_path.replace(" ", "")
    # 1) write out your model's vocal estimate
    out_vocals = os.path.join("tests", f"{track}__vocals.wav")
    separate_vocals(track_path, out_vocals)

    # 2) load back the estimate & the original mixture
    mix, _       = sf.read(track_path,   always_2d=True)  # → (nsamples, nch)
    est_vocals, _ = sf.read(out_vocals, always_2d=True)   # → (nsamples, nvoc_ch)
    # if your model output is mono, nvoc_ch==1, so tile it to stereo:
    if est_vocals.shape[1] == 1 and mix.shape[1] == 2:
        est_vocals = np.repeat(est_vocals, 2, axis=1)     # → (nsamples, 2)

    # now subtraction works
    est_accom = mix - est_vocals  
    estimates = {
        'vocals'       : est_vocals,
        'accompaniment': est_accom

    }

    # 4) compute BSSeval-v4 metrics for this track
    scores = museval.eval_mus_track(
        track,
        estimates,
        output_dir=None     # don't need the intermediate JSON files
    )
    results.add_track(scores)
    if i == 5:
        break

# ——— Aggregate & pretty-print ————————————————————————————————
print("\n\n===== AGGREGATED RESULTS =====\n")
print(results)  # prints something like your matrix: µ ± σ for SDR, SIR, SAR, ISR, etc.

# if you want the raw pandas DataFrame for further processing:
df = results.save('my_method.pandas')  # returns a DataFrame of all metrics
print("\n\nRaw DataFrame:\n", df)

▶ Processing AM Contra - Heart Peripheral
🎉 Saved separated vocals to tests/AM Contra - Heart Peripheral__vocals.wav
▶ Processing Al James - Schoolboy Facination
🎉 Saved separated vocals to tests/Al James - Schoolboy Facination__vocals.wav
▶ Processing Angels In Amplifiers - I'm Alright
🎉 Saved separated vocals to tests/Angels In Amplifiers - I'm Alright__vocals.wav
▶ Processing Arise - Run Run Run
🎉 Saved separated vocals to tests/Arise - Run Run Run__vocals.wav
▶ Processing BKS - Bulldozer
🎉 Saved separated vocals to tests/BKS - Bulldozer__vocals.wav
▶ Processing BKS - Too Much
🎉 Saved separated vocals to tests/BKS - Too Much__vocals.wav


===== AGGREGATED RESULTS =====

Aggrated Scores (median over frames, median over tracks)
vocals          ==> SDR:   2.923  SIR:   3.322  ISR:   7.139  SAR:   6.840  
accompaniment   ==> SDR:   7.489  SIR:  11.241  ISR:  10.775  SAR:  11.854  



Raw DataFrame:
 None
