### Task 1

Training models to generate MIDI sequences of piano music


In [None]:
# Import necessary libraries
import glob
import random
from typing import List
from collections import defaultdict

import os
import pandas as pd

import numpy as np
from numpy.random import choice

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from symusic import Score
from miditok import REMI, TokenizerConfig, TokenizerConfig
from miditok.pytorch_data import DatasetMIDI, DataCollator

from midi2audio import FluidSynth # Import library
from IPython.display import Audio, display

from pretty_midi import PrettyMIDI

import random
from mido import Message, MidiFile, MidiTrack


In [None]:
# Select the torch device: the GPU ('cuda') when torch can see one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Uncomment one of these to force a device manually:
#DEVICE = 'cpu'
#DEVICE = 'cuda'

print(DEVICE)

In [None]:
# Read the MAESTRO v3 metadata table and resolve the MIDI paths of each split.

ROOT = "maestro-v3.0.0"            # change if you unpacked elsewhere
meta = pd.read_csv(os.path.join(ROOT, "maestro-v3.0.0.csv"))


def list_midi_files(split):
    """Return the full path of every MIDI file whose metadata row is in `split`.

    `split` is compared against the CSV's "split" column (e.g. "train",
    "validation", "test"); paths are joined onto ROOT.
    """
    in_split = meta["split"] == split
    relative_paths = meta.loc[in_split, "midi_filename"]
    return [os.path.join(ROOT, rel_path) for rel_path in relative_paths]


train_files = list_midi_files("train")        # 962 MIDI files
val_files   = list_midi_files("validation")   # 137
test_files  = list_midi_files("test")         # 177


In [None]:
# Type validity: MIDI paths must be plain `str` objects that survive a UTF-8
# encode/decode round trip. The assertions fail loudly if either is not true.
first_path = train_files[0]
assert isinstance(first_path, str), f"expected a str path, got {type(first_path).__name__}"
assert first_path.encode('utf-8').decode('utf-8') == first_path
print(first_path.encode('utf-8'))

Tokenizer

In [None]:
# Tokenizer options: REMI with chord, tempo, time-signature and key-signature tokens
TOKENIZER_PARAMS = dict(
    use_chords=True,
    use_tempos=True,
    use_time_signatures=True,
    use_key_signatures=True,
)

# Build the configuration object and the REMI tokenizer from it
config = TokenizerConfig(**TOKENIZER_PARAMS)
tokenizer = REMI(config)


def _make_piano_dataset(paths):
    """Return a DatasetMIDI over `paths` (max_seq_len=1024, BOS/EOS ids from `tokenizer`)."""
    return DatasetMIDI(
        files_paths=paths,
        tokenizer=tokenizer,
        max_seq_len=1024,
        bos_token_id=tokenizer["BOS_None"],
        eos_token_id=tokenizer["EOS_None"],
    )


train_dataset = _make_piano_dataset(train_files)
test_dataset = _make_piano_dataset(test_files)

# Batches are collated with the tokenizer's PAD id
collator = DataCollator(tokenizer.pad_token_id)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collator)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=collator)

In [None]:
# Number of batches per epoch in the train / test loaders
n_train_batches, n_test_batches = len(train_loader), len(test_loader)
n_train_batches, n_test_batches

Transformer

In [None]:
# Define our basic (decoder-only) Transformer model for music generation
class MusicTransformer(nn.Module):
    """Autoregressive Transformer language model over token ids.

    Token embeddings plus a fixed sinusoidal positional table are fed through a
    stack of ``nn.TransformerEncoderLayer`` blocks run with a causal attention
    mask, so position ``i`` only attends to positions ``<= i``. A linear head
    maps every position to vocabulary logits.

    Args:
        vocab_size: number of token ids.
        embedding_dim: model width (``d_model``); must be divisible by ``num_heads``.
        num_heads: attention heads per layer.
        num_layers: number of stacked encoder layers.
        dropout: dropout rate inside each layer.
        max_seq_len: longest sequence the positional table covers.
    """

    def __init__(self, vocab_size, embedding_dim=256, num_heads=8, num_layers=6, dropout=0.1, max_seq_len=1024):
        super(MusicTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Fixed, non-trainable positional table of shape (1, max_seq_len, embedding_dim)
        self.pos_encoder = nn.Parameter(self._generate_positional_encoding(max_seq_len, embedding_dim), requires_grad=False)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embedding_dim,
            nhead=num_heads,
            dim_feedforward=embedding_dim * 4,
            dropout=dropout,
            batch_first=True
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc_out = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Map ids ``x`` of shape (batch, seq_len) to logits (batch, seq_len, vocab_size).

        Raises:
            ValueError: if ``seq_len`` exceeds the positional table length.
        """
        seq_len = x.size(1)
        max_len = self.pos_encoder.size(1)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_seq_len {max_len}")

        h = self.embedding(x) + self.pos_encoder[:, :seq_len, :]
        # Causal mask (True = may NOT attend): blocks every future position. Without
        # it each position could read the very token it is trained to predict.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1
        )
        h = self.transformer_encoder(h, mask=causal_mask)
        return self.fc_out(h)

    def _generate_positional_encoding(self, max_len, d_model):
        """Return the (1, max_len, d_model) sinusoidal positional-encoding table.

        Even feature columns hold ``sin(pos * w_i)`` and odd columns the matching
        ``cos(pos * w_i)`` with ``w_i = 10000^(-2i / d_model)``; an odd ``d_model``
        simply has one fewer cosine column.
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # There are d_model // 2 odd columns, which is one fewer than even ones when d_model is odd
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe.unsqueeze(0)  # shape: (1, max_len, d_model)


Training

In [None]:
# Train the model on the dataset
def train(model, train_loader, val_loader, vocab_size, num_epochs=20, lr=0.001, device=DEVICE):
    """Train ``model`` for next-token prediction and print per-epoch losses.

    Each batch is a dict with ``"input_ids"`` of shape (batch, seq_len). The
    model sees ``ids[:, :-1]`` and is scored against ``ids[:, 1:]``. If the
    batch also holds an ``"attention_mask"`` (assumed 1 = real token,
    0 = padding, as miditok's DataCollator is expected to produce — confirm
    for your version), padded target positions are excluded from the loss.

    Args:
        model: module mapping (batch, seq_len) ids to (batch, seq_len, vocab_size) logits.
        train_loader: iterable of training batches.
        val_loader: iterable of validation batches (evaluated without gradients).
        vocab_size: size of the logit dimension.
        num_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: torch device for the model and batches.

    Returns:
        list of ``(avg_train_loss, avg_val_loss)`` tuples, one per epoch; an
        empty loader yields ``nan`` for that entry.
    """
    model = model.to(device)
    # Targets equal to -100 (CrossEntropyLoss's ignore_index) do not contribute to the loss
    criterion = nn.CrossEntropyLoss(ignore_index=-100)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    def split_batch(batch):
        """Return ``(inputs, targets)`` with padded target positions set to -100."""
        ids = batch['input_ids'].to(device)  # (batch_size, seq_length)
        inputs = ids[:, :-1]
        targets = ids[:, 1:].clone()
        attention_mask = batch.get('attention_mask')
        if attention_mask is not None:
            targets[attention_mask[:, 1:].to(device) == 0] = -100
        return inputs, targets

    def run_epoch(loader, training):
        """Make one pass over ``loader``; return the mean batch loss (nan if empty)."""
        model.train(training)
        total_loss, n_batches = 0.0, 0
        with torch.set_grad_enabled(training):
            for batch in loader:
                inputs, targets = split_batch(batch)
                if training:
                    optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))
                if training:
                    loss.backward()
                    optimizer.step()
                total_loss += loss.item()
                n_batches += 1
        return total_loss / n_batches if n_batches else float("nan")

    history = []
    for epoch in range(num_epochs):
        avg_train_loss = run_epoch(train_loader, training=True)   # --------- Training ---------
        avg_val_loss = run_epoch(val_loader, training=False)      # --------- Validation -------
        history.append((avg_train_loss, avg_val_loss))
        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    return history

In [None]:
# Instantiate and train the model, using the hyper-parameters defined here
vocab_size = tokenizer.vocab_size
embedding_dim = 256
num_heads = 8
num_layers = 6

model = MusicTransformer(vocab_size, embedding_dim=embedding_dim, num_heads=num_heads, num_layers=num_layers)
# NOTE: test_loader doubles as the validation set here (val_files is loaded but not used)
history = train(model, train_loader, test_loader, vocab_size)

Sampling

In [None]:
# Sampling: Generate new music from the trained model
def sample(model, start_token, tokenizer, max_length=512, temperature=1.0, device=DEVICE, max_steps=None):
    """Autoregressively sample a REMI token sequence from ``model``.

    Tokens drawn from the model are filtered. Bar / TimeShift / Position tokens
    are kept. A Pitch token is kept and followed by a uniformly random Velocity
    token and a uniformly random Duration token. Any other sampled token is
    discarded and a new one is drawn. Sampling stops on EOS/PAD, once
    ``max_length`` tokens are collected, or after ``max_steps`` model calls.

    Args:
        model: callable mapping (1, seq_len) ids to (1, seq_len, vocab) logits.
        start_token: id that seeds the sequence (e.g. BOS).
        tokenizer: tokenizer exposing ``vocab`` (dict token -> id) or ``_vocab``.
        max_length: target length; a final Pitch group may overshoot it by 2.
        temperature: softmax temperature (> 0).
        device: torch device for the input tensor.
        max_steps: maximum number of model calls, which guards against endless
            resampling when the model keeps proposing discarded tokens.
            Defaults to ``20 * max_length``.

    Returns:
        list[int] of token ids beginning with ``start_token``.

    Raises:
        RuntimeError: if the vocab cannot be found or has no Velocity_/Duration_ tokens.
    """
    model.eval()

    # Build ID → string mapping
    if hasattr(tokenizer, 'vocab') and isinstance(tokenizer.vocab, dict):
        id_to_token = {v: k for k, v in tokenizer.vocab.items()}
    elif hasattr(tokenizer, '_vocab'):
        id_to_token = dict(enumerate(tokenizer._vocab))
    else:
        raise RuntimeError("Tokenizer vocab not found")

    # Candidate Velocity / Duration ids never change, so collect them once
    velocity_ids = [i for i, tok in id_to_token.items() if tok.startswith("Velocity_")]
    duration_ids = [i for i, tok in id_to_token.items() if tok.startswith("Duration_")]
    if not velocity_ids or not duration_ids:
        raise RuntimeError("Tokenizer has no Velocity_/Duration_ tokens to pair with pitches")

    # Longest context the model accepts, read from its positional table when it has one
    pos_table = getattr(model, "pos_encoder", None)
    max_context = pos_table.size(1) if pos_table is not None else None

    if max_steps is None:
        max_steps = 20 * max_length

    generated = [start_token]
    for _ in range(max_steps):
        if len(generated) >= max_length:
            break

        # Feed at most the `max_context` most recent tokens
        context = generated[-max_context:] if max_context else generated
        input_seq = torch.tensor([context], dtype=torch.long, device=device)
        with torch.no_grad():
            logits = model(input_seq)
            next_logits = logits[0, -1] / temperature
            probs = F.softmax(next_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()

        token_str = id_to_token.get(next_token, "")

        # Stop on EOS or PAD (exact token match)
        if token_str in ("EOS_None", "PAD_None"):
            break

        # Always add bar/position/timeshift
        if token_str.startswith(("Bar", "TimeShift", "Position")):
            generated.append(next_token)

        # If it's a pitch, follow with a random velocity and duration
        elif token_str.startswith("Pitch_"):
            generated.append(next_token)
            generated.append(random.choice(velocity_ids))
            generated.append(random.choice(duration_ids))

        # Any other sampled token is discarded and a new one is drawn

    return generated


In [None]:
# Generate a sample token sequence, seeding generation with the tokenizer's BOS token
start_token = tokenizer["BOS_None"]
generated_sequence = sample(model, start_token, tokenizer, max_length=1024)

print("Generated token sequence:")
print(generated_sequence)
print("num tokens:", len(generated_sequence))

In [None]:
# Decode the generated ids into a score (the sequence is passed wrapped in a list)
# and save it as a MIDI file
output_score = tokenizer.decode([generated_sequence])
output_score.dump_midi("transformer.mid")

In [None]:
# Inspect generated MIDI content.
# `midi_data` / `inst` are used rather than `pretty_midi` / `instrument` so that the
# pretty_midi package name and the music21 `instrument` module are not shadowed.
midi_data = PrettyMIDI("transformer.mid")
print("Duration (seconds):", midi_data.get_end_time())
for inst in midi_data.instruments:
    print(f"{inst.name or 'Unnamed'}:", len(inst.notes), "notes")

Random Baseline

In [None]:
# Randomly generate a MIDI file as a baseline
def generate_random_midi(filename="random.mid",
                         num_notes=50,
                         pitch_range=(21, 108),
                         velocity_range=(30, 100),
                         duration_range=(120, 480),
                         seed=None):
    """Write a single-track MIDI file of uniformly random, back-to-back notes.

    Each note_on comes zero ticks after the previous note_off, and each note
    lasts ``duration`` ticks. The result is monophonic.

    Args:
        filename: output path.
        num_notes: number of notes to write.
        pitch_range: inclusive (low, high) MIDI pitch bounds; 21-108 spans the 88 piano keys.
        velocity_range: inclusive (low, high) note-on velocity bounds.
        duration_range: inclusive (low, high) note length in ticks.
        seed: optional seed for a private RNG, giving reproducible output.
            When None, the global ``random`` module is used.
    """
    rng = random if seed is None else random.Random(seed)

    mid = MidiFile()
    track = MidiTrack()
    mid.tracks.append(track)

    for _ in range(num_notes):
        pitch = rng.randint(*pitch_range)
        velocity = rng.randint(*velocity_range)
        duration = rng.randint(*duration_range)

        # Note on, right after the previous note ends (delta time 0)
        track.append(Message('note_on', note=pitch, velocity=velocity, time=0))
        # Note off, `duration` ticks later
        track.append(Message('note_off', note=pitch, velocity=0, time=duration))

    mid.save(filename)
    print(f"Saved: {filename}")

# Write one random baseline file to disk
generate_random_midi(filename="random.mid")

Evaluate the generated MIDI file against the randomly generated baseline file

In [None]:
# Plot notes by beat to see if model learned timings
def plot_notes_by_beat(beat_note_counts):
    """Scatter-plot pitch-name occurrences against beat position.

    Args:
        beat_note_counts: mapping beat position -> mapping of pitch name -> count
            (e.g. a ``defaultdict(Counter)``). Every (beat, pitch) pair with a
            positive count becomes a point with marker size ``20 * count``.
            There is one scatter series, and one legend entry, per pitch name.
    """
    import matplotlib.pyplot as plt

    beat_positions = sorted(beat_note_counts.keys())
    pitch_names = sorted(
        {name for per_beat in beat_note_counts.values() for name in per_beat}
    )

    plt.figure(figsize=(12, 6))

    for name in pitch_names:
        # (beat, count) pairs at which this pitch name actually occurs
        occurrences = [
            (beat, beat_note_counts[beat].get(name, 0)) for beat in beat_positions
        ]
        occurrences = [(beat, n) for beat, n in occurrences if n > 0]
        plt.scatter(
            [beat for beat, _ in occurrences],
            [name for _ in occurrences],
            s=[n * 20 for _, n in occurrences],  # Adjust marker size scaling as needed
            alpha=0.6,
            label=name,
        )

    plt.title("Note Occurrences by Beat Position")
    plt.xlabel("Beat Number in Measure")
    plt.ylabel("Pitch")
    plt.legend(loc="upper right", fontsize="small", ncol=2)
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    

# Plot note frequencies
def plot_note_frequencies(note_list, detected_key):
    """Bar-chart how often each pitch name occurs, colouring out-of-key notes.

    Bars whose pitch class is not in ``detected_key`` are drawn in salmon and
    in-key bars in sky blue. Pitch classes are compared rather than
    spellings, so enharmonic names (e.g. "C#" and "D-") are treated as the same note.

    Args:
        note_list: iterable of music21 pitch names without octave (e.g. "C#", "B-").
        detected_key: music21 Key whose ``pitches`` define the scale.
    """
    from collections import Counter

    import matplotlib.pyplot as plt
    from music21 import pitch as m21_pitch

    note_counts = Counter(note_list)
    if not note_counts:
        print("No notes to plot.")
        return

    labels, counts = zip(*sorted(note_counts.items()))
    in_key = {p.pitchClass for p in detected_key.pitches}

    plt.figure(figsize=(10, 6))
    bars = plt.bar(labels, counts, color='skyblue', edgecolor='black')

    # Highlight notes outside the key
    for bar, label in zip(bars, labels):
        if m21_pitch.Pitch(label).pitchClass not in in_key:
            bar.set_color('salmon')

    plt.title(f"Note Occurrences in MIDI (Key: {detected_key})")
    plt.xlabel("Note Name (no octave)")
    plt.ylabel("Frequency")
    plt.grid(axis='y', linestyle='--', alpha=0.6)
    plt.tight_layout()
    plt.show()



In [None]:
# Analyze MIDI files against each other
def _collect_note_events(midi_stream):
    """Gather pitch events and time signatures from every element of ``midi_stream``.

    Returns:
        (note_events, time_signatures). ``note_events`` holds one
        ``(pitch_name, pitch_class, offset, beat)`` tuple per sounding pitch,
        and a chord contributes one tuple per member pitch. ``time_signatures``
        holds ``(ratio_string, offset)`` tuples. Offsets are each element's own
        music21 ``offset``, i.e. relative to its containing measure/voice.
    """
    from music21 import chord, meter, note

    note_events = []
    time_signatures = []
    for element in midi_stream.recurse():
        if isinstance(element, meter.TimeSignature):
            time_signatures.append((element.ratioString, element.offset))
        elif isinstance(element, (note.Note, chord.Chord)):
            beat = round(float(element.beat), 2)  # round for grouping
            for p in element.pitches:
                note_events.append((p.name, p.pitchClass, element.offset, beat))
    return note_events, time_signatures


def analyze_midi_key_compliance(file_path):
    """Report how well the notes of a MIDI file fit its detected key.

    The file is parsed with music21 and its key is estimated. The function
    prints every pitch whose pitch class lies outside that key, plots
    pitch-name frequencies, prints the detected time signatures and the
    per-beat pitch counts, and plots those counts. Chord members count as
    individual notes. If parsing or key detection fails, the error is
    printed and the function returns early.

    Args:
        file_path: path of the MIDI file to analyse.
    """
    from collections import Counter, defaultdict

    from music21 import converter

    # Parse the MIDI file
    try:
        midi_stream = converter.parse(file_path)
    except Exception as e:
        print(f"❌ Failed to parse MIDI: {e}")
        return

    # Detect key
    try:
        detected_key = midi_stream.analyze('key')
    except Exception as e:
        print(f"❌ Failed to detect key: {e}")
        return

    allowed_notes = set(p.name for p in detected_key.pitches)
    allowed_pitch_classes = {p.pitchClass for p in detected_key.pitches}

    print(f"\n🎼 Detected Key: {detected_key}")
    print(f"🎵 Allowed Notes: {sorted(allowed_notes)}\n")

    # Collect note data in a single pass over the stream
    note_events, time_signatures = _collect_note_events(midi_stream)
    note_names = [name for name, _, _, _ in note_events]
    # Compare pitch classes so enharmonic spellings (C# vs D-) are not flagged
    non_conforming = [
        (name, offset)
        for name, pitch_class, offset, _ in note_events
        if pitch_class not in allowed_pitch_classes
    ]
    total_notes = len(note_names)

    # Display compliance results
    if non_conforming:
        print(f"❌ Found {len(non_conforming)} non-conforming notes (out of {total_notes}):")
        for pitch_name, offset in non_conforming:
            print(f" - {pitch_name} at offset {offset}")
    else:
        print("✅ All notes conform to the detected key.")

    # Plot note frequency
    plot_note_frequencies(note_names, detected_key)

    # Print time signature info
    print("\n🕐 Time Signature(s) detected:")
    for ts, offset in time_signatures:
        print(f" - {ts} at offset {offset}")

    # Aggregate pitch-name counts per (rounded) beat position
    beat_note_counts = defaultdict(Counter)
    for name, _, _, beat in note_events:
        beat_note_counts[beat][name] += 1

    print("\n📊 Note occurrences per beat position (aggregated):")
    for beat in sorted(beat_note_counts):
        print(f" Beat {beat}:")
        for pitch_name, count in beat_note_counts[beat].items():
            print(f"   - {pitch_name}: {count}")

    plot_notes_by_beat(beat_note_counts)

In [None]:
# Evaluate the random baseline vs a sample from the Transformer
# (uncomment to run; the paths below are the files written by the cells above)

#rand_file_path = "random.mid"
#analyze_midi_key_compliance(rand_file_path)
#gen_file_path = "transformer.mid"
#analyze_midi_key_compliance(gen_file_path)

### Task 2

Training models to generate MIDI sequences of music conditioned on instruments


In [None]:
# Import necessary libraries
import os
import random
from typing import List, Optional

import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetMIDI, DataCollator

from collections import defaultdict
from symusic import Score

from pretty_midi import PrettyMIDI, program_to_instrument_name

import sys
from collections import Counter
from music21 import converter, note, key, stream, meter
import matplotlib.pyplot as plt

from music21 import converter, note, instrument
import matplotlib.pyplot as plt
import statistics

In [None]:
# Select the GPU when torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(DEVICE)

# ------------ Dataset -------------
# Download from: https://www.kaggle.com/datasets/imsparsh/lakh-midi-clean?resource=download
# Unzip and name dataset folder root "lakh-midi"
DATA_ROOT = "lakh-midi"

# Optional CSV with columns [filepath, split] holding a train/val/test split;
# with None, build_split_lists makes a random 80/10/10 split instead.
SPLIT_CSV = None

# ------------ Training -------------
MAX_SEQ_LEN = 1024     # split long pieces into chunks of this many tokens
BATCH_SIZE = 4
NUM_EPOCHS = 20
LEARNING_RATE = 0.001

# ------------ LSTM size -------------
EMBED_DIM = 256
HIDDEN_DIM = 512
NUM_LAYERS = 2

In [None]:
# Scan the dataset and count how many non-drum tracks use each GM program
def collect_program_counts(root: str):
    """Count non-drum tracks per General MIDI program under ``root``.

    Walks ``root`` recursively for files ending in ``.mid`` / ``.midi``
    (case-insensitive), loads each with ``Score``, and tallies
    ``track.program`` for every track that is not flagged ``is_drum``. Files
    that raise while loading are skipped.

    Returns:
        (counts, midi_files): a ``defaultdict(int)`` mapping program -> track
        count, and every MIDI path found (including ones that failed to load).
    """
    midi_files = []
    for dirpath, _subdirs, filenames in os.walk(root):
        midi_files.extend(
            os.path.join(dirpath, fname)
            for fname in filenames
            if fname.lower().endswith((".mid", ".midi"))
        )

    counts = defaultdict(int)
    for path in midi_files:
        try:
            for track in Score(path).tracks:
                if not track.is_drum:          # skip channel-10 drums
                    counts[track.program] += 1
        except Exception:
            continue                           # skip unreadable files
    return counts, midi_files

program_counts, all_midi_files = collect_program_counts(DATA_ROOT)

# Rank programs by track count (most used first) and keep the 20 most common
ranked_programs = sorted(program_counts.items(), key=lambda item: item[1], reverse=True)
TOP20 = ranked_programs[:20]
allowed_programs = [program for program, _n_tracks in TOP20]

print("Top-20 instruments:")
print(f"{'GM#':>4}  {'Name':30}  Tracks")
print("-" * 46)
for program, n_tracks in TOP20:
    instrument_name = program_to_instrument_name(program)   # e.g. 0 → Acoustic Grand Piano
    print(f"{program:>4}  {instrument_name:30}  {n_tracks:>6}")

In [None]:
# REMI with Program tokens → single flattened event stream
config = TokenizerConfig(
    use_programs=True,                   # ← this is the key flag
    # miditok's name for the "all tracks in one sequence" option. A keyword such as
    # `one_token_stream` is presumably kept only as an unrecognised extra parameter
    # (confirm against your miditok version).
    one_token_stream_for_programs=True,  # keeps the single-stream RNN forward
)

tokenizer = REMI(config)

# Helper map GM program -> Program token id, for fast lookup during constrained sampling.
# Membership is tested on the vocab dict (token string -> id) rather than on the
# tokenizer object itself.
token_to_id = tokenizer.vocab
PROGRAM_TOKEN_IDS = {
    p: tokenizer[f"Program_{p}"]
    for p in range(128)
    if f"Program_{p}" in token_to_id  # safety for programs missing from the vocabulary
}

ALL_PROGRAM_TOKEN_IDS = set(PROGRAM_TOKEN_IDS.values())

In [None]:
# Create train / validation / test file lists for training and evaluation
def build_split_lists(root: str, csv: Optional[str] = None, seed: Optional[int] = None):
    """Return ``(train, val, test)`` lists of MIDI paths.

    If ``csv`` is given, paths come from its ``filepath`` column, grouped by the
    ``split`` column (``"train"`` / ``"validation"`` / ``"test"``), and ``root`` is
    not used. Otherwise every ``.mid`` / ``.midi`` file under ``root``
    (case-insensitive) is shuffled and split 80/10/10.

    Args:
        root: dataset directory, scanned only when no CSV is given.
        csv: optional path to the split CSV.
        seed: optional shuffle seed. Paths are sorted before shuffling because
            ``os.walk`` order depends on the filesystem, so a fixed seed gives
            the same split on every run and machine. With None, the global
            ``random`` module is used.
    """
    if csv:
        meta = pd.read_csv(csv)

        def paths_in(split_name):
            """Return the CSV file paths labelled ``split_name``."""
            return meta.loc[meta["split"] == split_name, "filepath"].tolist()

        return paths_in("train"), paths_in("validation"), paths_in("test")

    # Simple 80‑10‑10 random split over MIDI files in the root
    all_midi = sorted(
        os.path.join(dp, f)
        for dp, _, files in os.walk(root)
        for f in files
        if f.lower().endswith((".mid", ".midi"))
    )
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(all_midi)
    n = len(all_midi)
    n_train, n_train_val = int(0.8 * n), int(0.9 * n)
    return all_midi[:n_train], all_midi[n_train:n_train_val], all_midi[n_train_val:]


# Build the file splits (from SPLIT_CSV when set, else a random split of DATA_ROOT)
train_files, val_files, test_files = build_split_lists(DATA_ROOT, csv=SPLIT_CSV)

In [None]:
MIN_LEN = 8                 # drop sequences shorter than this after filtering
MAX_RESAMPLE_TRIES = 1000   # give up (instead of looping forever) if nothing survives filtering

# Token types carrying timing, structure or sequence framing. They are kept even
# while a disallowed instrument is active, so Bar/Position timing is preserved.
STRUCTURAL_TOKEN_TYPES = frozenset(
    {"BOS", "EOS", "PAD", "MASK", "Bar", "Position", "TimeShift", "Rest", "TimeSig", "Tempo"}
)


# Create dataset class which can do instrument filtering on the fly
class FilteredDatasetMIDI(DatasetMIDI):
    """DatasetMIDI that removes the notes of instruments not in ``allowed_programs``.

    Tokens are assumed to follow miditok's REMI-with-programs layout, in which a
    note's tokens come right after a ``Program_*`` token. A disallowed Program
    token is dropped, along with every following non-structural token (Pitch,
    Velocity, Duration, ...), until the next Program token. Structural tokens
    (``STRUCTURAL_TOKEN_TYPES``) are always kept.
    NOTE(review): where Chord tokens sit relative to Program tokens has not been
    verified; here they follow the allowance of the preceding Program.
    """

    def __init__(self, *args, allowed_programs: List[int], **kw):
        super().__init__(*args, **kw)
        # Use the tokenizer this dataset was built with. This assumes DatasetMIDI
        # stores it as `self.tokenizer` (miditok 3.x) — confirm for your version.
        tok = self.tokenizer
        token_to_id = tok.vocab
        self.allowed_program_ids = {
            tok[f"Program_{p}"] for p in allowed_programs
        }
        self._program_ids = {
            tid for name, tid in token_to_id.items() if name.startswith("Program_")
        }
        self._structural_ids = {
            tid for name, tid in token_to_id.items()
            if name.split("_", 1)[0] in STRUCTURAL_TOKEN_TYPES
        }

    def _filter_tokens(self, ids: List[int]):
        """Return ``ids`` with the Program and note tokens of disallowed instruments removed."""
        keep = []
        ok = True   # tokens before the first Program token (BOS, Bar, Position, ...) are kept
        for tid in ids:
            if tid in self._program_ids:                # Program token: switch instrument
                ok = tid in self.allowed_program_ids
                if ok:
                    keep.append(tid)
            elif ok or tid in self._structural_ids:     # allowed instrument's events, or structure
                keep.append(tid)
        return keep

    def __getitem__(self, idx):
        """Return a sample whose filtered ``input_ids`` is usable.

        A sample is usable when its filtered ``input_ids`` has at least
        ``MIN_LEN`` tokens and includes at least one allowed Program token.
        Samples that are None, lack ``input_ids``, or fail that check are
        replaced by a uniformly random index.

        Raises:
            RuntimeError: if no usable sample is found within ``MAX_RESAMPLE_TRIES``
                attempts (e.g. no file uses any allowed program).
        """
        for _ in range(MAX_RESAMPLE_TRIES):
            sample = super().__getitem__(idx)               # may be None
            ids = None if sample is None else sample.get("input_ids")
            if ids is not None:
                if isinstance(ids, torch.Tensor):
                    ids = ids.tolist()
                filtered = self._filter_tokens(ids)
                has_notes = any(t in self.allowed_program_ids for t in filtered)
                if has_notes and len(filtered) >= MIN_LEN:
                    sample["input_ids"] = torch.tensor(filtered, dtype=torch.long)
                    return sample
            idx = random.randrange(len(self))
        raise RuntimeError(
            f"no sample with >= {MIN_LEN} tokens of an allowed program found "
            f"after {MAX_RESAMPLE_TRIES} attempts"
        )


In [None]:
# Build the instrument-filtered train / validation datasets and their loaders


def _make_filtered_dataset(paths):
    """Return a FilteredDatasetMIDI over ``paths`` restricted to ``allowed_programs``."""
    return FilteredDatasetMIDI(
        files_paths=paths,
        tokenizer=tokenizer,
        max_seq_len=MAX_SEQ_LEN,
        bos_token_id=tokenizer["BOS_None"],
        eos_token_id=tokenizer["EOS_None"],
        allowed_programs=allowed_programs,
    )


train_dataset = _make_filtered_dataset(train_files)
val_dataset = _make_filtered_dataset(val_files)

collator = DataCollator(tokenizer.pad_token_id)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collator)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collator)

LSTM

In [None]:
# Stacked-LSTM language model (two layers in this notebook, via NUM_LAYERS)
class MusicRNN(nn.Module):
    """LSTM next-token model: embedding -> stacked LSTM -> linear vocabulary head.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        embedding_dim: size of each token embedding.
        hidden_dim: LSTM hidden-state size.
        num_layers: number of stacked LSTM layers.
    """

    def __init__(self, vocab_size: int, embedding_dim: int, hidden_dim: int, num_layers: int):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(in_features=hidden_dim, out_features=vocab_size)

    def forward(self, x, hidden=None):
        """Map ids ``x`` (batch, seq_len) to ``(logits, hidden)``.

        ``logits`` has shape (batch, seq_len, vocab_size). ``hidden`` is the LSTM
        ``(h, c)`` state, which can be passed back in to continue the sequence.
        """
        lstm_out, new_hidden = self.rnn(self.embedding(x), hidden)
        return self.fc(lstm_out), new_hidden

In [None]:
# Training loop to train multi instrument LSTM
def train(model, train_loader, val_loader, vocab_size, num_epochs=10, lr=0.001, device=DEVICE):
    """Train a recurrent next-token model and print per-epoch train/val loss.

    ``model(ids)`` must return ``(logits, hidden)``, with logits of shape
    (batch, seq_len, vocab_size). Each batch dict provides ``"input_ids"``;
    the inputs are ``ids[:, :-1]`` and the targets ``ids[:, 1:]``. If the batch
    also has an ``"attention_mask"`` (assumed 1 = real token, 0 = padding, as
    miditok's DataCollator is expected to produce — confirm for your version),
    padded target positions are ignored by the loss.

    Args:
        model: recurrent model returning ``(logits, hidden)``.
        train_loader: iterable of training batches.
        val_loader: iterable of validation batches (evaluated without gradients).
        vocab_size: size of the logit dimension.
        num_epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        device: torch device for the model and batches.

    Returns:
        list of ``(avg_train, avg_val)`` tuples per epoch; an empty loader gives ``nan``.
    """
    model.to(device)
    # Targets equal to -100 (CrossEntropyLoss's ignore_index) do not contribute to the loss
    criterion = nn.CrossEntropyLoss(ignore_index=-100)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    def split_batch(batch):
        """Return ``(inputs, targets)`` with padded target positions set to -100."""
        ids = batch["input_ids"].to(device)
        inputs = ids[:, :-1]
        targets = ids[:, 1:].clone()
        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            targets[attention_mask[:, 1:].to(device) == 0] = -100
        return inputs, targets

    def run_epoch(loader, training):
        """Make one pass over ``loader``; return the mean batch loss (nan if empty)."""
        model.train(training)
        total_loss, n_batches = 0.0, 0
        with torch.set_grad_enabled(training):
            for batch in loader:
                inputs, targets = split_batch(batch)
                if training:
                    optimizer.zero_grad()
                logits, _ = model(inputs)
                loss = criterion(logits.reshape(-1, vocab_size), targets.reshape(-1))
                if training:
                    loss.backward()
                    optimizer.step()
                total_loss += loss.item()
                n_batches += 1
        return total_loss / n_batches if n_batches else float("nan")

    history = []
    for epoch in range(num_epochs):
        avg_train = run_epoch(train_loader, training=True)   # ---- training ----
        avg_val = run_epoch(val_loader, training=False)      # ---- validation ----
        history.append((avg_train, avg_val))
        print(f"Epoch {epoch+1}/{num_epochs} — train: {avg_train:.4f} | val: {avg_val:.4f}")

    return history

In [None]:
# Train LSTM using the hyper-parameters from the configuration cell
vocab = tokenizer.vocab_size
model = MusicRNN(vocab, EMBED_DIM, HIDDEN_DIM, NUM_LAYERS)

print("→ starting training on", len(train_files), "files (multi‑instrument)")
history = train(model, train_loader, val_loader, vocab, num_epochs=NUM_EPOCHS, lr=LEARNING_RATE)

Music Generation (Sampling from LSTM)

In [None]:
# First sampling method: the model has complete freedom on token generation,
# apart from Program tokens of instruments that were not requested.
def sample_free(
    model: nn.Module,
    tokenizer: REMI,
    allowed_programs: List[int],
    max_length: int = 1024,
    temperature: float = 1.0,
    device: str = DEVICE,
):
    """Generate a token sequence that uses ONLY the requested GM program numbers.

    The sequence is seeded with ``BOS`` and the Program token of
    ``allowed_programs[0]``, and both seed tokens are fed to the model. After
    that the model samples freely, except that Program tokens of any other
    instrument are masked out. Generation stops at EOS or once the sequence
    holds ``max_length`` tokens.

    Parameters
    ----------
    model : nn.Module
        Recurrent model called as ``model(ids, hidden) -> (logits, hidden)``.
    tokenizer : REMI
        Supplies ``vocab_size`` and the ``BOS_None`` / ``EOS_None`` ids.
    allowed_programs : list[int]
        List of MIDI program numbers (0‑127) the piece may contain. Example: [0] for solo
        piano, [40, 41, 42] for a string trio.
    max_length : int
        Maximum total length of the returned sequence, seed tokens included.
    temperature : float
        Softmax temperature (> 0); lower values sample more greedily.
    device : str
        Torch device for the model and inputs.

    Returns
    -------
    list[int]
        Generated token ids, starting with the two seed tokens.

    Raises
    ------
    ValueError
        If ``allowed_programs`` is empty or names a program with no Program token.
    """
    if not allowed_programs:
        raise ValueError("allowed_programs must contain at least one GM program")
    missing = [p for p in allowed_programs if p not in PROGRAM_TOKEN_IDS]
    if missing:
        raise ValueError(f"no Program token in the vocabulary for programs {missing}")

    model.to(device)
    model.eval()

    # Vocabulary mask: 1 keeps a token, 0 bans it. Only the Program tokens of
    # instruments outside `allowed_programs` are banned.
    allowed_prog_ids = {PROGRAM_TOKEN_IDS[p] for p in allowed_programs}
    mask = torch.ones(tokenizer.vocab_size, device=device)
    for pid in ALL_PROGRAM_TOKEN_IDS - allowed_prog_ids:
        mask[pid] = 0.0

    # Seed sequence: <BOS> + Program token of the first requested instrument
    generated = [tokenizer["BOS_None"], PROGRAM_TOKEN_IDS[allowed_programs[0]]]
    eos_id = tokenizer["EOS_None"]

    # Inference only: without no_grad the hidden state would chain an ever-growing
    # autograd graph across all steps.
    with torch.no_grad():
        input_tok = torch.tensor([generated], device=device)   # prime with the whole seed
        hidden = None
        while len(generated) < max_length:
            logits, hidden = model(input_tok, hidden)
            probs = F.softmax(logits[:, -1, :] / temperature, dim=-1) * mask  # zero out banned programs
            total = probs.sum()
            if total == 0:
                # catastrophic: mask wiped all mass, fall back to uniform over allowed tokens
                probs = mask / mask.sum()
            else:
                probs = probs / total  # renorm
            next_tok = torch.multinomial(probs, 1).item()
            generated.append(next_tok)
            if next_tok == eos_id:
                break
            input_tok = torch.tensor([[next_tok]], device=device)

    return generated

In [None]:
# Second sampling method: the model keeps its freedom inside a section, but a new
# instrument is forced in at fixed intervals so that several instruments are used.
def sample_weighted(
    model: nn.Module,
    tokenizer: REMI,
    allowed_programs: List[int],
    max_length: int = 1024,
    section_len: int = 256,      # distance between forced program switches
    temperature: float = 1.0,
    device: str = DEVICE,
):
    """Multi-instrument generator with a fixed instrument schedule.

    The sequence is partitioned into sections of `section_len` tokens.
    At the start of each section we *force-insert* the next unused Program
    token from `allowed_programs`. All other Program tokens remain masked
    out, so the model can’t override the schedule. Once the queue is used up,
    the last instrument stays active. Every requested instrument appears only
    if ``max_length > section_len * (len(allowed_programs) - 1)`` and the
    model does not emit EOS first.

    Args:
        model: recurrent model called as ``model(ids, hidden) -> (logits, hidden)``.
        tokenizer: supplies ``vocab_size`` and the ``BOS_None`` / ``EOS_None`` ids.
        allowed_programs: GM program numbers, in the order they are introduced.
        max_length: maximum total length of the returned sequence.
        section_len: number of tokens between forced Program switches (> 0).
        temperature: softmax temperature (> 0).
        device: torch device for the model and inputs.

    Returns:
        list[int] of token ids, starting with BOS and the first Program token.

    Raises:
        ValueError: if ``allowed_programs`` is empty, ``section_len <= 0``, or a
            program has no Program token in the vocabulary.
    """
    if not allowed_programs:
        raise ValueError("allowed_programs must contain at least one GM program")
    if section_len <= 0:
        raise ValueError("section_len must be positive")
    missing = [p for p in allowed_programs if p not in PROGRAM_TOKEN_IDS]
    if missing:
        raise ValueError(f"no Program token in the vocabulary for programs {missing}")

    model.to(device)
    model.eval()
    eos_id = tokenizer["EOS_None"]

    # vocab mask that blocks EVERY Program token; the active one is re-enabled per section
    global_mask = torch.ones(tokenizer.vocab_size, device=device)
    for pid in ALL_PROGRAM_TOKEN_IDS:
        global_mask[pid] = 0.0

    def mask_for(prog_id):
        """Return a copy of ``global_mask`` that also allows ``prog_id``."""
        section_mask = global_mask.clone()
        section_mask[prog_id] = 1.0              # allow notes for current program
        return section_mask

    # list of Program ids we will inject, in order
    prog_queue = [PROGRAM_TOKEN_IDS[p] for p in allowed_programs]
    active_prog_id = prog_queue.pop(0)          # first program is active

    generated = [tokenizer["BOS_None"], active_prog_id]
    mask = mask_for(active_prog_id)

    # Inference only; also avoids building an autograd graph through the hidden state
    with torch.no_grad():
        input_tok = torch.tensor([generated], device=device)   # prime with BOS + Program
        hidden = None

        # Each iteration appends exactly one token, so `t` is its index in `generated`
        for t in range(2, max_length):

            # --- force a new instrument at section boundaries ----------
            if t % section_len == 0 and prog_queue:
                active_prog_id = prog_queue.pop(0)
                generated.append(active_prog_id)
                mask = mask_for(active_prog_id)
                input_tok = torch.tensor([[active_prog_id]], device=device)
                hidden = None                            # reset hidden to avoid shock
                continue

            # --- normal token sampling --------------------------------
            logits, hidden = model(input_tok, hidden)
            probs = torch.softmax(logits[:, -1, :] / temperature, dim=-1) * mask
            total = probs.sum()
            # If the mask removed all probability mass, fall back to uniform over allowed tokens
            probs = mask / mask.sum() if total == 0 else probs / total
            next_tok = torch.multinomial(probs, 1).item()

            generated.append(next_tok)
            if next_tok == eos_id:
                break

            input_tok = torch.tensor([[next_tok]], device=device)

    return generated


In [None]:
# ---- sampling demo : string trio (Violin, Viola, Cello)
string_programs = [40, 41, 42]  # GM numbers for Violin, Viola, Cello
# NOTE: string programs should be program numbers that are in the top 20 instruments (model is trained on)
untrained = [p for p in string_programs if p not in allowed_programs]
if untrained:
    print(f"⚠️ programs {untrained} are not among the training instruments; expect poor output")

# Use an LSTM sampler (sample_free / sample_weighted). `sample` is the Transformer
# sampler: it takes a start token rather than programs and expects the model to
# return plain logits rather than (logits, hidden).
tokens = sample_weighted(model, tokenizer, string_programs, max_length=2048*10)
midi = tokenizer.decode(tokens)
midi.dump_midi("multi_RNN.mid")
print(f"Saved multi_RNN.mid ({len(tokens)} tokens)")

In [None]:
# Inspect generated MIDI content.
# `inst` is used as the loop variable because a global named `instrument` would
# shadow the music21 `instrument` module imported above. `midi_data` likewise
# avoids shadowing the pretty_midi package name.
midi_data = PrettyMIDI("multi_RNN.mid")
print("Duration (seconds):", midi_data.get_end_time())
for inst in midi_data.instruments:
    print(f"{inst.name or 'Unnamed'}:", len(inst.notes), "notes")

Random Baseline

In [None]:
# Generate a random token sequence across multiple instruments (uniform baseline)
def random_baseline_sample(tokenizer, allowed_programs, seq_len=1024, seed=None, program_token_ids=None):
    """
    Uniform-random baseline that never strays outside `allowed_programs`.

    Returns exactly ``seq_len`` ids. The sequence is BOS, then the Program token
    of ``allowed_programs[0]``, then ``seq_len - 3`` ids drawn uniformly from the
    vocabulary, then EOS. The random ids exclude special tokens and all Program
    tokens, except that the allowed Program tokens are added back.

    Args:
        tokenizer: provides ``len()``, ``special_tokens_ids`` and the
            ``"BOS_None"`` / ``"EOS_None"`` ids via indexing.
        allowed_programs: GM program numbers the sequence may use.
        seq_len: total sequence length, at least 3 (BOS, Program, EOS).
        seed: optional RNG seed for a reproducible baseline.
        program_token_ids: mapping GM program -> Program token id; defaults to
            the module-level ``PROGRAM_TOKEN_IDS``.

    Raises:
        ValueError: if ``allowed_programs`` is empty or ``seq_len < 3``.
    """
    if program_token_ids is None:
        program_token_ids = PROGRAM_TOKEN_IDS
    if not allowed_programs:
        raise ValueError("allowed_programs must contain at least one GM program")
    if seq_len < 3:
        raise ValueError("seq_len must be at least 3 (BOS, Program, EOS)")

    rng = random.Random(seed)

    all_prog_ids = set(program_token_ids.values())
    # ids for Program tokens we allow
    allow_prog_ids = {program_token_ids[p] for p in allowed_programs}
    # PAD/BOS/EOS/... are placed explicitly, never sampled into the body
    special_ids = set(getattr(tokenizer, "special_tokens_ids", ()))

    # every vocab id except ALL program tokens and special tokens
    legal_ids = [tid for tid in range(len(tokenizer))
                 if tid not in all_prog_ids and tid not in special_ids]

    # also permit the allowed Program tokens themselves (so a second track may appear)
    legal_ids += sorted(allow_prog_ids)

    out = [tokenizer["BOS_None"], program_token_ids[allowed_programs[0]]]
    out.extend(rng.choice(legal_ids) for _ in range(seq_len - 3))  # one slot reserved for EOS
    out.append(tokenizer["EOS_None"])
    return out


In [None]:
# Build a random multi-instrument baseline using GM programs 25 and 33, then save it
baseline_programs = [25, 33]
tokens = random_baseline_sample(
    tokenizer,
    allowed_programs=baseline_programs,
    seq_len=2048,
)

midi = tokenizer.decode(tokens)
midi.dump_midi("multi_random_2.mid")


Evaluation

In [None]:
# Analysis Function
def compute_instrument_metrics(midi_stream, label):
    """Compute per-part key compliance, average velocity and pitch range.

    Parts are skipped if they are percussion, have a name containing "drum",
    have a key music21 cannot estimate, or contain no pitches. For every other
    part the key is estimated and each sounding pitch is checked against it
    by pitch class, so enharmonic spellings agree. Chord members count as
    individual notes.

    Args:
        midi_stream: parsed music21 score exposing ``.parts``.
        label: model label stored in every result row (e.g. "LSTM").

    Returns:
        list of dicts with keys "instrument", "label", "key_compliance"
        (% of pitches in the detected key), "avg_velocity" (missing
        velocities count as 64) and "pitch_range" (semitones).
    """
    import statistics

    # Aliased local import: a notebook-level loop variable named `instrument`
    # can otherwise shadow the music21 module.
    from music21 import chord, note
    from music21 import instrument as m21_instrument

    results = []

    for i, part in enumerate(midi_stream.parts):
        instr = part.getInstrument(returnDefault=True)
        # NOTE(review): unnamed parts are numbered i // 2 + 1, so two consecutive
        # parts share a number — presumably deliberate; confirm.
        name = instr.instrumentName or f"Instrument {i // 2 + 1}"

        if isinstance(instr, m21_instrument.Percussion) or 'drum' in name.lower():
            continue

        try:
            detected_key = part.analyze('key')
        except Exception:
            continue

        key_pitch_classes = {p.pitchClass for p in detected_key.pitches}

        pitches = []
        velocities = []
        nonconforming = 0
        for el in part.recurse():
            if not isinstance(el, (note.Note, chord.Chord)):
                continue
            velocity = el.volume.velocity or 64
            for p in el.pitches:
                pitches.append(p.midi)
                velocities.append(velocity)
                if p.pitchClass not in key_pitch_classes:
                    nonconforming += 1

        if not pitches:
            continue

        results.append({
            "instrument": name,
            "label": label,
            "key_compliance": 100 * (1 - nonconforming / len(pitches)),
            "avg_velocity": statistics.mean(velocities),
            "pitch_range": max(pitches) - min(pitches),
        })

    return results

# --- Plotting Function ---

def plot_grouped_bar_chart(data, metric_name, title, ylabel):
    """Draw one group of bars per instrument, with one bar per model label.

    Args:
        data: metric rows such as those from ``compute_instrument_metrics``
            (dicts with "instrument", "label" and ``metric_name`` keys).
        metric_name: key of the metric to plot.
        title: chart title.
        ylabel: y-axis label.

    Missing (instrument, label) pairs are drawn as 0. When several rows share
    the same instrument name and label, the first one is used.
    """
    import matplotlib.pyplot as plt

    if not data:
        print(f"No data to plot for {metric_name}.")
        return

    instruments = sorted({d["instrument"] for d in data})
    labels = sorted({d["label"] for d in data})

    # First row per (instrument, label), for O(1) lookups instead of rescanning `data`
    first_row = {}
    for d in data:
        first_row.setdefault((d["instrument"], d["label"]), d)

    x = range(len(instruments))
    # 0.35-wide bars for up to two labels; narrower so the bars of a group still fit
    # in one x slot when there are more
    width = min(0.35, 0.8 / len(labels))

    fig, ax = plt.subplots(figsize=(12, 6))
    for i, label in enumerate(labels):
        values = [
            first_row[(instr, label)][metric_name] if (instr, label) in first_row else 0
            for instr in instruments
        ]
        ax.bar([p + i * width for p in x], values, width=width, label=label)

    ax.set_xticks([p + width * (len(labels) / 2 - 0.5) for p in x])
    ax.set_xticklabels(instruments, rotation=45)
    ax.set_title(title)
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(axis='y', linestyle='--', alpha=0.6)
    plt.tight_layout()
    plt.show()

# --- Driver Function ---

_METRIC_CHARTS = {
    "key_compliance": ("Key Compliance by Instrument", "Key Compliance (%)"),
    "avg_velocity": ("Average Velocity by Instrument", "Average Velocity"),
    "pitch_range": ("Pitch Range by Instrument", "Pitch Range (semitones)"),
}


def compare_midi_models(model_file_path, random_file_path, model_label="LSTM", baseline_label="Random",
                        metrics=("pitch_range",)):
    """Parse two MIDI files, compute per-instrument metrics, and plot them side by side.

    Args:
        model_file_path: MIDI file generated by the model.
        random_file_path: MIDI file from the random baseline.
        model_label: legend label for the model file.
        baseline_label: legend label for the baseline file.
        metrics: metrics to plot, any of "key_compliance", "avg_velocity" and
            "pitch_range" (default: pitch range only).

    Returns:
        The combined list of metric rows, or None if either file fails to parse.

    Raises:
        ValueError: if ``metrics`` names an unknown metric.
    """
    from music21 import converter

    unknown = [m for m in metrics if m not in _METRIC_CHARTS]
    if unknown:
        raise ValueError(f"unknown metrics {unknown}; choose from {sorted(_METRIC_CHARTS)}")

    try:
        model_stream = converter.parse(model_file_path)
        random_stream = converter.parse(random_file_path)
    except Exception as e:
        print(f"❌ Error loading MIDI files: {e}")
        return None

    model_metrics = compute_instrument_metrics(model_stream, model_label)
    random_metrics = compute_instrument_metrics(random_stream, baseline_label)
    combined = model_metrics + random_metrics

    for metric in metrics:
        title, ylabel = _METRIC_CHARTS[metric]
        plot_grouped_bar_chart(combined, metric, title, ylabel)

    return combined


In [None]:
# Example Evaluation Usage
# Compare the LSTM output with the random baseline, using the file names the
# cells above write. Point these elsewhere if you moved the files
# (e.g. into generated-*/ folders).
model_file = "multi_RNN.mid"
random_file = "multi_random_2.mid"
metric_rows = compare_midi_models(model_file, random_file)