In [7]:
# sep_gpt_v0.py
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import csv
import os
from torchaudio.transforms import MelSpectrogram
from typing import List, Dict, Tuple, Optional


# Set the multiprocessing start method to 'spawn' for CUDA compatibility
# This needs to be called once at the beginning of your script/notebook
# and before any multiprocessing (like DataLoader with num_workers > 0) is started.
# Force the 'spawn' multiprocessing start method so DataLoader workers
# (num_workers > 0) can use CUDA. This must run once, before any worker
# process is created.
try:
    torch.multiprocessing.set_start_method('spawn', force=True)
except RuntimeError as err:
    print(f"[warn] Could not set multiprocessing start method: {err}. It might already be set.")
else:
    print("[info] Multiprocessing start method set to 'spawn'.")

# =========================
# Config (minimal recipe)
# =========================
class Cfg:
    """
    Hyper-parameters for the minimal separation recipe.

    Every setting is a class attribute; the rest of the file reads them
    through the module-level ``cfg`` instance created below.
    """

    # ---------- audio / STFT (Option A) ----------
    # A small STFT keeps the token count low enough to fit max_seq_len.
    sr: int = 44_100
    n_fft: int = 256
    hop: int = 256          # hop == win: STFT frames don't overlap; overlap is
                            # applied later, when the audio is chunked into windows
    win: int = 256
    pad: int = 0
    center: bool = True
    power: float = 1.0      # 1.0 -> magnitude spectrogram (not power)

    # ---------- mel reducer ----------
    use_mel: bool = True
    n_mels: int = 64
    mel_fmin: float = 0.0
    mel_fmax: Optional[float] = None   # None -> sr / 2

    # ---------- tokenization of clipped log10 magnitude ----------
    logmag_min_db: float = -12.0   # floor, in dB relative to the per-window max
    logmag_max_db: float = 2.0
    n_bits: int = 8                # uniform quantization into 2**n_bits levels

    # ---------- transformer ----------
    d_model: int = 512
    n_heads: int = 8
    n_layers: int = 8
    ff_mult: int = 4
    dropout: float = 0.1

    # ---------- windowing ----------
    seconds: float = 0.1           # short window so packed sequences fit max_seq_len
    max_seq_len: int = 4096

    # ---------- stems (dataset sub-folders: vocals/, drums/, bass/, guitar/) ----------
    stems: List[str] = ["VOCALS", "DRUMS", "BASS", "GUITAR"]

    # ---------- runtime ----------
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


cfg = Cfg()

# =========================
# Helper: STFT / iSTFT
# =========================
def stft_mag_phase(wav: torch.Tensor):
    """
    Split a mono waveform into STFT magnitude and phase.

    wav: (T,) mono float tensor in [-1, 1]
    returns: (mag, phase), each shaped (frames, freq)
    """
    # Rectangular (all-ones) analysis window with hop == win. This is a
    # workaround for the "window overlap add min: 1" error that a Hann window
    # triggered at reconstruction time.
    rect = torch.ones(cfg.win, device=wav.device)
    spec_ft = torch.stft(
        wav,
        cfg.n_fft,
        hop_length=cfg.hop,
        win_length=cfg.win,
        window=rect,
        center=cfg.center,
        pad_mode="reflect",
        return_complex=True,
    )
    # torch.stft yields (freq, frames); callers work time-major.
    spec_tf = spec_ft.transpose(0, 1)
    return spec_tf.abs(), torch.angle(spec_tf)

def istft_from_mag_phase(mag: torch.Tensor, phase: torch.Tensor, length: int):
    """
    Rebuild a waveform from time-major magnitude and phase.

    mag, phase: (frames, freq)
    length: exact number of output samples (torch.istft pads/trims to it)
    returns: (length,) waveform
    """
    complex_tf = mag * torch.exp(1j * phase)
    # Same all-ones window as the analysis side so STFT/iSTFT stay consistent.
    rect = torch.ones(cfg.win, device=mag.device)
    return torch.istft(
        complex_tf.transpose(0, 1),  # back to (freq, frames)
        n_fft=cfg.n_fft,
        hop_length=cfg.hop,
        win_length=cfg.win,
        window=rect,
        center=cfg.center,
        length=length,
    )

# =========================
# Quantization (log-mag)
# =========================
def _logmag_bounds() -> Tuple[float, float]:
    """
    Return the (lo, hi) clipping range in log10-amplitude units.

    cfg.logmag_*_db are amplitude dB values, so log10(10 ** (db / 20)) is
    simply db / 20. Computing it directly avoids a float round-trip and
    keeps the three functions below in agreement.
    """
    return cfg.logmag_min_db / 20.0, cfg.logmag_max_db / 20.0


def to_logmag(mag: torch.Tensor, eps=1e-8):
    """
    Normalize a magnitude (or mel) spectrogram by its maximum and take log10.

    The loudest bin maps to ~0 (log10(1)), and the result is clipped to the
    configured [lo, hi] range. NOTE: this discards the window's absolute level.
    mag: non-negative tensor (one analysis window)
    eps: guards the division and log10(0) for silent windows
    returns: tensor with the same shape as ``mag``
    """
    lo, hi = _logmag_bounds()
    normed = torch.clamp(mag / (mag.max() + eps), min=eps)
    return torch.clamp(torch.log10(normed), lo, hi)


def quantize_logmag(logm: torch.Tensor):
    """
    Uniformly quantize clipped log10 magnitudes into integer tokens.

    logm: (frames, freq_or_mel) values, typically from ``to_logmag``
    returns: tokens (frames*feat,) long, in [0, 2**cfg.n_bits - 1]
    """
    lo, hi = _logmag_bounds()
    top = (1 << cfg.n_bits) - 1
    unit = torch.clamp((logm - lo) / (hi - lo), 0, 1)  # [lo, hi] -> [0, 1]
    # reshape (not view) so non-contiguous inputs are accepted as well
    return torch.round(unit * top).to(torch.long).reshape(-1)


def dequantize_to_mag(tokens: torch.Tensor, shape):
    """
    Invert ``quantize_logmag`` back to (per-window normalized) magnitudes.

    tokens: (frames*feat,) integer tokens
    shape: (frames, feat) output shape
    returns: float magnitudes (frames, feat); for tokens in
        [0, 2**cfg.n_bits - 1] the values lie in [10**lo, 10**hi]
    """
    lo, hi = _logmag_bounds()
    top = (1 << cfg.n_bits) - 1
    logm = tokens.float() / top * (hi - lo) + lo
    return torch.pow(10.0, logm).reshape(*shape)

# =========================
# Mel reducer (linear <-> mel), deterministic DSP (no NN)
# =========================
_mel_cache = {}
def _get_mel_mats(device):
    """
    Build (and memoize) the linear<->mel projection matrices.

    Returns (fb, pinv) where:
      fb:   (F, M) linear->mel filterbank, F = n_fft // 2 + 1, M = n_mels
      pinv: (M, F) Moore-Penrose pseudo-inverse used for mel->linear
    The cache key includes every cfg field that affects the matrices, so
    changing e.g. cfg.mel_fmax at runtime cannot return stale matrices.
    raises: RuntimeError if the filterbank does not have shape (F, M)
    """
    key = (cfg.n_fft, cfg.sr, cfg.n_mels, cfg.mel_fmin, cfg.mel_fmax, device)
    cached = _mel_cache.get(key)
    if cached is not None:
        return cached

    n_freqs = cfg.n_fft // 2 + 1
    n_mels = cfg.n_mels
    # Same fallback torchaudio.transforms.MelScale uses when f_max is None.
    f_max = float(cfg.sr // 2) if cfg.mel_fmax is None else float(cfg.mel_fmax)
    # melscale_fbanks returns the filterbank directly as (n_freqs, n_mels);
    # HTK mel scale, no area normalization (norm=None).
    fb = torchaudio.functional.melscale_fbanks(
        n_freqs=n_freqs,
        f_min=float(cfg.mel_fmin),
        f_max=f_max,
        n_mels=n_mels,
        sample_rate=cfg.sr,
        norm=None,
        mel_scale="htk",
    ).to(device)
    if tuple(fb.shape) != (n_freqs, n_mels):
        raise RuntimeError(
            f"Unexpected mel filterbank shape: {tuple(fb.shape)}. Expected ({n_freqs}, {n_mels})"
        )

    pinv = torch.pinverse(fb)  # (M, F)
    _mel_cache[key] = (fb, pinv)
    return fb, pinv

def linear_to_mel(mag_linear: torch.Tensor):
    """
    Project linear-frequency magnitudes onto the mel filterbank.

    mag_linear: (frames, F) with F = n_fft // 2 + 1
    returns: (frames, M) mel magnitudes, M = cfg.n_mels
    """
    filterbank = _get_mel_mats(mag_linear.device)[0]  # (F, M)
    return torch.matmul(mag_linear, filterbank)

def mel_to_linear(mag_mel: torch.Tensor):
    """
    Approximately invert ``linear_to_mel`` with the filterbank pseudo-inverse.

    mag_mel: (frames, M)
    returns: (frames, F) linear magnitudes; negative values produced by the
        least-squares inverse are clipped to 0.
    """
    inverse = _get_mel_mats(mag_mel.device)[1]  # (M, F)
    return torch.matmul(mag_mel, inverse).clamp(min=0.0)

# =========================
# Token vocab & packing
# =========================
class Vocab:
    """
    Token-id layout for the LM.

    Ids 0..255 are quantized magnitude bins (the hard-coded control ids below
    assume cfg.n_bits == 8). Control tokens follow, and stem-conditioning
    tokens start at STEM_BASE.
    """
    PAD = 256
    BOS = 257
    EOS = 258
    SEP = 259
    MIX = 260
    STEM_BASE = 300  # stem i is encoded as STEM_BASE + i

    def __init__(self):
        self.num_bins = 2 ** cfg.n_bits
        self.stem_to_id = {}
        for offset, stem_name in enumerate(cfg.stems):
            self.stem_to_id[stem_name] = self.STEM_BASE + offset
        self.id_to_stem = {tok: stem_name for stem_name, tok in self.stem_to_id.items()}
        # ids 261..STEM_BASE-1 are never used; the vocab ends after the last stem id
        self.vocab_size = self.STEM_BASE + len(cfg.stems)


vocab = Vocab()

def pack_sequence(mix_tokens: torch.Tensor, stem_name: str, target_tokens: torch.Tensor):
    """
    Assemble one training sequence for the decoder-only LM.

    Layout: [BOS, MIX, mix..., SEP, STEM, target..., EOS]
    Returns (seq, loss_mask), both shaped (L,). loss_mask is True on the
    target tokens and the trailing EOS (the positions the model must learn
    to predict) and False on the conditioning prefix.
    """
    dev = mix_tokens.device
    prefix = torch.tensor([vocab.BOS, vocab.MIX], device=dev)
    separator = torch.tensor([vocab.SEP], device=dev)
    stem_id = torch.tensor([vocab.stem_to_id[stem_name]], device=dev, dtype=torch.long)
    suffix = torch.tensor([vocab.EOS], device=dev)
    seq = torch.cat([prefix, mix_tokens, separator, stem_id, target_tokens, suffix])

    # index of the first target token: BOS + MIX + mix tokens + SEP + STEM
    first_target = mix_tokens.numel() + 4
    loss_mask = torch.zeros_like(seq, dtype=torch.bool)
    loss_mask[first_target:] = True
    return seq, loss_mask

def split_mix_target_from_seq(seq: torch.Tensor):
    """
    Placeholder for a debugging helper that would split a packed sequence
    back into its mixture and target parts.

    Not implemented and not used by the training loop; always returns None.
    """
    return None

# =========================
# GPT-style decoder-only model
# =========================
class GPTBlock(nn.Module):
    """Pre-LayerNorm transformer decoder block: self-attention then MLP, each residual."""

    def __init__(self, d_model, n_heads, ff_mult, dropout):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, ff_mult * d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_mult * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x, attn_mask):
        """
        x: (B, L, d_model)
        attn_mask: bool mask, True = attention not allowed. Either (L, L),
            which nn.MultiheadAttention broadcasts over batch and heads, or
            (B * n_heads, L, L).
        returns: (B, L, d_model)
        """
        h = self.ln1(x)
        y, _ = self.attn(h, h, h, attn_mask=attn_mask, need_weights=False)
        x = x + y
        h = self.ln2(x)
        x = x + self.ff(h)
        return x


class GPT(nn.Module):
    """Decoder-only transformer LM with learned absolute position embeddings."""

    def __init__(self, vocab_size, d_model, n_layers, n_heads, ff_mult, dropout,
                 max_seq_len=None):
        """
        max_seq_len: size of the positional table (default cfg.max_seq_len);
            longer inputs are truncated to it in ``forward``.
        """
        super().__init__()
        if max_seq_len is None:
            max_seq_len = cfg.max_seq_len
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = nn.Parameter(torch.zeros(1, max_seq_len, d_model))
        self.blocks = nn.ModuleList([
            GPTBlock(d_model, n_heads, ff_mult, dropout) for _ in range(n_layers)
        ])
        self.ln = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)
        self.n_heads = n_heads

    def forward(self, tokens, loss_mask=None):
        """
        tokens: (B, L) long
        loss_mask: (B, L) bool (optional); True where the token at that
            position is a prediction target
        returns: (logits (B, L', V), loss) with L' = min(L, max_seq_len);
            loss is None without a mask, and a 0-d tensor otherwise
            (a constant 0.0 without grad if no position is supervised)
        """
        _, L = tokens.shape
        max_pos = self.pos.size(1)
        if L > max_pos:
            # Only max_pos positions have embeddings: silently keep the prefix.
            tokens = tokens[:, :max_pos]
            if loss_mask is not None:
                loss_mask = loss_mask[:, :max_pos]
            L = max_pos
        x = self.embed(tokens) + self.pos[:, :L, :]
        # Causal mask (True = blocked): position i attends only to j <= i.
        # One (L, L) mask is broadcast by nn.MultiheadAttention across batch
        # and heads; materializing (B * n_heads, L, L) would cost
        # B * n_heads * L**2 bytes (128 MiB per item at L=4096, 8 heads).
        causal = torch.triu(torch.ones(L, L, device=tokens.device, dtype=torch.bool), diagonal=1)

        for blk in self.blocks:
            x = blk(x, causal)
        logits = self.out(self.ln(x))
        loss = None
        if loss_mask is not None:
            # Next-token prediction: logits at t are scored against token t + 1.
            pred = logits[:, :-1, :].reshape(-1, logits.size(-1))
            targets = tokens[:, 1:].reshape(-1)
            keep = loss_mask[:, 1:].reshape(-1)
            if keep.any():
                loss = F.cross_entropy(pred[keep], targets[keep])
            else:
                # nothing supervised (e.g. all targets truncated away)
                loss = torch.tensor(0.0, device=logits.device)
        return logits, loss

# =========================
# Data demo utilities
# =========================
def pad_or_trim(wav, length):
    """
    Force a 1-D signal to exactly ``length`` samples.

    Longer inputs are truncated; shorter ones are zero-padded at the end.
    The result keeps the input's dtype and device in both cases.
    """
    n = wav.numel()
    if n >= length:
        return wav[:length]
    out = wav.new_zeros(length)  # same dtype/device as wav
    out[:n] = wav
    return out

def mono(wav):
    """Average a 2-D (channels, T) tensor down to (T,); return any other rank as-is."""
    return wav.mean(0) if wav.dim() == 2 else wav

def load_wav(path, target_sr):
    """
    Read an audio file as a mono waveform at ``target_sr``.

    Channels are averaged, the signal is resampled only when the file's rate
    differs, and samples are clipped to [-1, 1].
    returns: (T,) tensor
    """
    audio, file_sr = torchaudio.load(path)
    audio = mono(audio)
    if file_sr != target_sr:
        audio = torchaudio.functional.resample(audio, file_sr, target_sr)
    return audio.clamp(-1, 1)

# =========================
# End-to-end: one batch example
# =========================
def make_batch(mix_wav, stem_wav, stem_name: str):
    """
    Tokenize one (mixture, stem) window pair into a single-item LM batch.

    returns:
      seq:        (1, L) long tokens on cfg.device
      mask:       (1, L) bool loss mask on cfg.device
      mix_phase:  (frames, freq) mixture STFT phase, on the input's device
      feat_shape: shape of the tokenized feature grid, (frames, n_mels) with
                  mel or (frames, n_fft // 2 + 1) without; used later to
                  reshape decoded target tokens
    """
    mix_mag, mix_phase = stft_mag_phase(mix_wav)
    stem_mag, _ = stft_mag_phase(stem_wav)

    # Features to tokenize: mel bands when cfg.use_mel, else raw linear bins.
    if not cfg.use_mel:
        mix_feat, stem_feat = mix_mag, stem_mag
    else:
        mix_feat, stem_feat = linear_to_mel(mix_mag), linear_to_mel(stem_mag)

    mix_tok = quantize_logmag(to_logmag(mix_feat))
    stem_tok = quantize_logmag(to_logmag(stem_feat))

    seq, mask = pack_sequence(mix_tok, stem_name, stem_tok)
    limit = cfg.max_seq_len
    if seq.numel() > limit:
        # Drop the tail (target tokens / EOS) so the positional table suffices.
        seq, mask = seq[:limit], mask[:limit]

    dev = cfg.device
    return seq.unsqueeze(0).to(dev), mask.unsqueeze(0).to(dev), mix_phase, mix_feat.shape

@torch.no_grad()
def greedy_decode(model, mix_tokens, stem_name: str, max_new=None):
    """
    Greedily generate target tokens conditioned on a mixture and a stem id.

    mix_tokens: (Lmix,) long mixture tokens
    stem_name: key of vocab.stem_to_id
    max_new: maximum number of tokens to generate. None means Lmix + 1024;
        0 generates nothing. Always capped so the sequence never exceeds
        cfg.max_seq_len.
    returns: (L,) prefix [BOS, MIX, mix..., SEP, STEM] followed by the
        generated tokens (ending with EOS if one was produced). Generated
        tokens are not restricted to magnitude bins.
    Each step re-runs the model on the whole sequence (no KV cache).
    """
    device = cfg.device
    prefix = torch.cat([
        torch.tensor([vocab.BOS, vocab.MIX], device=device),
        mix_tokens.to(device),
        torch.tensor([vocab.SEP, vocab.stem_to_id[stem_name]], device=device),
    ])
    seq = prefix.unsqueeze(0)  # (1, L)
    # `is None` (not `or`) so an explicit max_new=0 is honoured.
    if max_new is None:
        max_new = mix_tokens.numel() + 1024
    # The positional table only has cfg.max_seq_len slots.
    max_new = min(max_new, cfg.max_seq_len - seq.size(1))
    for _ in range(max_new):
        logits, _ = model(seq)
        next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (1, 1)
        seq = torch.cat([seq, next_tok], dim=1)
        if next_tok.item() == vocab.EOS:
            break
    return seq.squeeze(0)

def reconstruct_from_sequence(seq, mix_phase, feat_shape, target_len):
    """
    Turn a decoded token sequence into a waveform for the requested stem.

    seq: (L,) tokens laid out as [BOS, MIX, mix..., SEP, STEM, target..., EOS?]
    mix_phase: (frames, freq) mixture STFT phase, reused for the stem
    feat_shape: (frames, feat) grid the target tokens are reshaped to
    target_len: number of output samples
    returns: (target_len,) waveform clipped to [-1, 1], on mix_phase's device
    """
    expected_len = feat_shape[0] * feat_shape[1]
    sep_pos = (seq == vocab.SEP).nonzero(as_tuple=True)[0]
    if sep_pos.numel() == 0:
        print("Warning: SEP token not found in sequence. Reconstruction may be incomplete.")
        # Fall back to the nominal layout: BOS, MIX, expected_len mix tokens, SEP, STEM.
        start = min(2 + expected_len + 2, seq.numel())
        segment = seq[start:]
    else:
        # Use the FIRST SEP: it belongs to the conditioning prefix (mix tokens
        # are all bins). Any later SEP was generated by the model itself.
        start = int(sep_pos[0].item()) + 2  # skip SEP and STEM
        eos_pos = (seq[start:] == vocab.EOS).nonzero(as_tuple=True)[0]
        end = start + int(eos_pos[0].item()) if eos_pos.numel() > 0 else seq.numel()
        segment = seq[start:end]

    # Control/stem ids the model may emit are not magnitude bins; replace them
    # with bin 0 (the quietest level) instead of dequantizing them.
    is_bin = (segment >= 0) & (segment < vocab.num_bins)
    segment = torch.where(is_bin, segment, torch.zeros_like(segment))

    # Pad with the quietest bin, or trim, so the tokens fill feat_shape exactly.
    if segment.numel() < expected_len:
        filler = torch.zeros(expected_len - segment.numel(), dtype=torch.long, device=segment.device)
        segment = torch.cat([segment, filler])
    else:
        segment = segment[:expected_len]

    feat_mag = dequantize_to_mag(segment.detach(), feat_shape).to(mix_phase.device)
    # Back to linear-frequency bins before the iSTFT when tokens were mel.
    lin_mag = mel_to_linear(feat_mag) if cfg.use_mel else feat_mag
    # The stem borrows the mixture's phase.
    wav = istft_from_mag_phase(lin_mag, mix_phase, target_len)
    return wav.clamp(-1, 1)

def demo_step(model, optimizer, mix_wav, stem_wav, stem_name="VOCALS"):
    """
    One supervised optimization step, then a greedy decode of the same window.

    returns: (loss as float, or nan when the model returned no loss,
              predicted stem waveform)
    Leaves ``model`` in eval mode.
    """
    seq, mask, mix_phase, feat_shape = make_batch(mix_wav, stem_wav, stem_name)
    n_supervised = int(mask.sum().item())
    print(f"[debug] loss_mask target tokens this step: {n_supervised}")

    # ---- training step
    model.train()
    _, loss = model(seq, loss_mask=mask)
    optimizer.zero_grad()
    # A loss without grad is the constant the model returns when no position
    # is supervised; there is nothing to back-propagate then.
    trainable = loss is not None and loss.requires_grad
    if trainable:
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

    # ---- evaluation: decode from the mixture part of the packed sequence
    model.eval()
    row = seq[0]
    sep_hits = (row == vocab.SEP).nonzero(as_tuple=True)[0]
    if sep_hits.numel() == 0:
        mix_end = min(2 + feat_shape[0] * feat_shape[1], seq.size(1))
    else:
        mix_end = sep_hits[0].item()
    mix_tok = row[2:mix_end]  # tokens between MIX and SEP

    with torch.no_grad():
        full = greedy_decode(model, mix_tok, stem_name=stem_name)
    pred_wav = reconstruct_from_sequence(full, mix_phase, feat_shape, mix_wav.numel())
    loss_value = float('nan') if loss is None else float(loss.item())
    return loss_value, pred_wav

# =========================
# Inference: windowed separation with overlap–add
# =========================
def build_model():
    """Instantiate the GPT with hyper-parameters from ``cfg`` and move it to ``cfg.device``."""
    hparams = dict(
        vocab_size=vocab.vocab_size,
        d_model=cfg.d_model,
        n_layers=cfg.n_layers,
        n_heads=cfg.n_heads,
        ff_mult=cfg.ff_mult,
        dropout=cfg.dropout,
    )
    model = GPT(**hparams)
    return model.to(cfg.device)

def load_checkpoint(model: nn.Module, ckpt_path: Optional[str]):
    """
    Load weights into ``model`` in place.

    ckpt_path: file written by torch.save, containing either a raw state_dict
        or a dict wrapping it under "state_dict". None keeps the random init.
    Loading is non-strict, but missing and unexpected keys are reported as
    warnings. Otherwise a renamed layer would silently stay randomly
    initialized.
    """
    if ckpt_path is None:
        print("[warn] No checkpoint provided. Using randomly initialized weights (predictions will sound like noise).")
        return
    # NOTE(security): torch.load unpickles the file -- only load trusted checkpoints.
    state = torch.load(ckpt_path, map_location=cfg.device)
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    result = model.load_state_dict(state, strict=False)
    if result.missing_keys:
        print(f"[warn] Checkpoint is missing {len(result.missing_keys)} key(s), e.g. {result.missing_keys[:5]}")
    if result.unexpected_keys:
        print(f"[warn] Checkpoint has {len(result.unexpected_keys)} unexpected key(s), e.g. {result.unexpected_keys[:5]}")
    print(f"[info] Loaded checkpoint from: {ckpt_path}")

def separate_stems_from_mix(model: nn.Module, mix_wav: torch.Tensor, stems=None, overlap: float = 0.5):
    """
    Separate a full-length mono mixture window-by-window, then overlap-add.

    mix_wav: (T,) mono tensor. Outputs live on mix_wav's device, which is
        where the reconstructed chunks come back.
    stems: stem names to generate (None or empty -> cfg.stems)
    overlap: fraction in [0, 1); 0.5 -> 50% overlap between windows
    returns: dict {stem_name: (T,) tensor clipped to [-1, 1]}
    raises: ValueError for an overlap outside [0, 1) or an empty window
    The model runs in eval mode (dropout off); its previous train/eval mode
    is restored afterwards.
    """
    if not 0.0 <= overlap < 1.0:
        raise ValueError(f"overlap must be in [0, 1), got {overlap}")
    stems = stems or cfg.stems
    device = mix_wav.device
    T_total = mix_wav.numel()
    T_win = int(cfg.seconds * cfg.sr)
    if T_win <= 0:
        raise ValueError("Window size (seconds * sr) must be > 0")
    hop_win = max(1, int(T_win * (1.0 - overlap)))
    print(f"[info] Separation with window={T_win} samples ({cfg.seconds:.3f}s), hop={hop_win} samples, overlap={overlap:.2f}")

    # Symmetric Hann fade for overlap-add. Its endpoints are exactly 0, so a
    # sample covered only by window edges (e.g. sample 0) comes out as 0.
    fade = torch.hann_window(T_win, periodic=False, device=device)
    outs = {stem: torch.zeros(T_total, device=device) for stem in stems}
    weight = torch.zeros(T_total, device=device)

    was_training = model.training
    model.eval()
    try:
        for pos in range(0, T_total, hop_win):
            end = min(pos + T_win, T_total)
            span = end - pos  # the last window may be shorter than T_win
            chunk = pad_or_trim(mix_wav[pos:end], T_win)
            # STFT -> (optional) mel -> tokens
            mix_mag, mix_phase = stft_mag_phase(chunk)
            mix_feat = linear_to_mel(mix_mag) if cfg.use_mel else mix_mag
            feat_shape = mix_feat.shape
            mix_tok = quantize_logmag(to_logmag(mix_feat))

            fade_span = fade[:span]
            with torch.no_grad():
                for stem in stems:
                    full_seq = greedy_decode(model, mix_tok, stem_name=stem)
                    pred_chunk = reconstruct_from_sequence(full_seq, mix_phase, feat_shape, T_win)
                    outs[stem][pos:end] += pred_chunk[:span] * fade_span
            weight[pos:end] += fade_span
    finally:
        model.train(was_training)

    # Divide by the accumulated fade so overlapping regions have unit gain.
    eps = 1e-8
    return {stem: (acc / (weight + eps)).clamp(-1, 1) for stem, acc in outs.items()}

def separate_file(model: nn.Module, mix_path: str, out_dir: str = ".", stems: list[str] | None = None,
                   overlap: float = 0.5, ckpt: str | None = None):
    """
    Separate an audio file and write one WAV per stem into ``out_dir``.

    Output files are named ``<stem lower-case>.wav`` at cfg.sr.
    ckpt: optional checkpoint path; when given it is loaded into ``model``
        before separating. When None, ``model`` is used as-is (no load and
        no "random weights" warning, since the caller may have loaded it).
    """
    os.makedirs(out_dir, exist_ok=True)
    if ckpt is not None:
        load_checkpoint(model, ckpt)
    wav = load_wav(mix_path, cfg.sr).to(cfg.device)
    wav = mono(wav)
    outs = separate_stems_from_mix(model, wav, stems=stems, overlap=overlap)
    for stem, audio in outs.items():
        out_path = os.path.join(out_dir, f"{stem.lower()}.wav")
        torchaudio.save(out_path, audio.unsqueeze(0).cpu(), cfg.sr)
        print(f"[write] {out_path}")

# =========================
# Dataset + Training
# =========================
def _default_stem_files() -> Dict[str, str]:
    """
    Map our logical stem names to **subdirectory** names in each track folder.
    For the provided layout we expect, per track GUID:
        <GUID>/vocals/<wav>
        <GUID>/drums/<wav>
        <GUID>/bass/<wav>
        <GUID>/guitar/<wav>
    NOTE: name values here are directory names, not filenames.
    """
    return {
        "VOCALS": "vocals",
        "DRUMS":  "drums",
        "BASS":   "bass",
        "GUITAR": "guitar",
    }

def _safe_lower(s: str) -> str:
    return s.lower() if isinstance(s, str) else s

def _list_tracks(root: str) -> List[str]:
    tracks = [os.path.join(root, d) for d in os.listdir(root)
              if os.path.isdir(os.path.join(root, d))]
    tracks.sort()
    return tracks

def split_dataset(root: str, val_fraction: float = 0.1, test_fraction: float = 0.1, seed: int = 42):
    """
    Deterministically split the track directories under ``root`` into
    train / val / test lists.

    The sorted track list is shuffled with ``random.Random(seed)``.
    ``test_fraction`` is taken from the whole set, and ``val_fraction`` from
    what remains after the test split.
    returns: {"train": [...], "val": [...], "test": [...]}
    """
    import random

    shuffled = _list_tracks(root)
    random.Random(seed).shuffle(shuffled)
    total = len(shuffled)
    n_test = int(round(total * test_fraction))
    n_val = int(round((total - n_test) * val_fraction))
    val_end = n_test + n_val

    test_tracks = shuffled[:n_test]
    val_tracks = shuffled[n_test:val_end]
    train_tracks = shuffled[val_end:]
    print(f"[split] total={total} → train={len(train_tracks)}, val={len(val_tracks)}, test={len(test_tracks)}")
    return {"train": train_tracks, "val": val_tracks, "test": test_tracks}

def _find_file_case_insensitive(folder: str, filename: str) -> Optional[str]:
    """
    Return the path of the first entry of ``folder`` (in os.listdir order)
    whose name equals ``filename`` ignoring case, or None if none matches.
    """
    wanted = _safe_lower(filename)
    match = next((entry for entry in os.listdir(folder) if _safe_lower(entry) == wanted), None)
    return None if match is None else os.path.join(folder, match)

def _find_stem_wav_from_subdir(track_dir: str, subdir_name: str) -> Optional[str]:
    """
    Locate a WAV file under <track_dir>/<subdir_name>/
    Returns the first lexicographically sorted *.wav (case-insensitive), or None.
    """
    if subdir_name is None:
        return None
    stem_dir = os.path.join(track_dir, subdir_name)
    if not os.path.isdir(stem_dir):
        return None
    candidates = [f for f in os.listdir(stem_dir) if f.lower().endswith(".wav")]
    candidates.sort()
    if not candidates:
        return None
    return os.path.join(stem_dir, candidates[0])

class StemSeparationDataset(torch.utils.data.Dataset):
    """
    Window-level (mixture, stem) dataset over a directory of multitrack folders.

    Expected layout:
      root/<GUID>/
        vocals/<wav>
        drums/<wav>
        bass/<wav>
        guitar/<wav>
    Each track is cut into fixed-length windows (``window_seconds``, default
    cfg.seconds) whose starts are ``hop_samples`` apart (time-domain overlap).
    Every window is paired with every stem. Starts and lengths are counted in
    samples at cfg.sr, whatever rate the files are stored at.
    Each item is (mix_window, stem_window, stem_name, track_dir, start).
    The mixture is always built by summing the stems; ``mix_from_stems`` is
    stored but not consulted here.
    """

    def __init__(
        self,
        root: str,
        track_dirs: Optional[List[str]] = None,
        stems: Optional[List[str]] = None,
        window_seconds: Optional[float] = None,
        overlap: float = 0.5,
        stem_files: Optional[Dict[str, str]] = None,
        mix_from_stems: bool = True,
    ):
        """
        root: directory with one folder per track
        track_dirs: if given, restrict the dataset to these track folders
        stems: stem names to index (default cfg.stems)
        window_seconds: window length in seconds (default cfg.seconds)
        overlap: fraction of window overlap, in [0, 1)
        stem_files: stem name -> sub-directory name (default _default_stem_files())
        raises: ValueError if overlap is outside [0, 1)
        """
        super().__init__()
        self.root = root
        self.track_dirs = track_dirs  # if provided, restrict dataset to these tracks
        self.stems = stems or cfg.stems
        self.window_seconds = cfg.seconds if window_seconds is None else float(window_seconds)
        self.overlap = float(overlap)
        if not 0.0 <= self.overlap < 1.0:
            raise ValueError(f"overlap must be in [0, 1), got {self.overlap}")
        # 'stem_files' maps stem name -> sub-directory name (not a filename)
        self.stem_files = stem_files or _default_stem_files()
        self.mix_from_stems = bool(mix_from_stems)
        self.win_samples = int(self.window_seconds * cfg.sr)
        self.hop_samples = max(1, int(self.win_samples * (1.0 - self.overlap)))

        # Native sample rate per file path; seeded by _build_index, else lazily.
        self._sr_cache: Dict[str, int] = {}
        # Index: list of (track_dir, stem_name, start_sample at cfg.sr)
        self.index: List[Tuple[str, str, int]] = []
        self._build_index()

    def _build_index(self):
        """Append one (track_dir, stem, start) entry per window per stem to self.index."""
        tracks = self.track_dirs if self.track_dirs is not None else _list_tracks(self.root)
        for track_dir in tracks:
            # Track length (in cfg.sr samples) = longest readable stem.
            max_frames = 0
            for s in self.stems:
                subdir = self.stem_files.get(s, None)
                sp = _find_stem_wav_from_subdir(track_dir, subdir)
                if sp is None:
                    continue
                try:
                    info = torchaudio.info(sp)
                    frames_at_target = int(round(info.num_frames * (cfg.sr / info.sample_rate)))
                    self._sr_cache[sp] = info.sample_rate
                except Exception as e:
                    print(f"[warn] Could not read stem WAV in {track_dir}/{subdir}: {e}")
                    continue
                max_frames = max(max_frames, frames_at_target)
            if max_frames == 0:
                print(f"[warn] No usable stem audio in {track_dir}; skipping.")
                continue
            self.index.extend(
                (track_dir, stem, pos)
                for pos in range(0, max_frames, self.hop_samples)
                for stem in self.stems
            )

    def __len__(self):
        return len(self.index)

    def _native_sr(self, path: str) -> int:
        """Sample rate ``path`` is stored at (memoized per path)."""
        sr = self._sr_cache.get(path)
        if sr is None:
            sr = torchaudio.info(path).sample_rate
            self._sr_cache[path] = sr
        return sr

    def _load_wav_window(self, path: str, start: int, length: int) -> torch.Tensor:
        """
        Load ``length`` samples starting at ``start`` (both in cfg.sr samples)
        as a mono tensor at cfg.sr, zero-padded past the end of the file.

        torchaudio's frame_offset / num_frames count samples at the FILE's
        rate, so they are converted first. Otherwise, files not stored at
        cfg.sr would be read from the wrong position.
        """
        file_sr = self._native_sr(path)
        if file_sr == cfg.sr:
            offset, frames = max(0, start), length
        else:
            ratio = file_sr / cfg.sr
            offset = max(0, int(round(start * ratio)))
            frames = max(1, int(math.ceil(length * ratio)))
        wav, sr = torchaudio.load(path, frame_offset=offset, num_frames=frames)
        wav = mono(wav)
        if sr != cfg.sr and wav.numel() > 0:
            wav = torchaudio.functional.resample(wav, sr, cfg.sr)
        # Shorter than length at the end of the file -> pad; resampling may overshoot -> trim.
        return pad_or_trim(wav, length)

    def __getitem__(self, idx: int):
        track_dir, stem_name, start = self.index[idx]
        target_subdir = self.stem_files.get(stem_name, None)
        if target_subdir is None:
            raise KeyError(f"No filename mapping for stem '{stem_name}'")

        # Mixture = sum of stems (the dataset layout has no mixture file).
        mix_win = torch.zeros(self.win_samples, dtype=torch.float32)
        stem_win = None
        for s in self.stems:
            sp = _find_stem_wav_from_subdir(track_dir, self.stem_files.get(s, None))
            if sp is None:
                continue
            window = self._load_wav_window(sp, start, self.win_samples)
            mix_win = mix_win + window
            if s == stem_name:
                stem_win = window  # reuse the target instead of reading it twice

        if stem_win is None:
            stem_path = _find_stem_wav_from_subdir(track_dir, target_subdir)
            if stem_path is None:
                # Stem missing for this track: treat it as silence.
                stem_win = torch.zeros_like(mix_win)
            else:
                stem_win = self._load_wav_window(stem_path, start, self.win_samples)

        # Conservative anti-clipping. The same gain is applied to the target so
        # that mixture == sum of stems still holds for this pair.
        maxabs = mix_win.abs().max().item()
        if maxabs > 1.0:
            mix_win = (mix_win / maxabs).clamp_(-1, 1)
            stem_win = stem_win / maxabs
        # Minimal metadata so window-level metrics can be tagged.
        return mix_win, stem_win, stem_name, track_dir, start

def collate_pack(batch):
    """
    Tokenize a list of dataset items into one LM batch.

    batch: list of (mix_win, stem_win, stem_name, track_dir, start)
    Returns, in this order:
      tokens:      (B, L) long
      loss_mask:   (B, L) bool
      targets:     (B, T) float target waveforms
      mix_phases:  list of per-item mixture phase tensors
      feat_shapes: list of per-item feature shapes
      wav_lens:    list of per-item target lengths (samples)
      stem_names:  list of stem names
      track_names: list of track folder base names
      starts:      list of window start samples (at cfg.sr)
    Items are moved to cfg.device here. torch.stack requires every sequence
    to have the same L, which holds for fixed cfg.seconds and STFT/mel setup.
    """
    seqs, masks, targets = [], [], []
    phases, shapes, wav_lens = [], [], []
    stem_names, track_names, starts = [], [], []
    for mix_win, stem_win, stem_name, track_dir, start in batch:
        mix_dev = mix_win.to(cfg.device)
        stem_dev = stem_win.to(cfg.device)
        targets.append(stem_dev)
        seq, mask, mix_phase, feat_shape = make_batch(mix_dev, stem_dev, stem_name)
        seqs.append(seq.squeeze(0))    # (1, L) -> (L,)
        masks.append(mask.squeeze(0))
        phases.append(mix_phase)
        shapes.append(feat_shape)
        wav_lens.append(mix_dev.numel())
        stem_names.append(stem_name)
        track_names.append(os.path.basename(track_dir.rstrip("/\\")))
        starts.append(int(start))
    return (
        torch.stack(seqs, dim=0),
        torch.stack(masks, dim=0),
        torch.stack(targets, dim=0),
        phases,
        shapes,
        wav_lens,
        stem_names,
        track_names,
        starts,
    )

@torch.no_grad()
def decode_batch(model, batch_tokens, phases, shapes, wav_lens, stems):
    """
    Greedy-decode every item of a packed batch for its own stem.

    batch_tokens: (B, L) packed sequences (as built by collate_pack)
    phases, shapes, wav_lens, stems: per-item metadata lists
    returns: list of B reconstructed waveforms
    """
    def _mix_slice(row, shape):
        # Mixture tokens sit after BOS, MIX and before the first SEP; without a
        # SEP, fall back to the nominal mixture length.
        seps = (row == vocab.SEP).nonzero(as_tuple=True)[0]
        if seps.numel() == 0:
            stop = min(2 + shape[0] * shape[1], row.size(0))
        else:
            stop = seps[0].item()
        return row[2:stop]

    waveforms = []
    for i in range(batch_tokens.size(0)):
        mix_tok = _mix_slice(batch_tokens[i], shapes[i])
        decoded = greedy_decode(model, mix_tok, stems[i])
        waveforms.append(reconstruct_from_sequence(decoded, phases[i], shapes[i], wav_lens[i]))
    return waveforms

def si_sdr(ref: torch.Tensor, est: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """
    Scale-invariant signal-to-distortion ratio (dB) between two 1-D tensors.

    Both signals are mean-centred. The estimate is then split into its projection
    onto the reference (the "target" part) and the remaining residual. `eps`
    guards every division and the logarithm. Both tensors must live on the same device.
    """
    r = ref - ref.mean()
    e = est - est.mean()
    # Optimal scaling of the reference that best explains the estimate.
    alpha = (r * e).sum() / ((r ** 2).sum() + eps)
    target = alpha * r
    residual = e - target
    numerator = (target ** 2).sum() + eps
    denominator = (residual ** 2).sum() + eps
    return 10.0 * torch.log10(numerator / denominator + eps)

@torch.no_grad()
def evaluate_sisdr(
    dataset_root: str,
    ckpt: str,
    batch_size: int = 2,
    overlap: float = 0.5,
    num_workers: int = 2,
    max_batches: Optional[int] = 25,
    track_dirs: Optional[List[str]] = None,
    metrics_dir: Optional[str] = None,
):
    """
    Window-level evaluation with greedy decoding.

    For every evaluated window this:
      - greedily decodes the requested stem and computes SI-SDR (decoded audio vs. target);
      - computes the masked next-token cross-entropy (CE) on the same token sequence.

    Args:
      dataset_root: root folder handed to StemSeparationDataset.
      ckpt: checkpoint path loaded with load_checkpoint.
      batch_size, overlap, num_workers: windowing / DataLoader settings.
      max_batches: stop after this many batches; None evaluates every window.
      track_dirs: optional explicit list of track directories.
      metrics_dir: where to write sisdr_windows.csv (defaults to ./metrics).

    Returns:
      (avg_sisdr, per_stem_sisdr_dict), both in dB.

    Raises:
      RuntimeError: if the dataset yields no windows.
    """
    model = build_model()
    load_checkpoint(model, ckpt)
    ds = StemSeparationDataset(dataset_root, track_dirs=track_dirs, stems=cfg.stems,
                               window_seconds=cfg.seconds, overlap=overlap)
    if len(ds) == 0:
        raise RuntimeError(f"No evaluation items found under {dataset_root}")
    dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                                     collate_fn=collate_pack, pin_memory=False)
    model.eval()
    total = 0.0
    count = 0
    per_stem_sum: Dict[str, float] = {s: 0.0 for s in cfg.stems}
    per_stem_cnt: Dict[str, int] = {s: 0 for s in cfg.stems}
    # (track, start, stem, sisdr, ce_loss, num_masked_tokens)
    window_rows: List[Tuple[str, int, str, float, float, int]] = []

    for n_batches, batch in enumerate(dl, start=1):
        tokens, loss_mask, targets, phases, shapes, wav_lens, stems, track_names, starts = batch
        # Forward without the model's own loss; CE is computed per item below.
        logits, _ = model(tokens, loss_mask=None)
        vocab_size = logits.size(-1)
        # Next-token alignment: logits at position t score token t+1, and the mask selects targets.
        logits_shift = logits[:, :-1, :].contiguous()
        targets_shift = tokens[:, 1:].contiguous()
        # Cast to bool so the mask *selects* positions; a 0/1 integer mask used as an
        # index would instead gather rows 0 and 1, and a float mask would raise.
        mask_shift = loss_mask[:, 1:].bool()

        preds = decode_batch(model, tokens, phases, shapes, wav_lens, stems)
        for i, pred in enumerate(preds):
            d = float(si_sdr(targets[i], pred).item())
            total += d
            count += 1
            per_stem_sum[stems[i]] += d
            per_stem_cnt[stems[i]] += 1

            # Per-item CE over masked positions only.
            mi = mask_shift[i].reshape(-1)
            nm = int(mi.sum().item())
            if nm > 0:
                li = logits_shift[i].reshape(-1, vocab_size)[mi]
                ti = targets_shift[i].reshape(-1)[mi]
                ce_i = float(F.cross_entropy(li, ti).item())
            else:
                ce_i = float("nan")
            window_rows.append((track_names[i], int(starts[i]), stems[i], d, ce_i, nm))

        if max_batches is not None and n_batches >= max_batches:
            break

    avg = total / max(1, count)
    per_stem = {s: (per_stem_sum[s] / max(1, per_stem_cnt[s])) for s in cfg.stems}
    print(f"[SI-SDR] avg over {count} windows: {avg:.3f} dB")
    for s in cfg.stems:
        print(f"  {s:>6}: {per_stem[s]:.3f} dB (n={per_stem_cnt[s]})")

    # ---- CSV export for window-level metrics ----
    metrics_root = metrics_dir or "./metrics"
    os.makedirs(metrics_root, exist_ok=True)
    windows_csv = os.path.join(metrics_root, "sisdr_windows.csv")
    # Explicit UTF-8 so non-ASCII track names do not depend on the platform default encoding.
    with open(windows_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["track", "start_sample", "stem", "sisdr_db", "ce_loss", "num_masked_tokens"])
        for tr, st, stem, sisdr_val, ce_val, nmask in window_rows:
            ce_txt = "" if math.isnan(ce_val) else f"{ce_val:.5f}"  # blank when no masked tokens
            w.writerow([tr, st, stem, f"{sisdr_val:.3f}", ce_txt, nmask])
    print(f"[csv] wrote {windows_csv}")
    return avg, per_stem

def save_checkpoint(model: nn.Module, out_dir: str, tag: str):
    """
    Save `model.state_dict()` as <out_dir>/model_<tag>.pt, creating out_dir if needed.

    Returns the path that was written.
    """
    ckpt_path = os.path.join(out_dir, f"model_{tag}.pt")
    os.makedirs(out_dir, exist_ok=True)
    state = model.state_dict()
    torch.save(state, ckpt_path)
    print(f"[ckpt] saved: {ckpt_path}")
    return ckpt_path

def train_model(
    dataset_root: str,
    out_dir: str = "./checkpoints",
    epochs: int = 5,
    batch_size: int = 4,
    lr: float = 3e-4,
    overlap: float = 0.5,
    num_workers: int = 2,
    preview_every: int = 1,
    train_tracks: Optional[List[str]] = None,
):
    """
    Train the GPT-style separator on a folder-structured dataset.

    - When train_tracks is None, uses the "train" part of
      split_dataset(dataset_root, val_fraction=0.1, test_fraction=0.1, seed=42).
    - Optimizes with AdamW (lr, betas=(0.9, 0.95), weight_decay=0.1) and clips the
      gradient norm to 1.0.
    - Saves a checkpoint to out_dir after every epoch (model_epochNNN.pt).
    - Every `preview_every` epochs (0/None disables it), greedily decodes the first
      item of the epoch's last batch and writes out_dir/previews/epochNNN_pred.wav.

    Returns the trained model.

    Raises:
      RuntimeError: if the dataset yields no windows.
    """
    device = cfg.device
    print(f"[train] device={device}")
    model = build_model()
    opt = torch.optim.AdamW(model.parameters(), lr=lr, betas=(0.9, 0.95), weight_decay=0.1)

    if train_tracks is None:
        # default: build a split and use the train portion
        split = split_dataset(dataset_root, val_fraction=0.1, test_fraction=0.1, seed=42)
        train_tracks = split["train"]
    ds = StemSeparationDataset(dataset_root, track_dirs=train_tracks, stems=cfg.stems,
                               window_seconds=cfg.seconds, overlap=overlap)
    if len(ds) == 0:
        raise RuntimeError(f"No training items found under {dataset_root}")
    dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                                     collate_fn=collate_pack, pin_memory=False)

    for epoch in range(1, epochs + 1):
        model.train()
        running = 0.0
        steps = 0
        for tokens, loss_mask, _targets, phases, shapes, wav_lens, stems, _, _ in dl:
            _, loss = model(tokens, loss_mask=loss_mask)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            running += float(loss.item())
            steps += 1
            if steps % 50 == 0:
                print(f"[epoch {epoch}] step {steps} loss {running/steps:.4f}")

        avg_loss = running / max(1, steps)
        print(f"[epoch {epoch}] avg training loss: {avg_loss:.4f}")
        save_checkpoint(model, out_dir, tag=f"epoch{epoch:03d}")

        # Optional preview decode on the last seen batch.
        if preview_every and (epoch % preview_every == 0):
            model.eval()
            # Only the first item is written, and decode_batch greedily decodes items one
            # by one, so decoding the rest of the batch would be wasted work.
            with torch.no_grad():
                preds = decode_batch(model, tokens[:1], phases[:1], shapes[:1],
                                     wav_lens[:1], stems[:1])
            preview_dir = os.path.join(out_dir, "previews")
            os.makedirs(preview_dir, exist_ok=True)
            preview_path = os.path.join(preview_dir, f"epoch{epoch:03d}_pred.wav")
            torchaudio.save(preview_path, preds[0].unsqueeze(0).cpu(), cfg.sr)
            print(f"[preview] wrote {preview_path}")

    print("[train] done")
    return model

def validate_model(
    dataset_root: str,
    ckpt: str,
    batch_size: int = 4,
    overlap: float = 0.5,
    num_workers: int = 2,
    max_batches: int = 50,
    val_tracks: Optional[List[str]] = None,
):
    """
    Quick validation: average token cross-entropy over dataset windows.

    Loads `ckpt` into a fresh model. Iterates windows from `val_tracks`; when that is
    None, the "val" part of split_dataset(..., seed=42) is used. Averages the per-batch
    loss returned by the model over at most `max_batches` batches. This is much cheaper
    than decoding audio for SI-SDR.

    Returns the average loss.

    Raises:
      RuntimeError: if no validation windows are found.
    """
    model = build_model()
    load_checkpoint(model, ckpt)
    if val_tracks is None:
        val_tracks = split_dataset(dataset_root, val_fraction=0.1, test_fraction=0.1, seed=42)["val"]
    dataset = StemSeparationDataset(dataset_root, track_dirs=val_tracks, stems=cfg.stems,
                                    window_seconds=cfg.seconds, overlap=overlap)
    if len(dataset) == 0:
        raise RuntimeError(f"No validation items found under {dataset_root}")
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_pack,
        pin_memory=False,
    )
    model.eval()
    batch_losses: List[float] = []
    with torch.no_grad():
        for batch_idx, batch in enumerate(loader):
            tokens, loss_mask, *_ = batch
            _, batch_loss = model(tokens, loss_mask=loss_mask)
            batch_losses.append(float(batch_loss.item()))
            if batch_idx + 1 >= max_batches:
                break
    n_batches = len(batch_losses)
    avg = sum(batch_losses) / max(1, n_batches)
    print(f"[val] avg CE loss over {n_batches} batches: {avg:.4f}")
    return avg

def _sisdr_load_track_stems(track_dir: str, stem_names: List[str], name_map: Dict[str, str]):
    """Load a track's stem files (mono, at cfg.sr, on cfg.device), padded/trimmed to one common length.

    Returns (stem -> waveform for every stem whose file was found, common length in samples).
    The length is 0 (with an empty dict) when no stem file could be resolved.
    """
    paths: Dict[str, str] = {}
    max_len = 0
    for s in stem_names:
        subdir = name_map.get(s)
        if subdir is None:
            continue  # no folder mapping for this stem
        sp = _find_stem_wav_from_subdir(track_dir, subdir)
        if sp is None:
            continue
        paths[s] = sp
        # Length after resampling to cfg.sr, derived from the file header (torchaudio.info).
        info = torchaudio.info(sp)
        frames_at_target = int(round(info.num_frames * (cfg.sr / info.sample_rate)))
        max_len = max(max_len, frames_at_target)
    if max_len == 0:
        return {}, 0
    wavs: Dict[str, torch.Tensor] = {}
    for s, sp in paths.items():
        wav = mono(load_wav(sp, cfg.sr).to(cfg.device))
        wavs[s] = pad_or_trim(wav, max_len)
    return wavs, max_len


def _sisdr_write_track_csvs(metrics_root: str, stems: List[str], per_stem_avg: Dict[str, float],
                            per_stem_cnt: Dict[str, int],
                            per_track_scores: Dict[str, Dict[str, float]]):
    """Write sisdr_per_stem.csv and sisdr_per_track.csv (UTF-8) into metrics_root."""
    os.makedirs(metrics_root, exist_ok=True)

    # 1) Per-stem CSV
    per_stem_csv = os.path.join(metrics_root, "sisdr_per_stem.csv")
    with open(per_stem_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["stem", "avg_sisdr_db", "n_scored"])
        for s in stems:
            w.writerow([s, f"{per_stem_avg[s]:.3f}", per_stem_cnt[s]])
    print(f"[csv] wrote {per_stem_csv}")

    # 2) Per-track CSV (one row per track, one column per stem; blank when a stem was not scored)
    per_track_csv = os.path.join(metrics_root, "sisdr_per_track.csv")
    with open(per_track_csv, "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["track"] + list(stems) + ["num_scored_stems", "avg_sisdr_db"])
        for track_name, scores in per_track_scores.items():
            vals = [f"{scores[s]:.3f}" if s in scores else "" for s in stems]
            present = [scores[s] for s in stems if s in scores]
            row_avg = f"{sum(present) / len(present):.3f}" if present else ""
            w.writerow([track_name] + vals + [len(present), row_avg])
    print(f"[csv] wrote {per_track_csv}")


@torch.no_grad()
def evaluate_sisdr_tracks(
    dataset_root: str,
    ckpt: str,
    stems: Optional[List[str]] = None,
    overlap: float = 0.5,
    out_dir: Optional[str] = None,
    track_dirs: Optional[List[str]] = None,
    metrics_dir: Optional[str] = None,
):
    """
    Track-level SI-SDR evaluation.

    - Builds each track's full-length mixture by summing its stem files (the dataset has
      no mixture.wav). Every stem in cfg.stems, plus any extra requested stem, is mixed in,
      so scoring only a subset of stems does not make the mixture artificially easy.
    - Separates the mixture with separate_stems_from_mix.
    - Computes SI-SDR for each requested stem against its reference stem file.
    - Aggregates per-stem and overall averages.

    Per-track layout: one sub-folder per stem, resolved through _default_stem_files()
    (stem -> sub-directory) and _find_stem_wav_from_subdir.

    Args:
      dataset_root: root folder containing track subfolders.
      ckpt: path to model checkpoint (state_dict).
      stems: stems to evaluate (defaults to cfg.stems).
      overlap: window overlap for separation (0..1).
      out_dir: if provided, write predicted stems to out_dir/<track_name>/<stem>.wav.
      track_dirs: explicit track directories (defaults to the test split, or val if test is empty).
      metrics_dir: where to write CSVs; falls back to out_dir, then "./metrics".
                   Produces sisdr_per_stem.csv and sisdr_per_track.csv.

    Returns:
      (avg_all, per_stem_avg, per_track_scores)

    Raises:
      RuntimeError: if no tracks are available.
    """
    stems = list(stems or cfg.stems)
    # The mixture always contains every configured stem; dict.fromkeys de-duplicates in order.
    mix_stems = list(dict.fromkeys(list(cfg.stems) + stems))

    model = build_model()
    load_checkpoint(model, ckpt)

    if track_dirs is None:
        split = split_dataset(dataset_root, val_fraction=0.1, test_fraction=0.1, seed=42)
        track_dirs = split["test"]
        if len(track_dirs) == 0:
            # Fall back to 'val' if test empty
            track_dirs = split["val"]
    if len(track_dirs) == 0:
        raise RuntimeError("No tracks available for evaluation.")

    per_stem_sum: Dict[str, float] = {s: 0.0 for s in stems}
    per_stem_cnt: Dict[str, int] = {s: 0 for s in stems}
    per_track_scores: Dict[str, Dict[str, float]] = {}
    name_map = _default_stem_files()  # stem -> subdir

    if out_dir is not None:
        os.makedirs(out_dir, exist_ok=True)

    for track_dir in track_dirs:
        track_name = os.path.basename(track_dir.rstrip("/\\"))
        # Load each stem once; the same tensors build the mixture and serve as references.
        stem_wavs, max_len = _sisdr_load_track_stems(track_dir, mix_stems, name_map)
        if max_len == 0:
            print(f"[warn] {track_name}: no usable stems, skipping.")
            continue
        mix = torch.zeros(max_len, dtype=torch.float32, device=cfg.device)
        for wav in stem_wavs.values():
            mix += wav

        preds = separate_stems_from_mix(model, mix, stems=stems, overlap=overlap)

        if out_dir is not None:
            tdir = os.path.join(out_dir, track_name)
            os.makedirs(tdir, exist_ok=True)
            for s in stems:
                torchaudio.save(os.path.join(tdir, f"{s.lower()}.wav"),
                                preds[s].unsqueeze(0).cpu(), cfg.sr)

        track_scores: Dict[str, float] = {}
        T = mix.numel()
        for s in stems:
            ref_dir = name_map.get(s, None)
            if ref_dir is None:
                print(f"[warn] {track_name}: no filename mapping for stem '{s}', skipping.")
                continue
            ref = stem_wavs.get(s)
            if ref is None:
                print(f"[warn] {track_name}: missing reference subdir '{ref_dir}' for stem '{s}', skipping.")
                continue
            # Pad as well as trim so a slightly short prediction cannot break the comparison.
            pred = pad_or_trim(preds[s], T)
            d = float(si_sdr(ref, pred).item())
            track_scores[s] = d
            per_stem_sum[s] += d
            per_stem_cnt[s] += 1

        per_track_scores[track_name] = track_scores
        if track_scores:
            stem_list = ", ".join(f"{k}:{v:.2f}dB" for k, v in track_scores.items())
            print(f"[track] {track_name}: {stem_list}")
        else:
            print(f"[track] {track_name}: no scored stems.")

    total_sum = sum(per_stem_sum.values())
    total_cnt = sum(per_stem_cnt.values())
    avg_all = total_sum / max(1, total_cnt)
    per_stem_avg = {s: (per_stem_sum[s] / max(1, per_stem_cnt[s])) for s in stems}

    print(f"[SI-SDR/tracks] overall avg (all stems, all tracks): {avg_all:.3f} dB")
    for s in stems:
        print(f"  {s:>6}: {per_stem_avg[s]:.3f} dB (n={per_stem_cnt[s]})")

    metrics_root = metrics_dir or out_dir or "./metrics"
    _sisdr_write_track_csvs(metrics_root, stems, per_stem_avg, per_stem_cnt, per_track_scores)

    return avg_all, per_stem_avg, per_track_scores

# =========================
# Notebook-friendly entry points (no argparse)
# =========================
def init_model(ckpt: str | None = None):
    """
    Seed torch with 0, build the model, pass `ckpt` to load_checkpoint and return the model.

    `ckpt` may be None; it is forwarded to load_checkpoint unchanged.
    """
    torch.manual_seed(0)
    print("Device:", cfg.device)
    net = build_model()
    load_checkpoint(net, ckpt)
    return net

def run_separation(mix_path: str, out_dir: str = ".", stems: list[str] | None = None,
                   overlap: float = 0.5, ckpt: str | None = None,
                   model: nn.Module | None = None):
    """
    Separate a mixed WAV file into stems and write WAVs to out_dir.

    Jupyter usage with a pre-loaded model:
        model = init_model(ckpt="/path/to/model.pt")
        run_separation("/path/to/mix.wav", out_dir="./out", stems=["VOCALS", "DRUMS"], model=model)
    or let this function build and load the model itself:
        run_separation("/path/to/mix.wav", out_dir="./out", ckpt="/path/to/model.pt")

    Args:
      mix_path: input mixture WAV.
      out_dir: output directory handed to separate_file.
      stems: stems to extract (defaults to cfg.stems).
      overlap: window overlap handed to separate_file.
      ckpt: checkpoint to load. When `model` is also given, it is loaded into that model.
      model: optional already-built model. If omitted, a model is built (and `ckpt` loaded).
    """
    if model is None:
        model = init_model(ckpt) if ckpt is not None else build_model()
        if ckpt is None:
            print("[warn] No checkpoint provided. Outputs will likely sound noisy (untrained weights).")
    elif ckpt is not None:
        load_checkpoint(model, ckpt)
    stems_arg = stems or cfg.stems
    separate_file(model, mix_path, out_dir, stems=stems_arg, overlap=overlap)

def run_demo(output_dir: str = "/content/drive/MyDrive/MSS_Audio/"):
    """
    Run the synthetic demo (440 Hz sine + noise) and write debugging WAVs to output_dir.

    Files written:
      pred_vocals_demo.wav               model prediction after one demo_step
      target_vocals_ground_truth.wav     the clean 440 Hz target
      recon_from_gt_tokens.wav           target -> tokens -> dequantized magnitude, target phase
      recon_from_gt_tokens_mixphase.wav  same dequantized magnitude, mixture phase

    Jupyter usage:
        run_demo(output_dir="./demo_out")
    """
    torch.manual_seed(0)
    device = cfg.device
    print("Device:", device)

    # Build a fresh (random) model; demo_step runs one toy optimization step.
    model = build_model()

    # Synthetic example: 440 Hz sine as "vocals" plus white noise.
    T = int(cfg.seconds * cfg.sr)
    # Sample instants n / sr for n = 0..T-1. A linspace that includes the endpoint would
    # stretch the time grid and slightly detune the tone.
    t = torch.arange(T, dtype=torch.float32) / cfg.sr
    sine = 0.5 * torch.sin(2 * math.pi * 440 * t)
    noise = 0.1 * torch.randn_like(sine)
    vocals = sine
    mix = (vocals + noise).to(device)

    # Pad/trim exactly cfg.seconds
    mix = pad_or_trim(mix, T)
    vocals = pad_or_trim(vocals.to(device), T)

    opt = torch.optim.AdamW(model.parameters(), lr=3e-4, betas=(0.9, 0.95), weight_decay=0.1)
    loss, pred = demo_step(model, opt, mix, vocals, stem_name="VOCALS")
    print("Loss:", loss)

    os.makedirs(output_dir, exist_ok=True)

    def _save(name: str, wav: torch.Tensor) -> None:
        # Write a 1-D waveform as a mono WAV inside output_dir.
        torchaudio.save(os.path.join(output_dir, name), wav.unsqueeze(0).cpu(), cfg.sr)
        print(f"Wrote {name}")

    _save("pred_vocals_demo.wav", pred)

    # === Ground-truth recon sanity checks ===
    # 1) The clean target (should be a pure 440 Hz tone)
    _save("target_vocals_ground_truth.wav", vocals)

    # 2) Quantize -> dequantize the stem and reconstruct with the STEM phase (cleanest)
    with torch.no_grad():
        stem_mag_lin_dbg, stem_phase_dbg = stft_mag_phase(vocals)
        stem_feat_dbg = linear_to_mel(stem_mag_lin_dbg) if cfg.use_mel else stem_mag_lin_dbg
        stem_log_dbg = to_logmag(stem_feat_dbg)
        stem_tok_dbg = quantize_logmag(stem_log_dbg)
        deq_feat_dbg = dequantize_to_mag(stem_tok_dbg, stem_feat_dbg.shape).to(stem_phase_dbg.device)
        deq_mag_lin_dbg = mel_to_linear(deq_feat_dbg) if cfg.use_mel else deq_feat_dbg
        recon_gt_stemphase = istft_from_mag_phase(deq_mag_lin_dbg, stem_phase_dbg, T).clamp(-1, 1)
    _save("recon_from_gt_tokens.wav", recon_gt_stemphase)

    # 3) Same dequantized magnitude but with the MIXTURE phase (slightly dirtier)
    with torch.no_grad():
        _, mix_phase_dbg = stft_mag_phase(mix)
        recon_gt_mixphase = istft_from_mag_phase(deq_mag_lin_dbg, mix_phase_dbg, T).clamp(-1, 1)
    _save("recon_from_gt_tokens_mixphase.wav", recon_gt_mixphase)

[info] Multiprocessing start method set to 'spawn'.


In [None]:
path = "/content/drive/MyDrive/MSS_Audio/moisesdb"
dataset_root = path + "/dataset_root"
ckpt_dir = path + "/checkpoints"
epochs = 5
# train_model saves model_epochNNN.pt after every epoch; evaluate the last one.
ckpt_path = os.path.join(ckpt_dir, f"model_epoch{epochs:03d}.pt")

# 0) Make the split ONCE and reuse it everywhere, so evaluation never sees training tracks.
split = split_dataset(dataset_root, val_fraction=0.1, test_fraction=0.1, seed=42)

# 1) Train on the train split
train_model(
    dataset_root=dataset_root,
    out_dir=ckpt_dir,
    epochs=epochs,
    batch_size=4,
    overlap=0.5,
    train_tracks=split["train"],   # or omit to auto-split
)

# 2) Quick validation (token CE)
validate_model(
    dataset_root=dataset_root,
    ckpt=ckpt_path,
    batch_size=4,
    val_tracks=split["val"],       # or omit to auto-split
)

# 3) SI-SDR evaluation with decoding (heavier but more meaningful)
evaluate_sisdr(
    dataset_root=dataset_root,
    ckpt=ckpt_path,
    batch_size=2,
    max_batches=100,
    track_dirs=split["val"],
    metrics_dir="./metrics_val",
)

# 4) Inference on a full mix file (only the stems the model was trained on: cfg.stems)
model = init_model(ckpt=ckpt_path)
mix_file = "./some_mix.wav"
if os.path.exists(mix_file):
    separate_file(model, mix_path=mix_file, out_dir="./stems_out",
                  stems=list(cfg.stems), overlap=0.5)
else:
    print(f"[skip] {mix_file} not found; point mix_file at a real mixture to run full-mix inference.")

# 5) Evaluate on the test split, save predicted stems and CSVs
avg_all, per_stem, per_track = evaluate_sisdr_tracks(
    dataset_root=dataset_root,
    ckpt=ckpt_path,
    stems=list(cfg.stems),
    overlap=0.5,
    out_dir="./preds_test",        # optional: writes audio here
    metrics_dir="./metrics_test",  # CSVs will be written here
    track_dirs=split["test"],
)

[split] total=10 → train=8, val=1, test=1
[train] device=cuda
