# MIDI Music Generation

Generate piano music using the trained MusicGPT model with constrained decoding.

## Features
- **From scratch**: Generate new music starting from nothing
- **Seed continuation**: Continue an existing MIDI file in a similar style
- **Temperature control**: Adjust creativity vs coherence
- **Constrained decoding**: Enforces musical structure and variety

## Requirements
- Trained model at `midi_data/best_model_lm.pt`
- MIDI files for seeding (optional)

## 1. Setup

In [10]:
import random
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import Audio
import matplotlib.pyplot as plt

# Set to None for random generation each time, or a number for reproducibility
SEED = None  # e.g., 42 for reproducible, None for random

if SEED is not None:
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    print(f"Using fixed seed: {SEED} (reproducible)")
else:
    # Use current time as seed for variety
    random_seed = int(time.time()) % 100000
    torch.manual_seed(random_seed)
    print(f"Using random seed: {random_seed} (different each run)")

# Device
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
elif torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
else:
    DEVICE = torch.device('cpu')
print(f"Device: {DEVICE}")

DATA_DIR = Path("./midi_data")
OUTPUT_DIR = DATA_DIR / "generated"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

Using random seed: 15488 (different each run)
Device: mps


## 2. Load Model & Tokenizer

In [11]:
from midi_tokenizer import MIDITokenizer

tokenizer = MIDITokenizer()

VOCAB_SIZE = tokenizer.vocab_size
PAD_ID = tokenizer.pad_id
BOS_ID = tokenizer.bos_id
EOS_ID = tokenizer.eos_id

print(f"Vocabulary: {VOCAB_SIZE} tokens")
print(f"Special: PAD={PAD_ID}, BOS={BOS_ID}, EOS={EOS_ID}")

Vocabulary: 3406 tokens
Special: PAD=0, BOS=1, EOS=2


In [12]:
# Import model from shared module
import sys
sys.path.insert(0, str(DATA_DIR.parent / "midi-model"))
from model import MusicGPT

# Load checkpoint
MODEL_PATH = DATA_DIR / "best_model_lm.pt"
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)

# Check for mismatches
saved_vocab = checkpoint.get('vocab_size', 'NOT SAVED')
print(f"Checkpoint vocab_size: {saved_vocab}")
print(f"Tokenizer vocab_size: {VOCAB_SIZE}")
if saved_vocab != 'NOT SAVED' and saved_vocab != VOCAB_SIZE:
    print(f"‚ö†Ô∏è MISMATCH! Model was trained with different tokenizer!")

SEQUENCE_LENGTH = checkpoint.get('sequence_length', 512)
TOKENS_PER_EVENT = checkpoint.get('tokens_per_event', 8)

model = MusicGPT(
    vocab_size=VOCAB_SIZE,
    d_model=256, n_heads=8, n_layers=6,
    max_seq_len=SEQUENCE_LENGTH + 64,
    dropout=0.0,  # No dropout at inference
    tokens_per_event=TOKENS_PER_EVENT
).to(DEVICE)

# Try to load - may fail if old model without event_pos_emb
try:
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"\n‚úì Loaded model from epoch {checkpoint['epoch'] + 1}")
    print(f"Val loss: {checkpoint['val_loss']:.4f}")
except RuntimeError as e:
    if 'event_pos_emb' in str(e):
        print(f"\n‚ö†Ô∏è Old model without position-in-event embedding!")
        print("Need to retrain with updated architecture.")
    else:
        raise e

model.eval()

Checkpoint vocab_size: 3406
Tokenizer vocab_size: 3406

‚úì Loaded model from epoch 30
Val loss: 0.7274


MusicGPT(
  (token_emb): Embedding(3406, 256, padding_idx=0)
  (pos_emb): Embedding(576, 256)
  (event_pos_emb): Embedding(8, 256)
  (dropout): Dropout(p=0.0, inplace=False)
  (transformer): TransformerDecoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=1024, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (linear2): Linear(in_features=1024, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      

## 3. Generation Function

The generation function uses **constrained decoding** to improve output quality:

| Constraint | Purpose |
|------------|---------|
| `min_note_ratio` | Ensures enough musical notes vs metadata |
| `ctrl_penalty` | Reduces excessive control events (pedal spam) |
| `pitch_repeat_penalty` | Forces melodic variety |
| `top_k` / `top_p` | Controls randomness |
| `temperature` | Lower = conservative, Higher = creative |

In [13]:
# Event type to number of tokens (including event type itself)
EVENT_TOKEN_COUNTS = {
    0: 1, 1: 1, 2: 1,  # PAD, BOS, EOS
    3: 8,   # note
    4: 6,   # patch_change
    5: 7,   # control_change
    6: 5,   # set_tempo
    7: 6,   # time_signature
    8: 6,   # key_signature
}

NOTE_TYPE, CTRL_TYPE, EOS_TYPE = 3, 5, 2

# Pitch token range (from tokenizer)
PITCH_START = 2329  # First pitch token
PITCH_END = 2457    # Last pitch token (128 pitches)

@torch.no_grad()
def generate(model, tokenizer, seed_tokens=None, max_events=256, 
             temperature=1.0, top_k=50, top_p=0.95,
             min_note_ratio=0.7, ctrl_penalty=5.0,
             pitch_repeat_penalty=2.0, pitch_memory=8):
    """
    Generate MIDI tokens with constraints.
    
    New args:
        pitch_repeat_penalty: Penalty for recently used pitches
        pitch_memory: How many recent pitches to remember
    """
    tokens_per_event = tokenizer.max_token_seq
    
    if seed_tokens is None:
        tokens = [tokenizer.bos_id] + [tokenizer.pad_id] * (tokens_per_event - 1)
    else:
        tokens = list(seed_tokens)
    
    generated = list(tokens)
    events_generated = 0
    note_count = 0
    total_musical_events = 0
    recent_pitches = []  # Track recent pitches for variety
    
    while events_generated < max_events:
        event_tokens = []
        event_type = None
        
        for pos in range(tokens_per_event):
            if event_type is not None and pos >= EVENT_TOKEN_COUNTS.get(event_type, 8):
                next_token = tokenizer.pad_id
            else:
                context = generated[-SEQUENCE_LENGTH:]
                x = torch.LongTensor([context]).to(DEVICE)
                logits = model(x)[0, -1] / temperature
                
                # === POSITION 0: Event type constraints ===
                if pos == 0:
                    current_ratio = note_count / max(1, total_musical_events)
                    logits[CTRL_TYPE] -= ctrl_penalty
                    
                    if total_musical_events > 5 and current_ratio < min_note_ratio:
                        logits[NOTE_TYPE] += 3.0
                        for t in [4, 5, 6, 7, 8]:
                            logits[t] -= 2.0
                
                # === POSITION 5: Pitch constraints (for note events) ===
                if pos == 5 and event_type == NOTE_TYPE and pitch_repeat_penalty > 0:
                    # Penalize recently used pitches
                    for recent_pitch in recent_pitches:
                        if PITCH_START <= recent_pitch <= PITCH_END:
                            logits[recent_pitch] -= pitch_repeat_penalty
                
                # Top-k
                if top_k > 0:
                    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
                    logits[indices_to_remove] = float('-inf')
                
                # Top-p
                if top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = 0
                    indices_to_remove = sorted_indices_to_remove.scatter(0, sorted_indices, sorted_indices_to_remove)
                    logits[indices_to_remove] = float('-inf')
                
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1).item()
                
                # Track pitch for variety
                if pos == 5 and event_type == NOTE_TYPE:
                    recent_pitches.append(next_token)
                    if len(recent_pitches) > pitch_memory:
                        recent_pitches.pop(0)
            
            if pos == 0:
                event_type = next_token
                if event_type == NOTE_TYPE:
                    note_count += 1
                    total_musical_events += 1
                elif event_type in [4, 5]:
                    total_musical_events += 1
                
                if event_type == EOS_TYPE:
                    event_tokens = [EOS_TYPE] + [tokenizer.pad_id] * (tokens_per_event - 1)
                    generated.extend(event_tokens)
                    return np.array(generated).reshape(-1, tokens_per_event)
            
            event_tokens.append(next_token)
            generated.append(next_token)
        
        events_generated += 1
        
        if event_type is not None and event_type > 20:
            print(f"‚ö†Ô∏è Degeneration at event {events_generated}: type={event_type}")
            generated = generated[:-tokens_per_event]
            break
    
    final_ratio = note_count / max(1, total_musical_events)
    unique_pitches = len(set(recent_pitches))
    print(f"üìä Note ratio: {note_count}/{total_musical_events} = {final_ratio:.1%}")
    print(f"üéπ Pitch variety: {unique_pitches} unique in last {len(recent_pitches)}")
    
    return np.array(generated).reshape(-1, tokens_per_event)

print("Generation with pitch repetition penalty ready!")

Generation with pitch repetition penalty ready!


## 4. Generate Music

In [14]:
# Generate from scratch with all constraints
print("Generating from scratch...")

generated_tokens = generate(
    model, tokenizer,
    seed_tokens=None,
    max_events=300,
    temperature=0.9,          # Higher for variety
    top_k=50,                 # More options
    top_p=0.95,
    min_note_ratio=0.8,
    ctrl_penalty=10.0,
    pitch_repeat_penalty=3.0,  # Penalize same pitch
    pitch_memory=12            # Remember last 12 pitches
)

print(f"\nGenerated {len(generated_tokens)} events")

# Analyze pitch variety
pitches = [e[5] for e in generated_tokens if e[0] == NOTE_TYPE]
unique_pitches = len(set(pitches))
print(f"Total notes: {len(pitches)}, Unique pitches: {unique_pitches}")

Generating from scratch...

Generated 89 events
Total notes: 82, Unique pitches: 42


In [15]:
# Save to MIDI
output_path = OUTPUT_DIR / "generated_lm.mid"
success = tokenizer.tokens_to_midi_file(generated_tokens, output_path)

if success:
    print(f"Saved to: {output_path}")
else:
    print("Failed to save MIDI")

Saved to: midi_data/generated/generated_lm.mid


In [16]:
# Play the generated MIDI
import pretty_midi
from IPython.display import display

def _has_fluidsynth():
    """Check if fluidsynth is available."""
    try:
        import fluidsynth
        return True
    except ImportError:
        return False

def play_midi(midi_path, sample_rate=22050):
    """Play a MIDI file with fallback if fluidsynth not available."""
    midi = pretty_midi.PrettyMIDI(str(midi_path))
    
    # Check if MIDI has any notes
    total_notes = sum(len(inst.notes) for inst in midi.instruments)
    if total_notes == 0:
        print(f"Warning: MIDI file has no notes!")
        return None
    
    print(f"MIDI: {total_notes} notes, {midi.get_end_time():.1f}s duration")

    if _has_fluidsynth():
        audio = midi.fluidsynth(fs=sample_rate)
    else:
        audio = midi.synthesize(fs=sample_rate)
    
    if len(audio) == 0:
        print("Warning: Audio synthesis produced empty output")
        return None

    return Audio(audio, rate=sample_rate)

if output_path.exists():
    print("Playing generated music:")
    player = play_midi(output_path)
    if player:
        display(player)
else:
    print(f"File not found: {output_path}")

Playing generated music:
MIDI: 58 notes, 14.5s duration


## 5. Generate with MIDI Seed

Use an existing MIDI file as a starting point for generation.

In [17]:
# Generate with MIDI seed - with pitch variety constraint
# Use a file we know exists (from the debug cell)
seed_midi_path = DATA_DIR / "adl-piano-midi/Reggae/Reggae/Bob Marley/Buffalo Soldier.mid"

if seed_midi_path.exists():
    seed_tokens_2d = tokenizer.midi_file_to_tokens(seed_midi_path)
    if seed_tokens_2d is not None:
        context_events = min(30, len(seed_tokens_2d) - 1)
        context_2d = seed_tokens_2d[:context_events]
        
        print(f"Seed: {len(seed_tokens_2d)} events, using {context_events} as context")
        
        context_flat = tokenizer.flatten_tokens(context_2d)
        
        generated_tokens = generate(
            model, tokenizer,
            seed_tokens=list(context_flat),
            max_events=200,
            temperature=0.9,
            top_k=50,
            top_p=0.95,
            min_note_ratio=0.8,
            ctrl_penalty=10.0,
            pitch_repeat_penalty=3.0,
            pitch_memory=12
        )
        
        # Analyze NEW portion only
        new_tokens = generated_tokens[context_events:]
        new_pitches = [e[5] for e in new_tokens if e[0] == NOTE_TYPE]
        
        print(f"\nGenerated {len(new_tokens)} new events")
        print(f"New notes: {len(new_pitches)}, Unique pitches: {len(set(new_pitches))}")
        
        # Show pitch distribution
        if new_pitches:
            from collections import Counter
            pitch_counts = Counter(new_pitches).most_common(5)
            print(f"Top 5 pitches: {pitch_counts}")
        
        output_path = OUTPUT_DIR / "generated_style_primed.mid"
        tokenizer.tokens_to_midi_file(generated_tokens, output_path)
        print(f"\nSaved to: {output_path}")
        display(play_midi(output_path))
else:
    print(f"Seed not found: {seed_midi_path}")

Seed: 341 events, using 30 as context
‚ö†Ô∏è Degeneration at event 36: type=2203
üìä Note ratio: 35/35 = 100.0%
üéπ Pitch variety: 9 unique in last 12

Generated 35 new events
New notes: 35, Unique pitches: 20
Top 5 pitches: [(np.int64(2409), 4), (np.int64(2402), 4), (np.int64(2399), 3), (np.int64(2411), 3), (np.int64(2416), 2)]

Saved to: midi_data/generated/generated_style_primed.mid
MIDI: 58 notes, 9.8s duration


## 6. Explore Different Temperatures

In [18]:
temperatures = [0.5, 0.8, 1.0, 1.2]

for temp in temperatures:
    print(f"\nGenerating with temperature {temp}...")
    tokens = generate(model, tokenizer, max_events=300, temperature=temp)
    
    output_path = OUTPUT_DIR / f"generated_lm_temp_{temp}.mid"
    tokenizer.tokens_to_midi_file(tokens, output_path)
    print(f"  Saved: {output_path.name}")

print(f"\nAll files saved to: {OUTPUT_DIR}")


Generating with temperature 0.5...
  Saved: generated_lm_temp_0.5.mid

Generating with temperature 0.8...
  Saved: generated_lm_temp_0.8.mid

Generating with temperature 1.0...
‚ö†Ô∏è Degeneration at event 66: type=2204
üìä Note ratio: 61/63 = 96.8%
üéπ Pitch variety: 7 unique in last 8
  Saved: generated_lm_temp_1.0.mid

Generating with temperature 1.2...
  Saved: generated_lm_temp_1.2.mid

All files saved to: midi_data/generated


## Tips & Parameters

### Temperature Guide
| Value | Style | Best For |
|-------|-------|----------|
| 0.5 | Very conservative | Coherent, safe output |
| 0.8 | Balanced | General use (recommended) |
| 1.0 | Creative | More variation |
| 1.2+ | Experimental | Wild, potentially chaotic |

### Constraint Parameters
```python
min_note_ratio=0.8    # 80% of events should be notes
ctrl_penalty=10.0     # Strongly discourage control events
pitch_repeat_penalty=3.0  # Discourage same pitch repeatedly
pitch_memory=12       # Remember last 12 pitches
```

### Known Limitations
- Model may degenerate after 60-100 events
- Works best with seed context from training data genre
- Training was on ~11k piano MIDI files, 35 epochs