# MIDI Music Generation - Training

Train a deep learning model for polyphonic piano music generation.

**Models:** Mamba (recommended) or LSTM  
**Dataset:** ADL Piano MIDI  
**Output:** `midi_data/best_model.pt`

## 1. Setup

In [None]:
import random
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from midi_utils import (
    MIDI_CONFIG, NUM_PITCHES, SEQUENCE_LENGTH, STRIDE, SEED, POS_WEIGHT,
    get_device, midi_to_piano_roll,
    MusicMamba, MusicLSTM, MidiDataset
)

# Reproducibility: seed Python's, NumPy's and PyTorch's RNGs from the shared SEED.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
del _seed_rng

# Accelerator selection is delegated to midi_utils.get_device().
DEVICE = get_device()
print(f"PyTorch version: {torch.__version__}")

## 2. Load Dataset

In [None]:
DATA_DIR = Path("./midi_data")         # dataset root; the best checkpoint is written here too
ADL_DIR = DATA_DIR / "adl-piano-midi"  # extracted ADL Piano MIDI tree

# Train on a single genre (the first folder below ADL_DIR), e.g. "Classical",
# "Jazz", "Rock", "Pop"; set to None to keep every genre.
GENRE_FILTER = "Classical"

def load_adl_piano_midi(
    adl_dir: Path,
    genre_filter: "str | None" = None,
    extensions: tuple = (".mid", ".midi"),
) -> list:
    """Load ADL Piano MIDI dataset, optionally filtered by genre.

    A file's genre is the name of the first directory below ``adl_dir`` on its
    path; files lying directly in ``adl_dir`` get the genre ``"Unknown"``.

    Args:
        adl_dir: Root of the extracted dataset (a ``str`` is accepted too).
        genre_filter: Keep only files whose genre equals this value; ``None``
            keeps every genre.
        extensions: File suffixes to accept, compared case-insensitively so
            ``.MID`` / ``.midi`` files are not silently skipped.

    Returns:
        A list of ``{"path": Path, "genre": str}`` dicts sorted by path. Sorting
        makes the result independent of the filesystem's directory-listing
        order, so a seeded shuffle of it is reproducible across machines.

    Raises:
        FileNotFoundError: If ``adl_dir`` is not an existing directory.
    """
    adl_dir = Path(adl_dir)
    if not adl_dir.is_dir():
        raise FileNotFoundError(f"ADL Piano MIDI directory not found: {adl_dir}")

    wanted_suffixes = {ext.lower() for ext in extensions}
    files = []
    for midi_path in sorted(adl_dir.rglob("*")):
        if midi_path.suffix.lower() not in wanted_suffixes or not midi_path.is_file():
            continue
        parts = midi_path.relative_to(adl_dir).parts
        # The last part is the file name, so a genre folder exists only when
        # the relative path has at least two components.
        genre = parts[0] if len(parts) >= 2 else "Unknown"
        if genre_filter is None or genre == genre_filter:
            files.append({"path": midi_path, "genre": genre})
    return files

def create_splits(
    files: list,
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    seed: "int | None" = None,
) -> dict:
    """Create train/validation/test splits.

    ``files`` is shuffled with a private ``random.Random`` seeded from ``seed``
    (``SEED`` when omitted). The resulting order is identical to seeding the
    global ``random`` module, but the global RNG state is left untouched.

    Args:
        files: Items to split; the input list is not modified.
        train_ratio: Fraction of items placed in ``"train"``.
        val_ratio: Fraction of items placed in ``"validation"``; the remainder
            goes to ``"test"``.
        seed: Shuffle seed; defaults to the module-level ``SEED``.

    Returns:
        ``{"train": [...], "validation": [...], "test": [...]}``.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more than 1.
    """
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1 + 1e-9:
        raise ValueError(
            f"Invalid split ratios: train={train_ratio}, val={val_ratio} "
            "(both must be >= 0 and sum to at most 1)"
        )
    rng = random.Random(SEED if seed is None else seed)
    shuffled = list(files)
    rng.shuffle(shuffled)
    n = len(shuffled)
    train_end = int(n * train_ratio)
    val_end = int(n * (train_ratio + val_ratio))
    return {
        "train": shuffled[:train_end],
        "validation": shuffled[train_end:val_end],
        "test": shuffled[val_end:],
    }

all_files = load_adl_piano_midi(ADL_DIR, genre_filter=GENRE_FILTER)
print(f"Genre filter: {GENRE_FILTER or 'None (all genres)'}")
print(f"Found {len(all_files)} MIDI files")
if not all_files:
    # Stop here with an actionable message rather than failing later on an
    # empty dataset during DataLoader construction or training.
    raise FileNotFoundError(
        f"No MIDI files found under {ADL_DIR} for genre filter {GENRE_FILTER!r}"
    )

adl_splits = create_splits(all_files)
for split, files in adl_splits.items():
    print(f"  {split}: {len(files)} files")

In [None]:
# Optional cap on the number of files per split (handy for quick experiments).
# None keeps the full filtered dataset.
MAX_FILES_PER_SPLIT = None

if not MAX_FILES_PER_SPLIT:
    print(f"Using full {GENRE_FILTER or 'all'} dataset")
else:
    adl_splits.update(
        {name: split_files[:MAX_FILES_PER_SPLIT] for name, split_files in adl_splits.items()}
    )
    print(f"Limited to {MAX_FILES_PER_SPLIT} files per split")

## 3. Process MIDI Files

In [None]:
# Rolls must be strictly longer than this many frames to be kept.
MIN_ROLL_FRAMES = 100

piano_rolls_by_split = {"train": [], "validation": [], "test": []}

for split_name, files in adl_splits.items():
    print(f"Processing {split_name}...")
    kept_rolls = piano_rolls_by_split[split_name]
    for file_info in tqdm(files, desc=f"  {split_name}"):
        roll = midi_to_piano_roll(file_info["path"], MIDI_CONFIG)
        # Drop files the converter returned None for, and very short pieces.
        if roll is None or len(roll) <= MIN_ROLL_FRAMES:
            continue
        kept_rolls.append(roll)

print("\nProcessed piano rolls:")
for split_name, rolls in piano_rolls_by_split.items():
    total_frames = sum(map(len, rolls))
    # Frames -> seconds (divide by the frame rate fs) -> minutes.
    duration_min = total_frames / MIDI_CONFIG["fs"] / 60
    print(f"  {split_name}: {len(rolls)} files, {duration_min:.1f} min")

In [None]:
# Visualize sample: first 200 frames of the first training roll.
# The roll is indexed (time, pitch), so it is transposed to put pitch on the y-axis.
train_rolls = piano_rolls_by_split["train"]
if train_rolls:
    sample = train_rolls[0][:200]
    fig, ax = plt.subplots(figsize=(15, 5))
    image = ax.imshow(sample.T, aspect='auto', origin='lower', cmap='magma')
    ax.set_xlabel('Time (frames)')
    ax.set_ylabel('Pitch')
    ax.set_title('Training Sample - Piano Roll')
    fig.colorbar(image, ax=ax, label='Velocity')
    plt.show()

## 4. Create DataLoaders

In [None]:
# Larger batches on accelerators; worker processes stay disabled on MPS.
BATCH_SIZE = 64 if DEVICE.type in ('mps', 'cuda') else 32
NUM_WORKERS = 0 if DEVICE.type == 'mps' else 4
# Page-locked host buffers speed up host-to-GPU copies; this only applies to CUDA.
PIN_MEMORY = DEVICE.type == 'cuda'

train_dataset = MidiDataset(piano_rolls_by_split["train"], SEQUENCE_LENGTH, STRIDE)
val_dataset = MidiDataset(piano_rolls_by_split["validation"], SEQUENCE_LENGTH, STRIDE)

# An empty split would otherwise only fail later, during loader creation or training.
for _split_name, _dataset in (("train", train_dataset), ("validation", val_dataset)):
    if len(_dataset) == 0:
        raise ValueError(
            f"The {_split_name} split produced no sequences "
            f"(check the split size and SEQUENCE_LENGTH={SEQUENCE_LENGTH})"
        )

loader_kwargs = {"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS, "pin_memory": PIN_MEMORY}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

print(f"Train: {len(train_dataset):,} sequences ({len(train_loader)} batches)")
print(f"Val:   {len(val_dataset):,} sequences ({len(val_loader)} batches)")

## 5. Initialize Model

In [None]:
MODEL_TYPE = "mamba"  # "mamba" or "lstm"

if MODEL_TYPE == "mamba":
    model = MusicMamba(
        input_size=NUM_PITCHES,
        d_model=256,
        d_state=16,
        n_layers=4,
        dropout=0.3,  # raised from 0.1 to reduce overfitting
    ).to(DEVICE)
elif MODEL_TYPE == "lstm":
    model = MusicLSTM(
        input_size=NUM_PITCHES,
        hidden_size=512,
        num_layers=3,
        dropout=0.4,  # raised from 0.3 to reduce overfitting
    ).to(DEVICE)
else:
    # Reject typos explicitly instead of silently falling back to one architecture.
    raise ValueError(f"Unknown MODEL_TYPE {MODEL_TYPE!r}; expected 'mamba' or 'lstm'")

total_params = sum(p.numel() for p in model.parameters())
print(f"Model: {MODEL_TYPE.upper()}")
print(f"Parameters: {total_params:,}")
print(f"Device: {DEVICE}")

## 6. Training

In [None]:
# Training config
NUM_EPOCHS = 50               # upper bound; early stopping will likely trigger before
LEARNING_RATE = 0.001
WEIGHT_DECAY = 1e-4           # AdamW weight decay (raised from 1e-5 to reduce overfitting)
GRAD_CLIP = 1.0               # max total gradient norm per optimizer step
EARLY_STOPPING_PATIENCE = 10  # epochs without val-loss improvement before stopping
LR_PLATEAU_PATIENCE = 5       # epochs without improvement before the LR is halved

# IMPORTANT: use a weighted loss to handle class imbalance. Piano rolls are
# mostly zeros, so positive (note-on) targets are weighted by POS_WEIGHT
# (imported from midi_utils in the setup cell). The explicit float32 dtype keeps
# the weight tensor compatible with float logits even if POS_WEIGHT is an int.
pos_weight = torch.full((NUM_PITCHES,), float(POS_WEIGHT), dtype=torch.float32, device=DEVICE)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=LR_PLATEAU_PATIENCE
)

# Mixed precision only on CUDA; MPS (e.g. Mamba, for numerical stability) and
# CPU train in full precision.
USE_AMP = DEVICE.type == 'cuda'
if not USE_AMP:
    scaler = None
elif hasattr(torch.amp, "GradScaler"):
    scaler = torch.amp.GradScaler("cuda")
else:  # older PyTorch releases only provide the CUDA-specific class
    scaler = torch.cuda.amp.GradScaler()

print(f"Max Epochs: {NUM_EPOCHS}")
print(f"Early Stopping: patience={EARLY_STOPPING_PATIENCE}")
print(f"Positive Weight: {POS_WEIGHT} (handles class imbalance)")
print(f"Weight Decay: {WEIGHT_DECAY:g} (regularization)")
print(f"Mixed Precision: {USE_AMP}")

In [None]:
def train_epoch(model, loader, criterion, optimizer, device, grad_clip, use_amp=False, scaler=None):
    """Run one optimization pass over ``loader`` and return the mean training loss.

    The result is a per-sample average: each batch's loss is weighted by its
    batch size, so a short final batch does not skew the epoch value.

    Args:
        model: Module whose forward returns a pair; the first element is fed to
            ``criterion`` and the second is ignored.
        loader: Iterable of ``(inputs, targets)`` tensor pairs, batch dim first.
        criterion: Loss called as ``criterion(output, targets)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: ``torch.device`` that batches are moved to.
        grad_clip: Max total gradient norm passed to ``clip_grad_norm_``.
        use_amp: Enable CUDA mixed precision; has no effect on other devices.
        scaler: GradScaler; required when mixed precision is active.

    Returns:
        Mean loss per sample as a Python float.

    Raises:
        ValueError: If ``loader`` yields no samples, or mixed precision is
            active without a ``scaler``.
    """
    amp_active = use_amp and device.type == 'cuda'
    if amp_active and scaler is None:
        raise ValueError("use_amp=True on CUDA requires a GradScaler")

    model.train()
    total_loss = 0.0
    num_samples = 0
    for batch_x, batch_y in loader:
        # non_blocking lets copies from pinned memory overlap with compute;
        # it is a no-op otherwise.
        batch_x = batch_x.to(device, non_blocking=True)
        batch_y = batch_y.to(device, non_blocking=True)
        optimizer.zero_grad()

        if amp_active:
            with torch.autocast(device_type='cuda'):
                output, _ = model(batch_x)
                loss = criterion(output, batch_y)
            scaler.scale(loss).backward()
            # Unscale before clipping so the norm is computed on true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            output, _ = model(batch_x)
            loss = criterion(output, batch_y)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        batch_size = batch_x.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size

    if num_samples == 0:
        raise ValueError("train_epoch received an empty loader")
    return total_loss / num_samples


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients and return the mean loss.

    Like ``train_epoch``, the result is a per-sample average: each batch's loss
    is weighted by its batch size. The model is left in eval mode.

    Args:
        model: Module whose forward returns a pair; the first element is fed to
            ``criterion`` and the second is ignored.
        loader: Iterable of ``(inputs, targets)`` tensor pairs, batch dim first.
        criterion: Loss called as ``criterion(output, targets)``.
        device: ``torch.device`` that batches are moved to.

    Returns:
        Mean loss per sample as a Python float.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    num_samples = 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device, non_blocking=True)
            batch_y = batch_y.to(device, non_blocking=True)
            output, _ = model(batch_x)
            batch_size = batch_x.size(0)
            total_loss += criterion(output, batch_y).item() * batch_size
            num_samples += batch_size

    if num_samples == 0:
        raise ValueError("validate received an empty loader")
    return total_loss / num_samples

In [None]:
# Training loop with early stopping
train_losses, val_losses = [], []
best_val_loss = float('inf')
best_epoch = None  # 1-based epoch of the saved checkpoint; None until one is saved
epochs_without_improvement = 0
best_model_path = DATA_DIR / "best_model.pt"
DATA_DIR.mkdir(parents=True, exist_ok=True)

print(f"Training {MODEL_TYPE.upper()} on {DEVICE}...")
print(f"Saving best model to: {best_model_path}\n")

# perf_counter is monotonic, so durations are immune to wall-clock adjustments.
total_start = time.perf_counter()

for epoch in range(NUM_EPOCHS):
    epoch_start = time.perf_counter()

    train_loss = train_epoch(model, train_loader, criterion, optimizer, DEVICE, GRAD_CLIP, USE_AMP, scaler)
    val_loss = validate(model, val_loader, criterion, DEVICE)

    scheduler.step(val_loss)
    train_losses.append(train_loss)
    val_losses.append(val_loss)

    # Check for improvement (a NaN val_loss never compares as smaller)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch + 1
        epochs_without_improvement = 0
        # 'epoch' is stored 0-based, as before, for existing checkpoint readers.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'model_type': MODEL_TYPE,
        }, best_model_path)
        marker = " *"  # Mark best epoch
    else:
        epochs_without_improvement += 1
        marker = ""

    # Log progress
    lr = optimizer.param_groups[0]['lr']
    elapsed = time.perf_counter() - epoch_start
    print(f"Epoch {epoch+1:3d}/{NUM_EPOCHS} | Train: {train_loss:.4f} | Val: {val_loss:.4f} | LR: {lr:.6f} | {elapsed:.1f}s{marker}")

    # Early stopping
    if epochs_without_improvement >= EARLY_STOPPING_PATIENCE:
        print(f"\nEarly stopping! No improvement for {EARLY_STOPPING_PATIENCE} epochs.")
        break

total_time = time.perf_counter() - total_start
if best_epoch is None:
    print(f"\nDone in {total_time/60:.1f} min, but no epoch produced a finite validation loss "
          f"-- no checkpoint was saved to {best_model_path}.")
else:
    print(f"\nDone! Total: {total_time/60:.1f} min, Best val loss: {best_val_loss:.4f} (epoch {best_epoch})")

In [None]:
# Plot training curves; epochs are numbered from 1 to match the training log.
epoch_numbers = range(1, len(train_losses) + 1)
plt.figure(figsize=(10, 5))
plt.plot(epoch_numbers, train_losses, label='Train')
plt.plot(epoch_numbers, val_losses, label='Validation')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Progress')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

if best_model_path.exists():
    print(f"\nModel saved to: {best_model_path}")
    print("Use midi_generation.ipynb to generate music!")
else:
    print(f"\nNo checkpoint found at {best_model_path}; re-run training before generation.")