In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from tqdm import tqdm
from cnn_noah import CNN

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f"Using device: {device}")

Using device: cuda


In [2]:
class SpectrogramDataset(Dataset):
    """
    Dataset of pre-computed spectrograms and per-source targets stored as .pt files.

    Expected layout: ``data_dir/<song>/<clip>/mix.pt`` plus one ``<source>.pt``
    per requested source inside the same clip folder.

    Args:
        data_dir: Root directory containing one sub-folder per song.
        sources: Names of the target sources. Each name is looked up as
            ``<name>.pt`` in every clip folder, and the targets are stacked in
            this order.

    Attributes:
        samples: List of ``(mix_path, target_paths)`` tuples, one per usable clip,
            in sorted (song, clip) order.
        skipped_clips: Clip folders that were ignored because ``mix.pt`` or at
            least one requested source file was missing.
    """
    def __init__(self, data_dir, sources=('vocals', 'drums', 'bass', 'other')):
        self.data_dir = data_dir
        # Copy into a fresh list so no default/caller object is shared between instances.
        self.sources = list(sources)
        self.samples = []
        self.skipped_clips = []

        # Sort directory listings: os.listdir order is arbitrary, and a stable
        # order keeps sample indices reproducible between runs.
        for song in sorted(os.listdir(data_dir)):
            song_path = os.path.join(data_dir, song)
            if not os.path.isdir(song_path):
                continue  # stray file next to the song folders (e.g. .DS_Store)

            for clip in sorted(os.listdir(song_path)):
                clip_path = os.path.join(song_path, clip)
                if not os.path.isdir(clip_path):
                    continue  # stray file inside a song folder

                mix_path = os.path.join(clip_path, 'mix.pt')
                target_paths = [os.path.join(clip_path, f'{src}.pt') for src in self.sources]

                # Index only complete clips. This replaces the former hard-coded
                # exclusion of the one song known to lack bass.pt and also covers
                # any other clip with a missing file, which would otherwise fail
                # later inside __getitem__.
                if not all(os.path.isfile(path) for path in (mix_path, *target_paths)):
                    self.skipped_clips.append(clip_path)
                    continue

                self.samples.append((mix_path, target_paths))

    def __len__(self):
        """Return the number of indexed clips."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Load one clip.

        Returns:
            A ``(mix, targets)`` pair of float tensors: ``mix`` is the stored
            mixture with a leading axis of size 1 added, ``targets`` stacks the
            stored source tensors along a new leading axis in ``self.sources`` order.
        """
        mix_path, target_paths = self.samples[idx]
        # NOTE(review): torch.load unpickles the file; only point this dataset at
        # trusted data (or consider weights_only=True if the installed PyTorch supports it).
        mix = torch.load(mix_path)  # assumed shape: [freq, time]
        mix = mix.unsqueeze(0)  # [1, freq, time]
        targets = torch.stack([torch.load(tp) for tp in target_paths], dim=0)  # [n_sources, freq, time]
        return mix.float(), targets.float()

In [3]:
# Hyperparameters
batch_size = 4
epochs = 10
learning_rate = 1e-3
seed = 42

# Seed PyTorch's global RNG so the shuffling order and any later parameter
# initialisation are repeatable across runs (GPU kernels may still be nondeterministic).
torch.manual_seed(seed)

sources = ['vocals', 'drums', 'bass', 'other']
train_dir = '../data/musdb18hq/spectrograms/train'
test_dir = '../data/musdb18hq/spectrograms/test'

# Pass `sources` explicitly: the same list sizes the model output, so the
# datasets must stack exactly these targets in exactly this order.
train_dataset = SpectrogramDataset(train_dir, sources=sources)
test_dataset = SpectrogramDataset(test_dir, sources=sources)

# Shuffle only the training data; keep the test order fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
# Build the network (one output channel group per source), the loss and the optimiser.
model = CNN(n_sources=len(sources))
model = model.to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(1, epochs + 1):
    model.train()
    progress = tqdm(train_loader, desc=f'Epoch {epoch}/{epochs}')
    running_loss = 0

    for mix, targets in progress:
        mix = mix.to(device)
        targets = targets.to(device)

        # Standard optimisation step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(mix)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # The criterion returns a batch mean; weight it by the batch size so the
        # epoch figure is a per-sample average even when the last batch is smaller.
        running_loss += loss.item() * mix.size(0)

    train_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch {epoch}: Train Loss = {train_loss:.4f}')

Epoch 1/10: 100%|██████████| 553/553 [02:10<00:00,  4.25it/s]


Epoch 1: Train Loss = 0.0233


Epoch 2/10: 100%|██████████| 553/553 [02:07<00:00,  4.32it/s]


Epoch 2: Train Loss = 0.0201


Epoch 3/10: 100%|██████████| 553/553 [02:08<00:00,  4.30it/s]


Epoch 3: Train Loss = 0.0196


Epoch 4/10: 100%|██████████| 553/553 [02:08<00:00,  4.29it/s]


Epoch 4: Train Loss = 0.0193


Epoch 5/10: 100%|██████████| 553/553 [02:08<00:00,  4.31it/s]


Epoch 5: Train Loss = 0.0188


Epoch 6/10: 100%|██████████| 553/553 [02:08<00:00,  4.32it/s]


Epoch 6: Train Loss = 0.0187


Epoch 7/10: 100%|██████████| 553/553 [02:08<00:00,  4.30it/s]


Epoch 7: Train Loss = 0.0185


Epoch 8/10: 100%|██████████| 553/553 [02:08<00:00,  4.30it/s]


Epoch 8: Train Loss = 0.0183


Epoch 9/10: 100%|██████████| 553/553 [02:08<00:00,  4.31it/s]


Epoch 9: Train Loss = 0.0182


Epoch 10/10: 100%|██████████| 553/553 [02:10<00:00,  4.25it/s]

Epoch 10: Train Loss = 0.0180





In [5]:
# Persist only the learned parameters (the state dict), not the module object itself.
checkpoint_path = "cnn_noah.pth"
torch.save(model.state_dict(), checkpoint_path)