In [None]:
import os
import sys
from pathlib import Path
from typing import Generator
from tqdm.notebook import tqdm

import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import torchaudio
import torchaudio.transforms as AT
import torchaudio.functional as AF

from sklearn.model_selection import train_test_split

from deepmash.data_processing.constants import *
from deepmash.data_processing.gtzan_stems import *
from deepmash.data_processing.new_dataloader import *

%load_ext autoreload
%autoreload 2

In [None]:
# Audio / chunking configuration used by the helpers and dataset below.
TARGET_SR = 16000           # target sample rate (Hz); it is probably hard to go lower than 16 kHz
CHUNK_DURATION_SEC = 15     # all chunks will be exactly this long
MIN_CHUNK_DURATION_SEC = 5  # discard chunks shorter than this (otherwise zero-pad to CHUNK_DURATION_SEC)

# mel-spectrogram settings (using same as cocola for now); passed to the MelSpectrogram transform in ToLogMel
N_MELS = 64
F_MIN = 60 
F_MAX = 7800
WINDOW_SIZE = 1024 # 64ms @ 16kHz (should be power of 2 for efficiency)
HOP_SIZE = 320     # 20ms @ 16kHz

# Placeholder locations: default `root_dir` / `originals_root` of GTZANStemsDataset; set to local paths before running
INPUT_ROOT = "path/to/gtzan/stems"
ORIGINALS_ROOT = "path/to/gtzan/originals" 

In [None]:
# ---------- Utility functions ----------

def ensure_same_length(tensors: list[torch.Tensor]) -> list[torch.Tensor]:
    """Trim every tensor to the shortest length found along the last axis.

    The last axis is treated as time, so both 1-D waveforms `(T,)` and
    channel-first waveforms `(C, T)` are trimmed correctly. (Measuring with
    `len()` would look at the channel axis of a `(1, T)` tensor and trim nothing.)

    Args:
        tensors: non-empty list of tensors with at least one dimension.

    Returns:
        A new list with each tensor sliced to the common minimum length.

    Raises:
        ValueError: if `tensors` is empty (raised by `min`).
    """
    min_len = min(t.shape[-1] for t in tensors)
    # Ellipsis keeps any leading (e.g. channel) axes untouched.
    return [t[..., :min_len] for t in tensors]


def ensure_2d(x: torch.Tensor) -> torch.Tensor:
    """Return `x` as a 2-D tensor.

    A 2-D input is returned unchanged, a 1-D input gains a leading axis of
    size 1, and any other rank raises `ValueError`.
    """
    rank = x.ndim
    if rank == 2:
        return x
    if rank == 1:
        return x.unsqueeze(0)
    raise ValueError(f"Input tensor has {x.ndim} dimensions, expected 1 or 2.")


def mix_stems(stems: list[torch.Tensor], peak_val=0.98) -> torch.Tensor:
    """Sum the given stems and peak-normalise the mixture.

    The stems are first passed through `ensure_same_length`, then added
    together. If the mixture has a non-zero absolute peak it is rescaled so
    that peak equals `peak_val`; an all-zero mixture is returned as is.
    """
    aligned = ensure_same_length(stems)
    mixture: torch.Tensor = sum(aligned)  # type: ignore
    peak = mixture.abs().max()
    if not peak > 0:
        # Nothing to normalise (silent mixture)
        return mixture
    return mixture / peak * peak_val


def zero_pad_or_clip(x: torch.Tensor, target_len: int) -> torch.Tensor:
    """Force `x` to have exactly `target_len` entries along its first axis.

    Longer inputs are truncated; shorter inputs are right-padded with zeros.
    The padding is created with `x.new_zeros`, so it matches the dtype and
    device of `x` (a plain `torch.zeros` would be float32 on the CPU and would
    either change the result dtype or fail for tensors on another device).

    Args:
        x: tensor with at least one dimension.
        target_len: desired size of the first axis.

    Returns:
        Tensor whose first axis has size `target_len`; trailing axes are unchanged.
    """
    if len(x) >= target_len:
        return x[:target_len]
    pad_len = target_len - len(x)
    # Keep any trailing axes so the padding can be concatenated along dim 0.
    padding = x.new_zeros((pad_len, *x.shape[1:]))
    return torch.cat([x, padding], dim=0)


def load_audio(path: Path | str, sr: int | float, frame_offset=0, num_frames=-1):
    """Load an audio file as a mono waveform resampled to `sr`.

    Args:
        path: audio file to read with `torchaudio.load`.
        sr: sample rate to resample to.
        frame_offset: forwarded to `torchaudio.load`.
        num_frames: forwarded to `torchaudio.load`.

    Returns:
        Waveform tensor with a single channel in the first axis.
    """
    # Keep the file's native rate in its own name so the requested `sr` is not overwritten.
    y, orig_sr = torchaudio.load(path, frame_offset=frame_offset, num_frames=num_frames)
    if y.shape[0] > 1:
        y = y.mean(dim=0, keepdim=True)  # convert to mono
    y = AF.resample(y, orig_freq=orig_sr, new_freq=sr)
    return y  # (1, sr*duration)


def get_chunks(vocals: torch.Tensor, non_vocals: torch.Tensor) -> Generator[tuple[torch.Tensor, torch.Tensor], None, None]:
    """Yield paired, fixed-length chunks of the two signals.

    Both inputs are first passed through `ensure_same_length`, then walked in
    steps of `CHUNK_DURATION_SEC * TARGET_SR` samples. A window whose vocals
    slice is shorter than `MIN_CHUNK_DURATION_SEC * TARGET_SR` samples is
    skipped; every other window is passed through `zero_pad_or_clip` so both
    chunks have the full chunk length.
    """
    full_len = CHUNK_DURATION_SEC * TARGET_SR
    shortest_allowed = MIN_CHUNK_DURATION_SEC * TARGET_SR

    vocals, non_vocals = ensure_same_length([vocals, non_vocals])

    for begin in range(0, len(vocals), full_len):
        window = slice(begin, begin + full_len)
        vocals_part = vocals[window]
        if len(vocals_part) < shortest_allowed:
            # Too little audio left to be worth keeping
            continue
        yield (
            zero_pad_or_clip(vocals_part, full_len),
            zero_pad_or_clip(non_vocals[window], full_len),
        )


def get_gtzan_track_folders(root: Path | str):
    return sorted(p for p in Path(root).glob("*/*") if p.is_dir())

# ---------- Mel transform ----------

class ToLogMel(nn.Module):
    """Waveform -> mel spectrogram -> decibel scale.

    The mel spectrogram is configured from the notebook-level constants
    (`TARGET_SR`, `N_MELS`, `WINDOW_SIZE`, `HOP_SIZE`, `F_MIN`, `F_MAX`);
    the dB conversion uses `AmplitudeToDB` with its default arguments.
    """

    def __init__(self):
        super().__init__()
        mel_settings = dict(
            sample_rate=TARGET_SR,
            n_mels=N_MELS,
            n_fft=WINDOW_SIZE,
            hop_length=HOP_SIZE,
            f_min=F_MIN,
            f_max=F_MAX,
        )
        self.to_melspec = AT.MelSpectrogram(**mel_settings)
        self.to_db = AT.AmplitudeToDB()

    def forward(self, x: torch.Tensor):
        """Apply the mel transform followed by the dB conversion to `x`."""
        mel = self.to_melspec(x)
        return self.to_db(mel)


# ---------- Main dataset with preprocessing ----------

class GTZANStemsDataset(Dataset):
    """Fixed-length GTZAN chunks stored on disk as vocals / non-vocals / original tensors.

    Chunks are read from `<root_dir>-processed/<track>.chunk<k>/`, each folder
    holding `vocals.pt`, `non-vocals.pt` and `original.pt`. With
    `preprocess=True` that folder is (re)built from the raw stems and
    originals before the chunk list is read.

    Args:
        root_dir: folder with separated stems, laid out as `<genre>/<track>/<stem>.wav`.
        originals_root: folder with the unseparated tracks, laid out as `<genre>/<track>.wav`.
        preprocess: run `_preprocess()` in the constructor when True.
        preprocess_transform: optional module applied to every chunk before saving.
        runtime_transform: optional module applied to every tensor in `__getitem__`.
        device: device the returned tensors are moved to.
    """

    def __init__(
        self,
        root_dir: Path | str = INPUT_ROOT,
        originals_root: Path | str = ORIGINALS_ROOT,
        preprocess=True,
        preprocess_transform: nn.Module | None = None,
        runtime_transform: nn.Module | None = None,
        device: str = "cpu",
    ):
        self.root = Path(root_dir)
        self.originals_root = Path(originals_root)
        # Processed chunks live next to the stems folder, in "<stems folder name>-processed"
        self.processed_root = self.root.parent / (self.root.name + "-processed")

        self.preprocess_transform = preprocess_transform
        self.runtime_transform = runtime_transform
        self.device = device

        if preprocess:
            print(f"Preprocessing GTZAN stems from {self.root} and originals from {self.originals_root}")
            self._preprocess()

        # After preprocessing, load the chunk list
        self.chunk_dirs = sorted(p for p in self.processed_root.glob("*") if p.is_dir())
        if not self.chunk_dirs:
            # An empty dataset only fails much later (e.g. when splitting), so say why up front.
            print(f"⚠️ No processed chunks found in {self.processed_root}; run once with preprocess=True")

    def _preprocess(self):
        """
        Assuming input files like "`self.root`/blues/blues.000001/{drums|bass|other|vocals}.wav"
        and originals like "`self.originals_root`/blues/blues.000001.wav":
        1. load stems and original as tensors
        2. convert to mono if in stereo
        3. resample to `TARGET_SR`
        4. mix all non-vocal stems together
        5. chunk into `CHUNK_DURATION_SEC` segments; a trailing segment shorter than
           `MIN_CHUNK_DURATION_SEC` is dropped, otherwise it is zero-padded
        6. apply optional `preprocess_transform` (e.g. mel-spectrogram) to each chunk
        7. save as `self.processed_root`/blues.000001.chunk{1|2|...}/{non-vocals|vocals|original}.pt

        Tracks whose audio cannot be loaded are reported and skipped.
        """
        self.processed_root.mkdir(parents=True, exist_ok=True)
        track_folders = get_gtzan_track_folders(self.root)

        for track_folder in tqdm(track_folders):
            all_stem_paths = list(track_folder.glob("*.wav"))
            assert {p.stem for p in all_stem_paths} == {"drums", "bass", "other", "vocals"}, \
                f"Missing stems for {track_folder}"

            vocals_path = [p for p in all_stem_paths if p.stem == "vocals"][0]
            non_vocals_paths = [p for p in all_stem_paths if p.stem != "vocals"]

            track_name = track_folder.name
            genre = track_folder.parent.name  # e. g. "blues"
            orig_path = self.originals_root / f"{genre}" / f"{track_name}.wav"

            # Load and mix stems
            try:
                vocals = load_audio(vocals_path, sr=TARGET_SR)
                non_vocals = mix_stems([load_audio(p, sr=TARGET_SR) for p in non_vocals_paths])
                original = load_audio(orig_path, sr=TARGET_SR)
            except Exception as e:
                print(f"⚠️ Skipping track {track_name} due to loading error: {e}")
                continue

            # Drop the channel axis so chunking works on 1-D waveforms
            vocals = vocals.squeeze(0)
            non_vocals = non_vocals.squeeze(0)
            original = original.squeeze(0)

            # Generate aligned chunks; zip stops at the shorter of the two chunk streams
            for i, ((vocals_chunk, non_vocals_chunk), (orig_chunk, _)) in enumerate(
                zip(get_chunks(vocals, non_vocals), get_chunks(original, original))
            ):
                if self.preprocess_transform is not None:
                    with torch.no_grad():
                        vocals_chunk = self.preprocess_transform(vocals_chunk.unsqueeze(0))  # (1, T)
                        non_vocals_chunk = self.preprocess_transform(non_vocals_chunk.unsqueeze(0))
                        orig_chunk = self.preprocess_transform(orig_chunk.unsqueeze(0))

                chunk_folder = self.processed_root / f"{track_name}.chunk{i+1}"
                chunk_folder.mkdir(parents=True, exist_ok=True)
                torch.save(vocals_chunk, chunk_folder / "vocals.pt")
                torch.save(non_vocals_chunk, chunk_folder / "non-vocals.pt")
                torch.save(orig_chunk, chunk_folder / "original.pt")

    def __len__(self):
        """Number of processed chunk folders."""
        return len(self.chunk_dirs)

    def __getitem__(self, idx):
        """Load one chunk and return its tensors (on `self.device`) plus the chunk folder name."""
        chunk_dir = self.chunk_dirs[idx]
        tensors = {
            "vocals": torch.load(chunk_dir / "vocals.pt"),
            "non_vocals": torch.load(chunk_dir / "non-vocals.pt"),
            "original": torch.load(chunk_dir / "original.pt"),
        }

        # Explicit None check, matching how `preprocess_transform` is tested in `_preprocess`
        if self.runtime_transform is not None:
            tensors = {key: self.runtime_transform(value) for key, value in tensors.items()}

        sample = {key: value.to(self.device) for key, value in tensors.items()}
        sample["chunk_name"] = chunk_dir.name
        return sample


In [None]:
# ---------- Split dataset by tracks ----------

def split_dataset_by_tracks(dataset: GTZANStemsDataset, train_ratio=0.7, val_ratio=0.15, test_ratio=0.15, random_state=42):
    """Split a chunk dataset into train/val/test subsets without sharing tracks between them.

    The track name of a chunk is the part of its folder name before ".chunk".
    Tracks (not chunks) are split with `train_test_split`, and every chunk
    follows its track into the corresponding `Subset`.

    Args:
        dataset: dataset exposing `chunk_dirs`.
        train_ratio, val_ratio, test_ratio: track fractions; must sum to 1.
        random_state: seed forwarded to both `train_test_split` calls.

    Returns:
        `(train_dataset, val_dataset, test_dataset)` as `Subset` objects.
    """
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6

    # Track name of every chunk, computed once and kept in dataset order
    chunk_tracks = [p.name.split(".chunk")[0] for p in dataset.chunk_dirs]
    all_tracks = sorted(set(chunk_tracks))

    train_tracks, temp_tracks = train_test_split(all_tracks, test_size=(1 - train_ratio), random_state=random_state)
    val_tracks, test_tracks = train_test_split(temp_tracks, test_size=test_ratio / (test_ratio + val_ratio), random_state=random_state)

    def get_chunk_indices(track_list):
        # Set lookup keeps this linear in the number of chunks
        wanted = set(track_list)
        return [i for i, track in enumerate(chunk_tracks) if track in wanted]

    train_dataset = Subset(dataset, get_chunk_indices(train_tracks))
    val_dataset = Subset(dataset, get_chunk_indices(val_tracks))
    test_dataset = Subset(dataset, get_chunk_indices(test_tracks))

    return train_dataset, val_dataset, test_dataset

# ---------- DataLoaders ----------

def create_dataloaders(preprocess=False, batch_size=16, num_workers=2):
    """Build train/val/test `DataLoader`s over a track-wise split of the dataset.

    The dataset is created with default paths and a `ToLogMel` preprocessing
    transform. Only the training loader shuffles.
    """
    full_dataset = GTZANStemsDataset(preprocess=preprocess, preprocess_transform=ToLogMel())
    subsets = split_dataset_by_tracks(full_dataset)

    # Order is (train, val, test); shuffling is enabled for the first one only
    train_loader, val_loader, test_loader = (
        DataLoader(subset, batch_size=batch_size, shuffle=do_shuffle, num_workers=num_workers)
        for subset, do_shuffle in zip(subsets, (True, False, False))
    )

    return train_loader, val_loader, test_loader

In [None]:
# Sanity check: build the loaders and inspect one training batch.
# create_dataloaders defaults to preprocess=False, so the dataset only reads chunk folders that already exist on disk.
train_loader, val_loader, _ = create_dataloaders(batch_size=4)
# Get one batch from the training dataloader
batch = next(iter(train_loader))

# Print the batch keys, then the shape of each tensor entry
print(batch.keys())
print("Vocals:", batch['vocals'].shape) # [batch_size, channels, n_mels, time] if using Mel
print("Non-vocals:", batch['non_vocals'].shape)
print("Original:", batch['original'].shape)
print("Chunk names:", batch['chunk_name'])   # List of strings

In [2]:
class CNNTransformer(nn.Module):
    """CNN backbone followed by a transformer encoder, producing one embedding per input.

    Args:
        cnn_name: timm model name used as the backbone.
        transformer_dim: width of the transformer and of the output embedding.
        nhead: number of attention heads per encoder layer.
        num_layers: number of transformer encoder layers.
        device: stored on `self.device` only; the module is not moved by the
            constructor, so callers still need `.to(...)`.
    """

    def __init__(self, cnn_name="efficientnet_b0", transformer_dim=512, nhead=8, num_layers=4, device="gpu"):
        super().__init__()
        self.device = device

        # CNN backbone: single input channel, no classifier head, average-pooled features
        self.cnn = timm.create_model(cnn_name, pretrained=True, in_chans=1, num_classes=0, global_pool="avg")
        cnn_out_dim = self.cnn.num_features

        # Project CNN features to transformer dimension
        self.fc_proj = nn.Linear(cnn_out_dim, transformer_dim)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=transformer_dim, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Final embedding projection
        self.embedding = nn.Linear(transformer_dim, transformer_dim)

    def forward(self, x):
        """
        x: (B, 1, n_mels, time) -> log-mel spectrogram batch, treated as a single-channel image

        Returns an embedding of shape (B, transformer_dim).
        """
        cnn_feat = self.cnn(x)  # (B, cnn_out_dim)
        # The pooled feature vector becomes a sequence of length 1 for the transformer
        proj_feat = self.fc_proj(cnn_feat).unsqueeze(1)  # (B, 1, transformer_dim)
        trans_feat = self.transformer(proj_feat)  # (B, 1, transformer_dim)
        emb = self.embedding(trans_feat[:, 0, :])  # (B, transformer_dim)
        return emb

In [3]:
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over `dataloader` and return the mean batch loss.

    For every batch the model embeds the vocals, the non-vocals and the
    original mix. The average of the two stem embeddings is compared with the
    mix embedding through `criterion`, using a target of ones.
    """
    model.train()
    loss_sum = 0.0
    for batch in dataloader:
        optimizer.zero_grad()

        vocals, non_vocals, original = (
            batch[key].to(device) for key in ("vocals", "non_vocals", "original")
        )

        # Forward passes, always in the order vocals -> non-vocals -> mix
        emb_vocals = model(vocals)
        emb_non_vocals = model(non_vocals)
        emb_mix = model(original)

        # Predicted mix embedding is the mean of the two stem embeddings
        pred_emb = (emb_vocals + emb_non_vocals) / 2.0

        ones = torch.ones(pred_emb.size(0), device=device)  # one target per sample
        batch_loss = criterion(pred_emb, emb_mix, ones)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()

    return loss_sum / len(dataloader)

def validate(model, dataloader, criterion, device):
    """Evaluate the model on `dataloader` and return the mean batch loss.

    Mirrors `train_one_epoch` without gradient tracking or optimizer steps:
    the mean of the vocals and non-vocals embeddings is compared with the mix
    embedding through `criterion`, using a target of ones.

    Args:
        model: embedding model; switched to eval mode.
        dataloader: iterable of batches with "vocals", "non_vocals" and "original" tensors.
        criterion: loss called as `criterion(pred, mix, target)`.
        device: device the batch tensors are moved to.

    Returns:
        Sum of the batch losses divided by `len(dataloader)`.
    """
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for batch in dataloader:
            vocals = batch["vocals"].to(device)
            non_vocals = batch["non_vocals"].to(device)
            original = batch["original"].to(device)

            emb_vocals = model(vocals)
            emb_non_vocals = model(non_vocals)
            emb_mix = model(original)
            pred_emb = (emb_vocals + emb_non_vocals) / 2.0

            # The criterion needs an explicit target, exactly as in training;
            # calling it with two arguments raises a TypeError.
            target = torch.ones(pred_emb.size(0), device=device)
            loss = criterion(pred_emb, emb_mix, target)
            running_loss += loss.item()
    return running_loss / len(dataloader)

## To run this - use GPU
1. If running locally, first install the CUDA Toolkit and the cuDNN library.
2. Create an environment with Anaconda (or another environment manager).
3. Install PyTorch with CUDA support: `conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia`.
4. Register the environment as a Jupyter kernel and open this notebook with that kernel.
5. Start training!

In [None]:
# Load dataloaders for training (test loader is discarded here).
# preprocess is left at its default (False), so already-processed chunks are read from disk.
train_loader, val_loader, _ = create_dataloaders(batch_size=4)

In [None]:
# Report the CUDA setup. The device name is only queried when CUDA is present,
# because get_device_name raises on CPU-only machines and would stop the cell
# before the CPU fallback below is reached.
print(torch.cuda.is_available())  # True when a CUDA GPU is usable
print(torch.cuda.device_count())  # Number of GPUs
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # Name of the first GPU

device = "cuda" if torch.cuda.is_available() else "cpu"

# Initialize model
print(f"Using device: {device}")
model = CNNTransformer(device=device).to(device)

# Cosine embedding loss
criterion = nn.CosineEmbeddingLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

print("Starting training...")
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss = validate(model, val_loader, criterion, device)
    print(f"Epoch {epoch}/{num_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

# Save model
torch.save(model.state_dict(), "cnn_transformer_model.pt")
print("Model saved to cnn_transformer_model.pt")

Preprocessing GTZAN stems from D:\Users\ollet\Downloads\archive_1\Data\genres_stems and originals from D:\Users\ollet\Downloads\archive_1\Data\genres_original


  0%|          | 0/1000 [00:00<?, ?it/s]

⚠️ Skipping track reggae.00004 due to loading error: System error.
⚠️ Skipping track reggae.00005 due to loading error: System error.
dict_keys(['vocals', 'non_vocals', 'original', 'chunk_name'])
Vocals: torch.Size([4, 1, 64, 751])
Non-vocals: torch.Size([4, 1, 64, 751])
Original: torch.Size([4, 1, 64, 751])
Chunk names: ['metal.00026.chunk1', 'jazz.00067.chunk1', 'hiphop.00074.chunk2', 'metal.00021.chunk1']
