# MIDI Music Generation

Generate music using the trained MusicGPT model and SkyTNT tokenizer.

**Requirements:** Trained model at `midi_data/best_model_lm.pt`

## 1. Setup

In [1]:
import random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import Audio
import matplotlib.pyplot as plt

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
elif torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')
print(f"Device: {DEVICE}")

DATA_DIR = Path("./midi_data")
OUTPUT_DIR = DATA_DIR / "generated"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

Device: mps


## 2. Load Model & Tokenizer

In [2]:
from midi_tokenizer import MIDITokenizer

tokenizer = MIDITokenizer()

VOCAB_SIZE = tokenizer.vocab_size
PAD_ID = tokenizer.pad_id
BOS_ID = tokenizer.bos_id
EOS_ID = tokenizer.eos_id

print(f"Vocabulary: {VOCAB_SIZE} tokens")
print(f"Special: PAD={PAD_ID}, BOS={BOS_ID}, EOS={EOS_ID}")

Vocabulary: 3406 tokens
Special: PAD=0, BOS=1, EOS=2


In [3]:
# Model definition (same as training)
class MusicGPT(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_heads=8, n_layers=6, 
                 max_seq_len=1024, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)
        
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)
        
        self.register_buffer(
            'causal_mask',
            torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
        )
    
    def forward(self, x):
        batch_size, seq_len = x.shape
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        x = self.token_emb(x) + self.pos_emb(positions)
        x = self.dropout(x)
        mask = self.causal_mask[:seq_len, :seq_len]
        x = self.transformer(x, x, tgt_mask=mask, memory_mask=mask)
        return self.lm_head(x)

# Load model
MODEL_PATH = DATA_DIR / "best_model_lm.pt"
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)

# Check for mismatches
saved_vocab = checkpoint.get('vocab_size', 'NOT SAVED')
print(f"Checkpoint vocab_size: {saved_vocab}")
print(f"Tokenizer vocab_size: {VOCAB_SIZE}")
if saved_vocab != 'NOT SAVED' and saved_vocab != VOCAB_SIZE:
    print(f"⚠️ MISMATCH! Model was trained with different tokenizer!")

SEQUENCE_LENGTH = checkpoint.get('sequence_length', 512)

model = MusicGPT(
    vocab_size=VOCAB_SIZE,
    d_model=256, n_heads=8, n_layers=6,
    max_seq_len=SEQUENCE_LENGTH + 64,
    dropout=0.0  # No dropout at inference
).to(DEVICE)

model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"\nLoaded model from epoch {checkpoint['epoch'] + 1}")
print(f"Val loss: {checkpoint['val_loss']:.4f}")

Checkpoint vocab_size: 3406
Tokenizer vocab_size: 3406

Loaded model from epoch 30
Val loss: 0.7293


## 3. Generation Function

In [4]:
# Event type to number of tokens (including event type itself)
# Based on tokenizer.events definitions
EVENT_TOKEN_COUNTS = {
    0: 1,   # PAD - shouldn't appear as event type
    1: 1,   # BOS - just the token itself, rest are PAD
    2: 1,   # EOS - just the token itself, rest are PAD
    3: 8,   # note: type + time1, time2, track, channel, pitch, velocity, duration
    4: 6,   # patch_change: type + time1, time2, track, channel, patch
    5: 7,   # control_change: type + time1, time2, track, channel, controller, value
    6: 5,   # set_tempo: type + time1, time2, track, bpm
    7: 6,   # time_signature: type + time1, time2, track, nn, dd
    8: 6,   # key_signature: type + time1, time2, track, sf, mi
}

@torch.no_grad()
def generate(model, tokenizer, seed_tokens=None, max_events=256, 
             temperature=1.0, top_k=50, top_p=0.95):
    """
    Generate MIDI tokens autoregressively, respecting the 8-token event structure.
    Forces PAD tokens based on event type to maintain proper boundaries.
    """
    tokens_per_event = tokenizer.max_token_seq  # 8
    
    # Start with BOS or seed
    if seed_tokens is None:
        tokens = [tokenizer.bos_id] + [tokenizer.pad_id] * (tokens_per_event - 1)
    else:
        tokens = list(seed_tokens)
    
    generated = list(tokens)
    events_generated = 0
    
    while events_generated < max_events:
        event_tokens = []
        event_type = None
        tokens_needed = tokens_per_event  # Default to 8
        
        for pos in range(tokens_per_event):
            # Check if we should force PAD
            if event_type is not None and pos >= EVENT_TOKEN_COUNTS.get(event_type, 8):
                # Force PAD for remaining positions
                next_token = tokenizer.pad_id
            else:
                # Generate from model
                context = generated[-SEQUENCE_LENGTH:]
                x = torch.LongTensor([context]).to(DEVICE)
                logits = model(x)[0, -1] / temperature
                
                # Top-k filtering
                if top_k > 0:
                    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                    logits[indices_to_remove] = float('-inf')
                
                # Top-p filtering
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
                    logits[indices_to_remove] = float('-inf')
                
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()
            
            # Record event type at position 0
            if pos == 0:
                event_type = next_token
                # Stop at EOS
                if event_type == tokenizer.eos_id:
                    event_tokens = [tokenizer.eos_id] + [tokenizer.pad_id] * (tokens_per_event - 1)
                    generated.extend(event_tokens)
                    return np.array(generated).reshape(-1, tokens_per_event)
            
            event_tokens.append(next_token)
            generated.append(next_token)
        
        events_generated += 1
        
        # Sanity check for degeneration
        if event_type is not None and event_type > 20:
            print(f"⚠️ Degeneration at event {events_generated}: type={event_type}")
            generated = generated[:-tokens_per_event]
            break
    
    return np.array(generated).reshape(-1, tokens_per_event)


print("Generation function ready (with PAD enforcement)!")

Generation function ready (with PAD enforcement)!


## 4. Generate Music

In [28]:
# DEBUG: Check actual tokenizer output format FIRST
print("=== TOKENIZER OUTPUT FORMAT ===\n")

# Load a sample MIDI and see what tokens look like
sample_midi = DATA_DIR / "adl-piano-midi/Blues/Blues/Bobby Blue Bland/What A Beautiful World.mid"

if sample_midi.exists():
    tokens_2d = tokenizer.midi_file_to_tokens(sample_midi)
    print(f"Tokenized shape: {tokens_2d.shape}")
    print(f"Tokens per event: {tokens_2d.shape[1]}")
    print()
    
    print("First 10 events:")
    for i, event in enumerate(tokens_2d[:10]):
        # Check for PADs
        pad_count = sum(1 for t in event if t == 0)
        print(f"  Event {i}: {event.tolist()}  (PADs: {pad_count})")
    
    print("\n--- Checking if PAD tokens exist in events ---")
    # Count events by number of PAD tokens
    pad_counts = {}
    for event in tokens_2d:
        pads = sum(1 for t in event if t == 0)
        pad_counts[pads] = pad_counts.get(pads, 0) + 1
    
    print("Distribution of PAD counts per event:")
    for pads, count in sorted(pad_counts.items()):
        print(f"  {pads} PADs: {count} events")
    
    print("\n--- Checking event types ---")
    event_types = tokens_2d[:, 0]
    unique_types = np.unique(event_types)
    print(f"Unique event types in column 0: {unique_types.tolist()}")
else:
    print(f"Sample file not found: {sample_midi}")

=== TOKENIZER OUTPUT FORMAT ===

Tokenized shape: (386, 8)
Tokens per event: 8

First 10 events:
  Event 0: [1, 0, 0, 0, 0, 0, 0, 0]  (PADs: 7)
  Event 1: [7, 9, 137, 2201, 3374, 3387, 0, 0]  (PADs: 2)
  Event 2: [6, 9, 137, 2201, 3080, 0, 0, 0]  (PADs: 3)
  Event 3: [8, 9, 137, 2202, 3395, 3404, 0, 0]  (PADs: 2)
  Event 4: [4, 9, 137, 2202, 2329, 2605, 0, 0]  (PADs: 2)
  Event 5: [5, 9, 137, 2202, 2329, 2739, 2921, 0]  (PADs: 1)
  Event 6: [5, 9, 137, 2202, 2329, 2822, 2937, 0]  (PADs: 1)
  Event 7: [5, 9, 137, 2202, 2329, 2820, 2957, 0]  (PADs: 1)
  Event 8: [5, 9, 137, 2202, 2329, 2740, 2984, 0]  (PADs: 1)
  Event 9: [5, 9, 137, 2202, 2329, 2736, 2970, 0]  (PADs: 1)

--- Checking if PAD tokens exist in events ---
Distribution of PAD counts per event:
  0 PADs: 368 events
  1 PADs: 8 events
  2 PADs: 3 events
  3 PADs: 5 events
  7 PADs: 2 events

--- Checking event types ---
Unique event types in column 0: [1, 2, 3, 4, 5, 6, 7, 8]


In [6]:
# Generate from scratch
print("Generating from scratch...")

generated_tokens = generate(
    model, tokenizer,
    seed_tokens=None,
    max_events=500,
    temperature=1.8,
    top_k=30,
    top_p=0.95
)

print(f"\nGenerated {len(generated_tokens)} events total")

# Show first few events
print("\nFirst 5 generated events:")
for i, event in enumerate(generated_tokens[:5]):
    print(f"  Event {i}: {event.tolist()}")

Generating from scratch...
⚠️ Degeneration at event 65: type=145

Generated 65 events total

First 5 generated events:
  Event 0: [1, 0, 0, 0, 0, 0, 0, 0]
  Event 1: [7, 9, 137, 2201, 3372, 3386, 0, 0]
  Event 2: [6, 9, 137, 2201, 3107, 0, 0, 0]
  Event 3: [4, 9, 137, 2202, 2329, 2602, 0, 0]
  Event 4: [5, 9, 137, 2202, 2329, 2736, 2937, 0]


In [7]:
# Save to MIDI
output_path = OUTPUT_DIR / "generated_lm.mid"
success = tokenizer.tokens_to_midi_file(generated_tokens, output_path)

if success:
    print(f"Saved to: {output_path}")
else:
    print("Failed to save MIDI")

Saved to: midi_data/generated/generated_lm.mid


In [8]:
# Play the generated MIDI
import pretty_midi
from IPython.display import display

def _has_fluidsynth():
    """Check if fluidsynth is available."""
    try:
        import fluidsynth
        return True
    except ImportError:
        return False

def play_midi(midi_path, sample_rate=22050):
    """Play a MIDI file with fallback if fluidsynth not available."""
    midi = pretty_midi.PrettyMIDI(str(midi_path))
    
    # Check if MIDI has any notes
    total_notes = sum(len(inst.notes) for inst in midi.instruments)
    if total_notes == 0:
        print(f"Warning: MIDI file has no notes!")
        return None
    
    print(f"MIDI: {total_notes} notes, {midi.get_end_time():.1f}s duration")

    if _has_fluidsynth():
        audio = midi.fluidsynth(fs=sample_rate)
    else:
        audio = midi.synthesize(fs=sample_rate)
    
    if len(audio) == 0:
        print("Warning: Audio synthesis produced empty output")
        return None

    return Audio(audio, rate=sample_rate)

if output_path.exists():
    print("Playing generated music:")
    player = play_midi(output_path)
    if player:
        display(player)
else:
    print(f"File not found: {output_path}")

Playing generated music:
MIDI: 49 notes, 5.3s duration


## 5. Generate with MIDI Seed

Use an existing MIDI file as a starting point for generation.

In [32]:
# WORKAROUND: Use seed for context but generate fresh from BOS
# This primes the model with the style but starts proper generation

seed_midi_path = DATA_DIR / "adl-piano-midi/Country/Irish Country/Christie Hennessy/Can This Be All There Is To Love.mid"

if seed_midi_path.exists():
    seed_tokens_2d = tokenizer.midi_file_to_tokens(seed_midi_path)
    if seed_tokens_2d is not None:
        # Take first N events as "style context"
        context_events = 256
        context_2d = seed_tokens_2d[:context_events]
        
        print(f"Using {context_events} events as style context")
        print(f"Then generating fresh from BOS\n")
        
        # Create priming sequence: context + BOS
        # The model sees the context, then BOS signals "start generating"
        context_flat = tokenizer.flatten_tokens(context_2d)
        bos_event = [tokenizer.bos_id] + [tokenizer.pad_id] * 7
        primed_tokens = list(context_flat) + bos_event
        
        # Generate using the primed context
        generated_tokens = generate(
            model, tokenizer,
            seed_tokens=primed_tokens,
            max_events=256,
            temperature=0.1,
            top_k=30,
            top_p=0.95
        )
        
        # Remove the context prefix - only keep from BOS onwards
        # Find where our BOS event starts (after context)
        bos_start_event = context_events
        fresh_tokens = generated_tokens[bos_start_event:]
        
        print(f"Generated {len(fresh_tokens)} events (context removed)")
        
        # Validate
        valid_events = np.sum(fresh_tokens[:, 0] <= 20)
        print(f"Events with valid type: {valid_events} / {len(fresh_tokens)}")
        
        output_path = OUTPUT_DIR / "generated_style_primed.mid"
        tokenizer.tokens_to_midi_file(fresh_tokens, output_path)
        print(f"Saved to: {output_path}")
        
        print("\n--- Style-primed generation ---")
        display(play_midi(output_path))
else:
    print(f"Seed file not found")

Using 125 events as style context
Then generating fresh from BOS

Generated 257 events (context removed)
Events with valid type: 257 / 257
Saved to: midi_data/generated/generated_style_primed.mid

--- Style-primed generation ---
MIDI: 256 notes, 37.5s duration


In [10]:
# COMPARE: Save just the seed as MIDI to understand timing
print("=== COMPARISON: Seed only vs Generated ===\n")

# Use the correct variable name (seed_tokens_2d from earlier cell)
if 'seed_tokens_2d' not in dir():
    print("⚠️ Run the seed loading cell first!")
else:
    # Save seed-only MIDI for comparison
    seed_only_path = OUTPUT_DIR / "seed_only.mid"
    tokenizer.tokens_to_midi_file(seed_tokens_2d[:50], seed_only_path)  # First 50 events
    print(f"Seed-only saved to: {seed_only_path}")

    # Get MIDI info using pretty_midi
    import pretty_midi
    seed_midi = pretty_midi.PrettyMIDI(str(seed_only_path))
    
    print(f"\nSeed-only MIDI (first 50 events):")
    print(f"  Duration: {seed_midi.get_end_time():.2f} seconds")
    print(f"  Notes: {sum(len(inst.notes) for inst in seed_midi.instruments)}")
    
    if output_path.exists():
        generated_midi = pretty_midi.PrettyMIDI(str(output_path))
        print(f"\nGenerated MIDI:")
        print(f"  Duration: {generated_midi.get_end_time():.2f} seconds")
        print(f"  Notes: {sum(len(inst.notes) for inst in generated_midi.instruments)}")

    print("\n--- Playing SEED ONLY for comparison ---")
    display(play_midi(seed_only_path))

=== COMPARISON: Seed only vs Generated ===

Seed-only saved to: midi_data/generated/seed_only.mid

Seed-only MIDI (first 50 events):
  Duration: 20.53 seconds
  Notes: 37

Generated MIDI:
  Duration: 10.44 seconds
  Notes: 35

--- Playing SEED ONLY for comparison ---
MIDI: 37 notes, 20.5s duration




In [11]:
# DEBUG: Analyze the TIME encoding in tokens
print("=== TOKEN TIME ANALYSIS ===\n")

if 'seed_tokens_2d' not in dir():
    print("⚠️ Run the seed loading cell first!")
else:
    # Token format: [event_type, time1, time2, track, duration, channel, pitch, velocity]
    print("SEED tokens (first 10 events):")
    for i, event in enumerate(seed_tokens_2d[:10]):
        print(f"  Event {i}: type={event[0]:3d}, time1={event[1]:3d}, time2={event[2]:3d}, pitch={event[6]:3d}")

    print(f"\nSEED tokens (last 5 events):")
    for i, event in enumerate(seed_tokens_2d[-5:], start=len(seed_tokens_2d)-5):
        print(f"  Event {i}: type={event[0]:3d}, time1={event[1]:3d}, time2={event[2]:3d}, pitch={event[6]:3d}")

    if 'generated_tokens' in dir() and len(generated_tokens) > 55:
        print(f"\n--- GENERATED tokens around seed boundary ---")
        context_events = 30  # Same as in generation cell
        boundary = context_events
        start_idx = max(0, boundary - 2)
        end_idx = min(len(generated_tokens), boundary + 5)
        for i, event in enumerate(generated_tokens[start_idx:end_idx], start=start_idx):
            marker = " <-- CONTEXT ENDS" if i == boundary - 1 else ""
            print(f"  Event {i}: type={event[0]:3d}, time1={event[1]:3d}, time2={event[2]:3d}, pitch={event[6]:3d}{marker}")

    if 'generated_tokens' in dir():
        # Check if the model generated EOS early
        eos_mask = generated_tokens[:, 0] == tokenizer.eos_id
        if np.any(eos_mask):
            eos_idx = np.argmax(eos_mask)
            print(f"\n⚠️ EOS found at event {eos_idx}! Generation stopped early.")

=== TOKEN TIME ANALYSIS ===

SEED tokens (first 10 events):
  Event 0: type=  1, time1=  0, time2=  0, pitch=  0
  Event 1: type=  7, time1=  9, time2=137, pitch=  0
  Event 2: type=  6, time1=  9, time2=137, pitch=  0
  Event 3: type=  8, time1=  9, time2=137, pitch=  0
  Event 4: type=  4, time1=  9, time2=137, pitch=  0
  Event 5: type=  5, time1=  9, time2=137, pitch=2921
  Event 6: type=  5, time1=  9, time2=137, pitch=2937
  Event 7: type=  5, time1=  9, time2=137, pitch=2957
  Event 8: type=  5, time1=  9, time2=137, pitch=2984
  Event 9: type=  5, time1=  9, time2=137, pitch=2970

SEED tokens (last 5 events):
  Event 381: type=  3, time1=  9, time2=137, pitch=2535
  Event 382: type=  3, time1=  9, time2=137, pitch=2550
  Event 383: type=  3, time1=  9, time2=137, pitch=2550
  Event 384: type=  6, time1= 12, time2=137, pitch=  0
  Event 385: type=  2, time1=  0, time2=  0, pitch=  0

--- GENERATED tokens around seed boundary ---
  Event 28: type=  3, time1= 11, time2=145, pitch=

In [12]:
play_midi(output_path)

MIDI: 35 notes, 10.4s duration


## 6. Explore Different Temperatures

In [13]:
temperatures = [0.5, 0.8, 1.0, 1.2]

for temp in temperatures:
    print(f"\nGenerating with temperature {temp}...")
    tokens = generate(model, tokenizer, max_events=300, temperature=temp)
    
    output_path = OUTPUT_DIR / f"generated_lm_temp_{temp}.mid"
    tokenizer.tokens_to_midi_file(tokens, output_path)
    print(f"  Saved: {output_path.name}")

print(f"\nAll files saved to: {OUTPUT_DIR}")


Generating with temperature 0.5...
  Saved: generated_lm_temp_0.5.mid

Generating with temperature 0.8...
  Saved: generated_lm_temp_0.8.mid

Generating with temperature 1.0...
⚠️ Degeneration at event 67: type=137
  Saved: generated_lm_temp_1.0.mid

Generating with temperature 1.2...
⚠️ Degeneration at event 84: type=137
  Saved: generated_lm_temp_1.2.mid

All files saved to: midi_data/generated


## Tips

**Temperature guide:**
- `0.5` - Very coherent, conservative
- `0.8-0.9` - Balanced (recommended)
- `1.0` - More varied
- `1.2+` - Experimental, may be chaotic

**For better results:**
- Train for more epochs
- Use larger dataset
- Try different seeds from different genres