<a href="https://colab.research.google.com/github/gedman4b/AI/blob/main/sep_gpt_v0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import os

jupyter_config_dir = os.path.expanduser('~/.jupyter')
os.makedirs(jupyter_config_dir, exist_ok=True)

config_path = os.path.join(jupyter_config_dir, 'jupyter_notebook_config.py')
with open(config_path, 'a') as f:
    f.write("c.FileContentsManager.save_checkpoints = False\n")

print(f"Added 'c.FileContentsManager.save_checkpoints = False' to {config_path}")

Added 'c.FileContentsManager.save_checkpoints = False' to /root/.jupyter/jupyter_notebook_config.py


In [20]:
from __future__ import annotations

import os
import math
import csv
import json
import time
import random
import argparse
import functools
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import torchaudio
except Exception as e:
    raise RuntimeError("torchaudio is required. Install with: pip install torchaudio") from e


# -------------------------
# Repro / utils
# -------------------------
def set_seed(seed: int = 1337):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def ensure_dir(p: str):
    os.makedirs(p, exist_ok=True)


def human_time(seconds: float) -> str:
    if seconds < 60:
        return f"{seconds:.1f}s"
    if seconds < 3600:
        return f"{seconds/60:.1f}m"
    return f"{seconds/3600:.1f}h"

In [54]:
# -------------------------
# Audio + tokens
# -------------------------
@dataclass
class TokenSpec:
    # 8-bit magnitude tokens
    n_mag_tokens: int = 256  # 0..255

    # Special tokens (offset above mag tokens)
    BOS: int = 256
    MIX: int = 257
    mix: int = 258
    SEP: int = 259
    STEM: int = 260
    target: int = 261
    EOS: int = 262

    @property
    def n_special(self) -> int:
        return 7

    @property
    def vocab_size(self) -> int:
        return self.n_mag_tokens + self.n_special

    def stem_token(self, stem_index: int) -> int:
        # encode stem id as a single token above specials
        # (keeps file simple; supports up to a few dozen stems)
        return self.vocab_size + stem_index

    def total_vocab(self, n_stems: int) -> int:
        return self.vocab_size + n_stems


class MelReducer:
    """
    Converts linear-frequency magnitude (n_fft//2+1) <-> mel (n_mels)
    via a fixed mel filterbank matrix and pseudo-inverse.

    We use it on magnitudes (not power), and keep it simple.
    """
    def __init__(
        self, sample_rate: int, n_fft: int, n_mels: int = 64, f_min: float = 0.0, f_max: Optional[float] = None, device="cpu"
    ):
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.f_min = f_min
        self.f_max = f_max if f_max is not None else sample_rate / 2
        self.device = device

        n_freqs = n_fft // 2 + 1

        # Create mel filterbank. Torchaudio usually returns (n_mels, n_freqs).
        fb_raw = torchaudio.functional.melscale_fbanks(
            n_freqs=n_freqs,
            f_min=self.f_min,
            f_max=self.f_max,
            n_mels=self.n_mels,
            sample_rate=self.sample_rate,
            norm="slaney",
            mel_scale="htk",
        )
        fb_raw = fb_raw.to(torch.float32).to(device)

        # Robustly ensure mel_forward_matrix is (n_freqs, n_mels) i.e. (129, 16)
        if fb_raw.shape[0] == self.n_mels:
            # shape is (n_mels, n_freqs) -> transpose to (n_freqs, n_mels)
            self.mel_forward_matrix = fb_raw.t()
        else:
            # assume shape is already (n_freqs, n_mels)
            self.mel_forward_matrix = fb_raw

        # Pseudo-inverse: (n_mels, n_freqs)
        self.mel_inverse_matrix = torch.linalg.pinv(self.mel_forward_matrix)

    def linear_to_mel(self, mag_lin: torch.Tensor) -> torch.Tensor:
        # mag_lin: (..., n_freqs)
        # self.mel_forward_matrix: (n_freqs, n_mels)
        # Result: (..., n_mels)
        return torch.matmul(mag_lin, self.mel_forward_matrix)

    def mel_to_linear(self, mag_mel: torch.Tensor) -> torch.Tensor:
        # mag_mel: (..., n_mels)
        # self.mel_inverse_matrix: (n_mels, n_freqs)
        # Result: (..., n_freqs)
        return torch.matmul(mag_mel, self.mel_inverse_matrix)


class AudioTokenizer:
    """
    Tokenize waveform -> STFT magnitude -> log1p -> mel -> quantize (0..255)
    Also supports inverse: tokens -> mel -> linear -> ISTFT with mixture phase.
    """
    def __init__(
        self,
        sample_rate: int = 44100,
        n_fft: int = 256,
        hop: int = 256,
        n_mels: int = 64,
        token_spec: Optional[TokenSpec] = None,
        device: str = "cpu",
        log_eps: float = 1e-8,
    ):
        self.sr = sample_rate
        self.n_fft = n_fft
        self.hop = hop
        self.win_length = n_fft  # rectangular window equivalent
        self.window = torch.ones(self.win_length, device=device)  # rectangular
        self.n_mels = n_mels
        self.spec = token_spec or TokenSpec()
        self.device = device
        self.log_eps = log_eps
        self.reducer = MelReducer(sample_rate, n_fft, n_mels=n_mels, device=device)

        # Quantization calibration (global fixed baseline)
        # log1p magnitudes tend to live in [0, ~something]. We clamp to a range.
        # These can be tuned; baseline values work OK across many music datasets.
        self.q_min = 0.0
        self.q_max = 6.0  # log1p(mag) clamp upper bound

    def stft(self, wav: torch.Tensor) -> torch.Tensor:
        # wav: (T,)
        return torch.stft(
            wav,
            n_fft=self.n_fft,
            hop_length=self.hop,
            win_length=self.win_length,
            window=self.window,
            center=False,
            return_complex=True,
        )

    def istft(self, stft_c: torch.Tensor, length: int) -> torch.Tensor:
        return torch.istft(
            stft_c,
            n_fft=self.n_fft,
            hop_length=self.hop,
            win_length=self.win_length,
            window=self.window,
            center=False,
            length=length,
        )

    @torch.no_grad()
    def wav_to_tokens(self, wav: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
          tokens: (frames, n_mels) int64 in [0..255]
          mix_phase: complex phase from STFT (frames, freqs)
        """
        wav = wav.to(self.device).to(torch.float32)
        X = self.stft(wav)  # (freqs, frames) complex
        X = X.transpose(0, 1).contiguous()  # (frames, freqs)
        mag = torch.abs(X).clamp_min(self.log_eps)  # (frames, freqs)
        phase = X / mag  # unit complex, (frames, freqs)

        # linear -> mel on magnitude
        mag_mel = self.reducer.linear_to_mel(mag)  # (frames, mels)
        logmag = torch.log1p(mag_mel)  # (frames, mels)

        # quantize to 0..255
        q = (logmag.clamp(self.q_min, self.q_max) - self.q_min) / (self.q_max - self.q_min)
        tokens = torch.round(q * (self.spec.n_mag_tokens - 1)).to(torch.int64)
        return tokens, phase

    @torch.no_grad()
    def tokens_to_wav_with_phase(self, tokens: torch.Tensor, phase: torch.Tensor, length: int) -> torch.Tensor:
        """
        tokens: (frames, mels) int64 0..255
        phase: (frames, freqs) unit complex (from mixture)
        """
        tokens = tokens.to(self.device)
        # dequantize
        q = tokens.to(torch.float32) / (self.spec.n_mag_tokens - 1)
        logmag = q * (self.q_max - self.q_min) + self.q_min
        mag_mel = torch.expm1(logmag).clamp_min(0.0)  # (frames, mels)

        mag_lin = self.reducer.mel_to_linear(mag_mel)  # (frames, freqs)
        mag_lin = mag_lin.clamp_min(self.log_eps)
        X = (mag_lin * phase).transpose(0, 1).contiguous()  # (freqs, frames)
        wav = self.istft(X, length=length)
        return wav

In [55]:
# -------------------------
# Model: decoder-only GPT
# -------------------------
class CausalSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float):
        super().__init__()
        assert d_model % n_heads == 0
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)
        self.drop = nn.Dropout(dropout)

        # registered buffer for causal mask will be created dynamically by sequence length

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, C)
        B, T, C = x.shape
        qkv = self.qkv(x)  # (B, T, 3C)
        q, k, v = qkv.chunk(3, dim=-1)

        q = q.view(B, T, self.n_heads, self.d_head).transpose(1, 2)  # (B, H, T, Dh)
        k = k.view(B, T, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_head).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_head)  # (B, H, T, T)
        # causal mask
        mask = torch.triu(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=1)
        att = att.masked_fill(mask, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.drop(att)

        y = att @ v  # (B, H, T, Dh)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.proj(y)
        y = self.drop(y)
        return y


class MLP(nn.Module):
    def __init__(self, d_model: int, dropout: float, mult: int = 4):
        super().__init__()
        self.fc1 = nn.Linear(d_model, mult * d_model)
        self.fc2 = nn.Linear(mult * d_model, d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(nn.Module):
    def __init__(self, d_model: int, n_heads: int, dropout: float):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, dropout)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


class GPTSeparator(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        n_layers: int = 8,
        n_heads: int = 8,
        dropout: float = 0.1,
        max_seq_len: int = 4096,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([Block(d_model, n_heads, dropout) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # idx: (B, T)
        B, T = idx.shape
        if T > self.max_seq_len:
            raise ValueError(f"Sequence length {T} exceeds max_seq_len {self.max_seq_len}")

        pos = torch.arange(0, T, device=idx.device, dtype=torch.long).unsqueeze(0)
        x = self.tok_emb(idx) + self.pos_emb(pos)
        x = self.drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.head(x)  # (B, T, vocab)
        return logits


In [56]:
# -------------------------
# Dataset
# -------------------------
def list_tracks(dataset_root: str) -> List[str]:
    tracks = []
    for name in sorted(os.listdir(dataset_root)):
        if name.startswith('.'):  # Ignore hidden directories like .ipynb_checkpoints
            continue
        p = os.path.join(dataset_root, name)
        if os.path.isdir(p):
            tracks.append(p)
    return tracks


def load_wav_mono(path: str, target_sr: int) -> torch.Tensor:
    wav, sr = torchaudio.load(path)  # (C, T)
    if sr != target_sr:
        wav = torchaudio.functional.resample(wav, sr, target_sr)
    wav = wav.mean(dim=0)  # mono
    return wav


def stems_in_track(track_dir: str) -> List[str]:
    # all .wav except mixture.wav
    wavs = [f for f in os.listdir(track_dir) if f.lower().endswith(".wav")]
    stems = [f for f in wavs if f.lower() != "mixture.wav"]
    stems.sort()
    return stems

In [57]:
class StemSeparationDataset(torch.utils.data.Dataset):
    """
    Produces windowed examples:
      - mixture waveform window
      - target stem waveform window
    Tokenization happens in the collate_fn for speed batching (or in __getitem__ if desired).
    """
    def __init__(
        self,
        dataset_root: str,
        sample_rate: int = 44100,
        window_sec: float = 0.1,
        mix_from_stems: bool = True,
        explicit_tracks: Optional[List[str]] = None,
        stem_names: Optional[List[str]] = None,
        seed: int = 1337,
    ):
        super().__init__()
        self.dataset_root = dataset_root
        self.sr = sample_rate
        self.window_sec = window_sec
        self.win_len = int(round(window_sec * sample_rate))
        self.mix_from_stems = mix_from_stems
        self.rng = random.Random(seed)

        self.tracks = explicit_tracks if explicit_tracks is not None else list_tracks(dataset_root)
        if len(self.tracks) == 0:
            raise ValueError(f"No tracks found in {dataset_root}")

        # Determine stems universe
        # If stem_names provided, we enforce that; otherwise derive from first track and keep intersection across tracks.
        if stem_names is None:
            first = stems_in_track(self.tracks[0])
            if len(first) == 0:
                raise ValueError(f"No stems found in {self.tracks[0]}")
            self.stem_names = first
        else:
            self.stem_names = list(stem_names)

        # Precompute lengths and valid start positions per track (at least one window)
        self.track_meta = []
        for td in self.tracks:
            mix_path = os.path.join(td, "mixture.wav")
            if not os.path.exists(mix_path) and not self.mix_from_stems:
                continue

            # choose a reference stem file for length (or mixture)
            ref_path = mix_path if os.path.exists(mix_path) else os.path.join(td, self.stem_names[0])
            if not os.path.exists(ref_path):
                continue
            wav = load_wav_mono(ref_path, self.sr)
            T = wav.shape[0]
            if T < self.win_len:
                continue
            self.track_meta.append((td, T))
        if len(self.track_meta) == 0:
            raise ValueError("No usable tracks (check mixture.wav presence or enable mix_from_stems=True).")

        # Define dataset length as a large number; windows sampled randomly for each item
        # This avoids epoch-definition confusion and improves mixing.
        self.virtual_len = 200_000  # baseline; training is step-based

    def __len__(self):
        return self.virtual_len

    def __getitem__(self, idx: int):
        # Sample a random track and random stem and random window start
        td, T = self.rng.choice(self.track_meta)
        stem_name = self.rng.choice(self.stem_names)
        stem_path = os.path.join(td, stem_name)

        # choose start
        start = self.rng.randint(0, T - self.win_len)
        end = start + self.win_len

        # load target stem window
        y = load_wav_mono(stem_path, self.sr)[start:end]

        # mixture window
        mix_path = os.path.join(td, "mixture.wav")
        if os.path.exists(mix_path):
            x = load_wav_mono(mix_path, self.sr)[start:end]
        else:
            # synthesize mixture from all available stems (intersection)
            xs = []
            for s in self.stem_names:
                p = os.path.join(td, s)
                if os.path.exists(p):
                    xs.append(load_wav_mono(p, self.sr)[start:end])
            if len(xs) == 0:
                x = y.clone()
            else:
                x = torch.stack(xs, dim=0).sum(dim=0)

        # simple gain normalization (baseline)
        x = x / (x.abs().max().clamp_min(1e-6))
        y = y / (y.abs().max().clamp_min(1e-6))

        return {
            "track": os.path.basename(td),
            "stem": stem_name,
            "stem_index": self.stem_names.index(stem_name),
            "mixture_wav": x,
            "target_wav": y,
        }


In [69]:
def build_sequence_tokens(
    spec: TokenSpec,
    mix_tokens_2d: torch.Tensor,     # (F, M) with values 0..255
    tgt_tokens_2d: torch.Tensor,     # (F, M) with values 0..255
    stem_index: int,
    n_stems: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Flatten 2D tokens to 1D and build [context + target] sequence.
    Returns:
      seq: (T,) int64 full input sequence including EOS
      target_mask: (T,) bool mask where loss should be applied
    """
    # Determine device from inputs
    device = mix_tokens_2d.device

    mix_flat = mix_tokens_2d.reshape(-1).to(torch.int64)
    tgt_flat = tgt_tokens_2d.reshape(-1).to(torch.int64)

    stem_tok = spec.stem_token(stem_index)  # distinct token
    vocab_total = spec.total_vocab(n_stems)
    if stem_tok >= vocab_total:
        raise ValueError("stem token out of range; increase n_stems or adjust token mapping")

    seq = torch.cat([
        torch.tensor([spec.BOS, spec.MIX, spec.mix], dtype=torch.int64, device=device),
        mix_flat,
        torch.tensor([spec.SEP, spec.STEM, stem_tok, spec.target], dtype=torch.int64, device=device),
        tgt_flat,
        torch.tensor([spec.EOS], dtype=torch.int64, device=device),
    ], dim=0)

    # Loss only on target tokens (tgt_flat positions)
    mask = torch.zeros_like(seq, dtype=torch.bool)
    tgt_start = 3 + mix_flat.numel() + 4  # after BOS MIX mix + mix_flat + SEP STEM stem_tok target
    tgt_end = tgt_start + tgt_flat.numel()
    mask[tgt_start:tgt_end] = True
    return seq, mask


def collate_batch(
    batch: List[dict],
    tokenizer: AudioTokenizer,
    spec: TokenSpec,
    n_stems: int,
    max_len: int,
) -> Dict[str, torch.Tensor]:
    # tokenize each item and build sequence; then pad
    seqs = []
    masks = []
    meta = []
    for item in batch:
        mix_tok2d, mix_phase = tokenizer.wav_to_tokens(item["mixture_wav"])
        tgt_tok2d, _ = tokenizer.wav_to_tokens(item["target_wav"])
        seq, mask = build_sequence_tokens(
            spec, mix_tok2d, tgt_tok2d, item["stem_index"], n_stems
        )
        if seq.numel() > max_len:
            seq = seq[:max_len]
            mask = mask[:max_len]
        seqs.append(seq)
        masks.append(mask)
        meta.append((item["track"], item["stem"]))

    T = max(s.numel() for s in seqs)
    T = min(T, max_len)
    pad_id = spec.EOS  # pad with EOS

    # Use the device of the generated sequences
    device = seqs[0].device if seqs else torch.device('cpu')

    x = torch.full((len(batch), T), pad_id, dtype=torch.int64, device=device)
    m = torch.zeros((len(batch), T), dtype=torch.bool, device=device)

    for i, (s, mk) in enumerate(zip(seqs, masks)):
        t = min(s.numel(), T)
        x[i, :t] = s[:t]
        m[i, :t] = mk[:t]

    return {"idx": x, "loss_mask": m, "meta": meta}

In [70]:
# -------------------------
# Metrics: SI-SDR
# -------------------------
def si_sdr(est: torch.Tensor, ref: torch.Tensor, eps: float = 1e-8) -> float:
    """
    Scale-Invariant SDR in dB for 1D tensors.
    """
    est = est.detach().cpu().float()
    ref = ref.detach().cpu().float()
    ref_energy = torch.sum(ref * ref) + eps
    alpha = torch.sum(est * ref) / ref_energy
    s_target = alpha * ref
    e_noise = est - s_target
    ratio = (torch.sum(s_target * s_target) + eps) / (torch.sum(e_noise * e_noise) + eps)
    return float(10.0 * torch.log10(ratio))


In [87]:
# -------------------------
# Training / validation
# -------------------------
@dataclass
class TrainConfig:
    dataset_root: str
    out_dir: str
    sample_rate: int = 44100
    window_sec: float = 0.1
    n_fft: int = 256
    hop: int = 256
    n_mels: int = 64

    d_model: int = 512
    n_layers: int = 8
    n_heads: int = 8
    dropout: float = 0.1
    max_seq_len: int = 4096

    batch_size: int = 8
    grad_accum: int = 8  # effective batch = batch_size * grad_accum
    lr: float = 3e-5
    min_lr: float = 3e-5
    weight_decay: float = 0.05
    betas: Tuple[float, float] = (0.9, 0.95)
    eps: float = 1e-8
    clip_norm: float = 1.0
    warmup_steps: int = 1000
    total_steps: int = 5000

    val_every: int = 1000
    ckpt_every: int = 2500
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    seed: int = 1337

    mix_from_stems: bool = True
    num_workers: int = 2

    # evaluation
    sisdr_eval_every: int = 5000
    sisdr_patience_evals: int = 3
    sisdr_min_delta_db: float = 0.1

    # metrics output
    metrics_dir: Optional[str] = None


In [88]:
def cosine_lr(step: int, cfg: TrainConfig) -> float:
    if step < cfg.warmup_steps:
        return cfg.lr * (step + 1) / max(1, cfg.warmup_steps)
    t = (step - cfg.warmup_steps) / max(1, (cfg.total_steps - cfg.warmup_steps))
    t = min(max(t, 0.0), 1.0)
    # cosine from lr -> min_lr
    return cfg.min_lr + 0.5 * (cfg.lr - cfg.min_lr) * (1.0 + math.cos(math.pi * t))


def masked_ce_loss(logits: torch.Tensor, targets: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    """
    logits: (B, T, V)
    targets: (B, T)
    loss_mask: (B, T) bool - only positions True contribute
    """
    B, T, V = logits.shape
    # shift: predict targets[:, 1:] from logits[:, :-1]
    logits_s = logits[:, :-1, :].contiguous()
    targets_s = targets[:, 1:].contiguous()
    mask_s = loss_mask[:, 1:].contiguous()

    # flatten
    logits_f = logits_s.view(-1, V)
    targets_f = targets_s.view(-1)
    mask_f = mask_s.view(-1)

    if mask_f.sum().item() == 0:
        return torch.tensor(0.0, device=logits.device)

    loss = F.cross_entropy(logits_f[mask_f], targets_f[mask_f])
    return loss


def save_ckpt(path: str, model: nn.Module, optim: torch.optim.Optimizer, step: int, cfg: TrainConfig, extra: dict):
    payload = {
        "step": step,
        "cfg": cfg.__dict__,
        "model": model.state_dict(),
        "optim": optim.state_dict(),
        "extra": extra,
    }
    torch.save(payload, path)


def load_ckpt(path: str, model: nn.Module, optim: Optional[torch.optim.Optimizer] = None):
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    if optim is not None and "optim" in ckpt:
        optim.load_state_dict(ckpt["optim"])
    return ckpt


@torch.no_grad()
def validate_model(model: nn.Module, loader, spec: TokenSpec, device: str) -> float:
    model.eval()
    losses = []
    for batch in loader:
        idx = batch["idx"].to(device)
        m = batch["loss_mask"].to(device)
        logits = model(idx)
        loss = masked_ce_loss(logits, idx, m)
        losses.append(float(loss.detach().cpu()))
    model.train()
    return float(np.mean(losses)) if losses else float("nan")


@torch.no_grad()
def greedy_decode_target_tokens(
    model: nn.Module,
    prompt: torch.Tensor,      # (T0,)
    n_decode: int,
    device: str,
) -> torch.Tensor:
    """
    Greedy decode n_decode tokens following prompt. Returns full sequence (prompt + decoded).
    """
    model.eval()
    seq = prompt.clone().to(device).unsqueeze(0)  # (1, T)
    for _ in range(n_decode):
        logits = model(seq)  # (1, T, V)
        next_tok = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # (1,1)
        seq = torch.cat([seq, next_tok], dim=1)
        if seq.shape[1] >= model.max_seq_len:
            break
    model.train()
    return seq.squeeze(0).detach().cpu()


In [89]:
# -------------------------
# Inference: overlap-add
# -------------------------
def hann_fade(win_len: int, device: str) -> torch.Tensor:
    # Half-overlap Hann-ish fade for OLA smoothing
    w = torch.hann_window(win_len, periodic=False, device=device)
    return w


@torch.no_grad()
def separate_window(
    model: GPTSeparator,
    tokenizer: AudioTokenizer,
    spec: TokenSpec,
    mix_wav: torch.Tensor,
    stem_index: int,
    n_stems: int,
    device: str,
) -> torch.Tensor:
    """
    Separate one window of waveform using mixture phase.
    """
    mix_tok2d, phase = tokenizer.wav_to_tokens(mix_wav)
    mix_flat = mix_tok2d.reshape(-1)

    # Build prompt up to 'target' token (inclusive)
    stem_tok = spec.stem_token(stem_index)
    prompt = torch.cat([
        torch.tensor([spec.BOS, spec.MIX, spec.mix], dtype=torch.int64),
        mix_flat.to(torch.int64),
        torch.tensor([spec.SEP, spec.STEM, stem_tok, spec.target], dtype=torch.int64),
    ], dim=0)

    # Decode exactly the number of target tokens needed
    n_tgt = mix_flat.numel()  # same shape as mix in baseline
    full = greedy_decode_target_tokens(model, prompt, n_tgt + 1, device=device)  # +EOS token maybe
    # Extract decoded target tokens
    decoded = full[prompt.numel():prompt.numel() + n_tgt]
    decoded2d = decoded.view(mix_tok2d.shape[0], mix_tok2d.shape[1]).clamp(0, spec.n_mag_tokens - 1)

    # Reconstruct waveform using mixture phase
    wav_est = tokenizer.tokens_to_wav_with_phase(decoded2d, phase, length=mix_wav.numel())
    return wav_est


@torch.no_grad()
def overlap_add_separate(
    model: GPTSeparator,
    tokenizer: AudioTokenizer,
    spec: TokenSpec,
    mix_wav: torch.Tensor,
    stem_index: int,
    n_stems: int,
    window_sec: float,
    hop_sec: float,
    device: str,
) -> torch.Tensor:
    """
    Full-file separation by OLA with Hann fade.
    """
    model.eval()
    sr = tokenizer.sr
    win_len = int(round(window_sec * sr))
    hop_len = int(round(hop_sec * sr))
    if hop_len <= 0 or win_len <= 0:
        raise ValueError("Invalid window/hop for overlap-add.")
    if mix_wav.numel() < win_len:
        # pad
        pad = win_len - mix_wav.numel()
        mix_wav = torch.cat([mix_wav, torch.zeros(pad)], dim=0)

    w = hann_fade(win_len, device=device).detach().cpu()
    out = torch.zeros_like(mix_wav)
    norm = torch.zeros_like(mix_wav)

    for start in range(0, mix_wav.numel() - win_len + 1, hop_len):
        chunk = mix_wav[start:start + win_len].to(device)
        est = separate_window(model, tokenizer, spec, chunk, stem_index, n_stems, device=device).detach().cpu()
        out[start:start + win_len] += est * w
        norm[start:start + win_len] += w

    out = out / norm.clamp_min(1e-6)
    model.train()
    return out


In [90]:
# -------------------------
# Evaluation
# -------------------------
@torch.no_grad()
def evaluate_sisdr_windows(
    model: GPTSeparator,
    dataset: StemSeparationDataset,
    tokenizer: AudioTokenizer,
    spec: TokenSpec,
    stems: List[str],
    n_windows: int,
    out_csv: str,
    device: str,
):
    """
    Window-level evaluation: greedy decode each sampled window and compute SI-SDR and CE on masked tokens.
    """
    ensure_dir(os.path.dirname(out_csv) or ".")
    model.eval()

    rows = []
    for i in range(n_windows):
        item = dataset[i]  # dataset randomized anyway
        mix = item["mixture_wav"].to(device)
        tgt = item["target_wav"].to(device)

        mix_tok2d, phase = tokenizer.wav_to_tokens(mix)
        tgt_tok2d, _ = tokenizer.wav_to_tokens(tgt)

        seq, mask = build_sequence_tokens(spec, mix_tok2d, tgt_tok2d, item["stem_index"], len(stems))
        idx = seq.unsqueeze(0).to(device)
        m = mask.unsqueeze(0).to(device)

        logits = model(idx)
        ce = float(masked_ce_loss(logits, idx, m).detach().cpu())

        # decode target tokens and reconstruct
        est = separate_window(model, tokenizer, spec, mix, item["stem_index"], len(stems), device=device)
        s = si_sdr(est, tgt)

        rows.append({
            "track": item["track"],
            "start_sample": -1,  # unknown in this virtual sampler baseline
            "stem": item["stem"],
            "sisdr_db": s,
            "ce_loss": ce,
            "num_masked_tokens": int(mask.sum().item()),
        })

    # write
    with open(out_csv, "w", newline="", encoding="utf-8") as f:
        wcsv = csv.DictWriter(f, fieldnames=list(rows[0].keys()) if rows else [])
        wcsv.writeheader()
        for r in rows:
            wcsv.writerow(r)

    model.train()
    return rows


@torch.no_grad()
def evaluate_sisdr_tracks(
    model: GPTSeparator,
    dataset_root: str,
    tokenizer: AudioTokenizer,
    spec: TokenSpec,
    stems: List[str],
    out_dir: str,
    device: str,
    window_sec: float,
    hop_sec: float,
    max_tracks: Optional[int] = None,
):
    """
    Track-level evaluation: full-file separation per stem, SI-SDR per stem and per track.
    """
    ensure_dir(out_dir)
    tracks = list_tracks(dataset_root)
    if max_tracks is not None:
        tracks = tracks[:max_tracks]

    per_stem_rows = []
    per_track_rows = []

    for td in tracks:
        track_name = os.path.basename(td)
        mix_path = os.path.join(td, "mixture.wav")
        if not os.path.exists(mix_path):
            continue
        mix = load_wav_mono(mix_path, tokenizer.sr)

        stem_scores = []
        for si, stem in enumerate(stems):
            stem_path = os.path.join(td, stem)
            if not os.path.exists(stem_path):
                continue
            ref = load_wav_mono(stem_path, tokenizer.sr)

            est = overlap_add_separate(
                model, tokenizer, spec, mix, si, len(stems),
                window_sec=window_sec, hop_sec=hop_sec, device=device
            )

            L = min(est.numel(), ref.numel())
            s = si_sdr(est[:L], ref[:L])
            stem_scores.append(s)

            per_stem_rows.append({
                "track": track_name,
                "stem": stem,
                "sisdr_db": s,
            })

        if stem_scores:
            per_track_rows.append({
                "track": track_name,
                "sisdr_db_mean": float(np.mean(stem_scores)),
                "sisdr_db_min": float(np.min(stem_scores)),
                "sisdr_db_max": float(np.max(stem_scores)),
            })

    # export CSVs
    per_stem_csv = os.path.join(out_dir, "sisdr_per_stem.csv")
    per_track_csv = os.path.join(out_dir, "sisdr_per_track.csv")

    if per_stem_rows:
        with open(per_stem_csv, "w", newline="", encoding="utf-8") as f:
            wcsv = csv.DictWriter(f, fieldnames=list(per_stem_rows[0].keys()))
            wcsv.writeheader()
            wcsv.writerows(per_stem_rows)

    if per_track_rows:
        with open(per_track_csv, "w", newline="", encoding="utf-8") as f:
            wcsv = csv.DictWriter(f, fieldnames=list(per_track_rows[0].keys()))
            wcsv.writeheader()
            wcsv.writerows(per_track_rows)

    return per_stem_csv, per_track_csv


In [91]:
# -------------------------
# Train loop
# -------------------------
def split_dataset(tracks: List[str], val_frac: float = 0.1, seed: int = 1337) -> Tuple[List[str], List[str]]:
    rng = random.Random(seed)
    tracks = tracks[:]
    rng.shuffle(tracks)
    n_val = max(1, int(round(len(tracks) * val_frac)))
    return tracks[n_val:], tracks[:n_val]

def _global_collate_fn(batch: List[dict], tokenizer: 'AudioTokenizer', spec: TokenSpec, n_stems: int, max_len: int):
    return collate_batch(batch, tokenizer=tokenizer, spec=spec, n_stems=n_stems, max_len=max_len)

def train_model(cfg: TrainConfig):
    set_seed(cfg.seed)
    ensure_dir(cfg.out_dir)
    ckpt_dir = os.path.join(cfg.out_dir, "checkpoints")
    ensure_dir(ckpt_dir)
    metrics_dir = cfg.metrics_dir or os.path.join(cfg.out_dir, "metrics")
    ensure_dir(metrics_dir)

    # Dataset split by track
    all_tracks = list_tracks(cfg.dataset_root)
    train_tracks, val_tracks = split_dataset(all_tracks, val_frac=0.1, seed=cfg.seed)

    train_ds = StemSeparationDataset(
        cfg.dataset_root, sample_rate=cfg.sample_rate, window_sec=cfg.window_sec,
        mix_from_stems=cfg.mix_from_stems, explicit_tracks=train_tracks, seed=cfg.seed
    )
    val_ds = StemSeparationDataset(
        cfg.dataset_root, sample_rate=cfg.sample_rate, window_sec=cfg.window_sec,
        mix_from_stems=cfg.mix_from_stems, explicit_tracks=val_tracks, seed=cfg.seed + 1
    )

    stems = train_ds.stem_names
    n_stems = len(stems)
    spec = TokenSpec()

    device = cfg.device
    tokenizer = AudioTokenizer(
        sample_rate=cfg.sample_rate, n_fft=cfg.n_fft, hop=cfg.hop, n_mels=cfg.n_mels,
        token_spec=spec, device=device
    )

    vocab_total = spec.total_vocab(n_stems)

    model = GPTSeparator(
        vocab_size=vocab_total,
        d_model=cfg.d_model,
        n_layers=cfg.n_layers,
        n_heads=cfg.n_heads,
        dropout=cfg.dropout,
        max_seq_len=cfg.max_seq_len,
    ).to(device)

    optim = torch.optim.AdamW(
        model.parameters(),
        lr=cfg.lr,
        betas=cfg.betas,
        eps=cfg.eps,
        weight_decay=cfg.weight_decay,
    )

    # Load latest checkpoint if exists
    latest_path = os.path.join(ckpt_dir, "ckpt_latest.pt")
    start_step = 0
    best_sisdr = -1e9
    best_step = -1
    sisdr_no_improve = 0
    if os.path.exists(latest_path):
        ck = load_ckpt(latest_path, model, optim)
        start_step = int(ck.get("step", 0))
        extra = ck.get("extra", {})
        best_sisdr = float(extra.get("best_sisdr", best_sisdr))
        best_step = int(extra.get("best_step", best_step))
        sisdr_no_improve = int(extra.get("sisdr_no_improve", 0))
        print(f"Resumed from {latest_path} at step={start_step}")

    # DataLoaders
    train_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=cfg.batch_size, shuffle=False,  # dataset samples randomly
        num_workers=cfg.num_workers,
        collate_fn=functools.partial(_global_collate_fn, tokenizer=tokenizer, spec=spec, n_stems=n_stems, max_len=cfg.max_seq_len),
        pin_memory=False
    )
    val_loader = torch.utils.data.DataLoader(
        val_ds, batch_size=cfg.batch_size, shuffle=False,
        num_workers=max(0, cfg.num_workers // 2),
        collate_fn=functools.partial(_global_collate_fn, tokenizer=tokenizer, spec=spec, n_stems=n_stems, max_len=cfg.max_seq_len),
        pin_memory=False
    )

    # Training
    model.train()
    t0 = time.time()
    running = 0.0
    step = start_step

    # simple iterator cycling
    train_iter = iter(train_loader)

    # Save config snapshot
    with open(os.path.join(cfg.out_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(cfg.__dict__, f, indent=2)
    with open(os.path.join(cfg.out_dir, "stems.json"), "w", encoding="utf-8") as f:
        json.dump(stems, f, indent=2)

    print(f"Device={device} | stems={stems} | effective_batch={cfg.batch_size * cfg.grad_accum}")

    while step < cfg.total_steps:
        # LR schedule
        lr_now = cosine_lr(step, cfg)
        for pg in optim.param_groups:
            pg["lr"] = lr_now

        # fetch batch
        try:
            batch = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            batch = next(train_iter)

        idx = batch["idx"].to(device, non_blocking=True)
        m = batch["loss_mask"].to(device, non_blocking=True)

        logits = model(idx)
        loss = masked_ce_loss(logits, idx, m) / cfg.grad_accum
        loss.backward()

        if (step + 1) % cfg.grad_accum == 0:
            if cfg.clip_norm is not None and cfg.clip_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_norm)
            optim.step()
            optim.zero_grad(set_to_none=True)

        running += float(loss.detach().cpu()) * cfg.grad_accum  # unscale for reporting

        # validate CE
        if (step + 1) % cfg.val_every == 0:
            val_ce = validate_model(model, val_loader, spec, device=device)
            avg = running / cfg.val_every
            running = 0.0
            elapsed = time.time() - t0
            print(
                f"step {step+1}/{cfg.total_steps} | lr {lr_now:.2e} | "
                f"train_ce {avg:.4f} | val_ce {val_ce:.4f} | time {human_time(elapsed)}"
            )

        # SI-SDR eval (window-level quick proxy)
        if (step + 1) % cfg.sisdr_eval_every == 0:
            out_csv = os.path.join(metrics_dir, "sisdr_windows.csv")
            rows = evaluate_sisdr_windows(
                model, val_ds, tokenizer, spec, stems,
                n_windows=64, out_csv=out_csv, device=device
            )
            mean_s = float(np.mean([r["sisdr_db"] for r in rows])) if rows else float("-inf")
            print(f"  SI-SDR(window mean over 64): {mean_s:.3f} dB  (csv: {out_csv})")

            # early stopping on SI-SDR improvements
            if mean_s >= best_sisdr + cfg.sisdr_min_delta_db:
                best_sisdr = mean_s
                best_step = step + 1
                sisdr_no_improve = 0
                save_ckpt(
                    os.path.join(ckpt_dir, "ckpt_best.pt"),
                    model, optim, step + 1, cfg,
                    extra={"best_sisdr": best_sisdr, "best_step": best_step, "sisdr_no_improve": sisdr_no_improve}
                )
                print(f"  âœ… New best SI-SDR: {best_sisdr:.3f} dB at step {best_step}")
            else:
                sisdr_no_improve += 1
                print(f"  No SI-SDR improvement (patience {sisdr_no_improve}/{cfg.sisdr_patience_evals})")
                if sisdr_no_improve >= cfg.sisdr_patience_evals:
                    print(f"ðŸ›‘ Early stopping: SI-SDR plateaued. Best={best_sisdr:.3f} dB at step {best_step}")
                    break

        # checkpoints
        if (step + 1) % cfg.ckpt_every == 0:
            save_ckpt(
                latest_path, model, optim, step + 1, cfg,
                extra={"best_sisdr": best_sisdr, "best_step": best_step, "sisdr_no_improve": sisdr_no_improve}
            )
            print(f"  Saved checkpoint: {latest_path}")

        step += 1

    # final save
    save_ckpt(
        latest_path, model, optim, step, cfg,
        extra={"best_sisdr": best_sisdr, "best_step": best_step, "sisdr_no_improve": sisdr_no_improve}
    )
    print(f"Done. Final step={step}. Best SI-SDR(window mean)={best_sisdr:.3f} dB at step={best_step}.")

In [92]:
# -------------------------
# Notebook-friendly helpers
# -------------------------
def init_model(ckpt_path: str, device: Optional[str] = None):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    ck = torch.load(ckpt_path, map_location="cpu")
    cfgd = ck.get("cfg", {})
    stems_path = os.path.join(os.path.dirname(os.path.dirname(ckpt_path)), "stems.json")
    stems = json.load(open(stems_path, "r", encoding="utf-8")) if os.path.exists(stems_path) else []

    cfg = TrainConfig(dataset_root=cfgd.get("dataset_root", ""), out_dir=cfgd.get("out_dir", ""),
                      sample_rate=cfgd.get("sample_rate", 44100),
                      window_sec=cfgd.get("window_sec", 0.1),
                      n_fft=cfgd.get("n_fft", 256), hop=cfgd.get("hop", 256),
                      n_mels=cfgd.get("n_mels", 64),
                      d_model=cfgd.get("d_model", 512), n_layers=cfgd.get("n_layers", 8), n_heads=cfgd.get("n_heads", 8),
                      dropout=cfgd.get("dropout", 0.1), max_seq_len=cfgd.get("max_seq_len", 4096),
                      device=device)

    spec = TokenSpec()
    vocab_total = spec.total_vocab(len(stems) if stems else 4)  # fallback
    model = GPTSeparator(vocab_size=vocab_total, d_model=cfg.d_model, n_layers=cfg.n_layers,
                         n_heads=cfg.n_heads, dropout=cfg.dropout, max_seq_len=cfg.max_seq_len).to(device)

    model.load_state_dict(ck["model"])
    model.eval()

    tokenizer = AudioTokenizer(sample_rate=cfg.sample_rate, n_fft=cfg.n_fft, hop=cfg.hop, n_mels=cfg.n_mels,
                               token_spec=spec, device=device)

    return model, tokenizer, spec, stems, cfg


def run_separation(model, tokenizer, spec, stems: List[str], mixture_wav_path: str, out_wav_path: str,
                   stem_name: str = "vocals", window_sec: float = 0.1, hop_sec: float = 0.05, device: Optional[str] = None):
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    mix = load_wav_mono(mixture_wav_path, tokenizer.sr)
    if stem_name not in stems:
        raise ValueError(f"stem_name={stem_name} not in stems={stems}")
    si = stems.index(stem_name)

    est = overlap_add_separate(model, tokenizer, spec, mix, si, len(stems), window_sec, hop_sec, device=device)
    est = est.unsqueeze(0)  # (1, T)
    torchaudio.save(out_wav_path, est.cpu(), tokenizer.sr)
    return out_wav_path


def run_demo():
    print("Demo placeholder: load a checkpoint with init_model(), then run_separation().")


In [93]:
def parse_args(argv=None):
    p = argparse.ArgumentParser()
    p.add_argument("--dataset_root", type=str, required=True)
    p.add_argument("--out_dir", type=str, default="./runs")
    p.add_argument("--ckpt", type=str, default="")
    p.add_argument("--device", type=str, default="")

    p.add_argument("--sample_rate", type=int, default=44100)
    p.add_argument("--window_sec", type=float, default=0.1)
    p.add_argument("--n_fft", type=int, default=256)
    p.add_argument("--hop", type=int, default=256)
    p.add_argument("--n_mels", type=int, default=64)

    p.add_argument("--d_model", type=int, default=512)
    p.add_argument("--n_layers", type=int, default=8)
    p.add_argument("--n_heads", type=int, default=8)
    p.add_argument("--dropout", type=float, default=0.1)
    p.add_argument("--max_seq_len", type=int, default=4096)

    p.add_argument("--batch_size", type=int, default=8)
    p.add_argument("--grad_accum", type=int, default=8)
    p.add_argument("--lr", type=float, default=3e-4)
    p.add_argument("--min_lr", type=float, default=3e-5)
    p.add_argument("--weight_decay", type=float, default=0.05)
    p.add_argument("--warmup_steps", type=int, default=2000)
    p.add_argument("--total_steps", type=int, default=150_000)
    p.add_argument("--val_every", type=int, default=1000)
    p.add_argument("--ckpt_every", type=int, default=5000)
    p.add_argument("--sisdr_eval_every", type=int, default=5000)
    p.add_argument("--mix_from_stems", action="store_true")

    p.add_argument("--metrics_dir", type=str, default="")
    p.add_argument("--seed", type=int, default=1337)
    p.add_argument("--num_workers", type=int, default=2)

    p.add_argument("cmd", choices=["train", "eval_windows", "eval_tracks"])
    p.add_argument("--eval_out", type=str, default="")
    p.add_argument("--eval_tracks_max", type=int, default=0)
    p.add_argument("--ola_hop_sec", type=float, default=0.05)

    return p.parse_args(argv)

In [94]:
def main(argv=None):
    args = parse_args(argv)

    device = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
    cfg = TrainConfig(
        dataset_root=args.dataset_root,
        out_dir=args.out_dir,
        sample_rate=args.sample_rate,
        window_sec=args.window_sec,
        n_fft=args.n_fft,
        hop=args.hop,
        n_mels=args.n_mels,
        d_model=args.d_model,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        dropout=args.dropout,
        max_seq_len=args.max_seq_len,
        batch_size=args.batch_size,
        grad_accum=args.grad_accum,
        lr=args.lr,
        min_lr=args.min_lr,
        weight_decay=args.weight_decay,
        warmup_steps=args.warmup_steps,
        total_steps=args.total_steps,
        val_every=args.val_every,
        ckpt_every=args.ckpt_every,
        sisdr_eval_every=args.sisdr_eval_every,
        mix_from_stems=args.mix_from_stems,
        metrics_dir=args.metrics_dir if args.metrics_dir else None,
        seed=args.seed,
        num_workers=args.num_workers,
        device=device,
    )

    if args.cmd == "train":
        train_model(cfg)
        return

    # For eval, load checkpoint
    spec = TokenSpec()
    stems_path = os.path.join(cfg.out_dir, "stems.json")
    if not os.path.exists(stems_path):
        # fallback: infer stems from dataset
        ds_tmp = StemSeparationDataset(cfg.dataset_root, sample_rate=cfg.sample_rate, window_sec=cfg.window_sec,
                                       mix_from_stems=cfg.mix_from_stems, seed=cfg.seed)
        stems = ds_tmp.stem_names
    else:
        stems = json.load(open(stems_path, "r", encoding="utf-8"))

    n_stems = len(stems)
    tokenizer = AudioTokenizer(cfg.sample_rate, cfg.n_fft, cfg.hop, cfg.n_mels, token_spec=spec, device=device)
    model = GPTSeparator(
        vocab_size=spec.total_vocab(n_stems),
        d_model=cfg.d_model, n_layers=cfg.n_layers, n_heads=cfg.n_heads,
        dropout=cfg.dropout, max_seq_len=cfg.max_seq_len
    ).to(device)

    ckpt_path = args.ckpt if args.ckpt else os.path.join(cfg.out_dir, "checkpoints", "ckpt_latest.pt")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    load_ckpt(ckpt_path, model, None)
    model.eval()

    if args.cmd == "eval_windows":
        ds = StemSeparationDataset(cfg.dataset_root, sample_rate=cfg.sample_rate, window_sec=cfg.window_sec,
                                   mix_from_stems=cfg.mix_from_stems, seed=cfg.seed + 1)
        out_csv = args.eval_out or os.path.join(cfg.metrics_dir or os.path.join(cfg.out_dir, "metrics"), "sisdr_windows.csv")
        evaluate_sisdr_windows(model, ds, tokenizer, spec, stems, n_windows=256, out_csv=out_csv, device=device)
        print(f"Wrote: {out_csv}")
        return

    if args.cmd == "eval_tracks":
        out_dir = args.eval_out or (cfg.metrics_dir or os.path.join(cfg.out_dir, "metrics"))
        max_tracks = args.eval_tracks_max if args.eval_tracks_max > 0 else None
        per_stem_csv, per_track_csv = evaluate_sisdr_tracks(
            model, cfg.dataset_root, tokenizer, spec, stems,
            out_dir=out_dir, device=device,
            window_sec=cfg.window_sec, hop_sec=args.ola_hop_sec, max_tracks=max_tracks
        )
        print(f"Wrote: {per_stem_csv}")
        print(f"Wrote: {per_track_csv}")
        return

In [86]:
import torch.multiprocessing
torch.multiprocessing.set_start_method('spawn', force=True)

# Retry training after fixing pin_memory
main(argv=["--dataset_root", "/content/dataset_root", "--out_dir", "./runs/exp_notebook_train", "train", "--window_sec", "0.01", "--n_mels", "16", "--num_workers", "2", "--batch_size", "64"])

Device=cuda | stems=['bass.wav', 'drums.wav', 'other.wav', 'vocals.wav'] | effective_batch=64
step 1000/150000 | lr 1.50e-04 | train_ce 1.3966 | val_ce 0.9076 | time 4.2h


KeyboardInterrupt: 