In [1]:
import os
import random
from typing import List, Optional

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetMIDI, DataCollator

from collections import defaultdict
from symusic import Score

from pretty_midi import PrettyMIDI

In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)

# ------------ Dataset -------------
# Download from: https://www.kaggle.com/datasets/imsparsh/lakh-midi-clean?resource=download
# Unzip and name dataset folder root "lakh-midi"
DATA_ROOT = "lakh-midi"

# If you have a CSV with train/val/test splits, load it here; otherwise we create one.
SPLIT_CSV = None  # path to optional CSV with columns [filepath, split]

MAX_SEQ_LEN = 1024           # split long pieces into chunks of this many tokens
BATCH_SIZE  = 4
NUM_EPOCHS  = 20
LEARNING_RATE = 0.001

EMBED_DIM  = 256
HIDDEN_DIM = 512
NUM_LAYERS = 2

cuda


In [3]:
def collect_program_counts(root: str):
    counts = defaultdict(int)
    midi_files = [os.path.join(dp, f)
                  for dp, _, files in os.walk(root)
                  for f in files if f.lower().endswith((".mid", ".midi"))]
    for fp in midi_files:
        try:
            score = Score(fp)
            for track in score.tracks:
                if track.is_drum:                 # skip channel-10 drums
                    continue
                counts[track.program] += 1
        except Exception:
            continue                              # skip unreadable files
    return counts, midi_files

program_counts, all_midi_files = collect_program_counts(DATA_ROOT)
TOP20 = sorted(program_counts.items(), key=lambda kv: kv[1], reverse=True)[:20]
allowed_programs = [p for p, _ in TOP20]

print("Top-20 instruments:")
for prog, cnt in TOP20:
    print(f"  Program {prog:3d}: {cnt} tracks")

Top-20 instruments:
  Program   0: 13455 tracks
  Program  25: 6211 tracks
  Program  29: 6127 tracks
  Program  48: 6038 tracks
  Program  33: 5848 tracks
  Program  30: 5434 tracks
  Program  27: 5334 tracks
  Program  35: 4211 tracks
  Program  52: 3894 tracks
  Program  26: 3379 tracks
  Program  28: 3315 tracks
  Program  49: 3258 tracks
  Program  24: 3250 tracks
  Program  32: 2877 tracks
  Program   1: 2780 tracks
  Program  53: 2542 tracks
  Program  65: 2420 tracks
  Program  50: 2389 tracks
  Program  61: 2293 tracks
  Program  73: 2135 tracks


In [4]:
# REMI with Program‑Change events → single flattened event stream
config = TokenizerConfig(
    use_programs=True,      # ← this is the key flag
    one_token_stream=True,  # keeps your single‑stream RNN forward
)

tokenizer = REMI(config)

# Helper set for fast lookup during constrained sampling
PROGRAM_TOKEN_IDS = {
    p: tokenizer[f"Program_{p}"]
    for p in range(128)
    if f"Program_{p}" in tokenizer  # safety for unseen programs
}

ALL_PROGRAM_TOKEN_IDS = set(PROGRAM_TOKEN_IDS.values())

  super().__init__(tokenizer_config, params)


In [5]:
def build_split_lists(root: str, csv: Optional[str] = None):
    """Return lists of MIDI paths grouped by split."""
    if csv:
        meta = pd.read_csv(csv)
        train = meta.loc[meta["split"] == "train", "filepath"].tolist()
        val   = meta.loc[meta["split"] == "validation", "filepath"].tolist()
        test  = meta.loc[meta["split"] == "test", "filepath"].tolist()
    else:
        # Simple 80‑10‑10 random split over *.mid files in the root
        all_midi = [os.path.join(dp, f)
                    for dp, _, files in os.walk(root)
                    for f in files if f.lower().endswith(".mid") or f.lower().endswith(".midi")]
        random.shuffle(all_midi)
        n = len(all_midi)
        train, val, test = (
            all_midi[: int(0.8 * n)],
            all_midi[int(0.8 * n): int(0.9 * n)],
            all_midi[int(0.9 * n):],
        )
    return train, val, test


train_files, val_files, test_files = build_split_lists(DATA_ROOT, SPLIT_CSV)

In [6]:
print(len(train_files))

13782


In [7]:
MIN_LEN = 8          # drop sequences shorter than this after filtering

class FilteredDatasetMIDI(DatasetMIDI):
    def __init__(self, *args, allowed_programs: List[int], **kw):
        super().__init__(*args, **kw)
        self.allowed_program_ids = {
            tokenizer[f"Program_{p}"] for p in allowed_programs
        }

    def _filter_tokens(self, ids: List[int]):
        keep, ok = [], False
        for tid in ids:
            if tid in ALL_PROGRAM_TOKEN_IDS:         # program-change
                ok = tid in self.allowed_program_ids
                if ok:
                    keep.append(tid)
            elif ok:                                 # events of allowed track
                keep.append(tid)
        return keep

    def __getitem__(self, idx):
        """
        Keep resampling indices until we get a dict with a **non-None**
        'input_ids' tensor that remains ≥ MIN_LEN tokens after filtering.
        """
        while True:
            sample = super().__getitem__(idx)               # may be None
            if sample is None or sample.get("input_ids") is None:
                idx = random.randrange(len(self)); continue

            ids = sample["input_ids"]
            if isinstance(ids, torch.Tensor):
                ids = ids.tolist()

            filtered = self._filter_tokens(ids)
            if len(filtered) < MIN_LEN:
                idx = random.randrange(len(self)); continue

            sample["input_ids"] = torch.tensor(filtered, dtype=torch.long)
            return sample


In [8]:
train_dataset = FilteredDatasetMIDI(
    files_paths=train_files,
    tokenizer=tokenizer,
    max_seq_len=MAX_SEQ_LEN,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
    allowed_programs=allowed_programs,
)
val_dataset = FilteredDatasetMIDI(
    files_paths=val_files,
    tokenizer=tokenizer,
    max_seq_len=MAX_SEQ_LEN,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
    allowed_programs=allowed_programs,
)

In [9]:
collator = DataCollator(tokenizer.pad_token_id)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=collator)
val_loader   = DataLoader(val_dataset,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=collator)

In [10]:
class MusicRNN(nn.Module):
    def __init__(self, vocab_size: int, embedding_dim: int, hidden_dim: int, num_layers: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        x = self.embedding(x)
        out, hidden = self.rnn(x, hidden)
        out = self.fc(out)
        return out, hidden

In [11]:
def train(model, train_loader, val_loader, vocab_size, num_epochs=10, lr=0.001, device=DEVICE):
    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        # ---- training ----
        model.train()
        total_train_loss = 0.0
        for batch in train_loader:
            batch = batch["input_ids"].to(device)
            inputs, targets = batch[:, :-1], batch[:, 1:]
            optimizer.zero_grad()
            logits, _ = model(inputs)
            loss = criterion(logits.reshape(-1, vocab_size), targets.reshape(-1))
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
        avg_train = total_train_loss / len(train_loader)

        # ---- validation ----
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                batch = batch["input_ids"].to(device)
                inputs, targets = batch[:, :-1], batch[:, 1:]
                logits, _ = model(inputs)
                loss = criterion(logits.reshape(-1, vocab_size), targets.reshape(-1))
                total_val_loss += loss.item()
        avg_val = total_val_loss / len(val_loader)

        print(f"Epoch {epoch+1}/{num_epochs} — train: {avg_train:.4f} | val: {avg_val:.4f}")

In [None]:
vocab = tokenizer.vocab_size
model = MusicRNN(vocab, EMBED_DIM, HIDDEN_DIM, NUM_LAYERS)

print("→ starting training on", len(train_files), "files (multi‑instrument)")
train(model, train_loader, val_loader, vocab)

→ starting training on 13782 files (multi‑instrument)
Epoch 1/10 — train: 1.6422 | val: 1.3758
Epoch 2/10 — train: 1.2645 | val: 1.2165
Epoch 3/10 — train: 1.1526 | val: 1.1582
Epoch 4/10 — train: 1.0847 | val: 1.1211


In [None]:
def sample_with_programs(
    model: nn.Module,
    tokenizer: REMI,
    allowed_programs: List[int],
    max_length: int = 1024,
    temperature: float = 1.0,
    device: str = DEVICE,
):
    """Generate a token sequence that uses ONLY the requested GM program numbers.

    Parameters
    ----------
    allowed_programs : list[int]
        List of MIDI program numbers (0‑127) the piece may contain. Example: [0] for solo
        piano, [40, 41, 42] for a string trio.
    """

    model.to(device)
    model.eval()

    # Build a mask over the vocabulary: 1 for allowed tokens, 0 to ban.
    allow_ids = {PROGRAM_TOKEN_IDS[p] for p in allowed_programs if p in PROGRAM_TOKEN_IDS}
    vocab_size = tokenizer.vocab_size

    # ---- 3. sampler mask: get rid of token_types_indices ----------
    mask = torch.ones(tokenizer.vocab_size, device=device)
    for pid in ALL_PROGRAM_TOKEN_IDS:
        if pid not in {PROGRAM_TOKEN_IDS[p] for p in allowed_programs}:
            mask[pid] = 0.0

    
    # Seed sequence: <BOS> + first Program token (choose first allowed)
    # first_prog = allowed_programs[0]
    generated = [tokenizer["BOS_None"], allowed_programs[0]]

    input_tok = torch.tensor([[generated[-1]]], device=device)
    hidden = None

    for _ in range(max_length):
        logits, hidden = model(input_tok, hidden)
        logits = logits[:, -1, :] / temperature
        probs = F.softmax(logits, dim=-1) * mask  # zero out banned programs
        if probs.sum() == 0:
            # catastrophic: mask wiped all mass, fall back to uniform over allowed
            probs = mask / mask.sum()
        else:
            probs = probs / probs.sum()  # renorm
        next_tok = torch.multinomial(probs, 1).item()
        generated.append(next_tok)
        if next_tok == tokenizer["EOS_None"]:
            break
        input_tok = torch.tensor([[next_tok]], device=device)

    return generated

In [None]:
# ---- sampling demo : string quartet (Violin, Viola, Cello)
#string_programs = [40, 41, 42]  # GM numbers for Violin, Viola, Cello
string_programs = [24, 33]
tokens = sample_with_programs(model, tokenizer, string_programs, max_length=2048*6)
midi = tokenizer.decode(tokens)
midi.dump_midi("multi_RNN.mid")
print("Saved multi_RNN.mid (", len(tokens), "tokens )")

NameError: name 'sample_with_programs' is not defined

In [None]:
pretty_midi = PrettyMIDI("multi_RNN.mid")
print("Duration (seconds):", pretty_midi.get_end_time())
for i, instrument in enumerate(pretty_midi.instruments):
    print(f"{instrument.name or 'Unnamed'}:", len(instrument.notes), "notes")

Duration (seconds): 9.4375
Acoustic Grand Piano: 1 notes
Drums: 100 notes


  Program   Instrument
  0:
  25:
  29:
  48:
  33:
  30:
  27:
  35:
  26:
  28:
  49:
  24:
  32:
   1:
  53:
  65:
  50:
  61:
  73:

In [None]:
midi = PrettyMIDI("multi_RNN.mid")
print("Duration (seconds):", midi.get_end_time())
for i, instrument in enumerate(midi.instruments):
    print(f"{instrument.name or 'Unnamed'}:", len(instrument.notes), "notes")

Duration (seconds): 40.5
String Ensembles 2: 442 notes
Drums: 29 notes
Flute: 28 notes


File Instruments

multi_RNN_11: 
    Duration (seconds): 15.0
    Bright Acoustic Piano: 33 notes
    Electric Bass (finger): 49 notes
    Drums: 79 notes

multi_RNN_12:
    Duration (seconds): 10.9375
    Acoustic Guitar (nylon): 156 notes
    Electric Bass (finger): 23 notes
    Drums: 37 notes

multi_RNN_14:
    Duration (seconds): 15.25
    Electric Bass (finger): 47 notes
    Drums: 141 notes
    Acoustic Guitar (nylon): 57 notes    

multi_RNN_15:
    Duration (seconds): 40.5
    String Ensembles 2: 442 notes
    Drums: 29 notes
    Flute: 28 notes