# MIDI Music Generation - Generate with Language Model

Generate music using the trained MusicGPT model and SkyTNT tokenizer.

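**Requirements:** a trained checkpoint at `midi_data/best_model_lm.pt` and the `midi_tokenizer_wrapper` module on the Python path. Playback uses `pretty_midi`. It uses `fluidsynth` when that is installed and otherwise falls back to a basic synthesizer.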

## 1. Setup

In [1]:
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pretty_midi
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import Audio, display

from midi_tokenizer_wrapper import MIDITokenizerWrapper

# Reproducibility: seed every RNG this notebook draws from, in a fixed order.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)

# Prefer CUDA, then Apple MPS, and fall back to the CPU.
_backend = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
DEVICE = torch.device(_backend)
print(f"Device: {DEVICE}")

# Generated MIDI files are written under <DATA_DIR>/generated.
DATA_DIR = Path("./midi_data")
OUTPUT_DIR = DATA_DIR / "generated"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

Device: mps


## 2. Load Model & Tokenizer

In [2]:
# Build the MIDI tokenizer and expose its vocabulary size and special-token ids
# as notebook-level constants used by the model and the sampler below.
tokenizer = MIDITokenizerWrapper(version="v2", optimise_midi=True)
VOCAB_SIZE, PAD_ID, BOS_ID, EOS_ID = (
    tokenizer.vocab_size,
    tokenizer.pad_id,
    tokenizer.bos_id,
    tokenizer.eos_id,
)

print(f"Vocabulary: {VOCAB_SIZE} tokens")
print(f"Special: PAD={PAD_ID}, BOS={BOS_ID}, EOS={EOS_ID}")

Vocabulary: 3406 tokens
Special: PAD=0, BOS=1, EOS=2


In [3]:
# Model definition (same as training)
class MusicGPT(nn.Module):
    """Autoregressive Transformer language model over flat MIDI token ids.

    This definition must mirror the one used for training: the submodule names
    (`token_emb`, `pos_emb`, `transformer`, `lm_head`) and the persistent
    `causal_mask` buffer are keys in the saved state dict, so renaming them or
    changing `max_seq_len` makes `load_state_dict` fail.

    Args:
        vocab_size: Number of token ids (embedding rows and output logits).
        d_model: Hidden size.
        n_heads: Attention heads per decoder layer.
        n_layers: Number of stacked decoder layers.
        max_seq_len: Longest input covered by the positional table and causal mask.
        dropout: Dropout probability for embeddings and inside each layer.
    """

    def __init__(self, vocab_size, d_model=256, n_heads=8, n_layers=6,
                 max_seq_len=1024, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        # padding_idx=0 assumes the tokenizer's PAD id is 0 (PAD=0 in this notebook).
        self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.dropout = nn.Dropout(dropout)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=n_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

        # Boolean upper-triangular mask: True above the diagonal means
        # "may not attend", i.e. position i only sees positions <= i.
        self.register_buffer(
            'causal_mask',
            torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
        )

    def forward(self, x):
        """Return next-token logits for every position.

        Args:
            x: LongTensor of token ids, shape (batch, seq_len), with
               seq_len <= max_seq_len.

        Returns:
            Tensor of shape (batch, seq_len, vocab_size).

        Raises:
            ValueError: if seq_len exceeds max_seq_len (the positional
                embedding table has no row for later positions).
        """
        _, seq_len = x.shape
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len={self.max_seq_len}"
            )
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)
        h = self.dropout(self.token_emb(x) + self.pos_emb(positions))
        mask = self.causal_mask[:seq_len, :seq_len]
        # The embedded input doubles as the decoder "memory"; applying the same
        # causal mask to it keeps cross-attention from looking ahead as well.
        h = self.transformer(h, h, tgt_mask=mask, memory_mask=mask)
        return self.lm_head(h)

# Load the best checkpoint and rebuild the model with matching hyper-parameters.
MODEL_PATH = DATA_DIR / "best_model_lm.pt"
# NOTE(review): weights_only=False unpickles arbitrary Python objects and can
# execute code. Only load checkpoints you produced yourself or otherwise trust.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)

SEQUENCE_LENGTH = checkpoint.get('sequence_length', 512)

# Architecture hyper-parameters. They must equal the values used in training,
# otherwise the state dict below will not load.
D_MODEL, N_HEADS, N_LAYERS = 256, 8, 6
# Extra positions beyond SEQUENCE_LENGTH in the positional table / causal mask.
# `causal_mask` is a persistent buffer stored in the state dict, so this margin
# must match the training value or load_state_dict fails on a shape mismatch.
POS_MARGIN = 64

model = MusicGPT(
    vocab_size=VOCAB_SIZE,
    d_model=D_MODEL, n_heads=N_HEADS, n_layers=N_LAYERS,
    max_seq_len=SEQUENCE_LENGTH + POS_MARGIN,
    dropout=0.0  # no dropout at inference (model.eval() also disables it)
).to(DEVICE)

model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Training metadata is informational only; don't crash if a key is absent.
if 'epoch' in checkpoint:
    print(f"Loaded model from epoch {checkpoint['epoch'] + 1}")
else:
    print(f"Loaded model from {MODEL_PATH}")
if 'val_loss' in checkpoint:
    print(f"Val loss: {checkpoint['val_loss']:.4f}")

Loaded model from epoch 7
Val loss: 1.0762


## 3. Generation Function

In [4]:
@torch.no_grad()
def generate(model, tokenizer, seed_tokens=None, max_events=256,
             temperature=1.0, top_k=50, top_p=0.95, context_length=None):
    """
    Generate MIDI tokens autoregressively, one token per forward pass.

    Args:
        model: Trained model. It is called as ``model(x)`` on a (1, seq) LongTensor,
            and ``[0, -1]`` of the result is taken as the next-token logits.
        tokenizer: MIDITokenizerWrapper (uses ``max_token_seq``, ``bos_id``,
            ``pad_id`` and ``eos_id``).
        seed_tokens: Optional flat iterable of starting token ids. When None or
            empty, generation starts from one BOS event (BOS followed by PADs).
        max_events: Maximum number of events to generate after the seed. Each
            event is ``tokenizer.max_token_seq`` tokens.
        temperature: Sampling temperature (lower = more deterministic).
            ``0`` means greedy decoding (argmax). Negative values are rejected.
        top_k: Keep only the k highest logits (``<= 0`` disables). Values larger
            than the vocabulary are clamped to its size.
        top_p: Nucleus filtering. Keeps the smallest set of most-likely tokens
            whose cumulative probability exceeds ``top_p`` (``>= 1`` disables).
        context_length: Number of most recent tokens fed to the model at each
            step. Defaults to ``SEQUENCE_LENGTH`` from the model-loading cell.

    Returns:
        np.ndarray of shape (events, tokenizer.max_token_seq). It holds the
        seed/BOS prefix followed by the generated tokens, right-padded with
        ``pad_id`` to a whole number of events. Generation stops early after
        an EOS token.

    Raises:
        ValueError: if ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")
    if context_length is None:
        context_length = SEQUENCE_LENGTH

    event_len = tokenizer.max_token_seq
    max_tokens = max_events * event_len

    # Send inputs to whatever device the model lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Normalize the seed to plain ints (it may be a numpy array). An empty seed
    # would mean a zero-length forward pass, so fall back to a lone BOS event.
    generated = [] if seed_tokens is None else [int(t) for t in seed_tokens]
    if not generated:
        generated = [tokenizer.bos_id] + [tokenizer.pad_id] * (event_len - 1)

    for _ in range(max_tokens):
        # Sliding window over the most recent tokens.
        context = generated[-context_length:]
        x = torch.tensor([context], dtype=torch.long, device=device)
        logits = model(x)[0, -1]

        if temperature == 0:
            next_token = int(torch.argmax(logits).item())
        else:
            logits = logits / temperature

            # Top-k filtering (clamped so topk never asks for more than the vocab).
            if top_k > 0:
                k = min(top_k, logits.size(-1))
                kth_best = torch.topk(logits, k).values[-1]
                logits = logits.masked_fill(logits < kth_best, float('-inf'))

            # Top-p (nucleus) filtering.
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
                sorted_remove = cumulative_probs > top_p
                # Shift right so the token that crosses the threshold is kept,
                # and always keep the single most likely token.
                sorted_remove[1:] = sorted_remove[:-1].clone()
                sorted_remove[0] = False
                remove = sorted_remove.scatter(0, sorted_indices, sorted_remove)
                logits = logits.masked_fill(remove, float('-inf'))

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()

        generated.append(next_token)

        if next_token == tokenizer.eos_id:
            break

    # Right-pad to a whole number of events, then reshape to (events, event_len).
    result = np.asarray(generated, dtype=np.int64)
    pad_len = (-len(result)) % event_len
    if pad_len:
        result = np.pad(result, (0, pad_len), constant_values=tokenizer.pad_id)

    return result.reshape(-1, event_len)

print("Generation function ready!")

Generation function ready!


## 4. Generate Music

In [5]:
# Sample a fresh piece; with no seed, generate() starts from a lone BOS event.
_scratch_sampling = dict(
    max_events=500,  # ~500 notes/events
    temperature=0.9,
    top_k=50,
    top_p=0.95,
)

print("Generating music from scratch...")
generated_tokens = generate(model, tokenizer, **_scratch_sampling)

print(f"Generated {len(generated_tokens)} events")
print(f"First 10 events:\n{generated_tokens[:10]}")

Generating music from scratch...
Generated 501 events
First 10 events:
[[   1    0    0    0    0    0    0    0]
 [   7    9  137 2201 3372 3386 3404 3404]
 [   8    9  137 2201 3396 3404 3404 3404]
 [3404 3404  137 2201 3394 3404 3404  185]
 [   4    9  137 2202 2329 2601 2600  185]
 [   4    9  137 2203 2330 2601 2600  185]
 [   4    9  137 2203 2330 2601 2600  185]
 [   3    9  137 2203 2330 2388 2600  185]
 [   3   10  137 2202 2329 2405 2600  169]
 [   3   10  137 2202 2329 2405 2600  169]]


In [6]:
# Write the generated event grid to a MIDI file and report whether it worked.
output_path = OUTPUT_DIR / "generated_lm.mid"
success = tokenizer.tokens_to_midi_file(generated_tokens, output_path)
print(f"Saved to: {output_path}" if success else "Failed to save MIDI")

Saved to: midi_data/generated/generated_lm.mid


In [7]:
# Play the generated MIDI (pretty_midi and display are imported in the setup cell).

def _has_fluidsynth():
    """Return True if the `fluidsynth` Python bindings can be imported."""
    try:
        import fluidsynth  # noqa: F401  (imported only to probe availability)
    except ImportError:
        return False
    return True


def play_midi(midi_path, sample_rate=22050):
    """Render a MIDI file to audio and return an IPython Audio widget.

    Uses FluidSynth when its bindings are importable, and otherwise falls back
    to pretty_midi's basic built-in synthesizer.

    Args:
        midi_path: Path (or str) of the MIDI file to render.
        sample_rate: Output sample rate in Hz.

    Returns:
        IPython.display.Audio wrapping the rendered waveform.
    """
    midi = pretty_midi.PrettyMIDI(str(midi_path))

    if _has_fluidsynth():
        audio = midi.fluidsynth(fs=sample_rate)
    else:
        # Fallback to basic synthesis
        audio = midi.synthesize(fs=sample_rate)

    return Audio(audio, rate=sample_rate)


if output_path.exists():
    print("Playing generated music:")
    # The call is inside an `if`, so it is not the cell's last expression.
    # Without an explicit display() the player would never be rendered.
    display(play_midi(output_path))

Playing generated music:


## 5. Generate with MIDI Seed

Use the opening events of an existing MIDI file as a prompt; the model then continues the piece from there.

In [8]:
# Continue an existing piece: tokenize a seed MIDI, keep its opening events as
# the prompt, and let the model generate the rest.
SEED_MIDI_DIR = DATA_DIR / "adl-piano-midi"
seed_midi_path = SEED_MIDI_DIR / "Classical/Classical/Johann Sebastian Bach/Aria, BWV 988.mid"
SEED_EVENTS = 50  # number of opening events from the seed file used as the prompt

if not seed_midi_path.exists():
    print(f"Seed file not found: {seed_midi_path}")
    # Fall back to the first MIDI file in the dataset (sorted => reproducible).
    fallback_seeds = sorted(SEED_MIDI_DIR.rglob("*.mid")) if SEED_MIDI_DIR.exists() else []
    if fallback_seeds:
        seed_midi_path = fallback_seeds[0]
        print(f"Using fallback seed: {seed_midi_path}")
    else:
        seed_midi_path = None
        print("Try a different file from midi_data/adl-piano-midi/")

if seed_midi_path is not None:
    seed_tokens_2d = tokenizer.midi_file_to_tokens(seed_midi_path)
    if seed_tokens_2d is None:
        print(f"Could not tokenize seed file: {seed_midi_path}")
    else:
        seed_2d = seed_tokens_2d[:SEED_EVENTS]
        seed_flat = tokenizer.flatten_tokens(seed_2d)

        print(f"Seed: {len(seed_2d)} events, {len(seed_flat)} tokens")

        # Use distinct names so re-running earlier cells doesn't pick up this
        # cell's result through `generated_tokens` / `output_path`.
        seeded_tokens = generate(
            model, tokenizer,
            seed_tokens=seed_flat,
            max_events=400,
            temperature=0.85,
            top_k=40
        )

        seed_output_path = OUTPUT_DIR / "generated_from_seed_lm.mid"
        if tokenizer.tokens_to_midi_file(seeded_tokens, seed_output_path):
            print(f"Saved to: {seed_output_path}")
            display(play_midi(seed_output_path))
        else:
            print("Failed to save MIDI")

Seed file not found: midi_data/adl-piano-midi/Classical/Classical/Johann Sebastian Bach/Aria, BWV 988.mid
Try a different file from midi_data/adl-piano-midi/


## 6. Explore Different Temperatures

In [None]:
# Sweep sampling temperatures and save one file per setting for comparison.
temperatures = [0.5, 0.8, 1.0, 1.2]

for temp in temperatures:
    print(f"\nGenerating with temperature {temp}...")
    temp_tokens = generate(model, tokenizer, max_events=300, temperature=temp)

    # Local name avoids clobbering `output_path`, which the playback cell uses.
    temp_path = OUTPUT_DIR / f"generated_lm_temp_{temp}.mid"
    if tokenizer.tokens_to_midi_file(temp_tokens, temp_path):
        print(f"  Saved: {temp_path.name}")
    else:
        print(f"  Failed to save: {temp_path.name}")

print(f"\nAll files saved to: {OUTPUT_DIR}")


Generating with temperature 0.5...


## Tips

**Temperature guide:**
- `0.5` - Very coherent, conservative
- `0.8-0.9` - Balanced (recommended)
- `1.0` - More varied
- `1.2+` - Experimental, may be chaotic

**For better results:**
- Train for more epochs
- Use larger dataset
- Try different seeds from different genres