In [None]:
'''
This notebook is designed to run on Kaggle, so that Kaggle handles
downloading and storing the 100+ GB dataset.
The MAESTRO v2 dataset must be added as an input to the notebook.
It can be found at:
https://www.kaggle.com/datasets/jackvial/themaestrodatasetv2/data
'''

In [None]:
!pip install pretty_midi

In [None]:
# Imports
import numpy as np
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import matplotlib.pyplot as plt
from data.dataset import MaestroDataset
from torch.utils.data import DataLoader
import librosa.display
import torch
from torch.utils.data import Dataset
import pandas as pd
import librosa
import pretty_midi
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# MaestroASTDataset
class MaestroASTDataset(Dataset):
    """
    For AST training: returns waveform + token_ids (LongTensor)
    Tokenization is done from the MIDI file directly (bypasses piano-roll).
    """

    def __init__(
        self,
        root_dir: str,
        tokenizer,                 # EventMIDITokenizer instance
        csv_path: str | None = None,
        year=None,
        split: str = "train",
        sr: int = 16000,
        subset_size: int | None = None,
        max_token_len: int = 256,
        return_midi_path: bool = False,   # <-- add this
    ):
        self.root_dir = root_dir
        self.tokenizer = tokenizer
        self.sr = sr
        self.max_token_len = max_token_len
        self.return_midi_path = return_midi_path

        if csv_path is None:
            csv_path = os.path.join(root_dir, "maestro-v2.0.0.csv")
        df = pd.read_csv(csv_path)

        if year is not None:
            if isinstance(year, str) and "," in year:
                years = [int(y.strip()) for y in year.split(",") if y.strip()]
            elif isinstance(year, (list, tuple, set)):
                years = [int(y) for y in year]
            else:
                years = [int(year)]
            df = df[df["year"].isin(years)]

        if split is not None:
            df = df[df["split"] == split]

        if subset_size:
            df = df.head(subset_size)

        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        audio_path = os.path.join(self.root_dir, row["audio_filename"])
        midi_path  = os.path.join(self.root_dir, row["midi_filename"])

        if not os.path.exists(audio_path) and audio_path.endswith(".wav"):
            audio_path = audio_path.replace(".wav", ".mp3")

        y, _ = librosa.load(audio_path, sr=self.sr, mono=True)
        waveform = torch.tensor(y.astype(np.float32))

        token_list = self.tokenizer.encode_midi_path(midi_path, max_len=self.max_token_len)
        token_ids = torch.tensor(token_list, dtype=torch.long)

        if self.return_midi_path:
            return waveform, token_ids, midi_path
        return waveform, token_ids

In [None]:
# Build a tiny dataset preview: only the 2017 recordings, capped at 5 rows so the cell loads quickly.
# NOTE(review): MaestroDataset comes from the project's `data.dataset` module, which must be on the
# Kaggle kernel's path; MaestroASTDataset (defined above) may be the intended class -- TODO confirm.
dataset = MaestroDataset(
    year="2017",
    subset_size=5,
    root_dir="/kaggle/input/themaestrodatasetv2/maestro-v2.0.0",
)

# Preview number of samples
print("Total samples: " + str(len(dataset)))

In [None]:
#EventMIDITokenizer
from typing import List, Tuple, Dict, Optional, Union
import pretty_midi

class EventMIDITokenizer:
    """
    Event-based MIDI tokenizer (REMI-ish but simpler).

    Token layout:
      - 0: <sos>
      - 1: <eos>
      - 2: <pad>

      - NOTE_ON(pitch):  note_on_base + pitch          (pitch 0..127)  -> 10..137
      - NOTE_OFF(pitch): note_off_base + pitch         (pitch 0..127)  -> 160..287
      - TIME_SHIFT(k):   time_shift_base + (k-1)       (k=1..max_time_shift) -> 320..

    Time is quantised to frames of ``1 / frame_rate`` seconds. A gap longer than
    ``max_time_shift`` frames is split into several TIME_SHIFT tokens.

    Args:
        vocab_size: total vocabulary size; must cover every TIME_SHIFT id.
        frame_rate: frames per second used to quantise note times (must be > 0).
        max_time_shift: largest single TIME_SHIFT, in frames (must be >= 1).

    Raises:
        ValueError: if ``frame_rate`` or ``max_time_shift`` is not positive, or if
            ``vocab_size`` is too small for the TIME_SHIFT range.
    """

    def __init__(self, vocab_size: int = 512, frame_rate: int = 100, max_time_shift: int = 100):
        self.vocab_size = vocab_size
        self.frame_rate = int(frame_rate)
        self.max_time_shift = int(max_time_shift)

        # frame_rate converts seconds <-> frames (decode divides by it), so it must be positive.
        if self.frame_rate <= 0:
            raise ValueError(f"frame_rate must be > 0, got {frame_rate}")
        # With max_time_shift < 1, encode would loop emitting TIME_SHIFT ids that
        # is_time_shift() rejects, filling the whole sequence with garbage.
        if self.max_time_shift < 1:
            raise ValueError(f"max_time_shift must be >= 1, got {max_time_shift}")

        self.sos = 0
        self.eos = 1
        self.pad = 2

        # Keep gaps between ranges for safety/readability
        self.note_on_base = 10           # 10..137
        self.note_off_base = 160         # 160..287
        self.time_shift_base = 320       # 320..(320+max_time_shift-1)

        # TIME_SHIFT is the highest range, so it determines the minimum vocab size.
        needed = self.time_shift_base + self.max_time_shift
        if needed > self.vocab_size:
            raise ValueError(f"vocab_size too small; need >= {needed}, got {self.vocab_size}")

    def note_on_id(self, pitch: int) -> int:
        """Token id for NOTE_ON(pitch)."""
        return self.note_on_base + int(pitch)

    def note_off_id(self, pitch: int) -> int:
        """Token id for NOTE_OFF(pitch)."""
        return self.note_off_base + int(pitch)

    def time_shift_id(self, k: int) -> int:
        """Token id for TIME_SHIFT(k); ``k`` is clamped to [1, max_time_shift]."""
        k = int(max(1, min(self.max_time_shift, k)))
        return self.time_shift_base + (k - 1)

    def is_note_on(self, tok: int) -> bool:
        return self.note_on_base <= tok < self.note_on_base + 128

    def is_note_off(self, tok: int) -> bool:
        return self.note_off_base <= tok < self.note_off_base + 128

    def is_time_shift(self, tok: int) -> bool:
        return self.time_shift_base <= tok < self.time_shift_base + self.max_time_shift

    def tok_to_pitch(self, tok: int) -> int:
        """Pitch of a NOTE_ON/NOTE_OFF token; raises ValueError for any other token."""
        if self.is_note_on(tok):
            return tok - self.note_on_base
        if self.is_note_off(tok):
            return tok - self.note_off_base
        raise ValueError("Not a pitch token")

    def tok_to_shift(self, tok: int) -> int:
        """Number of frames encoded by a TIME_SHIFT token (not validated)."""
        return (tok - self.time_shift_base) + 1

    def encode_midi_path(self, midi_path: Union[str, os.PathLike], max_len: int = 512) -> List[int]:
        """Load a MIDI file (str or path-like) and encode it; see ``encode_pretty_midi``."""
        if isinstance(midi_path, os.PathLike):
            # pretty_midi treats non-str inputs as file objects, so convert Path -> str.
            midi_path = os.fspath(midi_path)
        pm = pretty_midi.PrettyMIDI(midi_path)
        return self.encode_pretty_midi(pm, max_len=max_len)

    def encode_midi(self, midi: "Union[str, os.PathLike, pretty_midi.PrettyMIDI]", max_len: int = 512) -> List[int]:
        """
        Backwards compatible: accept a path (str or os.PathLike) or a PrettyMIDI object.

        Raises:
            TypeError: for any other input type.
        """
        if isinstance(midi, (str, os.PathLike)):
            pm = pretty_midi.PrettyMIDI(os.fspath(midi))
        elif isinstance(midi, pretty_midi.PrettyMIDI):
            pm = midi
        else:
            raise TypeError(f"encode_midi expected str, os.PathLike or PrettyMIDI, got {type(midi)}")
        return self.encode_pretty_midi(pm, max_len=max_len)

    def encode_pretty_midi(self, pm: "pretty_midi.PrettyMIDI", max_len: int = 512) -> List[int]:
        """
        Encode all non-drum notes of ``pm`` into a fixed-length list of token ids.

        Layout: ``[<sos>, events..., <eos>, <pad>...]`` with exactly ``max_len`` ids
        (when ``max_len >= 2``). Events that do not fit are dropped, so long pieces are
        truncated to their beginning. Notes with a pitch outside 0..127 are skipped;
        notes shorter than one frame are stretched to one frame.
        """
        # Collect note on/off events from ALL non-drum instruments
        events: List[Tuple[int, int, int]] = []  # (frame, kind, pitch) kind: 0=off, 1=on
        for inst in pm.instruments:
            if inst.is_drum:
                continue
            for n in inst.notes:
                on_f = int(round(n.start * self.frame_rate))
                off_f = int(round(n.end * self.frame_rate))
                if off_f <= on_f:
                    off_f = on_f + 1  # guarantee a non-zero duration
                pitch = int(n.pitch)
                if 0 <= pitch <= 127:
                    events.append((on_f, 1, pitch))
                    events.append((off_f, 0, pitch))

        # Sort by time, then OFF (0) before ON (1) so a re-struck pitch is released first.
        events.sort(key=lambda x: (x[0], x[1]))

        seq = [self.sos]
        budget = max_len - 1  # the final slot is reserved for <eos>
        cur_f = 0

        def emit_shift(delta: int):
            # Split a gap into TIME_SHIFT tokens of at most max_time_shift frames each.
            while delta > 0 and len(seq) < budget:
                k = min(self.max_time_shift, delta)
                seq.append(self.time_shift_id(k))
                delta -= k

        for f, kind, pitch in events:
            if len(seq) >= budget:
                break

            delta = f - cur_f
            if delta > 0:
                emit_shift(delta)
                cur_f = f

            if len(seq) >= budget:
                break

            seq.append(self.note_off_id(pitch) if kind == 0 else self.note_on_id(pitch))

        seq.append(self.eos)

        if len(seq) < max_len:
            seq += [self.pad] * (max_len - len(seq))
        else:
            seq = seq[:max_len]
        return seq

    def decode_to_pretty_midi(self, tokens: List[int], out_path: str) -> str:
        """
        Decode token ids into a single-instrument (program 0) MIDI file written to ``out_path``.

        <sos>/<pad> are skipped and decoding stops at the first <eos>. A NOTE_ON for a
        pitch that is already sounding is ignored; a NOTE_OFF without a matching NOTE_ON
        is ignored. Notes still sounding at the end are closed at the final time.
        Every note gets velocity 80 and lasts at least one frame. Returns ``out_path``.
        """
        pm = pretty_midi.PrettyMIDI()
        inst = pretty_midi.Instrument(program=0)

        t_f = 0
        active: Dict[int, int] = {}  # pitch -> start_frame

        for tok in tokens:
            tok = int(tok)
            if tok in (self.sos, self.pad):
                continue
            if tok == self.eos:
                break

            if self.is_time_shift(tok):
                t_f += self.tok_to_shift(tok)
                continue

            if self.is_note_on(tok):
                p = self.tok_to_pitch(tok)
                # if already active, ignore or restart; ignore is safest
                if p not in active:
                    active[p] = t_f
                continue

            if self.is_note_off(tok):
                p = self.tok_to_pitch(tok)
                if p in active:
                    start_f = active.pop(p)
                    start = start_f / self.frame_rate
                    end = max((t_f / self.frame_rate), start + (1.0 / self.frame_rate))
                    inst.notes.append(pretty_midi.Note(velocity=80, pitch=p, start=start, end=end))
                continue

        # close hanging notes
        end_f = t_f
        for p, start_f in active.items():
            start = start_f / self.frame_rate
            end = max((end_f / self.frame_rate), start + (1.0 / self.frame_rate))
            inst.notes.append(pretty_midi.Note(velocity=80, pitch=p, start=start, end=end))

        pm.instruments.append(inst)
        pm.write(out_path)
        return out_path

In [None]:
# transcription_model
class TranscriptionModel(nn.Module):
    """
    High-level wrapper for automatic music transcription models.
    Handles:
      - forward pass
      - loss computation
    """

    def __init__(
        self,
        model_type: str = "transformer",
        device: str = "cpu",
        vocab_size: int = 512,
    ):
        super().__init__()

        self.model_type = model_type.lower()
        self.device = device
        self.vocab_size = vocab_size

        if self.model_type in ["ast", "transformer", "audio_transformer"]:
            self.model = ASTModel(device=device, max_output_len=1024, remi_vocab_size=vocab_size)
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        # --- Tokenizer (ONE instance, reused everywhere) ---
        self.tokenizer = EventMIDITokenizer(vocab_size=self.vocab_size)

        # --- Weighted CE loss (REPLACES old CrossEntropyLoss) ---
        self.criterion = self._build_weighted_criterion()

        self.to(device)

    # ------------------------------------------------------------
    # Weighted loss (THIS replaces your old nn.CrossEntropyLoss)
    # ------------------------------------------------------------
    def _build_weighted_criterion(self):
        weights = torch.ones(self.vocab_size, device=self.device)
        
        tok = self.tokenizer
        
        # --------------------------------------------------
        # TIME_SHIFT tokens: [time_shift_base, base + max_time_shift)
        # --------------------------------------------------
        ts_start = tok.time_shift_base
        ts_end = min(tok.time_shift_base + tok.max_time_shift, self.vocab_size)
        
        weights[ts_start:ts_end] = 0.3
        
        # --------------------------------------------------
        # NOTE_ON tokens: [note_on_base, note_off_base)
        # --------------------------------------------------
        no_start = tok.note_on_base
        no_end = min(tok.note_off_base, self.vocab_size)
        
        weights[no_start:no_end] = 2.0
        
        # --------------------------------------------------
        # Never predict SOS
        # --------------------------------------------------
        weights[tok.sos] = 0.0
        weights[tok.eos] = 0.0

        
        return nn.CrossEntropyLoss(
            weight=weights,
            ignore_index=tok.pad,
        )


    # ------------------------------------------------------------
    # Forward (delegates to ASTModel)
    # ------------------------------------------------------------
    def forward(
        self,
        x,
        sampling_rate: int = 16000,
        targets: torch.LongTensor | None = None,
        generate_max_len: int = 256,
    ):
        """
        For AST:
          - If targets is provided → training mode
            returns (logits, tgt_out)
          - If targets is None → generation mode
            returns generated token IDs
        """
        return self.model(
            x,
            targets=targets,
            # generate_max_len=generate_max_len,
        )

    @torch.no_grad()
    def generate_from_audio(self, audio, max_len=256):
        return self.model.generate_from_audio(
            audio=audio,
            max_len=max_len,
            sos_token=self.tokenizer.sos,
            eos_token=self.tokenizer.eos,
        )

    
    # ------------------------------------------------------------
    # LOSS COMPUTATION (FIXED )
    # ------------------------------------------------------------
    def compute_loss(self, model_output, targets=None):
        """
        AST training:
          model_output = (logits, tgt_out)
          logits:  (B, T-1, V)
          tgt_out: (B, T-1)
        """

        if self.model_type in ["ast", "transformer", "audio_transformer"]:
            logits, tgt_out = model_output

            B, T, V = logits.shape
            assert V == self.vocab_size, "Vocab size mismatch"

            loss = self.criterion(
                logits.reshape(-1, V),
                tgt_out.reshape(-1),
            )
            return loss

        raise NotImplementedError("Only AST model supported here")

    # ------------------------------------------------------------
    # Prediction (generation)
    # ------------------------------------------------------------
    @torch.no_grad()
    def predict(self, x, sampling_rate: int = 16000, max_len: int = 256):
        """
        Generates REMI tokens using autoregressive decoding.
        """
        return self.forward(
            x,
            targets=None,
            generate_max_len=max_len,
        )

In [None]:
# train_transcriber

import torch.optim as optim
from torch.amp import autocast, GradScaler
from tqdm import tqdm
from torch.amp import autocast
from functools import partial
from typing import List, Tuple


def collate_ast_tokens(batch, pad_id: int = 2):
    """
    Collate AST samples into a model-ready batch.

    Args:
        batch: list of ``(waveform_1d, token_ids_1d)`` samples, optionally with a third
            element (e.g. the MIDI path from ``MaestroASTDataset(return_midi_path=True)``).
        pad_id: id used to right-pad shorter token sequences.

    Returns:
        ``(list_of_waveforms, token_tensor)`` where ``token_tensor`` is a LongTensor
        of shape (B, L_max); when samples carry a third element,
        ``(list_of_waveforms, token_tensor, list_of_extras)``.
        Waveforms are kept as a list (not stacked) so they may differ in length.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_ast_tokens received an empty batch")

    waveforms = [sample[0] for sample in batch]
    # ensure tokens are Long
    token_seqs = [torch.as_tensor(sample[1]).to(torch.long) for sample in batch]

    max_len = max(t.numel() for t in token_seqs)
    out = torch.full((len(token_seqs), max_len), pad_id, dtype=torch.long)
    for i, t in enumerate(token_seqs):
        out[i, : t.numel()] = t

    if len(batch[0]) > 2:
        # Pass extras (e.g. MIDI paths) through instead of failing to unpack them.
        return waveforms, out, [sample[2] for sample in batch]
    return waveforms, out

def assert_tokens_ok(token_targets: torch.Tensor, vocab_size: int, pad_id: int = 2):
    """
    Validate a batch of target token ids (typically (B, T)) before it reaches the model.

    Raises:
        TypeError: if the tensor dtype is not ``torch.long``.
        ValueError: if any id lies outside ``[0, vocab_size)`` (the message lists up to
            10 offending indices), or if ``pad_id`` itself is out of range.
    """
    if token_targets.dtype != torch.long:
        raise TypeError(f"token_targets must be torch.long, got {token_targets.dtype}")

    lo, hi = (int(v.item()) for v in (token_targets.min(), token_targets.max()))
    if lo < 0 or hi >= vocab_size:
        out_of_range = torch.logical_or(token_targets < 0, token_targets >= vocab_size)
        examples = out_of_range.nonzero(as_tuple=False)[:10].tolist()
        raise ValueError(
            f"Token id out of range: min={lo}, max={hi}, vocab={vocab_size}. Examples idx={examples}"
        )

    # The pad id must itself be a valid vocabulary entry.
    if pad_id < 0 or pad_id >= vocab_size:
        raise ValueError("pad_id out of range")


def train_one_epoch(model, dataloader, optimizer, device):
    model.train()
    scaler = GradScaler()

    total_loss = 0.0
    step_losses = []

    # Initialize tqdm progress bar
    progress_bar = tqdm(dataloader, desc="Training", leave=False)
    
    # handle AST vs framewise training
    if getattr(model, "model_type", "cnn_rnn") in ["ast", "transformer", "audio_transformer"]:
        for batch in progress_bar:
            # Support variable batch structures (some collate fns may return 2- or 3-tuples)
            optimizer.zero_grad(set_to_none=True)

            # Unpack robustly
            if isinstance(batch, (list, tuple)) and len(batch) == 2:
                waveforms, token_targets = batch
            elif isinstance(batch, (list, tuple)) and len(batch) == 3:
                # some collates may return (waveforms, tokens, extra)
                waveforms, token_targets, _ = batch
            else:
                # fallback: assume batch is a dict-like from HF datasets
                try:
                    waveforms = batch[0]
                    token_targets = batch[1]
                except Exception:
                    raise ValueError("Unsupported batch format for AST training: %r" % (type(batch),))

            # Move targets to device; waveforms will be processed by model
            token_targets = token_targets.to(device)

            # tokens = token_targets[0]
            # unique, counts = torch.unique(tokens, return_counts=True)
            # print(dict(zip(unique.tolist(), counts.tolist())))

            
            # # Mixed-precision forward/backward
            # with autocast('cuda'):
            #     model_output = model(waveforms, targets=token_targets)
            #     loss = model.compute_loss(model_output, token_targets)


            # In train_one_epoch / evaluate, before model(waveforms,...):
            assert_tokens_ok(token_targets, vocab_size=512, pad_id=2)

            # No AMP
            model_output = model(waveforms, targets=token_targets)
            loss = model.compute_loss(model_output, token_targets)


            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            for name, p in model.named_parameters():
                if p.grad is not None:
                    print(f"{name}: grad mean = {p.grad.abs().mean().item():.6f}")
                    break

            step_loss = loss.item()
            total_loss += step_loss
            step_losses.append(step_loss)
            progress_bar.set_postfix({"step_loss": f"{step_loss:.4f}"})
    else:
        for mel, roll, lengths in progress_bar:
            mel, roll = mel.to(device), roll.to(device)

            optimizer.zero_grad(set_to_none=True)

            # Mixed precision forward
            with autocast('cuda'):
                logits = model(mel)
                loss = model.compute_loss(logits, roll, lengths)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Update running totals
            step_loss = loss.item()
            total_loss += step_loss
            step_losses.append(step_loss)

            # Update tqdm bar dynamically
            progress_bar.set_postfix({"step_loss": f"{step_loss:.4f}"})

        # Update running totals
        step_loss = loss.item()
        total_loss += step_loss
        step_losses.append(step_loss)

        # Update tqdm bar dynamically
        progress_bar.set_postfix({"step_loss": f"{step_loss:.4f}"})

    avg_loss = total_loss / len(dataloader)
    progress_bar.close()
    return avg_loss, step_losses

@torch.no_grad()
def evaluate(model, dataloader, device, max_batches: int = None):
    model.eval()
    total_loss = 0.0
    n = 0

    for batch_i, batch in enumerate(tqdm(dataloader, desc="Validation", leave=False)):
        if max_batches is not None and batch_i >= max_batches:
            break

        if isinstance(batch, (list, tuple)) and len(batch) == 2:
            waveforms, token_targets = batch
        elif isinstance(batch, (list, tuple)) and len(batch) == 3:
            waveforms, token_targets, _ = batch
        else:
            waveforms, token_targets = batch[0], batch[1]

        assert_tokens_ok(token_targets, vocab_size=512, pad_id=2)

        token_targets = token_targets.to(device)
        logits = model(waveforms, targets=token_targets)
        loss = model.compute_loss(logits, token_targets)

        total_loss += loss.item()
        n += 1

    return total_loss / max(1, n)

In [None]:
# AST Model

try:
    # Hugging Face imports (optional). If not installed, init will raise an informative error.
    from transformers import AutoFeatureExtractor, AutoModel
    _HF_AVAILABLE = True
except Exception:
    AutoFeatureExtractor = None
    AutoModel = None
    _HF_AVAILABLE = False

class ASTModel(nn.Module):
    """
    Audio Spectrogram Transformer encoder + Transformer decoder for token generation.

    - Uses a pretrained AST encoder from Hugging Face (``pretrained_model_name``), or a
      tiny random mock feature extractor/encoder when ``use_mock_encoder=True``
      (for unit tests; no download).
    - Freezes encoder weights by default (``freeze_encoder=True``) to reduce compute.
    - A Transformer decoder generates tokens autoregressively (teacher forcing during
      training). Training and every generation path share the same embedding and
      decoding helpers, so they see identically scaled inputs.

    Input to forward():
      - audio: sequence of 1D waveforms (torch tensors or array-likes); each is converted
        to a float32 NumPy array for the feature extractor
      - targets (optional): LongTensor[B, T] of token ids for teacher forcing
      - sampling_rate: passed to the feature extractor (default 16000)

    Returns:
      - If targets provided -> ``(logits[B, T-1, vocab_size], tgt_out[B, T-1])``
      - Else -> generated token ids: LongTensor[B, <= generate_max_len]
    """

    def __init__(
        self,
        pretrained_model_name: str = "MIT/ast-finetuned-audioset-10-10-0.4593",
        use_mock_encoder: bool = False,
        freeze_encoder: bool = True,
        remi_vocab_size: int = 512,
        decoder_layers: int = 1,    # TODO Change to ~4
        decoder_dim: int = 256,     # TODO change to 384
        decoder_heads: int = 4,     # TODO change to 6
        dropout: float = 0.0,       # TODO change to 0.2
        max_output_len: int = 1024,
        device: str = "cpu",
    ):
        super().__init__()

        self.device = device
        self.pretrained_model_name = pretrained_model_name
        self.freeze_encoder = freeze_encoder
        self.remi_vocab_size = remi_vocab_size
        self.decoder_dim = decoder_dim
        self.max_output_len = max_output_len
        self.use_mock_encoder = use_mock_encoder

        if self.use_mock_encoder:
            # Tiny mock feature extractor + encoder for unit tests (no HF download)
            class _MockFeatureExtractor:
                def __init__(self, hidden_size=64):
                    self.hidden_size = hidden_size

                def __call__(self, waveforms, sampling_rate=None, return_tensors="pt", padding=True):
                    # Return a random tensor shaped (B, S, hidden), S ~ longest waveform / 160
                    if isinstance(waveforms, list):
                        B = len(waveforms)
                        max_len = max([w.shape[0] if hasattr(w, 'shape') else len(w) for w in waveforms])
                    else:
                        waveforms = [waveforms]
                        B = 1
                        max_len = waveforms[0].shape[0]
                    S = max(1, max_len // 160)  # coarse time dimension
                    return {"input_values": torch.randn(B, S, self.hidden_size)}

            class _MockEncoder(nn.Module):
                def __init__(self, hidden_size=64):
                    super().__init__()
                    self.config = type("C", (), {"hidden_size": hidden_size})

                def forward(self, **kwargs):
                    x = kwargs.get("input_values")
                    # pass-through: x is already (B, S, H)
                    return type("O", (), {"last_hidden_state": x})

            self.feature_extractor = _MockFeatureExtractor(hidden_size=decoder_dim)
            self.encoder = _MockEncoder(hidden_size=decoder_dim)
        else:
            if not _HF_AVAILABLE:
                raise ImportError(
                    "The ASTModel requires `transformers` to be installed. "
                )

            # Feature extractor converts raw audio waveforms to the inputs expected by AST
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name)

            # Pretrained AST encoder (frozen below by default)
            self.encoder = AutoModel.from_pretrained(pretrained_model_name)

        # encoder hidden size (from model config)
        enc_hidden = getattr(self.encoder.config, "hidden_size", None)
        if enc_hidden is None:
            # fallback: try common attributes
            enc_hidden = getattr(self.encoder.config, "embed_dim", decoder_dim)

        # Freeze encoder if requested
        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False

        # Project encoder features to decoder dimensionality
        self.enc_to_dec = nn.Linear(enc_hidden, decoder_dim)

        # Decoder token embeddings + learned positional embeddings
        self.token_emb = nn.Embedding(remi_vocab_size, decoder_dim)
        self.pos_emb = nn.Embedding(max_output_len, decoder_dim)

        # Transformer decoder stack (sequence-first layout: (T, B, D))
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=decoder_dim,
            nhead=decoder_heads,
            dim_feedforward=decoder_dim * 4,
            dropout=dropout,
            activation="gelu",
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=decoder_layers)

        # Final projection to the token vocabulary
        self.output_fc = nn.Linear(decoder_dim, remi_vocab_size)

        self._reset_parameters()

        self.to(device)

    def _reset_parameters(self):
        # small init for newly added heads
        nn.init.normal_(self.enc_to_dec.weight, mean=0.0, std=0.02)
        if self.enc_to_dec.bias is not None:
            nn.init.zeros_(self.enc_to_dec.bias)
        nn.init.normal_(self.output_fc.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.output_fc.bias)

    def _generate_square_subsequent_mask(self, sz: int):
        """Float causal mask: -inf above the diagonal, 0 elsewhere."""
        return torch.triu(
            torch.full((sz, sz), float("-inf")),
            diagonal=1
        )

    # ------------------------------------------------------------
    # Shared helpers (used by forward and both generation paths)
    # ------------------------------------------------------------
    @staticmethod
    def _to_numpy_batch(audio):
        """Convert each waveform to a float32 NumPy array (what the HF feature extractor expects)."""
        batch = []
        for w in audio:
            if isinstance(w, torch.Tensor):
                batch.append(w.detach().cpu().numpy().astype(np.float32))
            else:
                batch.append(np.asarray(w, dtype=np.float32))
        return batch

    def _encode(self, audio, device, sampling_rate: int = 16000):
        """Feature extraction + encoder + projection; returns memory of shape (T_enc, B, decoder_dim)."""
        inputs = self.feature_extractor(
            self._to_numpy_batch(audio),
            sampling_rate=sampling_rate,
            return_tensors="pt",
            padding=True,
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        if self.freeze_encoder:
            with torch.no_grad():
                enc_outputs = self.encoder(**inputs)
        else:
            enc_outputs = self.encoder(**inputs)

        encoder_out = self.enc_to_dec(enc_outputs.last_hidden_state)  # (B, T_enc, decoder_dim)
        return encoder_out.transpose(0, 1)                            # (T_enc, B, decoder_dim)

    def _embed_tokens(self, tokens):
        """(token + positional) embeddings scaled by sqrt(decoder_dim), returned as (T, B, D)."""
        positions = torch.arange(tokens.size(1), device=tokens.device).unsqueeze(0)
        emb = (self.token_emb(tokens) + self.pos_emb(positions)) * (self.decoder_dim ** 0.5)
        return emb.transpose(0, 1)

    def _decode_logits(self, tokens, memory):
        """Causal decoder pass over ``tokens`` (B, T) attending to ``memory``; returns logits (B, T, V)."""
        tgt = self._embed_tokens(tokens)
        tgt_mask = self._generate_square_subsequent_mask(tokens.size(1)).to(tokens.device)
        dec_out = self.decoder(tgt, memory, tgt_mask=tgt_mask)  # (T, B, D)
        return self.output_fc(dec_out.transpose(0, 1))

    # ------------------------------------------------------------
    # Forward
    # ------------------------------------------------------------
    def forward(self, audio, targets=None, sampling_rate: int = 16000, generate_max_len: int = 256):
        """
        Teacher-forced pass when ``targets`` (B, T) is given: the decoder reads
        ``targets[:, :-1]`` and is scored against ``targets[:, 1:]``.
        Returns ``(logits[B, T-1, V], tgt_out[B, T-1])``.

        With ``targets=None``, falls back to greedy ``generate_from_audio``
        (SOS id 0, no EOS stopping, at most ``generate_max_len`` tokens per row).
        """
        if targets is None:
            return self.generate_from_audio(audio, max_len=generate_max_len, sampling_rate=sampling_rate)

        memory = self._encode(audio, targets.device, sampling_rate)

        tgt_in = targets[:, :-1]
        tgt_out = targets[:, 1:]
        logits = self._decode_logits(tgt_in, memory)
        return logits, tgt_out

    @torch.no_grad()
    def generate_from_audio(
        self,
        audio,
        max_len=256,
        sos_token=0,
        eos_token=None,
        sampling_rate: int = 16000,
    ):
        """
        Greedy autoregressive decoding from raw audio.

        Returns LongTensor (B, <= max_len) starting with ``sos_token``. Stops early only
        when every row emits ``eos_token`` at the same step (if ``eos_token`` is given).
        Temporarily switches to eval mode and restores the previous train/eval mode afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            device = next(self.parameters()).device
            memory = self._encode(audio, device, sampling_rate)

            generated = torch.full(
                (len(audio), 1), sos_token, dtype=torch.long, device=device
            )

            for _ in range(max_len - 1):
                logits = self._decode_logits(generated, memory)[:, -1]  # (B, vocab)
                next_token = logits.argmax(dim=-1, keepdim=True)

                generated = torch.cat([generated, next_token], dim=1)

                if eos_token is not None and (next_token == eos_token).all():
                    break
            return generated
        finally:
            self.train(was_training)

    @torch.no_grad()
    def generate(
        self,
        memory,
        sos_id: int = 0,
        max_len: int = 256,
        do_sample: bool = False,
        temperature: float = 1.0,
        top_k: int = 0,
        mask_sos: bool = True,
        repetition_penalty: float = 0.0,
    ):
        """Autoregressive generation from the decoder using provided encoder memory.

        With the defaults, this is greedy decoding (do_sample=False, repetition_penalty=0.0)
        that forbids re-emitting ``sos_id`` after the first generated token (mask_sos=True).
        Embeddings are scaled exactly as in ``forward``.

        Args:
            memory: (S, B, D) encoder memory
            sos_id: start token id
            max_len: number of tokens to generate
            do_sample: whether to sample from the softmax (vs argmax)
            temperature: softmax temperature when sampling
            top_k: if >0, restrict sampling to top_k logits
            mask_sos: if True, forbid emitting sos_id from the second generated token on
            repetition_penalty: float >=0. Subtracts penalty*count[token] from logits

        Returns:
            Tensor[B, max_len] of generated token ids (the leading sos_id is stripped)
        """
        device = memory.device
        _, B, _ = memory.shape
        vocab_size = self.remi_vocab_size

        generated = torch.full((B, 1), sos_id, dtype=torch.long, device=device)

        # counts per batch item for repetition penalty
        if repetition_penalty and repetition_penalty > 0.0:
            counts = torch.zeros((B, vocab_size), dtype=torch.long, device=device)
            # initialize with sos count
            counts.scatter_add_(1, generated, torch.ones_like(generated, dtype=torch.long))
        else:
            counts = None

        def top_k_logits(logits, k: int):
            if k <= 0:
                return logits
            values, _ = torch.topk(logits, k)
            min_values = values[..., -1, None]
            return torch.where(logits < min_values, torch.full_like(logits, float("-1e9")), logits)

        for step in range(max_len):
            logits = self._decode_logits(generated, memory)[:, -1]  # (B, V)

            # Optionally forbid producing sos after the first position
            if mask_sos and step > 0:
                if 0 <= sos_id < logits.size(-1):
                    logits[:, sos_id] = float("-1e9")

            # Apply repetition penalty (simple count-based subtraction)
            if counts is not None:
                logits = logits - repetition_penalty * counts.float()

            if do_sample:
                # sampling path: apply temperature and top_k filtering
                sample_logits = logits / max(1e-8, float(temperature))
                if top_k > 0:
                    sample_logits = top_k_logits(sample_logits, top_k)
                probs = torch.softmax(sample_logits, dim=-1)
                next_tokens = torch.multinomial(probs, num_samples=1)
            else:
                # greedy argmax
                next_tokens = logits.argmax(dim=-1, keepdim=True)

            if counts is not None:
                counts.scatter_add_(1, next_tokens, torch.ones_like(next_tokens, dtype=torch.long))

            generated = torch.cat([generated, next_tokens], dim=1)

        return generated[:, 1:]

In [None]:
# main.py

import argparse
from datetime import datetime


# Plot: per-epoch averaged losses
def plot_training_curves(train_losses, val_losses, save_path):
    """Draw per-epoch training and validation loss on one axes and save it to `save_path`.

    Args:
        train_losses: sequence of per-epoch training losses (x-axis is the list index).
        val_losses: sequence of per-epoch validation losses.
        save_path: file path the figure is written to; the figure is closed afterwards.
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    series_specs = (
        (train_losses, "Train Loss", "royalblue"),
        (val_losses, "Val Loss", "tomato"),
    )
    for values, series_label, line_colour in series_specs:
        ax.plot(values, label=series_label, color=line_colour, linewidth=2)

    ax.set_title("Training and Validation Loss", fontsize=14)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path)
    # Close explicitly: this is called once per epoch and open figures accumulate.
    plt.close(fig)

# Plot: per-step losses with epoch boundaries
def plot_step_losses(global_step_losses, num_epochs, save_path):
    """
    Plot per-step training loss across all epochs and save the figure.

    Args:
        global_step_losses: list of per-epoch sequences; each sublist holds the
            step losses recorded during that epoch.
        num_epochs: unused. Epoch boundaries are derived from the lengths of the
            sublists in `global_step_losses`. Kept for backward compatibility.
        save_path: file path for saving the figure. The figure is closed afterwards.

    An empty `global_step_losses` produces an axes-only figure instead of raising.
    """
    per_epoch = [np.asarray(losses, dtype=float) for losses in global_step_losses]
    # np.concatenate raises on an empty list, so fall back to an empty series.
    flat_losses = np.concatenate(per_epoch) if per_epoch else np.empty(0, dtype=float)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(flat_losses, color="mediumseagreen", linewidth=1.2)
    ax.set_title("Training Loss per Step", fontsize=14)
    ax.set_xlabel("Step")
    ax.set_ylabel("Loss")
    ax.grid(True, alpha=0.3)

    # Read the y-limit once. axvline/text only affect x autoscaling, so the
    # value is the same for every label.
    label_y = ax.get_ylim()[1] * 0.95

    # Mark the end of each epoch with a dashed vertical line and a label.
    step = 0
    for epoch_idx, losses in enumerate(per_epoch, 1):
        step += len(losses)
        ax.axvline(x=step, color="gray", linestyle="--", alpha=0.3)
        ax.text(step, label_y, f"Epoch {epoch_idx}", rotation=90,
                fontsize=8, color="gray", va="top", ha="right")

    fig.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)
    
def _build_train_arg_parser():
    """Return the argparse parser describing the training options accepted by `main`."""
    parser = argparse.ArgumentParser(description="Train music transcription model")
    parser.add_argument("--root_dir", type=str, default="maestro-v2.0.0", help="Path to MAESTRO dataset root")
    parser.add_argument("--year", type=str, default=None, help="Subset year (e.g. 2017). Deprecated: prefer --years")
    parser.add_argument("--years", type=str, default=None, help="Comma-separated list of years to include (e.g. '2013,2017')")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size")
    parser.add_argument("--epochs", type=int, default=25, help="Number of training epochs")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--subset_size", type=int, default=None, help="Limit dataset size for debugging")
    parser.add_argument("--model_type", type=str, default="cnn_rnn", help="Model type: cnn_rnn or ast")
    parser.add_argument("--save_every", type=int, default=10, help="Save model checkpoint every N epochs")
    parser.add_argument("--val_batches", type=int, default=1, help="Number of validation batches evaluated per epoch")
    return parser


def main(args):
    """
    Train the transcription model on MAESTRO and write checkpoints, logs and loss plots.

    Args:
        args: list of command-line style strings parsed by argparse
            (None parses sys.argv, as with ArgumentParser.parse_args).

    Returns:
        The model after the final training epoch. Its state_dict is also saved
        as model_final.pth in the checkpoint directory.
    """
    # Imported locally: no visible cell imports these names.
    from functools import partial
    from torch import optim

    args = _build_train_arg_parser().parse_args(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # print(torch.cuda.memory_summary(device=device, abbreviated=True)) GPU memory debugging

    # Output structure ---
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    run_dir = os.path.join("outputs", timestamp)
    # NOTE(review): checkpoints go under /kaggle/output while logs go under the
    # working directory (run_dir). Kept as-is; confirm this split is intended.
    checkpoint_dir = os.path.join("/kaggle/output/checkpoints", timestamp)
    logs_dir = os.path.join(run_dir, "logs")
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(logs_dir, exist_ok=True)

    model_path = os.path.join(checkpoint_dir, "model_final.pth")
    loss_plot_path = os.path.join(logs_dir, "loss_curve.png")
    step_plot_path = os.path.join(logs_dir, "loss_per_step.png")

    # --years (comma-separated) takes precedence over the deprecated --year.
    # Empty strings collapse to None so the dataset applies no year filter.
    years_arg = args.years or args.year or None

    # Create the tokenizer once and share it between both datasets.
    tokenizer = EventMIDITokenizer(vocab_size=512)

    train_ds = MaestroASTDataset(
        root_dir=args.root_dir,
        tokenizer=tokenizer,
        year=years_arg,
        split="train",
        subset_size=args.subset_size,
        max_token_len=1024,
    )

    # Validation gets ~1/5 of the training subset, but at least one item. Without
    # the floor, tiny debug subsets (e.g. --subset_size 1) produce 0 items.
    val_subset_size = max(1, args.subset_size // 5) if args.subset_size else None
    val_ds = MaestroASTDataset(
        root_dir=args.root_dir,
        tokenizer=tokenizer,
        year=years_arg,
        split="validation",
        subset_size=val_subset_size,
        max_token_len=1024,
    )

    # Nicely formatted parameters.txt
    with open(os.path.join(logs_dir, "parameters.txt"), "w") as file:
        file.write("=== Training Parameters ===\n")
        file.write(f"Timestamp: {timestamp}\n")
        file.write(f"Device: {device}\n\n")
        for k, v in vars(args).items():
            file.write(f"{k:>15}: {v}\n")
        file.write(f"Training dataset size: {len(train_ds)}\n")
        file.write(f"Validation dataset size: {len(val_ds)}\n")

    # On Kaggle kernels multiprocessing can be problematic; detect and use 0 workers there
    if os.environ.get('KAGGLE_KERNEL_RUN_TYPE') is not None:
        num_workers = 0
    else:
        num_workers = min(4, os.cpu_count() or 1)
    pin_memory = device.type == 'cuda'

    # Prefer tokenizer.pad_id, then tokenizer.pad (the id the evaluation code
    # trims on), and only then the historical fallback of 2.
    pad_id = getattr(tokenizer, "pad_id", getattr(tokenizer, "pad", 2))
    used_collate = partial(collate_ast_tokens, pad_id=pad_id)

    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, collate_fn=used_collate, num_workers=num_workers, pin_memory=pin_memory)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, collate_fn=used_collate, num_workers=num_workers, pin_memory=pin_memory)

    model = TranscriptionModel(model_type=args.model_type, device=device)
    model.to(device)

    # Only optimize parameters that require gradients (the encoder may be frozen).
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if not trainable_params:
        trainable_params = list(model.parameters())
    optimizer = optim.AdamW(trainable_params, lr=args.lr, weight_decay=0)  # TODO change weight decay to 1e-2

    # Logging
    train_losses, val_losses = [], []
    global_losses = []  # list of per-step lists, one per epoch
    log_txt_path = os.path.join(logs_dir, "training_log.txt")

    with open(log_txt_path, "w") as log_file:
        log_file.write(f"Training started: {timestamp}\n")
        log_file.write(f"Device: {device}\n")
        log_file.write(f"Epochs: {args.epochs}, Batch size: {args.batch_size}, LR: {args.lr}\n\n")

        for epoch in range(1, args.epochs + 1):
            print(f"\nEpoch {epoch}/{args.epochs}")

            train_loss, step_losses = train_one_epoch(model, train_loader, optimizer, device)
            val_loss = evaluate(model, val_loader, device, max_batches=args.val_batches)

            train_losses.append(train_loss)
            val_losses.append(val_loss)
            global_losses.append(step_losses)

            print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
            log_file.write(f"Epoch {epoch:02d}: Train={train_loss:.4f}, Val={val_loss:.4f}\n")
            log_file.flush()

            # Refresh the plots every epoch so progress is visible mid-run.
            plot_training_curves(train_losses, val_losses, loss_plot_path)
            plot_step_losses(global_losses, args.epochs, step_plot_path)

            # Save checkpoint periodically. save_every <= 0 disables the periodic
            # saves instead of raising ZeroDivisionError. The last epoch is always saved.
            is_periodic = args.save_every > 0 and epoch % args.save_every == 0
            if is_periodic or epoch == args.epochs:
                ckpt_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch}.pth")
                torch.save(model.state_dict(), ckpt_path)
                print(f"Checkpoint saved to {ckpt_path}")

        # Final save
        torch.save(model.state_dict(), model_path)
        print(f"\nFinal model saved to {model_path}")
        log_file.write("\nTraining complete.\n")
    return model

In [None]:
# Driver code -- Training
# Debug run: one 2017 clip (--subset_size 1) for 50 epochs. Because --save_every
# (52) exceeds --epochs (50), only the final-epoch checkpoint is written.
# Alternatives kept from earlier runs:
#   "--years", "2004,2006,2008,2009,2011,2013,2014,2015,2017,2018"
#   "--subset_size", "2000"
# TODO: change --lr to 1e-4.
model = main([
    "--root_dir", "/kaggle/input/themaestrodatasetv2/maestro-v2.0.0",
    "--year", "2017",
    "--epochs", "50",
    "--batch_size", "1",
    "--lr", "1e-3",
    "--save_every", "52",
    "--subset_size", "1",
    "--model_type", "transformer",
])

In [None]:
# Evaluation driver
import os
import numpy as np
import torch
import pretty_midi

def _trim_tokens(token_list, eos_id=1, pad_id=2):
    out = []
    for t in token_list:
        t = int(t)
        if t == pad_id:
            continue
        out.append(t)
        if t == eos_id:
            break
    return out

def midi_to_note_events(pm: pretty_midi.PrettyMIDI):
    """Collect (pitch, onset_sec, offset_sec) tuples from every non-drum instrument.

    Returns:
        list of (int pitch, float start, float end), ordered by onset and then
        pitch, so downstream greedy matching sees a stable order.
    """
    events = [
        (int(note.pitch), float(note.start), float(note.end))
        for instrument in pm.instruments
        if not instrument.is_drum
        for note in instrument.notes
    ]
    return sorted(events, key=lambda event: (event[1], event[0]))

def note_f1(pred_notes, ref_notes, onset_tol=0.05, pitch_tol=0):
    """
    Note-level precision/recall/F1 using greedy onset+pitch matching (offsets are ignored).

    Predicted notes are matched in the order given. Each one takes the unused
    reference note with the smallest onset difference among those where
        |pitch_pred - pitch_ref| <= pitch_tol  and  |onset_pred - onset_ref| <= onset_tol.
    On equal onset differences, the earliest reference in `ref_notes` wins.

    Args:
        pred_notes: sequence of (pitch, onset, offset) predicted notes.
        ref_notes: sequence of (pitch, onset, offset) reference notes.
        onset_tol: maximum onset difference, in the same units as the onsets
            (seconds when the notes come from midi_to_note_events).
        pitch_tol: maximum absolute pitch difference.

    Returns:
        dict with keys "tp", "fp", "fn", "precision", "recall", "f1".
        A ratio whose denominator is zero is reported as 0.0. A perfect match
        yields exactly 1.0 rather than an epsilon-biased value.
    """
    used = np.zeros(len(ref_notes), dtype=bool)
    tp = 0

    for pred_pitch, pred_onset, _pred_offset in pred_notes:
        best_j = -1
        best_dt = None
        for j, (ref_pitch, ref_onset, _ref_offset) in enumerate(ref_notes):
            if used[j] or abs(pred_pitch - ref_pitch) > pitch_tol:
                continue
            dt = abs(pred_onset - ref_onset)
            if dt <= onset_tol and (best_dt is None or dt < best_dt):
                best_dt = dt
                best_j = j
        if best_j >= 0:
            used[best_j] = True
            tp += 1

    fp = len(pred_notes) - tp
    fn = len(ref_notes) - tp

    # Explicit zero-denominator guards instead of "+ 1e-9", which biased every
    # score slightly below its true value (e.g. perfect F1 < 1.0).
    prec = tp / (tp + fp) if (tp + fp) else 0.0
    rec = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
    return {"tp": tp, "fp": fp, "fn": fn, "precision": prec, "recall": rec, "f1": f1}

@torch.no_grad()
def evaluate_one_sample(model, tokenizer, waveform, gt_token_ids, out_dir="/kaggle/working",
                        max_len=1024, onset_tol=0.1, pitch_tol=0.5, file_prefix=""):
    """
    Transcribe one waveform, decode prediction and reference tokens to MIDI, and score them.

    Args:
        model: object exposing generate_from_audio(audio=[wav], max_len=...).
            It is switched to eval mode and left there.
        tokenizer: object exposing `eos`, `pad` and decode_to_pretty_midi(tokens, path).
        waveform: audio for one clip, as a torch tensor or array-like. It is
            converted to a CPU tensor before generation.
        gt_token_ids: reference token ids (tensor or array-like). Padding is
            dropped and the sequence is cut after the first EOS.
        out_dir: directory for the decoded MIDI files (created if missing).
        max_len: generation length passed to generate_from_audio.
        onset_tol, pitch_tol: matching tolerances forwarded to note_f1.
        file_prefix: prepended to "gen.mid"/"gt.mid". The default "" keeps the
            historical names; set it to score several samples without overwriting.

    Returns:
        dict with the two MIDI paths, note counts, and the note_f1 metrics.
    """
    os.makedirs(out_dir, exist_ok=True)
    model.eval()

    # Accept tensors or array-likes; generation receives a CPU tensor.
    wav = torch.as_tensor(waveform).detach().cpu()

    gen = model.generate_from_audio(audio=[wav], max_len=max_len)
    gen_tokens = gen[0].detach().cpu().tolist()
    gt_tokens = torch.as_tensor(gt_token_ids).detach().cpu().tolist()

    gen_tokens = _trim_tokens(gen_tokens, eos_id=tokenizer.eos, pad_id=tokenizer.pad)
    gt_tokens = _trim_tokens(gt_tokens, eos_id=tokenizer.eos, pad_id=tokenizer.pad)

    # Decode both token streams to MIDI files on disk.
    gen_mid_path = os.path.join(out_dir, f"{file_prefix}gen.mid")
    gt_mid_path = os.path.join(out_dir, f"{file_prefix}gt.mid")
    tokenizer.decode_to_pretty_midi(gen_tokens, gen_mid_path)
    tokenizer.decode_to_pretty_midi(gt_tokens, gt_mid_path)

    # Reload the written files so scoring sees exactly what was saved.
    gen_notes = midi_to_note_events(pretty_midi.PrettyMIDI(gen_mid_path))
    gt_notes = midi_to_note_events(pretty_midi.PrettyMIDI(gt_mid_path))

    metrics = note_f1(gen_notes, gt_notes, onset_tol=onset_tol, pitch_tol=pitch_tol)

    return {
        "gen_mid_path": gen_mid_path,
        "gt_mid_path": gt_mid_path,
        "n_gen_notes": len(gen_notes),
        "n_gt_notes": len(gt_notes),
        **metrics
    }

# -----------------------
# Example usage: score the trained model on the first 2017 training clip
# -----------------------
# NOTE(review): `tok` is not defined in this cell. It must come from an earlier
# cell and should match the tokenizer used during training
# (main() builds EventMIDITokenizer(vocab_size=512)). TODO confirm.
ds = MaestroASTDataset(
    root_dir="/kaggle/input/themaestrodatasetv2/maestro-v2.0.0",
    tokenizer=tok,
    split="train",
    year="2017",
    subset_size=1,
    max_token_len=1024,
)
waveform, token_ids = ds[0]

metrics = evaluate_one_sample(model, tok, waveform, token_ids, max_len=1024)
print(metrics)
print("GT MIDI:", metrics["gt_mid_path"])
print("GEN MIDI:", metrics["gen_mid_path"])

In [None]:
# Working sheet music display
# Renders a MIDI file to a PNG score with music21 (via MuseScore) and shows it inline.

# --- Install MuseScore so music21 can render sheet music ---
# (IPython shell magics: these lines only run inside a notebook kernel.)
!apt-get update -qq
!apt-get install -y musescore -qq

# --- Import libraries ---
from music21 import converter, environment
from IPython.display import Image, display

# Tell music21 where MuseScore is
# NOTE(review): assumes the apt package installs the binary at /usr/bin/mscore — verify on the image.
us = environment.UserSettings()
us['musescoreDirectPNGPath'] = '/usr/bin/mscore'
us['musicxmlPath'] = '/usr/bin/mscore'

# --- Load MIDI file ---
# NOTE(review): no visible cell writes generated_from_audio.mid. evaluate_one_sample
# writes gen.mid / gt.mid under /kaggle/working by default. Confirm this file
# exists, or point midi_path at one of those.
midi_path = "/kaggle/working/generated_from_audio.mid"
# midi_path = "/kaggle/input/themaestrodatasetv2/maestro-v2.0.0/2004/MIDI-Unprocessed_SMF_16_R1_2004_01-08_ORIG_MID--AUDIO_16_R1_2004_13_Track13_wav.midi"
score = converter.parse(midi_path)

# --- Convert to PNG sheet music ---
# The returned path is the rendered image that is displayed below.
png_path = score.write('musicxml.png')

# --- Display the resulting sheet music image ---
display(Image(filename=png_path))