In [None]:
import os
import random
from typing import List, Optional

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetMIDI, DataCollator

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ------------ Reproducibility -------------
# Seed the RNGs this notebook draws from (Python's `random` and PyTorch) so a
# fresh "Restart Kernel → Run All" starts from the same random state.
# NOTE: some GPU kernels may still be non-deterministic.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# ------------ Dataset -------------
# Replace with the folder that contains **multi‑instrument** MIDI files.
# For quick tests you can point to the Lakh Cleaned dataset root.

# https://www.kaggle.com/datasets/imsparsh/lakh-midi-clean?resource=download
DATA_ROOT = "lakh_midi"

# If you have a CSV with train/val/test splits, load it here; otherwise we create one.
SPLIT_CSV = None  # path to optional CSV with columns [filepath, split]

# ------------ Training hyper-parameters -------------
MAX_SEQ_LEN = 1024           # split long pieces into chunks of this many tokens
BATCH_SIZE  = 4
NUM_EPOCHS  = 20
LEARNING_RATE = 0.001

# ------------ Model size -------------
EMBED_DIM  = 256             # token embedding width
HIDDEN_DIM = 512             # LSTM hidden state width
NUM_LAYERS = 2               # number of stacked LSTM layers

In [None]:
# REMI with Program‑Change events → single flattened event stream
# NOTE(review): depending on the installed MidiTok version the stream option may
# be spelled `one_token_stream_for_programs`; confirm `one_token_stream` is honoured.
config = TokenizerConfig(
    use_programs=True,      # ← this is the key flag
    one_token_stream=True,  # keeps your single‑stream RNN forward
)

tokenizer = REMI(config)

# Helper mapping (GM program number → token id) for fast lookup during
# constrained sampling.  Membership is checked against the vocabulary mapping
# itself (assumed token string → id for this single-vocabulary tokenizer), so a
# program missing from the vocabulary is skipped rather than raising.
PROGRAM_TOKEN_IDS = {
    p: tokenizer.vocab[f"Program_{p}"]
    for p in range(128)
    if f"Program_{p}" in tokenizer.vocab  # safety for unseen programs
}

In [None]:
def build_split_lists(root: str, csv: Optional[str] = None):
    """Return lists of MIDI paths grouped by split.

    Parameters
    ----------
    root : str
        Directory searched recursively for ``.mid`` / ``.midi`` files
        (case-insensitive). Only used when ``csv`` is not given.
    csv : str, optional
        Path to a CSV with columns ``filepath`` and ``split``. Rows whose
        ``split`` is ``"train"``, ``"validation"`` or ``"test"`` are returned
        in the matching list; when given, ``root`` is not scanned.

    Returns
    -------
    tuple[list[str], list[str], list[str]]
        ``(train, val, test)`` path lists. Without a CSV the files are split
        80/10/10 after a shuffle driven by the global ``random`` state.
    """
    if csv:
        meta = pd.read_csv(csv)
        train = meta.loc[meta["split"] == "train", "filepath"].tolist()
        val   = meta.loc[meta["split"] == "validation", "filepath"].tolist()
        test  = meta.loc[meta["split"] == "test", "filepath"].tolist()
    else:
        # Simple 80‑10‑10 random split over MIDI files in the root.
        midi_exts = (".mid", ".midi")
        # Sort first: os.walk yields entries in filesystem order, so without
        # this the split would differ between machines even with a fixed seed.
        all_midi = sorted(
            os.path.join(dp, f)
            for dp, _, files in os.walk(root)
            for f in files
            if f.lower().endswith(midi_exts)
        )
        random.shuffle(all_midi)
        n = len(all_midi)
        train_end, val_end = int(0.8 * n), int(0.9 * n)
        train, val, test = (
            all_midi[:train_end],
            all_midi[train_end:val_end],
            all_midi[val_end:],
        )
    return train, val, test

In [None]:
# Resolve the per-split file lists (test_files is not used further in this cell).
train_files, val_files, test_files = build_split_lists(DATA_ROOT, SPLIT_CSV)


def _make_dataset(paths):
    """Wrap a list of MIDI paths in a DatasetMIDI that uses the shared tokenizer settings."""
    return DatasetMIDI(
        files_paths=paths,
        tokenizer=tokenizer,
        max_seq_len=MAX_SEQ_LEN,
        bos_token_id=tokenizer["BOS_None"],
        eos_token_id=tokenizer["EOS_None"],
    )


train_dataset = _make_dataset(train_files)
val_dataset = _make_dataset(val_files)

# One collator (built from the tokenizer's pad token id) serves both loaders;
# only the training loader shuffles.
collator = DataCollator(tokenizer.pad_token_id)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collator,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collator,
)

In [None]:
class MusicRNN(nn.Module):
    """Token language model: embedding → stacked LSTM → linear projection onto the vocabulary.

    Parameters
    ----------
    vocab_size : int
        Number of distinct token ids (embedding rows and output logits).
    embedding_dim : int
        Width of each token embedding fed to the LSTM.
    hidden_dim : int
        Width of the LSTM hidden state.
    num_layers : int
        Number of stacked LSTM layers.
    """

    def __init__(self, vocab_size: int, embedding_dim: int, hidden_dim: int, num_layers: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True → tensors are laid out as (batch, seq, feature).
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Map token ids to next-token logits.

        ``x`` holds integer token ids laid out (batch, seq); ``hidden`` is an
        optional LSTM state to continue from (``None`` lets the LSTM start from
        its default state). Returns ``(logits, hidden)`` where ``logits`` has a
        trailing dimension of ``vocab_size`` and ``hidden`` is the updated
        LSTM state.
        """
        recurrent_out, next_hidden = self.rnn(self.embedding(x), hidden)
        return self.fc(recurrent_out), next_hidden

In [None]:
def _next_token_loss(model, batch, criterion, vocab_size, device):
    """Return the next-token cross-entropy for one collated batch, skipping padded targets."""
    token_ids = batch["input_ids"].to(device)
    # Teacher forcing: position t of the input predicts token t + 1.
    inputs, targets = token_ids[:, :-1], token_ids[:, 1:]

    # When the collator supplies an attention mask (assumed 1 = real token,
    # 0 = padding — confirm against the collator in use), replace padded targets
    # with the criterion's ignore index so padding does not contribute to the loss.
    attention_mask = batch["attention_mask"] if "attention_mask" in batch else None
    if attention_mask is not None and attention_mask.shape == token_ids.shape:
        is_padding = attention_mask[:, 1:].to(device) == 0
        targets = targets.masked_fill(is_padding, criterion.ignore_index)

    logits, _ = model(inputs)
    return criterion(logits.reshape(-1, vocab_size), targets.reshape(-1))


def _mean_or_nan(values):
    """Arithmetic mean of ``values``; NaN for an empty list instead of dividing by zero."""
    return sum(values) / len(values) if values else float("nan")


def train(model, train_loader, val_loader, vocab_size, num_epochs=20, lr=0.001, device=DEVICE):
    """Train ``model`` as a next-token predictor and report the loss after every epoch.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(inputs)`` and expected to return ``(logits, hidden)``.
    train_loader, val_loader : iterable
        Yield mapping batches with an ``"input_ids"`` tensor laid out
        (batch, seq); an optional ``"attention_mask"`` entry of the same shape
        is used to exclude padding from the loss.
    vocab_size : int
        Size of the last logits dimension.
    num_epochs : int
        Number of passes over ``train_loader``.
    lr : float
        Adam learning rate.
    device : str
        Device the model and batches are moved to.

    Returns
    -------
    list[tuple[float, float]]
        One ``(avg_train_loss, avg_val_loss)`` pair per epoch. An average is
        NaN when the corresponding loader produced no batches.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    history = []

    for epoch in range(num_epochs):
        # ---- training ----
        model.train()
        train_losses = []
        for batch in train_loader:
            optimizer.zero_grad()
            loss = _next_token_loss(model, batch, criterion, vocab_size, device)
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        avg_train = _mean_or_nan(train_losses)

        # ---- validation ----
        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch in val_loader:
                loss = _next_token_loss(model, batch, criterion, vocab_size, device)
                val_losses.append(loss.item())
        avg_val = _mean_or_nan(val_losses)

        print(f"Epoch {epoch+1}/{num_epochs} — train: {avg_train:.4f} | val: {avg_val:.4f}")
        history.append((avg_train, avg_val))

    return history

In [None]:
def sample_with_programs(
    model: nn.Module,
    tokenizer: REMI,
    allowed_programs: List[int],
    max_length: int = 1024,
    temperature: float = 1.0,
    device: str = DEVICE,
):
    """Generate a token sequence that uses ONLY the requested GM program numbers.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(tokens, hidden)`` and expected to return
        ``(logits, hidden)``.
    tokenizer : REMI
        Provides the vocabulary size and the BOS/EOS token ids.
    allowed_programs : list[int]
        List of MIDI program numbers (0‑127) the piece may contain. Example: [0] for solo
        piano, [40, 41, 42] for a string trio. Numbers without an entry in
        ``PROGRAM_TOKEN_IDS`` are ignored; the first remaining one seeds the sequence.
    max_length : int
        Maximum number of tokens sampled after the two-token seed.
    temperature : float
        Softmax temperature; must be strictly positive.
    device : str
        Device used for the model and the sampling tensors.

    Returns
    -------
    list[int]
        Token ids starting with BOS and the seed program token; ends with EOS
        if it was sampled before ``max_length`` was reached.

    Raises
    ------
    ValueError
        If ``temperature`` is not positive or no requested program has a token.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    # Keep only programs with a known token, preserving the caller's order.
    # NOTE(review): PROGRAM_TOKEN_IDS is the module-level table, not derived
    # from the `tokenizer` argument — the two are assumed to match.
    usable_programs = [p for p in allowed_programs if p in PROGRAM_TOKEN_IDS]
    if not usable_programs:
        raise ValueError(
            f"none of the requested programs {list(allowed_programs)} has a Program token"
        )

    model.to(device)
    model.eval()

    # Build a mask over the vocabulary: 1 for allowed tokens, 0 to ban.
    allow_ids = {PROGRAM_TOKEN_IDS[p] for p in usable_programs}
    vocab_size = tokenizer.vocab_size
    mask = torch.ones(vocab_size, device=device)
    for pid in PROGRAM_TOKEN_IDS.values():
        if pid not in allow_ids:
            mask[pid] = 0.0

    eos_id = tokenizer["EOS_None"]  # looked up once instead of on every step

    # Seed sequence: <BOS> + first Program token (choose first allowed)
    generated = [tokenizer["BOS_None"], PROGRAM_TOKEN_IDS[usable_programs[0]]]

    # Feed the whole seed on the first step so the recurrent state is also
    # conditioned on <BOS>; later steps feed only the newly sampled token.
    input_tok = torch.tensor([generated], device=device)
    hidden = None

    # Inference only: without no_grad the autograd graph would keep growing
    # through `hidden` on every step.
    with torch.no_grad():
        for _ in range(max_length):
            logits, hidden = model(input_tok, hidden)
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1) * mask  # zero out banned programs
            total = probs.sum()
            if total == 0:
                # catastrophic: mask wiped all mass, fall back to uniform over
                # every non-banned token
                probs = mask / mask.sum()
            else:
                probs = probs / total  # renorm
            next_tok = torch.multinomial(probs, 1).item()
            generated.append(next_tok)
            if next_tok == eos_id:
                break
            input_tok = torch.tensor([[next_tok]], device=device)

    return generated

In [None]:
vocab = tokenizer.vocab_size
model = MusicRNN(vocab, EMBED_DIM, HIDDEN_DIM, NUM_LAYERS)

print("→ starting training on", len(train_files), "files (multi‑instrument)")
# Pass the config-cell hyper-parameters explicitly so that editing NUM_EPOCHS /
# LEARNING_RATE actually changes the run (train() would otherwise use its own defaults).
train(model, train_loader, val_loader, vocab, num_epochs=NUM_EPOCHS, lr=LEARNING_RATE)

# ---- sampling demo : string trio (Violin, Viola, Cello)
# (the output file name below still says "quartet"; only three programs are requested)
string_programs = [40, 41, 42]  # GM numbers for Violin, Viola, Cello
tokens = sample_with_programs(model, tokenizer, string_programs, max_length=2048)

# A one-token-stream tokenizer is expected to take a flat id sequence, whereas a
# per-track tokenizer takes one sequence per track — wrap only in the latter case.
# NOTE(review): confirm against the installed MidiTok version.
decoder_input = tokens if getattr(tokenizer, "one_token_stream", False) else [tokens]
midi = tokenizer.tokens_to_midi(decoder_input)
midi.dump_midi("sample_quartet.mid")
print("Saved sample_quartet.mid (", len(tokens), "tokens )")