In [1]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from cnn_noah import CNN
from unet_noah import UNet

# Seed every RNG that can influence training (Python, NumPy, PyTorch CPU and all
# CUDA devices) with one shared value so runs are repeatable.
seed = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)
del seed_fn

# Ask cuDNN for deterministic kernels and disable its autotuner, which may pick
# different (non-deterministic) algorithms from run to run.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Automatically reload edited project modules (e.g. cnn_noah, unet_noah) before
# each cell runs, so model-code changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using device: " + str(device))

Using device: cuda


## Dataset Creation

In [2]:
class SpectrogramDataset(Dataset):
    """
    Dataset of (mixture, stacked source targets) spectrogram pairs stored as .pt files.

    Expected layout::

        data_dir/<song>/<clip>/mix.pt
        data_dir/<song>/<clip>/<source>.pt    # one file per entry in ``sources``

    Clips missing ``mix.pt`` or any requested source file are skipped at index
    time instead of failing later in ``__getitem__`` (previously the
    "Remember December - C U Next Time" song, which has no bass.pt, was
    excluded by name). Non-directory entries (e.g. ``.DS_Store``) are ignored.
    Songs and clips are enumerated in sorted order so the index -> file mapping
    is the same on every filesystem and run.

    Args:
        data_dir: Root directory containing one sub-folder per song.
        sources: Names of the target stems, in output-channel order.
    """
    def __init__(self, data_dir, sources=('vocals', 'drums', 'bass', 'other')):
        self.data_dir = data_dir
        # Private copy: never alias (or share a default with) the caller's list.
        self.sources = list(sources)
        self.samples = []

        # os.listdir order is arbitrary; sort for reproducible sample indices.
        for song in sorted(os.listdir(data_dir)):
            song_path = os.path.join(data_dir, song)
            if not os.path.isdir(song_path):
                continue  # stray file at song level
            for clip in sorted(os.listdir(song_path)):
                clip_path = os.path.join(song_path, clip)
                if not os.path.isdir(clip_path):
                    continue  # stray file at clip level
                mix_path = os.path.join(clip_path, 'mix.pt')
                target_paths = [os.path.join(clip_path, f'{src}.pt') for src in self.sources]
                # Skip incomplete clips (e.g. a song without bass.pt).
                if not all(os.path.isfile(p) for p in [mix_path, *target_paths]):
                    continue
                self.samples.append((mix_path, target_paths))

    def __len__(self):
        """Number of complete clips found under ``data_dir``."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(mix, targets)`` as float32 tensors for sample ``idx``.

        ``mix`` gets a new leading axis of size 1. ``targets`` stacks the
        per-source tensors along a new leading axis, in ``self.sources`` order.
        """
        mix_path, target_paths = self.samples[idx]
        # map_location='cpu': keep dataset tensors on the CPU even if they were
        # saved from GPU tensors; the training loop moves batches to the device.
        mix = torch.load(mix_path, map_location='cpu')  # assumed [freq, time] — TODO confirm
        mix = mix.unsqueeze(0)  # [1, freq, time]
        targets = torch.stack(
            [torch.load(tp, map_location='cpu') for tp in target_paths], dim=0
        )  # [n_sources, freq, time], assuming each source file is [freq, time]
        return mix.float(), targets.float()

## Data Loading & Training

In [3]:
# TODO: hyper-parameter tuning
# Hyper-parameters
BATCH_SIZE = 4
EPOCHS = 10
LEARNING_RATE = 1e-3

# Data loaders
sources = ['vocals', 'drums', 'bass', 'other']
train_dir = '../data/musdb18hq/spectrograms/train'
test_dir = '../data/musdb18hq/spectrograms/test'
# Pass `sources` explicitly so the datasets' target channels always match the
# n_sources=len(sources) the models below are built with.
train_dataset = SpectrogramDataset(train_dir, sources=sources)
test_dataset = SpectrogramDataset(test_dir, sources=sources)
assert len(train_dataset) > 0, f"no complete clips found under {train_dir}"
assert len(test_dataset) > 0, f"no complete clips found under {test_dir}"

# A dedicated, seeded generator keeps the shuffle order reproducible regardless of
# how many global RNG draws happen before iteration (e.g. model initialisation).
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(seed),
)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
def train_model(model, criterion, optimizer, loader=None, epochs=None):
    """Train ``model`` in place, printing the mean per-sample training loss each epoch.

    Args:
        model: Network mapping a batch of mixtures to predicted source tensors.
            Inputs are moved to the global ``device``, so the model should live there too.
        criterion: Loss called as ``criterion(outputs, targets)``, returning a scalar tensor.
        optimizer: Optimizer over ``model``'s parameters.
        loader: Iterable of ``(mix, targets)`` batches; defaults to the global ``train_loader``.
        epochs: Number of passes over ``loader``; defaults to the global ``EPOCHS``.

    Returns:
        List of per-epoch average training losses (epoch 1 first).
    """
    # Resolve defaults at call time so re-running the data/config cells is respected.
    loader = train_loader if loader is None else loader
    epochs = EPOCHS if epochs is None else epochs

    history = []
    for epoch in range(1, epochs + 1):
        model.train()
        running_loss, n_seen = 0.0, 0
        for mix, targets in tqdm(loader, desc=f'Epoch {epoch}/{epochs}'):
            mix, targets = mix.to(device), targets.to(device)
            optimizer.zero_grad()
            loss = criterion(model(mix), targets)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a short final batch doesn't skew the epoch mean.
            batch_size = mix.size(0)
            running_loss += loss.item() * batch_size
            n_seen += batch_size
        # max(..., 1) guards against an empty loader.
        epoch_loss = running_loss / max(n_seen, 1)
        history.append(epoch_loss)
        print(f'Epoch {epoch}: Train Loss = {epoch_loss:.4f}')
    return history

## CNN Training

In [5]:
# TODO: appropriate criterion and optimizer?
# Baseline CNN: regress every source spectrogram with plain MSE, optimised with Adam.
cnn_model = CNN(n_sources=len(sources)).to(device)
optimizer = optim.Adam(params=cnn_model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

train_model(model=cnn_model, criterion=criterion, optimizer=optimizer)

Epoch 1/10: 100%|██████████| 553/553 [04:53<00:00,  1.89it/s]


Epoch 1: Train Loss = 0.0247


Epoch 2/10: 100%|██████████| 553/553 [02:07<00:00,  4.32it/s]


Epoch 2: Train Loss = 0.0202


Epoch 3/10: 100%|██████████| 553/553 [02:08<00:00,  4.30it/s]


Epoch 3: Train Loss = 0.0197


Epoch 4/10: 100%|██████████| 553/553 [02:08<00:00,  4.31it/s]


Epoch 4: Train Loss = 0.0194


Epoch 5/10: 100%|██████████| 553/553 [02:08<00:00,  4.30it/s]


Epoch 5: Train Loss = 0.0192


Epoch 6/10: 100%|██████████| 553/553 [02:08<00:00,  4.31it/s]


Epoch 6: Train Loss = 0.0188


Epoch 7/10: 100%|██████████| 553/553 [02:08<00:00,  4.30it/s]


Epoch 7: Train Loss = 0.0187


Epoch 8/10: 100%|██████████| 553/553 [02:08<00:00,  4.30it/s]


Epoch 8: Train Loss = 0.0187


Epoch 9/10: 100%|██████████| 553/553 [02:08<00:00,  4.30it/s]


Epoch 9: Train Loss = 0.0185


Epoch 10/10: 100%|██████████| 553/553 [02:08<00:00,  4.30it/s]

Epoch 10: Train Loss = 0.0183





In [6]:
# Persist only the learned weights (state_dict); the architecture is rebuilt from code when loading.
torch.save(obj=cnn_model.state_dict(), f="cnn_noah.pth")

## CNN Test Evaluation

In [5]:
# TODO: better evaluation metrics (SDR, etc.)
# TODO: hearing the output .pt files
# Rebuild the CNN from the saved checkpoint so this cell also works on a fresh kernel.
# map_location lets a checkpoint written on the GPU be loaded on a CPU-only machine
# (and vice versa) by placing the tensors directly on the current `device`.
cnn_model = CNN(n_sources=len(sources)).to(device)
cnn_model.load_state_dict(torch.load("cnn_noah.pth", map_location=device))
cnn_model.eval()

criterion = nn.MSELoss()
test_loss, n_seen = 0.0, 0

with torch.no_grad():
    for mix, targets in test_loader:
        mix, targets = mix.to(device), targets.to(device)
        loss = criterion(cnn_model(mix), targets)
        # Weight by batch size so the result is a per-sample mean even if the last batch is short.
        test_loss += loss.item() * mix.size(0)
        n_seen += mix.size(0)

# max(..., 1) guards against an empty test loader.
test_loss /= max(n_seen, 1)
print(f"Test Loss = {test_loss:.4f}")

Test Loss = 0.0194


## U-Net Training

In [6]:
# TODO: not working
# TODO: appropriate criterion and optimizer?
# NOTE(review): the last run failed inside UNet with "Sizes of tensors must match
# except in dimension 1. Expected size 214 but got size 215". This presumably comes
# from a skip-connection concat after an odd spatial size is halved and upsampled
# back (215 -> 107 -> 214). Likely fix: pad/crop before torch.cat in unet_noah.py,
# or pad inputs to a multiple of 2**depth. Confirm against the UNet code.
unet_model = UNet(n_sources=len(sources)).to(device)
optimizer = optim.Adam(params=unet_model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()

train_model(model=unet_model, criterion=criterion, optimizer=optimizer)

Epoch 1/10:   0%|          | 0/553 [00:00<?, ?it/s]



RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 214 but got size 215 for tensor number 1 in the list.

In [None]:
# Persist only the learned U-Net weights (state_dict); the architecture is rebuilt from code when loading.
torch.save(obj=unet_model.state_dict(), f="unet_noah.pth")