In [2]:
!pip install pretty_midi music21
!apt install -y fluidsynth

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
fluidsynth is already the newest version (2.2.5-1).
0 upgraded, 0 newly installed, 0 to remove and 35 not upgraded.


In [3]:
# LSTM Classical Music Generator with Optimizations (PyTorch)

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pretty_midi
import numpy as np
import pandas as pd
import os, glob, random
from collections import defaultdict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [4]:
!wget https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip
!unzip maestro-v3.0.0-midi.zip -d maestro_dataset

--2025-08-02 03:57:15--  https://storage.googleapis.com/magentadata/datasets/maestro/v3.0.0/maestro-v3.0.0-midi.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.99.207, 142.250.107.207, 142.251.188.207, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.99.207|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 58416533 (56M) [application/octet-stream]
Saving to: ‘maestro-v3.0.0-midi.zip.1’


2025-08-02 03:57:15 (163 MB/s) - ‘maestro-v3.0.0-midi.zip.1’ saved [58416533/58416533]

Archive:  maestro-v3.0.0-midi.zip
replace maestro_dataset/maestro-v3.0.0/2004/MIDI-Unprocessed_XP_08_R1_2004_01-02_ORIG_MID--AUDIO_08_R1_2004_01_Track01_wav.midi? [y]es, [n]o, [A]ll, [N]one, [r]ename: A
  inflating: maestro_dataset/maestro-v3.0.0/2004/MIDI-Unprocessed_XP_08_R1_2004_01-02_ORIG_MID--AUDIO_08_R1_2004_01_Track01_wav.midi  
  inflating: maestro_dataset/maestro-v3.0.0/2004/MIDI-Unprocessed_XP_09_R1_2004_05_ORIG_MID--AUDIO_09_R1_2004

In [10]:
MIDI_DIR = "maestro_dataset/maestro-v3.0.0"

PITCHES = list(range(21, 109))
STEP_BINS = np.round(np.arange(0.05, 1.1, 0.05), 2)
DUR_BINS = np.round(np.arange(0.05, 1.5, 0.05), 2)

def quantize(x, bins):
    return bins[np.argmin(np.abs(bins - x))]

def parse_midi(file):
    try:
        pm = pretty_midi.PrettyMIDI(file)
    except Exception as e:
        print(f"Could not parse {file}: {e}")
        return []
    notes = []
    for inst in pm.instruments:
        if inst.is_drum: continue
        inst.notes.sort(key=lambda n: n.start)
        prev_start = 0
        for note in inst.notes:
            if note.pitch not in PITCHES:
                continue
            pitch = note.pitch
            step = quantize(note.start - prev_start, STEP_BINS)
            dur = quantize(note.end - note.start, DUR_BINS)
            notes.append((pitch, step, dur))
            prev_start = note.start
    return notes

def encode_notes(notes):
    encoded = []
    for pitch, step, dur in notes:
        pitch_id = PITCHES.index(pitch)
        step_id = np.where(STEP_BINS == step)[0][0]
        dur_id = np.where(DUR_BINS == dur)[0][0]
        encoded.append((pitch_id, step_id, dur_id))
    return encoded

class MusicDataset(Dataset):
    def __init__(self, files, seq_len=50):
        self.data = []
        for f in files:
            notes = parse_midi(f)
            if len(notes) < seq_len + 1:
                continue
            encoded = encode_notes(notes)
            for i in range(len(encoded) - seq_len):
                seq = encoded[i:i+seq_len]
                target = encoded[i+seq_len]
                self.data.append((seq, target))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        seq, target = self.data[idx]
        pitch_seq = torch.tensor([x[0] for x in seq], dtype=torch.long).to(device)
        step_seq  = torch.tensor([x[1] for x in seq], dtype=torch.long).to(device)
        dur_seq   = torch.tensor([x[2] for x in seq], dtype=torch.long).to(device)
        target    = torch.tensor(target, dtype=torch.long).to(device)
        return pitch_seq, step_seq, dur_seq, target

files = sorted(glob.glob(os.path.join(MIDI_DIR, "**/*.mid*"), recursive=True))[:300]
dataset = MusicDataset(files)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# LSTM with Embeddings

In [6]:
# 2. MODEL DEFINITION

class MusicLSTM(nn.Module):
    def __init__(self, pitch_size, step_size, dur_size, embed_dim=64, hidden_size=256):
        super().__init__()
        self.pitch_emb = nn.Embedding(pitch_size, embed_dim)
        self.step_emb  = nn.Embedding(step_size, embed_dim)
        self.dur_emb   = nn.Embedding(dur_size, embed_dim)

        self.lstm = nn.LSTM(embed_dim * 3, hidden_size, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(0.3)
        self.pitch_out = nn.Linear(hidden_size * 2, pitch_size)
        self.step_out  = nn.Linear(hidden_size * 2, step_size)
        self.dur_out   = nn.Linear(hidden_size * 2, dur_size)

    def forward(self, pitch, step, dur):
        x = torch.cat([self.pitch_emb(pitch), self.step_emb(step), self.dur_emb(dur)], dim=-1)
        x, _ = self.lstm(x)
        x = self.dropout(x[:, -1, :])
        return self.pitch_out(x), self.step_out(x), self.dur_out(x)

model = MusicLSTM(pitch_size=len(PITCHES), step_size=len(STEP_BINS), dur_size=len(DUR_BINS)).to(device)


In [11]:
# 3. TRAINING PROCESS

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    model.train()
    total_loss = 0
    for pitch_seq, step_seq, dur_seq, target in dataloader:
        pitch_seq, step_seq, dur_seq = pitch_seq.to(device), step_seq.to(device), dur_seq.to(device)
        target = target.to(device)

        pitch_tgt, step_tgt, dur_tgt = target[:,0], target[:,1], target[:,2]

        optimizer.zero_grad()
        pitch_pred, step_pred, dur_pred = model(pitch_seq, step_seq, dur_seq)

        loss = criterion(pitch_pred, pitch_tgt) + \
               criterion(step_pred, step_tgt) + \
               criterion(dur_pred, dur_tgt)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} Loss: {total_loss / len(dataloader):.4f}")

Epoch 1 Loss: 5.9354
Epoch 2 Loss: 5.6060
Epoch 3 Loss: 5.4827
Epoch 4 Loss: 5.4073
Epoch 5 Loss: 5.3561
Epoch 6 Loss: 5.3172
Epoch 7 Loss: 5.2862
Epoch 8 Loss: 5.2618
Epoch 9 Loss: 5.2429
Epoch 10 Loss: 5.2257


In [17]:
# 4. INFERENCE / DECODING

def sample(logits, temperature=1.0, top_k=5, top_p=0.9):
    logits = logits / temperature
    probs = torch.softmax(logits, dim=-1)

    top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
    top_k_probs = top_k_probs / torch.sum(top_k_probs)

    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cum_probs = torch.cumsum(sorted_probs, dim=-1)
    cutoff = torch.where(cum_probs > top_p, True, False).float()
    cutoff_idx = torch.argmax(cutoff)
    nucleus_idx = sorted_idx[:cutoff_idx+1]
    nucleus_probs = probs[nucleus_idx]
    nucleus_probs = nucleus_probs / nucleus_probs.sum()

    idx = nucleus_idx[torch.multinomial(nucleus_probs, 1)]
    return idx.item()

def generate(model, seed, length=100):
    model.eval()
    generated = seed.copy()
    for _ in range(length):
        pitch_seq = torch.tensor([x[0] for x in generated[-50:]]).unsqueeze(0).to(device)
        step_seq  = torch.tensor([x[1] for x in generated[-50:]]).unsqueeze(0).to(device)
        dur_seq   = torch.tensor([x[2] for x in generated[-50:]]).unsqueeze(0).to(device)

        with torch.no_grad():
            pitch_logits, step_logits, dur_logits = model(pitch_seq, step_seq, dur_seq)

        pitch = sample(pitch_logits[0])
        step = sample(step_logits[0])
        dur = sample(dur_logits[0])
        generated.append((pitch, step, dur))
    return generated

In [18]:
# 5. OUTPUT TO MIDI

def decode(pitch_id, step_id, dur_id):
    return PITCHES[pitch_id], STEP_BINS[step_id], DUR_BINS[dur_id]

def to_midi(sequence, filename='output.mid'):
    pm = pretty_midi.PrettyMIDI()
    inst = pretty_midi.Instrument(program=0)
    time = 0
    for p, s, d in sequence:
        pitch, step, dur = decode(p, s, d)
        time += step
        inst.notes.append(pretty_midi.Note(velocity=100, pitch=pitch, start=time, end=time+dur))
    pm.instruments.append(inst)
    pm.write(filename)

seed = encode_notes(parse_midi(files[1]))[:50]
gen_seq = generate(model, seed)
to_midi(gen_seq, 'generated_output.midi')


# Method 1 for saving just the model's weights

In [19]:
model_save_path = 'music_lstm_model.pth'

torch.save(model.state_dict(), model_save_path)

print(f"Model state_dict saved to {model_save_path}")

Model state_dict saved to music_lstm_model.pth


In [20]:
import torch
import torch.nn as nn

loaded_model = MusicLSTM(
    pitch_size=len(PITCHES),
    step_size=len(STEP_BINS),
    dur_size=len(DUR_BINS),
    embed_dim=64,
    hidden_size=256
).to(device)

model_load_path = 'music_lstm_model.pth'
loaded_model.load_state_dict(torch.load(model_load_path, map_location=device))

loaded_model.eval()

print(f"Model loaded successfully from {model_load_path}")

Model loaded successfully from music_lstm_model.pth


# Method 2 For saving both model wieght and optimizer checkpoint

In [21]:
import torch

checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': total_loss / len(dataloader), # or validation loss
    'pitch_size': len(PITCHES),
    'step_size': len(STEP_BINS),
    'dur_size': len(DUR_BINS),
    'embed_dim': 64,
    'hidden_size': 256
}

checkpoint_path = 'music_lstm_checkpoint.pth'
torch.save(checkpoint, checkpoint_path)
print(f"Checkpoint saved to {checkpoint_path}")

Checkpoint saved to music_lstm_checkpoint.pth


In [22]:
import torch
import torch.nn as nn

loaded_model = MusicLSTM(
    pitch_size=len(PITCHES),
    step_size=len(STEP_BINS),
    dur_size=len(DUR_BINS),
    embed_dim=64,
    hidden_size=256
).to(device)

optimizer = torch.optim.Adam(loaded_model.parameters(), lr=0.001)

checkpoint_path = 'music_lstm_checkpoint.pth'

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    loaded_model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    loaded_loss = checkpoint['loss']
    print(f"Checkpoint loaded. Resuming training from epoch {start_epoch}, previous loss: {loaded_loss:.4f}")
else:
    start_epoch = 0
    print("No checkpoint found. Starting training from epoch 0.")

Checkpoint loaded. Resuming training from epoch 10, previous loss: 5.2257


# Resumed Training

In [None]:
criterion = nn.CrossEntropyLoss()

for epoch in range(start_epoch, start_epoch + 10):
    loaded_model.train()
    total_loss = 0
    for pitch_seq, step_seq, dur_seq, target in dataloader:
        pitch_seq, step_seq, dur_seq = pitch_seq.to(device), step_seq.to(device), dur_seq.to(device)
        target = target.to(device)

        pitch_tgt, step_tgt, dur_tgt = target[:,0], target[:,1], target[:,2]

        optimizer.zero_grad()
        pitch_pred, step_pred, dur_pred = loaded_model(pitch_seq, step_seq, dur_seq)

        loss = criterion(pitch_pred, pitch_tgt) + criterion(step_pred, step_tgt) + criterion(dur_pred, dur_tgt)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} Loss: {total_loss / len(dataloader):.4f}")

# Improved Architecture

In [None]:
# Improving the model

import torch
import torch.nn as nn

class MusicLSTM_Advanced(nn.Module):
    def __init__(self, pitch_size, step_size, dur_size, embed_dim=128, hidden_size=512, num_layers=3, dropout=0.3):
        super().__init__()

        self.pitch_emb = nn.Embedding(pitch_size, embed_dim)
        self.step_emb  = nn.Embedding(step_size, embed_dim)
        self.dur_emb   = nn.Embedding(dur_size, embed_dim)

        # LSTM
        self.lstm = nn.LSTM(
            input_size=embed_dim * 3,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=True
        )

        # Self-Attention: simple dot-product attention over time
        self.attn = nn.MultiheadAttention(embed_dim=hidden_size * 2, num_heads=4, batch_first=True)
        self.attn_norm = nn.LayerNorm(hidden_size * 2)

        # Output heads
        self.pitch_out = nn.Linear(hidden_size * 2, pitch_size)
        self.step_out  = nn.Linear(hidden_size * 2, step_size)
        self.dur_out   = nn.Linear(hidden_size * 2, dur_size)

    def forward(self, pitch, step, dur):
        # Embed & concat: [B, T, 3*embed_dim]
        x = torch.cat([
            self.pitch_emb(pitch),
            self.step_emb(step),
            self.dur_emb(dur)
        ], dim=-1)

        # LSTM: [B, T, 2*H]
        x, _ = self.lstm(x)

        # Self-Attention over time steps
        attn_output, _ = self.attn(x, x, x)  # Query = Key = Value = LSTM output
        x = self.attn_norm(attn_output + x)  # Residual + norm

        # Mean pooling over sequence length
        x = x.mean(dim=1)  # [B, 2*H]

        # Output heads
        return self.pitch_out(x), self.step_out(x), self.dur_out(x)


In [None]:
# 3. TRAINING PROCESS


model = MusicLSTM_Advanced(pitch_size=len(PITCHES), step_size=len(STEP_BINS), dur_size=len(DUR_BINS)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    model.train()
    total_loss = 0
    for pitch_seq, step_seq, dur_seq, target in dataloader:
        pitch_seq, step_seq, dur_seq = pitch_seq.to(device), step_seq.to(device), dur_seq.to(device)
        target = target.to(device)

        pitch_tgt, step_tgt, dur_tgt = target[:,0], target[:,1], target[:,2]

        optimizer.zero_grad()
        pitch_pred, step_pred, dur_pred = model(pitch_seq, step_seq, dur_seq)

        loss = criterion(pitch_pred, pitch_tgt) + criterion(step_pred, step_tgt) + criterion(dur_pred, dur_tgt)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} Loss: {total_loss / len(dataloader):.4f}")

In [None]:
# ---------------------------
# 4. INFERENCE / DECODING
# ---------------------------

def sample(logits, temperature=1.0, top_k=5, top_p=0.9):
    logits = logits / temperature
    probs = torch.softmax(logits, dim=-1)

    # Top-k
    top_k_probs, top_k_indices = torch.topk(probs, k=top_k)
    top_k_probs = top_k_probs / torch.sum(top_k_probs)

    # Nucleus sampling
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cum_probs = torch.cumsum(sorted_probs, dim=-1)
    cutoff = torch.where(cum_probs > top_p, True, False).float()
    cutoff_idx = torch.argmax(cutoff)
    nucleus_idx = sorted_idx[:cutoff_idx+1]
    nucleus_probs = probs[nucleus_idx]
    nucleus_probs = nucleus_probs / nucleus_probs.sum()

    idx = nucleus_idx[torch.multinomial(nucleus_probs, 1)]
    return idx.item()

def generate(model, seed, length=100):
    model.eval()
    generated = seed.copy()
    for _ in range(length):
        pitch_seq = torch.tensor([x[0] for x in generated[-50:]]).unsqueeze(0).to(device)
        step_seq  = torch.tensor([x[1] for x in generated[-50:]]).unsqueeze(0).to(device)
        dur_seq   = torch.tensor([x[2] for x in generated[-50:]]).unsqueeze(0).to(device)

        with torch.no_grad():
            pitch_logits, step_logits, dur_logits = model(pitch_seq, step_seq, dur_seq)

        pitch = sample(pitch_logits[0])
        step = sample(step_logits[0])
        dur = sample(dur_logits[0])
        generated.append((pitch, step, dur))
    return generated

In [None]:
# ---------------------------
# 5. OUTPUT TO MIDI
# ---------------------------

def decode(pitch_id, step_id, dur_id):
    return PITCHES[pitch_id], STEP_BINS[step_id], DUR_BINS[dur_id]

def to_midi(sequence, filename='output.mid'):
    pm = pretty_midi.PrettyMIDI()
    inst = pretty_midi.Instrument(program=0)
    time = 0
    for p, s, d in sequence:
        pitch, step, dur = decode(p, s, d)
        time += step
        inst.notes.append(pretty_midi.Note(velocity=100, pitch=pitch, start=time, end=time+dur))
    pm.instruments.append(inst)
    pm.write(filename)

# Example
seed = encode_notes(parse_midi(files[1]))[:50]
gen_seq = generate(model, seed)
to_midi(gen_seq, 'improved_generated_output.midi')
