# MIDI Music Generation - Generate with Language Model

Generate music using the trained MusicGPT model and SkyTNT tokenizer.

**Requirements:** Trained model at `midi_data/best_model_lm.pt`

## 1. Setup

In [None]:
import setproctitle # we installed this package already, see above
setproctitle.setproctitle('midi-gen - sprec1')


In [1]:
import random
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import Audio
import matplotlib.pyplot as plt

from midi_tokenizer_wrapper import MIDITokenizerWrapper

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Device
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
elif torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')
print(f"Device: {DEVICE}")

DATA_DIR = Path("./midi_data")
OUTPUT_DIR = DATA_DIR / "generated"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

Device: mps


## 2. Load Model & Tokenizer

In [2]:
# Initialize tokenizer
tokenizer = MIDITokenizerWrapper(version="v2", optimise_midi=True)
VOCAB_SIZE = tokenizer.vocab_size
PAD_ID = tokenizer.pad_id
BOS_ID = tokenizer.bos_id
EOS_ID = tokenizer.eos_id

print(f"Vocabulary: {VOCAB_SIZE} tokens")
print(f"Special: PAD={PAD_ID}, BOS={BOS_ID}, EOS={EOS_ID}")

Vocabulary: 3406 tokens
Special: PAD=0, BOS=1, EOS=2


In [3]:
# Model definition (same as training)
class MusicGPT(nn.Module):
    def __init__(self, vocab_size, d_model=256, n_heads=8, n_layers=6, 
                 max_seq_len=1024, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)
        
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)
        
        self.register_buffer(
            'causal_mask',
            torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
        )
    
    def forward(self, x):
        batch_size, seq_len = x.shape
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        x = self.token_emb(x) + self.pos_emb(positions)
        x = self.dropout(x)
        mask = self.causal_mask[:seq_len, :seq_len]
        x = self.transformer(x, x, tgt_mask=mask, memory_mask=mask)
        return self.lm_head(x)

# Load model
MODEL_PATH = DATA_DIR / "best_model_lm.pt"
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)

SEQUENCE_LENGTH = checkpoint.get('sequence_length', 512)

model = MusicGPT(
    vocab_size=VOCAB_SIZE,
    d_model=256, n_heads=8, n_layers=6,
    max_seq_len=SEQUENCE_LENGTH + 64,
    dropout=0.0  # No dropout at inference
).to(DEVICE)

model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"Loaded model from epoch {checkpoint['epoch'] + 1}")
print(f"Val loss: {checkpoint['val_loss']:.4f}")

Loaded model from epoch 9
Val loss: 1.0365


## 3. Generation Function

In [4]:
@torch.no_grad()
def generate(model, tokenizer, seed_tokens=None, max_events=256, 
             temperature=1.0, top_k=50, top_p=0.95):
    """
    Generate MIDI tokens autoregressively, respecting the 8-token event structure.
    
    Args:
        model: Trained model
        tokenizer: MIDITokenizerWrapper
        seed_tokens: Optional starting tokens (1D array, must be multiple of 8)
        max_events: Max events to generate
        temperature: Sampling temperature (lower = more deterministic)
        top_k: Keep only top k tokens
        top_p: Keep tokens with cumulative prob <= p
    
    Returns:
        Token sequence (2D: events x tokens_per_event)
    """
    tokens_per_event = tokenizer.max_token_seq  # 8
    
    # Start with BOS or seed
    if seed_tokens is None:
        tokens = [tokenizer.bos_id] + [tokenizer.pad_id] * (tokens_per_event - 1)
    else:
        tokens = list(seed_tokens)
    
    generated = list(tokens)
    events_generated = 0
    
    while events_generated < max_events:
        # Generate one complete event (8 tokens)
        event_tokens = []
        
        for pos in range(tokens_per_event):
            # Prepare input (last SEQUENCE_LENGTH tokens)
            context = generated[-SEQUENCE_LENGTH:]
            x = torch.LongTensor([context]).to(DEVICE)
            
            # Get logits for next token
            logits = model(x)[0, -1] / temperature
            
            # Top-k filtering
            if top_k > 0:
                indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                logits[indices_to_remove] = float('-inf')
            
            # Top-p (nucleus) filtering
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_indices_to_remove = cumulative_probs > top_p
                sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                sorted_indices_to_remove[..., 0] = 0
                indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
                logits[indices_to_remove] = float('-inf')
            
            # Sample
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            
            event_tokens.append(next_token)
            generated.append(next_token)
            
            # Stop at EOS (only check first position of event)
            if pos == 0 and next_token == tokenizer.eos_id:
                # Pad rest of event and stop
                event_tokens.extend([tokenizer.pad_id] * (tokens_per_event - 1))
                generated.extend([tokenizer.pad_id] * (tokens_per_event - 1))
                return np.array(generated).reshape(-1, tokens_per_event)
        
        events_generated += 1
        
        # Sanity check: if event type is invalid (>100), we're degenerating
        if event_tokens[0] > 100:
            print(f"⚠️ Degeneration detected at event {events_generated}: type={event_tokens[0]}")
            print(f"   Stopping generation early.")
            # Remove the bad event
            generated = generated[:-tokens_per_event]
            break
    
    return np.array(generated).reshape(-1, tokens_per_event)


print("Generation function ready (with structure awareness)!")

Generation function ready (with structure awareness)!


## 4. Generate Music

In [None]:
# Generate from scratch - test stability
print("Generating from scratch - testing how long before degeneration...")

generated_tokens = generate(
    model, tokenizer,
    seed_tokens=None,
    max_events=500,  # Try to generate a lot
    temperature=0.7,
    top_k=30,
    top_p=0.95
)

print(f"\nGenerated {len(generated_tokens)} events total")

# Check where degeneration starts (if any)
invalid_mask = generated_tokens[:, 0] > 20
if np.any(invalid_mask):
    first_invalid = np.argmax(invalid_mask)
    print(f"First invalid event type at index {first_invalid}")
else:
    print("All events have valid types!")

# Show distribution of event types
event_types = generated_tokens[:, 0]
unique, counts = np.unique(event_types, return_counts=True)
print(f"\nEvent type distribution:")
for t, c in zip(unique[:10], counts[:10]):  # First 10
    print(f"  Type {t}: {c} events")

In [None]:
# Save to MIDI
output_path = OUTPUT_DIR / "generated_lm.mid"
success = tokenizer.tokens_to_midi_file(generated_tokens, output_path)

if success:
    print(f"Saved to: {output_path}")
else:
    print("Failed to save MIDI")

In [None]:
# Play the generated MIDI
import pretty_midi
from IPython.display import display

def _has_fluidsynth():
    """Check if fluidsynth is available."""
    try:
        import fluidsynth
        return True
    except ImportError:
        return False

def play_midi(midi_path, sample_rate=22050):
    """Play a MIDI file with fallback if fluidsynth not available."""
    midi = pretty_midi.PrettyMIDI(str(midi_path))

    if _has_fluidsynth():
        audio = midi.fluidsynth(fs=sample_rate)
    else:
        # Fallback to basic synthesis
        audio = midi.synthesize(fs=sample_rate)

    return Audio(audio, rate=sample_rate)

if output_path.exists():
    print("Playing generated music:")
    display(play_midi(output_path))  # Use display() to show the player

## 5. Generate with MIDI Seed

Use an existing MIDI file as a starting point for generation.

In [None]:
# WORKAROUND: Use seed for context but generate fresh from BOS
# This primes the model with the style but starts proper generation

seed_midi_path = DATA_DIR / "adl-piano-midi/Jazz/Jazz Blues/Aretha Franklin/I Never Loved A Man The Way I Love You Stereo Version.mid"

if seed_midi_path.exists():
    seed_tokens_2d = tokenizer.midi_file_to_tokens(seed_midi_path)
    if seed_tokens_2d is not None:
        # Take first N events as "style context"
        context_events = 30
        context_2d = seed_tokens_2d[:context_events]
        
        print(f"Using {context_events} events as style context")
        print(f"Then generating fresh from BOS\n")
        
        # Create priming sequence: context + BOS
        # The model sees the context, then BOS signals "start generating"
        context_flat = tokenizer.flatten_tokens(context_2d)
        bos_event = [tokenizer.bos_id] + [tokenizer.pad_id] * 7
        primed_tokens = list(context_flat) + bos_event
        
        # Generate using the primed context
        generated_tokens = generate(
            model, tokenizer,
            seed_tokens=primed_tokens,
            max_events=200,
            temperature=0.7,
            top_k=30,
            top_p=0.95
        )
        
        # Remove the context prefix - only keep from BOS onwards
        # Find where our BOS event starts (after context)
        bos_start_event = context_events
        fresh_tokens = generated_tokens[bos_start_event:]
        
        print(f"Generated {len(fresh_tokens)} events (context removed)")
        
        # Validate
        valid_events = np.sum(fresh_tokens[:, 0] <= 20)
        print(f"Events with valid type: {valid_events} / {len(fresh_tokens)}")
        
        output_path = OUTPUT_DIR / "generated_style_primed.mid"
        tokenizer.tokens_to_midi_file(fresh_tokens, output_path)
        print(f"Saved to: {output_path}")
        
        print("\n--- Style-primed generation ---")
        display(play_midi(output_path))
else:
    print(f"Seed file not found")

In [9]:
# COMPARE: Save just the seed as MIDI to understand timing
print("=== COMPARISON: Seed only vs Generated ===\n")

# Save seed-only MIDI for comparison
seed_only_path = OUTPUT_DIR / "seed_only.mid"
tokenizer.tokens_to_midi_file(seed_2d, seed_only_path)
print(f"Seed-only saved to: {seed_only_path}")

# Get MIDI info using pretty_midi
import pretty_midi
seed_midi = pretty_midi.PrettyMIDI(str(seed_only_path))
generated_midi = pretty_midi.PrettyMIDI(str(output_path))

print(f"\nSeed-only MIDI:")
print(f"  Duration: {seed_midi.get_end_time():.2f} seconds")
print(f"  Notes: {sum(len(inst.notes) for inst in seed_midi.instruments)}")

print(f"\nGenerated MIDI:")
print(f"  Duration: {generated_midi.get_end_time():.2f} seconds")
print(f"  Notes: {sum(len(inst.notes) for inst in generated_midi.instruments)}")

# Check timing of first few notes in generated
if generated_midi.instruments:
    notes = sorted(generated_midi.instruments[0].notes, key=lambda n: n.start)
    if notes:
        print(f"\nFirst 10 note start times in generated:")
        for n in notes[:10]:
            print(f"  t={n.start:.3f}s pitch={n.pitch}")
        
        # Find gap in timing
        print(f"\nLast 10 note start times:")
        for n in notes[-10:]:
            print(f"  t={n.start:.3f}s pitch={n.pitch}")

print("\n--- Playing SEED ONLY for comparison ---")
display(play_midi(seed_only_path))

=== COMPARISON: Seed only vs Generated ===



NameError: name 'seed_2d' is not defined

In [10]:
# DEBUG: Analyze the TIME encoding in tokens
print("=== TOKEN TIME ANALYSIS ===\n")

# Token format: [event_type, time1, time2, track, duration, channel, pitch, velocity]
# time1 and time2 together encode the absolute time

print("SEED tokens (first 10 events) - time1, time2 columns:")
for i, event in enumerate(seed_2d[:10]):
    print(f"  Event {i}: type={event[0]:3d}, time1={event[1]:3d}, time2={event[2]:3d}, pitch={event[6]:3d}")
yo
print(f"\nSEED tokens (last 5 events):")
for i, event in enumerate(seed_2d[-5:], start=len(seed_2d)-5):
    print(f"  Event {i}: type={event[0]:3d}, time1={event[1]:3d}, time2={event[2]:3d}, pitch={event[6]:3d}")

print(f"\n--- GENERATED tokens around boundary (events 48-55) ---")
for i, event in enumerate(generated_tokens[48:55], start=48):
    marker = " <-- SEED ENDS" if i == 49 else ""
    print(f"  Event {i}: type={event[0]:3d}, time1={event[1]:3d}, time2={event[2]:3d}, pitch={event[6]:3d}{marker}")

print(f"\nGENERATED tokens (last 10 events):")
for i, event in enumerate(generated_tokens[-10:], start=len(generated_tokens)-10):
    print(f"  Event {i}: type={event[0]:3d}, time1={event[1]:3d}, time2={event[2]:3d}, pitch={event[6]:3d}")

# Check if the model generated EOS early
eos_mask = generated_tokens[:, 0] == tokenizer.eos_id
if np.any(eos_mask):
    eos_idx = np.argmax(eos_mask)
    print(f"\n⚠️ EOS found at event {eos_idx}! Generation stopped early.")
    print(f"   Expected ~{50 + 400} events, got {len(generated_tokens)}")

=== TOKEN TIME ANALYSIS ===

SEED tokens (first 10 events) - time1, time2 columns:


NameError: name 'seed_2d' is not defined

In [11]:
play_midi(output_path)

## 6. Explore Different Temperatures

In [None]:
temperatures = [0.5, 0.8, 1.0, 1.2]

for temp in temperatures:
    print(f"\nGenerating with temperature {temp}...")
    tokens = generate(model, tokenizer, max_events=300, temperature=temp)
    
    output_path = OUTPUT_DIR / f"generated_lm_temp_{temp}.mid"
    tokenizer.tokens_to_midi_file(tokens, output_path)
    print(f"  Saved: {output_path.name}")

print(f"\nAll files saved to: {OUTPUT_DIR}")

## Tips

**Temperature guide:**
- `0.5` - Very coherent, conservative
- `0.8-0.9` - Balanced (recommended)
- `1.0` - More varied
- `1.2+` - Experimental, may be chaotic

**For better results:**
- Train for more epochs
- Use larger dataset
- Try different seeds from different genres