In [None]:
import os
import random
from typing import List, Optional

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetMIDI, DataCollator

from collections import defaultdict
from symusic import Score

from pretty_midi import PrettyMIDI, program_to_instrument_name

import sys
from collections import Counter
from music21 import converter, note, key, stream, meter
import matplotlib.pyplot as plt

In [None]:
# ------------ Runtime -------------
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(DEVICE)

# ------------ Reproducibility -------------
# The notebook draws random numbers in several places (file shuffling, dataset
# resampling, weight initialisation, multinomial sampling).  Seeding both RNGs
# here makes "Restart Kernel -> Run All" repeatable.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# ------------ Dataset -------------
# Download from: https://www.kaggle.com/datasets/imsparsh/lakh-midi-clean?resource=download
# Unzip and name dataset folder root "lakh-midi"
DATA_ROOT = "lakh-midi"

# If you have a CSV with train/val/test splits, load it here; otherwise we create one.
SPLIT_CSV = None  # path to optional CSV with columns [filepath, split]

# ------------ Training hyper-parameters -------------
MAX_SEQ_LEN = 1024           # split long pieces into chunks of this many tokens
BATCH_SIZE  = 4
NUM_EPOCHS  = 20
LEARNING_RATE = 0.001

# ------------ Model size -------------
EMBED_DIM  = 256
HIDDEN_DIM = 512
NUM_LAYERS = 2

In [None]:
def collect_program_counts(root: str):
    """Tally, per program value, how many non-drum tracks use it under `root`.

    The directory tree is walked recursively and every file whose name ends in
    ``.mid`` or ``.midi`` (case-insensitive) is collected.  Each collected file
    is opened with `Score`; for every track whose ``is_drum`` flag is false the
    counter for ``track.program`` is incremented.  A file that raises while being
    processed is skipped, but it still appears in the returned path list.

    Returns
    -------
    (counts, midi_files)
        ``counts`` is a ``defaultdict(int)`` keyed by ``track.program``;
        ``midi_files`` holds every discovered path, in `os.walk` order.
    """
    midi_suffixes = (".mid", ".midi")

    midi_files = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for filename in filenames:
            if filename.lower().endswith(midi_suffixes):
                midi_files.append(os.path.join(dirpath, filename))

    counts = defaultdict(int)
    for path in midi_files:
        try:
            # Lazy filter: drum tracks are skipped, the rest are counted one by one.
            melodic_tracks = (trk for trk in Score(path).tracks if not trk.is_drum)
            for trk in melodic_tracks:
                counts[trk.program] += 1
        except Exception:
            # Best-effort survey: an unreadable file must not abort the scan.
            continue
    return counts, midi_files

# Survey the whole dataset once, then keep the 20 programs with the most tracks.
program_counts, all_midi_files = collect_program_counts(DATA_ROOT)
TOP20 = sorted(
    program_counts.items(),
    key=lambda item: item[1],   # rank by track count
    reverse=True,
)[:20]
allowed_programs = [program for program, _count in TOP20]

# Header, column titles and ruler in one call; the rows follow.
print(
    "Top-20 instruments:",
    f"{'GM#':>4}  {'Name':30}  Tracks",
    "-" * 46,
    sep="\n",
)
for prog, cnt in TOP20:
    name = program_to_instrument_name(prog)   # e.g. 0 -> Acoustic Grand Piano
    print(f"{prog:>4}  {name:30}  {cnt:>6}")

In [None]:
# REMI tokenizer configured with Program tokens and a single flattened token
# stream, so the RNN below consumes one sequence per piece.
config = TokenizerConfig(use_programs=True, one_token_stream=True)
tokenizer = REMI(config)

# Program number -> token id, for every "Program_<n>" (n in 0..127) that the
# tokenizer reports as present.  Used for filtering and constrained sampling.
PROGRAM_TOKEN_IDS = {
    program: tokenizer[token_name]
    for program, token_name in ((p, f"Program_{p}") for p in range(128))
    if token_name in tokenizer  # safety for unseen programs
}

# Flat set of the ids above, for fast membership tests.
ALL_PROGRAM_TOKEN_IDS = {*PROGRAM_TOKEN_IDS.values()}

In [None]:
def build_split_lists(root: str, csv: Optional[str] = None, seed: Optional[int] = None):
    """Return ``(train, val, test)`` lists of MIDI file paths.

    Parameters
    ----------
    root : str
        Directory searched recursively for ``.mid``/``.midi`` files
        (case-insensitive).  Only used when `csv` is falsy.
    csv : str, optional
        Path to a CSV with columns ``filepath`` and ``split``.  Rows are routed
        by the labels ``"train"``, ``"validation"`` and ``"test"``.
    seed : int, optional
        Seed for the random 80-10-10 split.  With ``None`` (the default) the
        shuffle uses the global `random` state, as it always did.

    Returns
    -------
    tuple[list[str], list[str], list[str]]
        Train, validation and test path lists.
    """
    if csv:
        meta = pd.read_csv(csv)
        train, val, test = (
            meta.loc[meta["split"] == label, "filepath"].tolist()
            for label in ("train", "validation", "test")
        )
        return train, val, test

    # Sort before shuffling: os.walk order depends on the filesystem, so an
    # unsorted list would give different splits on different machines even
    # with a fixed seed.
    all_midi = sorted(
        os.path.join(dp, f)
        for dp, _, files in os.walk(root)
        for f in files
        if f.lower().endswith((".mid", ".midi"))
    )
    shuffler = random if seed is None else random.Random(seed)
    shuffler.shuffle(all_midi)

    # Simple 80-10-10 partition of the shuffled list.
    n = len(all_midi)
    train_end, val_end = int(0.8 * n), int(0.9 * n)
    return all_midi[:train_end], all_midi[train_end:val_end], all_midi[val_end:]


# Build the train/val/test path lists once; later cells construct datasets from them.
train_files, val_files, test_files = build_split_lists(DATA_ROOT, SPLIT_CSV)

In [None]:
# Sanity check: number of files assigned to the training split.
print(len(train_files))

In [None]:
MIN_LEN = 8          # drop sequences shorter than this after filtering

class FilteredDatasetMIDI(DatasetMIDI):
    """`DatasetMIDI` whose samples keep only tokens that follow an allowed Program token.

    Parameters
    ----------
    allowed_programs : list[int]
        Program numbers to keep; each is mapped to its ``Program_<n>`` token id
        through the module-level `tokenizer`.
    max_resample_tries : int, default 1000
        Upper bound on how many indices `__getitem__` tries before giving up.
        Without a bound, a dataset in which nothing survives filtering would
        spin forever.
    *args, **kw
        Forwarded unchanged to `DatasetMIDI`.
    """

    def __init__(self, *args, allowed_programs: List[int], max_resample_tries: int = 1000, **kw):
        super().__init__(*args, **kw)
        self.allowed_program_ids = {
            tokenizer[f"Program_{p}"] for p in allowed_programs
        }
        self.max_resample_tries = max_resample_tries

    def _filter_tokens(self, ids: List[int]):
        """Return the sub-sequence of `ids` that belongs to allowed programs.

        A Program token switches the ``ok`` gate on or off depending on whether
        it is allowed; every token (the Program token included) is kept only
        while the gate is on.  Tokens before the first Program token are dropped.

        NOTE(review): the gate applies to *every* non-Program token, so whatever
        follows a disallowed Program token is dropped regardless of its type.
        This presumably includes timing tokens such as Bar/Position - confirm
        against the REMI token layout that this is intended.
        """
        keep, ok = [], False
        for tid in ids:
            if tid in ALL_PROGRAM_TOKEN_IDS:         # program-change: update the gate
                ok = tid in self.allowed_program_ids
            if ok:                                   # inside an allowed track
                keep.append(tid)
        return keep

    def __getitem__(self, idx):
        """Return a sample whose filtered ``input_ids`` has at least MIN_LEN tokens.

        If the sample at `idx` is missing, has no ``input_ids``, or is too short
        after filtering, a random index is tried instead, up to
        ``max_resample_tries`` attempts in total.

        Raises
        ------
        RuntimeError
            If no attempt produced a usable sample.
        """
        for _ in range(self.max_resample_tries):
            sample = super().__getitem__(idx)               # may be None
            ids = None if sample is None else sample.get("input_ids")
            if ids is not None:
                if isinstance(ids, torch.Tensor):
                    ids = ids.tolist()
                filtered = self._filter_tokens(ids)
                if len(filtered) >= MIN_LEN:
                    sample["input_ids"] = torch.tensor(filtered, dtype=torch.long)
                    return sample
            # Unusable sample: try a random index instead.
            idx = random.randrange(len(self))

        raise RuntimeError(
            f"No sample with at least {MIN_LEN} tokens after program filtering "
            f"was found in {self.max_resample_tries} attempts; check "
            "`allowed_programs` against the dataset."
        )


In [None]:
# The train and validation datasets share every setting except the file list,
# so both are built from one template (train first, then validation).
train_dataset, val_dataset = (
    FilteredDatasetMIDI(
        files_paths=split_files,
        tokenizer=tokenizer,
        max_seq_len=MAX_SEQ_LEN,
        bos_token_id=tokenizer["BOS_None"],
        eos_token_id=tokenizer["EOS_None"],
        allowed_programs=allowed_programs,
    )
    for split_files in (train_files, val_files)
)

In [None]:
# One collator (given the tokenizer's PAD id) is shared by both loaders.
# Only the training loader shuffles.
collator = DataCollator(tokenizer.pad_token_id)
train_loader = DataLoader(
    train_dataset, collate_fn=collator, batch_size=BATCH_SIZE, shuffle=True
)
val_loader = DataLoader(
    val_dataset, collate_fn=collator, batch_size=BATCH_SIZE, shuffle=False
)

In [None]:
class MusicRNN(nn.Module):
    """Next-token model: token embedding -> stacked LSTM -> linear vocabulary head."""

    def __init__(self, vocab_size: int, embedding_dim: int, hidden_dim: int, num_layers: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True: inputs and outputs are laid out (batch, seq, feature).
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        """Map token ids to per-position vocabulary logits.

        Parameters
        ----------
        x : tensor of token ids, fed to the embedding layer.
        hidden : optional LSTM state to continue from; ``None`` starts fresh.

        Returns
        -------
        (logits, hidden)
            Logits over the vocabulary for every input position, and the LSTM
            state after the last position.
        """
        embedded = self.embedding(x)
        features, next_hidden = self.rnn(embedded, hidden)
        return self.fc(features), next_hidden

In [None]:
def train(model, train_loader, val_loader, vocab_size, num_epochs=10, lr=0.001, device=DEVICE):
    """Train `model` for next-token prediction and report the loss per epoch.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(inputs)`` and expected to return ``(logits, hidden)``.
    train_loader, val_loader : iterable of dict batches
        Each batch must provide ``"input_ids"``.  If it also provides an
        ``"attention_mask"`` of the same shape, positions where the mask is 0
        are excluded from the loss.
    vocab_size : int
        Size of the last logits dimension.
    num_epochs : int
    lr : float
        Learning rate for Adam.
    device : str
        Device the model and batches are moved to.

    Returns
    -------
    list[dict]
        One ``{"epoch", "train", "val"}`` entry per epoch, holding the mean
        per-batch losses (``nan`` for an empty loader).
    """
    model.to(device)
    # CrossEntropyLoss ignores targets equal to its default ignore_index (-100);
    # padded positions are rewritten to that value below.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    def batch_loss(batch):
        """Next-token loss for one batch, skipping padded targets when a mask exists."""
        ids = batch["input_ids"].to(device)
        inputs, targets = ids[:, :-1], ids[:, 1:]
        attn = batch["attention_mask"] if "attention_mask" in batch else None
        if attn is not None and attn.shape == ids.shape:
            # Without this, the model is also trained and scored on predicting PAD.
            is_real = attn.to(device)[:, 1:].bool()
            targets = targets.masked_fill(~is_real, -100)
        logits, _ = model(inputs)
        return criterion(logits.reshape(-1, vocab_size), targets.reshape(-1))

    def mean_or_nan(total, count):
        """Average, without dividing by zero when a loader produced no batches."""
        return total / count if count else float("nan")

    history = []
    for epoch in range(num_epochs):
        # ---- training ----
        model.train()
        total_train_loss, n_train = 0.0, 0
        for batch in train_loader:
            optimizer.zero_grad()
            loss = batch_loss(batch)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item()
            n_train += 1
        avg_train = mean_or_nan(total_train_loss, n_train)

        # ---- validation ----
        model.eval()
        total_val_loss, n_val = 0.0, 0
        with torch.no_grad():
            for batch in val_loader:
                total_val_loss += batch_loss(batch).item()
                n_val += 1
        avg_val = mean_or_nan(total_val_loss, n_val)

        print(f"Epoch {epoch+1}/{num_epochs} — train: {avg_train:.4f} | val: {avg_val:.4f}")
        history.append({"epoch": epoch + 1, "train": avg_train, "val": avg_val})

    return history

In [None]:
# Build the model from the config-cell sizes and train it.
vocab = tokenizer.vocab_size
model = MusicRNN(vocab, EMBED_DIM, HIDDEN_DIM, NUM_LAYERS)

print("→ starting training on", len(train_files), "files (multi‑instrument)")
# The learning rate now comes from the config cell instead of train()'s default.
# NOTE(review): the config cell declares NUM_EPOCHS = 20 but this run uses 7;
# decide which is intended and keep a single source of truth.
train(model, train_loader, val_loader, vocab, num_epochs=7, lr=LEARNING_RATE)

In [None]:
def sample_with_programs(
    model: nn.Module,
    tokenizer: REMI,
    allowed_programs: List[int],
    max_length: int = 1024,
    temperature: float = 1.0,
    device: str = DEVICE,
):
    """Sample a token sequence in which only the requested Program tokens may appear.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(tokens, hidden)``; must return ``(logits, hidden)``.
    tokenizer : REMI
        Supplies the vocabulary size and the BOS/EOS token ids.
    allowed_programs : list[int]
        MIDI program numbers (0-127) the piece may contain.  Example: [0] for
        solo piano, [40, 41, 42] for a string trio.  Numbers without an entry
        in PROGRAM_TOKEN_IDS are ignored.
    max_length : int
        Maximum number of tokens sampled after the two-token seed.
    temperature : float
        Softmax temperature; must be > 0.
    device : str
        Device the model and tensors are placed on.

    Returns
    -------
    list[int]
        ``[BOS, first allowed Program token, ...sampled tokens]``.  Sampling
        stops early once EOS is drawn (EOS is included).

    Raises
    ------
    ValueError
        If no usable program is given or `temperature` is not positive.

    Notes
    -----
    NOTE(review): only Program tokens listed in PROGRAM_TOKEN_IDS (programs
    0-127) are masked.  If the vocabulary has a separate drum Program token it
    is not blocked here - confirm whether drums should be allowed.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0")

    known_programs = [p for p in allowed_programs if p in PROGRAM_TOKEN_IDS]
    if not known_programs:
        raise ValueError("allowed_programs contains no program known to the tokenizer")

    model.to(device)
    model.eval()

    # Vocabulary mask: 1 for tokens that may be sampled, 0 for banned Program tokens.
    allow_ids = {PROGRAM_TOKEN_IDS[p] for p in known_programs}
    banned_ids = sorted(ALL_PROGRAM_TOKEN_IDS - allow_ids)
    mask = torch.ones(tokenizer.vocab_size, device=device)
    if banned_ids:
        mask[banned_ids] = 0.0

    # Seed sequence: <BOS> + token id of the first allowed program.  The seed
    # must hold the Program *token id*, not the raw program number.
    eos_id = tokenizer["EOS_None"]
    generated = [tokenizer["BOS_None"], PROGRAM_TOKEN_IDS[known_programs[0]]]

    # Only the Program token is fed as the first input (BOS is not).
    input_tok = torch.tensor([[generated[-1]]], device=device)
    hidden = None

    with torch.no_grad():  # inference only: do not build an autograd graph per step
        for _ in range(max_length):
            logits, hidden = model(input_tok, hidden)
            logits = logits[:, -1, :] / temperature
            probs = F.softmax(logits, dim=-1) * mask  # zero out banned programs
            total = probs.sum()
            if total == 0:
                # catastrophic: mask wiped all mass, fall back to uniform over allowed
                probs = mask / mask.sum()
            else:
                probs = probs / total  # renorm
            next_tok = torch.multinomial(probs, 1).item()
            generated.append(next_tok)
            if next_tok == eos_id:
                break
            input_tok = torch.tensor([[next_tok]], device=device)

    return generated

In [None]:
def sample(
    model: nn.Module,
    tokenizer: REMI,
    allowed_programs: List[int],
    max_length: int = 1024,
    section_len: int = 256,      # distance between forced program switches
    temperature: float = 1.0,
    device: str = DEVICE,
):
    """Guaranteed multi-instrument generator.

    The sequence is partitioned into sections of `section_len` tokens.
    At the start of each section the next unused Program token from
    `allowed_programs` is force-inserted.  While sampling, every Program token
    in ALL_PROGRAM_TOKEN_IDS is masked out except the currently active one, so
    the model cannot override the schedule.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(tokens, hidden)``; must return ``(logits, hidden)``.
    tokenizer : REMI
        Supplies the vocabulary size and the BOS/EOS token ids.
    allowed_programs : list[int]
        Program numbers to schedule, in order.  Each must be a key of
        PROGRAM_TOKEN_IDS.
    max_length : int
        Upper bound on the length of the returned sequence, counting the
        two-token seed and the forced Program tokens.
    section_len : int
        Position stride at which the next program is forced; must be >= 1.
    temperature : float
        Softmax temperature; must be > 0.
    device : str
        Device the model and tensors are placed on.

    Returns
    -------
    list[int]
        Generated token ids, starting with BOS and the first Program token.
        Sampling stops early once EOS is drawn (EOS is included).

    Raises
    ------
    ValueError
        If `allowed_programs` is empty, or `section_len` / `temperature` is
        not positive.

    Notes
    -----
    NOTE(review): only Program tokens for programs 0-127 are masked.  If the
    vocabulary has a separate drum Program token it stays available - confirm
    whether drums should be allowed.
    """
    if not allowed_programs:
        raise ValueError("allowed_programs must contain at least one program number")
    if section_len < 1:
        raise ValueError("section_len must be >= 1")
    if temperature <= 0:
        raise ValueError("temperature must be > 0")

    model.to(device); model.eval()

    # vocab mask that blocks EVERY Program token for now
    global_mask = torch.ones(tokenizer.vocab_size, device=device)
    for pid in ALL_PROGRAM_TOKEN_IDS:
        global_mask[pid] = 0.0

    def mask_for(prog_id):
        """Copy of the global mask with only `prog_id` re-enabled."""
        m = global_mask.clone()
        m[prog_id] = 1.0
        return m

    # list of Program ids we will inject
    prog_queue = [PROGRAM_TOKEN_IDS[p] for p in allowed_programs]
    active_prog_id = prog_queue.pop(0)          # first program is active

    eos_id = tokenizer["EOS_None"]
    generated = [tokenizer["BOS_None"], active_prog_id]
    mask = mask_for(active_prog_id)

    input_tok = torch.tensor([[active_prog_id]], device=device)
    hidden = None

    with torch.no_grad():  # inference only: do not build an autograd graph per step
        for t in range(2, max_length):

            # --- force a new instrument at section boundaries ----------
            if (t % section_len == 0) and prog_queue:
                active_prog_id = prog_queue.pop(0)
                generated.append(active_prog_id)
                mask = mask_for(active_prog_id)
                input_tok = torch.tensor([[active_prog_id]], device=device)
                hidden = None                            # reset hidden to avoid shock
                continue

            # --- normal token sampling --------------------------------
            logits, hidden = model(input_tok, hidden)
            probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1) * mask
            total = probs.sum()
            if total == 0:
                # Mask removed all probability mass: fall back to uniform over
                # the allowed tokens instead of dividing by zero.
                probs = mask / mask.sum()
            else:
                probs = probs / total
            next_tok = torch.multinomial(probs, 1).item()

            generated.append(next_tok)
            if next_tok == eos_id:
                break

            input_tok = torch.tensor([[next_tok]], device=device)

    return generated


In [None]:
# ---- sampling demo: generate with two scheduled programs and save the result as MIDI
# Alternative program set, kept for reference (GM numbers 40, 41, 42):
#string_programs = [40, 41, 42]  # GM numbers for Violin, Viola, Cello
# NOTE(review): despite the variable name, 25 and 33 are not strings; the top-20
# table lists them as Acoustic Guitar (steel) and Electric Bass (finger).
string_programs = [25, 33]
tokens = sample(model, tokenizer, string_programs, max_length=2048*10)
midi = tokenizer.decode(tokens)
midi.dump_midi("multi_RNN.mid")
print("Saved multi_RNN.mid (", len(tokens), "tokens )")

In [None]:
# Reload the generated file and summarise it: total duration, then the note
# count of each instrument (falling back to 'Unnamed' when the name is empty).
pretty_midi = PrettyMIDI("multi_RNN.mid")
print("Duration (seconds):", pretty_midi.get_end_time())
for i, instrument in enumerate(pretty_midi.instruments):
    print(f"{instrument.name or 'Unnamed'}:", len(instrument.notes), "notes")

Top-20 instruments:

 GM#  Name                            Tracks

----------------------------------------------

   0  Acoustic Grand Piano

  25  Acoustic Guitar (steel)

  29  Overdriven Guitar

  48  String Ensemble 1

  33  Electric Bass (finger)

  30  Distortion Guitar

  27  Electric Guitar (clean)
  
  35  Fretless Bass

  52  Choir Aahs

  26  Electric Guitar (jazz)

  28  Electric Guitar (muted)

  49  String Ensemble 2

  24  Acoustic Guitar (nylon)

  32  Acoustic Bass

   1  Bright Acoustic Piano

  53  Voice Oohs

  65  Alto Sax

  50  Synth Strings 1

  61  Brass Section

  73  Flute



File Instruments

multi_RNN_11: 
    Duration (seconds): 15.0
    Bright Acoustic Piano: 33 notes
    Electric Bass (finger): 49 notes
    Drums: 79 notes

multi_RNN_12:
    Duration (seconds): 10.9375
    Acoustic Guitar (nylon): 156 notes
    Electric Bass (finger): 23 notes
    Drums: 37 notes

multi_RNN_14:
    Duration (seconds): 15.25
    Electric Bass (finger): 47 notes
    Drums: 141 notes
    Acoustic Guitar (nylon): 57 notes    

multi_RNN_15:
    Duration (seconds): 40.5
    String Ensembles 2: 442 notes
    Drums: 29 notes
    Flute: 28 notes

16
Duration (seconds): 21.1875
Acoustic Grand Piano: 1 notes
Acoustic Guitar (nylon): 12 notes
Drums: 161 notes

17
Duration (seconds): 70.25
Acoustic Guitar (nylon): 57 notes
Drums: 447 notes
Electric Bass (finger): 13 notes
Choir Aahs: 55 notes

18
distortion guitar
alto sax
string ensemble

19
Duration (seconds): 24.5
Acoustic Bass: 14 notes
Drums: 70 notes
Brass Section: 34 notes
Acoustic Grand Piano: 163 notes

20
Duration (seconds): 39.75
Acoustic Grand Piano: 54 notes
Acoustic Guitar (steel): 53 notes
Drums: 13 notes
String Ensembles 1: 55 notes
Choir Aahs: 45 notes
Acoustic Bass: 27 notes

22
Duration (seconds): 20.25
SynthStrings 1: 63 notes
Acoustic Bass: 95 notes
Drums: 135 notes

23
Duration (seconds): 32.3125
Electric Guitar (jazz): 48 notes
Drums: 31 notes
Alto Sax: 27 notes
Bright Acoustic Piano: 234 notes

24
Duration (seconds): 20.75
Electric Guitar (jazz): 43 notes
Drums: 8 notes
Alto Sax: 9 notes

25
Duration (seconds): 16.125
Electric Guitar (jazz): 57 notes
Drums: 194 notes
Alto Sax: 157 notes

26
Duration (seconds): 34.25
Acoustic Guitar (steel): 54 notes
Electric Bass (finger): 78 notes
Drums: 236 notes

In [None]:
def random_baseline_sample(tokenizer, allowed_programs, seq_len=1024, seed: Optional[int] = None):
    """Uniform-random baseline that never emits a Program token outside `allowed_programs`.

    Parameters
    ----------
    tokenizer
        Supplies the vocabulary size (``len(tokenizer)``) and the BOS/EOS ids.
    allowed_programs : list[int]
        Program numbers that may appear; each must be a key of PROGRAM_TOKEN_IDS.
    seq_len : int
        Target total length, including BOS, the first Program token and EOS.
        For ``seq_len < 3`` the three framing tokens are still returned.
    seed : int, optional
        Seed for the private RNG.  ``None`` (the default) seeds from system
        entropy, as before.

    Returns
    -------
    list[int]
        ``[BOS, Program(allowed_programs[0]), ...uniform draws..., EOS]``.

    Raises
    ------
    ValueError
        If `allowed_programs` is empty.

    Notes
    -----
    NOTE(review): draws are uniform over every non-banned id, so special tokens
    (including BOS/EOS) can also appear mid-sequence, as can any Program token
    that is not part of PROGRAM_TOKEN_IDS - confirm this is acceptable for the
    baseline.
    """
    if not allowed_programs:
        raise ValueError("allowed_programs must contain at least one program number")

    rng = random.Random(seed)

    # ids for Program tokens we allow
    allow_prog_ids = {PROGRAM_TOKEN_IDS[p] for p in allowed_programs}

    # Every vocab id except the Program tokens that are not allowed.  The
    # allowed ones stay in, so a second track may appear.
    banned_ids = ALL_PROGRAM_TOKEN_IDS - allow_prog_ids
    legal_ids = [tid for tid in range(len(tokenizer)) if tid not in banned_ids]

    out = [tokenizer["BOS_None"], PROGRAM_TOKEN_IDS[allowed_programs[0]]]

    # Three slots are reserved: BOS and the first Program token above, EOS below.
    n_random = max(seq_len - 3, 0)
    if n_random:
        out.extend(rng.choices(legal_ids, k=n_random))

    out.append(tokenizer["EOS_None"])
    return out


In [None]:
# Random baseline with the same two programs as the RNN sampling demo (25 and 33),
# decoded and written to its own MIDI file.
tokens = random_baseline_sample(tokenizer, allowed_programs=[25, 33], seq_len=2048)

midi = tokenizer.decode(tokens)
midi.dump_midi("multi_random_2.mid")


In [None]:
# Same summary as for the RNN output, applied to the random-baseline file:
# total duration, then the note count of each instrument.
pretty_midi = PrettyMIDI("multi_random_2.mid")
print("Duration (seconds):", pretty_midi.get_end_time())
for i, instrument in enumerate(pretty_midi.instruments):
    print(f"{instrument.name or 'Unnamed'}:", len(instrument.notes), "notes")