In [1]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from cnn_noah import CNN
from unet_noah import UNet

# Fix every RNG this notebook touches (Python, NumPy, PyTorch CPU + all CUDA devices)
# so that weight init and DataLoader shuffling repeat on Restart & Run All.
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn  # keep the notebook namespace clean

# Ask cuDNN for deterministic kernels and disable its autotuner, which may pick
# different (non-deterministic) algorithms from run to run.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

%load_ext autoreload
%autoreload 2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


## Dataset Creation

In [2]:
class SpectrogramDataset(Dataset):
    """
    Dataset of (mixture, stacked source targets) spectrogram pairs stored as .pt files.

    Expected layout::

        data_dir/<song>/<clip>/mix.pt
        data_dir/<song>/<clip>/<source>.pt   (one file per entry in ``sources``)

    Songs and clips are visited in sorted order, because ``os.listdir`` order is
    arbitrary. This keeps the sample index -> file mapping, and therefore seeded
    shuffling, identical across machines. Non-directory entries (e.g. ``.DS_Store``)
    are ignored. Clips missing ``mix.pt`` or any source file are skipped and their
    paths are recorded in ``self.skipped_clips``, instead of failing later inside
    the DataLoader.

    Args:
        data_dir: Root directory containing one sub-folder per song.
        sources: Target source names, in the order they are stacked in ``targets``.
        skip_songs: Song folder names to exclude entirely. The default keeps the
            historical exclusion of a song that has no bass.pt.
    """

    def __init__(self, data_dir, sources=('vocals', 'drums', 'bass', 'other'),
                 skip_songs=("Remember December - C U Next Time",)):
        self.data_dir = data_dir
        self.sources = list(sources)
        self.samples = []        # list of (mix_path, [target_path per source])
        self.skipped_clips = []  # clip dirs lacking mix.pt or a source .pt
        excluded = set(skip_songs)

        for song in sorted(os.listdir(data_dir)):
            song_path = os.path.join(data_dir, song)
            if song in excluded or not os.path.isdir(song_path):
                continue
            for clip in sorted(os.listdir(song_path)):
                clip_path = os.path.join(song_path, clip)
                if not os.path.isdir(clip_path):
                    continue
                mix_path = os.path.join(clip_path, 'mix.pt')
                target_paths = [os.path.join(clip_path, f'{src}.pt') for src in self.sources]
                if all(os.path.isfile(p) for p in [mix_path, *target_paths]):
                    self.samples.append((mix_path, target_paths))
                else:
                    self.skipped_clips.append(clip_path)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(mix, targets)`` as float32 tensors; ``mix`` gains a leading channel axis."""
        mix_path, target_paths = self.samples[idx]
        # map_location='cpu': tensors saved from a GPU still load on CPU-only machines;
        # the training loop moves each batch to the compute device itself.
        mix = torch.load(mix_path, map_location='cpu')  # assumed [freq, time] on disk
        mix = mix.unsqueeze(0)  # [1, freq, time]
        targets = torch.stack(
            [torch.load(tp, map_location='cpu') for tp in target_paths], dim=0
        )  # [n_sources, freq, time]
        return mix.float(), targets.float()

## Data Loading, Training, & Testing

In [3]:
# TODO: hyper-parameter tuning
# Hyper-parameters
BATCH_SIZE = 4
EPOCHS = 5  # TODO: increase when model architecture finalized
LEARNING_RATE = 1e-3

# Data loaders
# `sources` fixes both the order in which targets are stacked and the models'
# n_sources=len(sources). Pass it to the datasets explicitly so editing this list
# can never desync dataset targets from model output channels.
sources = ['vocals', 'drums', 'bass', 'other']
train_dir = '../data/musdb18hq/spectrograms/train'
test_dir = '../data/musdb18hq/spectrograms/test'
train_dataset = SpectrogramDataset(train_dir, sources=sources)
test_dataset = SpectrogramDataset(test_dir, sources=sources)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
def train_model(model, criterion, optimizer, epochs=None, loader=None, compute_device=None):
    """
    Train ``model`` and print the sample-weighted mean training loss after each epoch.

    Args:
        model: Module mapping a mix batch to per-source outputs.
        criterion: Loss called as ``criterion(outputs, targets)``. Assumed to average
            over the batch (as ``nn.MSELoss()`` does by default), because each batch
            loss is re-weighted by batch size below.
        optimizer: Optimizer already bound to ``model``'s parameters.
        epochs: Number of epochs; defaults to the notebook-level ``EPOCHS``.
        loader: DataLoader yielding ``(mix, targets)``; defaults to ``train_loader``.
        compute_device: Device batches are moved to; defaults to the notebook-level ``device``.

    Returns:
        List of per-epoch mean training losses (one float per epoch).

    Raises:
        ValueError: If an epoch yields no samples, or if output and target shapes
            differ outside their last two axes.
    """
    epochs = EPOCHS if epochs is None else epochs
    loader = train_loader if loader is None else loader
    compute_device = device if compute_device is None else compute_device

    history = []
    for epoch in range(1, epochs + 1):
        model.train()
        loss_sum, n_seen = 0.0, 0

        for mix, targets in tqdm(loader, desc=f'Epoch {epoch}/{epochs}'):
            mix, targets = mix.to(compute_device), targets.to(compute_device)
            optimizer.zero_grad()
            outputs = model(mix)

            # A batch/source-count mismatch would otherwise be silently broadcast by the loss.
            if outputs.shape[:-2] != targets.shape[:-2]:
                raise ValueError(
                    f"model output shape {tuple(outputs.shape)} does not match target shape "
                    f"{tuple(targets.shape)} outside the last two axes"
                )
            # Some models (e.g. the U-Net) can return a smaller spatial extent than the
            # input; crop both tensors to the common top-left region before the loss.
            if outputs.shape[-2:] != targets.shape[-2:]:
                h = min(outputs.shape[-2], targets.shape[-2])
                w = min(outputs.shape[-1], targets.shape[-1])
                outputs = outputs[..., :h, :w]
                targets = targets[..., :h, :w]

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a smaller final batch counts proportionally.
            loss_sum += loss.item() * mix.size(0)
            n_seen += mix.size(0)

        if n_seen == 0:
            raise ValueError("training loader yielded no samples")
        epoch_loss = loss_sum / n_seen
        history.append(epoch_loss)
        print(f'Epoch {epoch}: Train Loss = {epoch_loss:.4f}')
    return history
    
# TODO: better evaluation metrics (SDR, etc.)
# TODO: hearing the output .pt files
def test_model(model, criterion):
    model.eval()
    test_loss = 0
    criterion = nn.MSELoss()
    
    with torch.no_grad():
        for mix, targets in test_loader:
            mix, targets = mix.to(device), targets.to(device)
            outputs = model(mix)

            # Crop targets to match outputs shape (for U-Net)
            if outputs.shape != targets.shape:
                _, _, h, w = outputs.shape
                targets = targets[:, :, :h, :w]
            
            loss = criterion(outputs, targets)
            test_loss += loss.item() * mix.size(0)
    
    test_loss /= len(test_loader.dataset)
    print(f"Test Loss = {test_loss:.4f}")

## CNN Training

In [5]:
# TODO: appropriate criterion and optimizer?
# Fresh CNN sized for len(sources) sources, trained with element-wise MSE and Adam.
criterion = nn.MSELoss()
cnn_model = CNN(n_sources=len(sources)).to(device)
optimizer = optim.Adam(params=cnn_model.parameters(), lr=LEARNING_RATE)

train_model(model=cnn_model, criterion=criterion, optimizer=optimizer);

Epoch 1/5: 100%|██████████| 553/553 [02:06<00:00,  4.36it/s]
Epoch 1/5: 100%|██████████| 553/553 [02:06<00:00,  4.36it/s]


Epoch 1: Train Loss = 0.0247


Epoch 2/5: 100%|██████████| 553/553 [02:07<00:00,  4.33it/s]
Epoch 2/5: 100%|██████████| 553/553 [02:07<00:00,  4.33it/s]


Epoch 2: Train Loss = 0.0202


Epoch 3/5: 100%|██████████| 553/553 [02:08<00:00,  4.29it/s]
Epoch 3/5: 100%|██████████| 553/553 [02:08<00:00,  4.29it/s]


Epoch 3: Train Loss = 0.0197


Epoch 4/5: 100%|██████████| 553/553 [02:08<00:00,  4.32it/s]
Epoch 4/5: 100%|██████████| 553/553 [02:08<00:00,  4.32it/s]


Epoch 4: Train Loss = 0.0194


Epoch 5/5: 100%|██████████| 553/553 [02:07<00:00,  4.34it/s]



Epoch 5: Train Loss = 0.0192


In [6]:
# Persist only the learned parameters; the evaluation cell rebuilds CNN and loads them.
torch.save(obj=cnn_model.state_dict(), f="cnn_noah.pth")

## CNN Test Evaluation

In [7]:
cnn_model = CNN(n_sources=len(sources)).to(device)
# The checkpoint was written from a model on `device` (CUDA in the recorded run).
# map_location remaps its tensors so it also loads on a CPU-only machine.
cnn_model.load_state_dict(torch.load("cnn_noah.pth", map_location=device))

# Trailing ';' suppresses any echoed return value; the loss is printed.
test_model(cnn_model, criterion);

Test Loss = 0.0204


## U-Net Training

In [8]:
# TODO: appropriate criterion and optimizer?
# Fresh U-Net sized for len(sources) sources, trained with element-wise MSE and Adam.
criterion = nn.MSELoss()
unet_model = UNet(n_sources=len(sources)).to(device)
optimizer = optim.Adam(params=unet_model.parameters(), lr=LEARNING_RATE)

train_model(model=unet_model, criterion=criterion, optimizer=optimizer);

Epoch 1/5: 100%|██████████| 553/553 [00:45<00:00, 12.28it/s]
Epoch 1/5: 100%|██████████| 553/553 [00:45<00:00, 12.28it/s]


Epoch 1: Train Loss = 0.0240


Epoch 2/5: 100%|██████████| 553/553 [00:44<00:00, 12.45it/s]
Epoch 2/5: 100%|██████████| 553/553 [00:44<00:00, 12.45it/s]


Epoch 2: Train Loss = 0.0187


Epoch 3/5: 100%|██████████| 553/553 [00:44<00:00, 12.53it/s]
Epoch 3/5: 100%|██████████| 553/553 [00:44<00:00, 12.53it/s]


Epoch 3: Train Loss = 0.0175


Epoch 4/5: 100%|██████████| 553/553 [00:44<00:00, 12.43it/s]
Epoch 4/5: 100%|██████████| 553/553 [00:44<00:00, 12.43it/s]


Epoch 4: Train Loss = 0.0167


Epoch 5/5: 100%|██████████| 553/553 [00:45<00:00, 12.26it/s]

Epoch 5: Train Loss = 0.0160





In [9]:
# Persist only the learned parameters; the evaluation cell rebuilds UNet and loads them.
torch.save(obj=unet_model.state_dict(), f="unet_noah.pth")

## U-Net Test Evaluation

In [12]:
unet_model = UNet(n_sources=len(sources)).to(device)
# The checkpoint was written from a model on `device` (CUDA in the recorded run).
# map_location remaps its tensors so it also loads on a CPU-only machine.
unet_model.load_state_dict(torch.load("unet_noah.pth", map_location=device))

# Trailing ';' suppresses any echoed return value; the loss is printed.
test_model(unet_model, criterion);

Test Loss = 0.0159
