In [1]:
!pip install scikit-learn



In [1]:
import pretty_midi
import os
from collections import defaultdict
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import time
from miditok.pytorch_data import DatasetMIDI, DataCollator
import datetime
from collections import Counter
from sklearn.model_selection import train_test_split

In [2]:
def print_progress_bar(iteration, total, prefix='', length=50):
    percent = ("{0:.1f}").format(100 * (iteration / float(total)))
    filled_length = int(length * iteration // total)
    bar = '█' * filled_length + '-' * (length - filled_length)
    print(f'\r{prefix} |{bar}| {percent}% Complete', end='\r', flush=True)
    if iteration == total:
        print()

In [3]:
def valid_midi_files(filepaths):
    valid_files = []
    for i, filepath in enumerate(filepaths):
        try:
            midi_data = pretty_midi.PrettyMIDI(filepath)
            if len(midi_data.instruments) > 0:
                valid_files.append(filepath)
        except Exception as e:
            print(f"Error processing {filepath}: {e}")
        print_progress_bar(i+1, len(filepaths), prefix='Validating MIDI files')
    return valid_files

In [4]:
midi_dirpath = 'nesmdb_midi/'
midi_train_dirpath = os.path.join(midi_dirpath, 'train')
midi_test_dirpath = os.path.join(midi_dirpath, 'test')
midi_val_dirpath = os.path.join(midi_dirpath, 'valid')
midi_train_filesnames = os.listdir(midi_train_dirpath)
midi_test_filesnames = os.listdir(midi_test_dirpath)
midi_val_filenames = os.listdir(midi_val_dirpath)

midi_train_filepaths = valid_midi_files([os.path.join(midi_train_dirpath, filename) for filename in midi_train_filesnames])
midi_test_filepaths = valid_midi_files([os.path.join(midi_test_dirpath, filename) for filename in midi_test_filesnames])
midi_val_filepaths = valid_midi_files([os.path.join(midi_val_dirpath, filename) for filename in midi_val_filenames])
all_filepaths = midi_train_filepaths + midi_test_filepaths + midi_val_filepaths

Error processing nesmdb_midi/train/064_DeepDungeonII_YuushinoMonshou_09_10DungeonFloor3.mid: MIDI file has a largest tick of 59777468, it is likely corrupt
Error processing nesmdb_midi/train/090_DragonWarriorIII_30_31IntoTheLegend.mid: MIDI file has a largest tick of 16605640, it is likely corrupt
Error processing nesmdb_midi/train/091_DragonWarriorIV_43_44FinaleGuidingPeople.mid: MIDI file has a largest tick of 25710905, it is likely corrupt
Error processing nesmdb_midi/train/104_FamicomJumpII_Saikyono7_nin_19_20ThemeofFriendshipEffortVictoryCreditRoll.mid: MIDI file has a largest tick of 12944249, it is likely corrupt
Error processing nesmdb_midi/train/117_FinalFantasy_17_18EndTheme.mid: MIDI file has a largest tick of 11324494, it is likely corrupt
Error processing nesmdb_midi/train/122_FireEmblem_AnkokuRyutoHikarinoTsurugi_28_29EndingOmnibusBallad.mid: MIDI file has a largest tick of 17014635, it is likely corrupt
Error processing nesmdb_midi/train/122_FireEmblem_AnkokuRyutoHikarin

In [5]:
TIME_SHIFT_RESOLUTION = 0.01  # 50 ms
MAX_SHIFT_STEPS = 100  # Max 5 seconds
BEGINNING_OF_SONG_TOKEN = '<BOS>'
END_OF_SONG_TOKEN = '<EOS>'
PAD_TOKEN = '<PAD>'
VOCABULARY = dict()
index = 0

for special_token in [PAD_TOKEN, BEGINNING_OF_SONG_TOKEN, END_OF_SONG_TOKEN]:
    VOCABULARY[special_token] = index
    index += 1

for time_shift in range(1, MAX_SHIFT_STEPS + 1):
    VOCABULARY[f'time_shift_{time_shift}'] = index
    index += 1

for action in ["note_on", "note_off"]:
    for pitch in range(128):
        for program in [80, 81, 38, 121]:
            VOCABULARY[f'{action}_{pitch}_instrument_{program}'] = index
            index += 1

print(f'Vocabulary size: {len(VOCABULARY)}')
print(f'Vocabulary: {VOCABULARY}')

Vocabulary size: 1127
Vocabulary: {'<PAD>': 0, '<BOS>': 1, '<EOS>': 2, 'time_shift_1': 3, 'time_shift_2': 4, 'time_shift_3': 5, 'time_shift_4': 6, 'time_shift_5': 7, 'time_shift_6': 8, 'time_shift_7': 9, 'time_shift_8': 10, 'time_shift_9': 11, 'time_shift_10': 12, 'time_shift_11': 13, 'time_shift_12': 14, 'time_shift_13': 15, 'time_shift_14': 16, 'time_shift_15': 17, 'time_shift_16': 18, 'time_shift_17': 19, 'time_shift_18': 20, 'time_shift_19': 21, 'time_shift_20': 22, 'time_shift_21': 23, 'time_shift_22': 24, 'time_shift_23': 25, 'time_shift_24': 26, 'time_shift_25': 27, 'time_shift_26': 28, 'time_shift_27': 29, 'time_shift_28': 30, 'time_shift_29': 31, 'time_shift_30': 32, 'time_shift_31': 33, 'time_shift_32': 34, 'time_shift_33': 35, 'time_shift_34': 36, 'time_shift_35': 37, 'time_shift_36': 38, 'time_shift_37': 39, 'time_shift_38': 40, 'time_shift_39': 41, 'time_shift_40': 42, 'time_shift_41': 43, 'time_shift_42': 44, 'time_shift_43': 45, 'time_shift_44': 46, 'time_shift_45': 47, 

In [6]:
ID_TO_TOKEN = {v: k for k, v in VOCABULARY.items()}

In [7]:
def midi_to_tokens(pm: pretty_midi.PrettyMIDI):
    events = []

    for instrument in pm.instruments:
        for note in instrument.notes:
            events.append((note.start, f'note_on_{note.pitch}_instrument_{instrument.program}'))
            events.append((note.end, f'note_off_{note.pitch}_instrument_{instrument.program}'))
    
    events.sort()  # Sort by time

    tokens = []
    last_time = 0.0
    for time, event in events:
        delta = time - last_time
        steps = round(delta / TIME_SHIFT_RESOLUTION)

        while steps > 0:
            shift = min(steps, MAX_SHIFT_STEPS)
            tokens.append(f'time_shift_{shift}')
            steps -= shift
        
        tokens.append(event)
        last_time = time
    return [BEGINNING_OF_SONG_TOKEN] + tokens + [END_OF_SONG_TOKEN]

def tokens_to_midi(tokens):
    pm = pretty_midi.PrettyMIDI()
    instruments = dict()
    active_notes = dict()

    current_time = 0.0
    for token in tokens:
        if token.startswith('time_shift_'):
            shift_steps = int(token.split('_')[-1])
            current_time += shift_steps * TIME_SHIFT_RESOLUTION
        elif token.startswith('note_on_'):
            pitch = int(token.split('_')[2])
            instrument = int(token.split('_')[-1])
            active_notes[(pitch, instrument)] = current_time
        elif token.startswith('note_off_'):
            pitch = int(token.split('_')[2])
            instrument = int(token.split('_')[-1])
            if (pitch, instrument) not in active_notes:
                print(f"Warning: Note off for {pitch} on instrument {instrument} without matching note on.")
                continue
            start_time = active_notes[(pitch, instrument)]

            if instrument not in instruments:
                instruments[instrument] = pretty_midi.Instrument(program=instrument)

            note = pretty_midi.Note(
                velocity=100, pitch=pitch, start=start_time, end=current_time
            )

            instruments[instrument].notes.append(note)
            
    for instrument in instruments.values():
        pm.instruments.append(instrument)
    
    return pm

In [8]:
midi = pretty_midi.PrettyMIDI(all_filepaths[0])
tokens = midi_to_tokens(midi)
print(tokens[:30])
midi_reconstructed = tokens_to_midi(tokens)
# midi_reconstructed.write('reconstructed_midi.mid')
# midi.write('original_midi.mid')

['<BOS>', 'time_shift_1', 'note_on_62_instrument_38', 'note_on_65_instrument_80', 'time_shift_39', 'note_off_65_instrument_80', 'note_on_69_instrument_80', 'note_off_62_instrument_38', 'note_on_65_instrument_38', 'time_shift_13', 'note_off_69_instrument_80', 'note_on_65_instrument_80', 'note_off_65_instrument_38', 'note_on_62_instrument_38', 'time_shift_13', 'note_off_65_instrument_80', 'note_on_69_instrument_80', 'note_off_62_instrument_38', 'note_on_65_instrument_38', 'time_shift_13', 'note_off_69_instrument_80', 'note_on_70_instrument_80', 'note_off_65_instrument_38', 'note_on_67_instrument_38', 'time_shift_40', 'note_off_70_instrument_80', 'note_on_74_instrument_80', 'note_off_67_instrument_38', 'note_on_70_instrument_38', 'time_shift_13']


In [9]:
def load_sequences(filepaths):
    sequences = []
    for i, filepath in enumerate(filepaths):
        pm = pretty_midi.PrettyMIDI(filepath)
        tokens = midi_to_tokens(pm)
        if not tokens:
            raise ValueError(f'No tokens generated for {filepath}')
        sequences.append([VOCABULARY[token] for token in tokens])
        print_progress_bar(i+1, len(filepaths), prefix='Loading sequences')
    return sequences

class MIDITokenDataset(Dataset):
    def __init__(self, sequences, seq_length=512):
        self.inputs = []
        self.targets = []
        for ind, seq in enumerate(sequences):
            num_chunks = len(seq) // (seq_length + 1)
            for chunk in range(num_chunks + 1):
                chunk_start = chunk * (seq_length + 1)
                chunk_end = chunk_start + seq_length + 1
                chunk = seq[chunk_start:chunk_end]

                if len(chunk) < seq_length / 4:
                    continue

                input = chunk[:-1]
                input = np.pad(input, (0, seq_length - len(input)), constant_values=VOCABULARY[PAD_TOKEN])
                
                target = chunk[1:]
                target = np.pad(target, (0, seq_length - len(target)), constant_values=VOCABULARY[PAD_TOKEN])
                
                self.inputs.append(torch.tensor(input, dtype=torch.long))
                self.targets.append(torch.tensor(target, dtype=torch.long))
            
            print_progress_bar(ind+1, len(sequences), prefix='Processing sequences')

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

In [10]:
# Load and convert
train_sequences = load_sequences(midi_train_filepaths)
val_sequences = load_sequences(midi_val_filepaths)

Loading sequences |██████████████████████████████████████████████████| 100.0% Complete
Loading sequences |██████████████████████████████████████████████████| 100.0% Complete


In [11]:

def print_sequence_percentiles(sequences, prefix=''):
    print(f'{prefix} sequences: {len(sequences)}')
    print(f"{prefix} 90th percile length: {np.percentile([len(seq) for seq in sequences], 90)}")
    print(f"{prefix} 50th percile length: {np.percentile([len(seq) for seq in sequences], 50)}")
    print(f"{prefix} 25th percile length: {np.percentile([len(seq) for seq in sequences], 25)}")
    print(f"{prefix} 10th percile length: {np.percentile([len(seq) for seq in sequences], 10)}")
    print(f"{prefix} max train sequence length: {max(len(seq) for seq in sequences)}")
    print(f"{prefix} min train sequence length: {min(len(seq) for seq in sequences)}")

print_sequence_percentiles(train_sequences, prefix='Train')
print_sequence_percentiles(val_sequences, prefix='Validation')

Train sequences: 4470
Train 90th percile length: 4361.699999999999
Train 50th percile length: 944.0
Train 25th percile length: 286.25
Train 10th percile length: 99.0
Train max train sequence length: 27506
Train min train sequence length: 7
Validation sequences: 402
Validation 90th percile length: 4329.300000000001
Validation 50th percile length: 682.0
Validation 25th percile length: 176.0
Validation 10th percile length: 73.0
Validation max train sequence length: 14466
Validation min train sequence length: 14


In [12]:
# Create datasets
train_dataset = MIDITokenDataset(train_sequences, seq_length=512)
val_dataset = MIDITokenDataset(val_sequences, seq_length=512)
print(f'Train dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(val_dataset)}')

Processing sequences |██████████████████████████████████████████████████| 100.0% Complete
Processing sequences |██████████████████████████████████████████████████| 100.0% Complete
Train dataset size: 16287
Validation dataset size: 1328


In [13]:
class MusicRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(MusicRNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        # x: (batch_size, seq_length)
        x = self.embedding(x)  # (batch_size, seq_length, embedding_dim)
        out, hidden = self.rnn(x, hidden)  # out: (batch_size, seq_length, hidden_dim)
        out = self.fc(out)  # (batch_size, seq_length, vocab_size)
        return out, hidden

In [14]:
def train(model, train_loader, val_loader, vocab_size, num_epochs=20, lr=0.001, device='cuda'):
    time_start = time.time()
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        # --------- Training ---------
        model.train()
        total_train_loss = 0

        for i, batch in enumerate(train_loader):
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs, _ = model(inputs)

            outputs = outputs.reshape(-1, vocab_size)
            targets = targets.reshape(-1)

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            print_progress_bar(i+1, len(train_loader), prefix='Training...')

        avg_train_loss = total_train_loss / len(train_loader)

        # --------- Validation ---------
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                inputs, targets = batch
                inputs = inputs.to(device)
                targets = targets.to(device)

                outputs, _ = model(inputs)
                outputs = outputs.reshape(-1, vocab_size)
                targets = targets.reshape(-1)

                loss = criterion(outputs, targets)
                total_val_loss += loss.item()

                print_progress_bar(i+1, len(val_loader), prefix='Validating...')

        avg_val_loss = total_val_loss / len(val_loader)

        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")


In [18]:
batch_size = 32  # You can adjust this depending on your GPU memory
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MusicRNN(vocab_size=len(VOCABULARY), embedding_dim=256, hidden_dim=256, num_layers=2)
train(model, train_loader, val_loader, vocab_size=len(VOCABULARY), device=device)

Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 1/20 | Train Loss: 6.9929 | Val Loss: 6.8670
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 2/20 | Train Loss: 6.8542 | Val Loss: 6.6067
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 3/20 | Train Loss: 6.5978 | Val Loss: 6.2484
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 4/20 | Train Loss: 6.1302 | Val Loss: 5.6940
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |████████████████████████████████████████████

In [13]:
def sample(model, start_token, max_length=100, temperature=1.0, device='cuda'):
    model = model.to(device)
    model.eval()

    generated = [start_token]
    input_token = torch.tensor([[start_token]], device=device)  # (1, 1)

    hidden = None

    for _ in range(max_length):
        output, hidden = model(input_token, hidden)  # output: (1, 1, vocab_size)
        output = output[:, -1, :]  # take the last output
        output = output / temperature  # adjust randomness

        probs = F.softmax(output, dim=-1)  # (1, vocab_size)
        next_token = torch.multinomial(probs, num_samples=1).item()
        generated.append(next_token)
        if next_token == VOCABULARY[END_OF_SONG_TOKEN] or VOCABULARY[PAD_TOKEN]: # reach end of sequence
          break

        input_token = torch.tensor([[next_token]], device=device)

    return generated



In [17]:
start_token = VOCABULARY[BEGINNING_OF_SONG_TOKEN]
device = 'cuda' if torch.cuda.is_available() else 'cpu'
generated_sequence = sample(model, start_token, max_length=1024, device=device)
generated_tokens = [ID_TO_TOKEN[token] for token in generated_sequence]
tokens_reconstructed = tokens_to_midi(generated_tokens)
tokens_reconstructed.write('generated_midi.mid')



# Transformer


In [14]:
class MusicTransformer(nn.Module):
    def __init__(
        self, 
        vocab_size, 
        d_model=512,
        nhead=8,
        num_layers=4,
        dim_feedforward=2048,
        dropout=0.1,
        max_seq_length=2048
    ):
        super(MusicTransformer, self).__init__()
        
        # Token embeddings
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_seq_length, d_model)
        
        # Transformer decoder layers (each has multi-head attention + MLP)
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='relu',
            batch_first=True,
            norm_first=True  # Pre-LN for better stability
        )
        
        self.transformer_decoder = nn.TransformerDecoder(
            decoder_layer,
            num_layers=num_layers
        )
        
        # Output projection
        self.fc = nn.Linear(d_model, vocab_size)
        
        # Initialize weights
        self.init_weights()
        
        self.d_model = d_model
        self.max_seq_length = max_seq_length
        
    def init_weights(self):
        # Xavier uniform initialization
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()
        self.fc.weight.data.uniform_(-initrange, initrange)
        
    def generate_square_subsequent_mask(self, sz):
        """Generate a mask to prevent attention to future positions"""
        mask = torch.triu(torch.ones(sz, sz), diagonal=1)
        mask = mask.masked_fill(mask == 1, float('-inf'))
        return mask
    
    def forward(self, x, memory=None):
        # x: (batch_size, seq_length)
        batch_size, seq_length = x.shape
        
        # Token embeddings
        token_embeddings = self.embedding(x) * (self.d_model ** 0.5)
        
        # Positional embeddings
        positions = torch.arange(seq_length, device=x.device).unsqueeze(0)
        pos_embeddings = self.pos_embedding(positions)
        
        # Combine embeddings
        embeddings = token_embeddings + pos_embeddings
        
        # Create causal mask
        tgt_mask = self.generate_square_subsequent_mask(seq_length).to(x.device)
        
        # Pass through transformer decoder
        # In decoder-only mode, we use the embeddings as both tgt and memory
        output = self.transformer_decoder(
            tgt=embeddings,
            memory=embeddings,  # Self-attention only
            tgt_mask=tgt_mask,
            memory_mask=tgt_mask
        )
        
        # Project to vocabulary
        output = self.fc(output)
        
        return output, None  # Return None for hidden state compatibility


In [15]:
import math
from torch.utils.data import DataLoader
import torch, torch.nn as nn, torch.optim as optim

def accuracy_from_logits(logits, targets):
    """Compute token-level accuracy."""
    preds = logits.argmax(dim=-1)
    correct = (preds == targets).sum().item()
    total   = targets.numel()
    return correct / total

def train(
    model, train_loader, val_loader,
    vocab_size, num_epochs=10, lr=1e-3, device='cuda'):

    model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        # ---------- Training ----------
        model.train()
        tr_loss = tr_tokens = tr_correct = 0

        for inputs, targets in train_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            logits, _ = model(inputs)              # (B, T, V)
            loss = criterion(logits.view(-1, vocab_size),
                             targets.view(-1))
            loss.backward()
            optimizer.step()

            tr_loss     += loss.item()
            tr_correct  += (logits.argmax(-1) == targets).sum().item()
            tr_tokens   += targets.numel()

        avg_tr_loss = tr_loss / len(train_loader)
        tr_ppl      = math.exp(avg_tr_loss)
        tr_acc      = tr_correct / tr_tokens

        # ---------- Validation ----------
        model.eval()
        val_loss = val_tokens = val_correct = 0

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                logits, _ = model(inputs)
                loss = criterion(logits.view(-1, vocab_size),
                                 targets.view(-1))
                val_loss    += loss.item()
                val_correct += (logits.argmax(-1) == targets).sum().item()
                val_tokens  += targets.numel()

        avg_val_loss = val_loss / len(val_loader)
        val_ppl      = math.exp(avg_val_loss)
        val_acc      = val_correct / val_tokens

        print(f"E{epoch+1:02d} train | loss {avg_tr_loss:.4f} | "
              f"ppl {tr_ppl:6.2f} | acc {tr_acc:.3f}")
        print(f"E{epoch+1:02d} val   | loss {avg_val_loss:.4f} | "
              f"ppl {val_ppl:6.2f} | acc {val_acc:.3f}\n")

    return model


# Model initialization with correct parameters
batch_size = 32  # You can adjust this depending on your GPU memory
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using device: {device}')

# Initialize MusicTransformer with correct parameters
model = MusicTransformer(
    vocab_size=len(VOCABULARY),
    d_model=512,           # Model dimension (replaces embedding_dim)
    nhead=8,               # Number of attention heads
    num_layers=3,          # Number of transformer layers
    dim_feedforward=2048,  # MLP hidden dimension
    dropout=0.1,           # Dropout rate
    max_seq_length=512    # Maximum sequence length
)

# Train the model
trained_model = train(model, train_loader, val_loader, vocab_size=len(VOCABULARY), device=device)

Using device: cuda
E01 train | loss 3.2707 | ppl  26.33 | acc 0.319
E01 val   | loss 3.0130 | ppl  20.35 | acc 0.302

E02 train | loss 1.9959 | ppl   7.36 | acc 0.495
E02 val   | loss 1.8509 | ppl   6.37 | acc 0.539

E03 train | loss 1.3613 | ppl   3.90 | acc 0.638
E03 val   | loss 1.5033 | ppl   4.50 | acc 0.620

E04 train | loss 1.1574 | ppl   3.18 | acc 0.687
E04 val   | loss 1.4033 | ppl   4.07 | acc 0.645

E05 train | loss 1.0615 | ppl   2.89 | acc 0.710
E05 val   | loss 1.3593 | ppl   3.89 | acc 0.656

E06 train | loss 1.0003 | ppl   2.72 | acc 0.725
E06 val   | loss 1.3314 | ppl   3.79 | acc 0.664

E07 train | loss 0.9558 | ppl   2.60 | acc 0.735
E07 val   | loss 1.3250 | ppl   3.76 | acc 0.667

E08 train | loss 0.9206 | ppl   2.51 | acc 0.743
E08 val   | loss 1.3136 | ppl   3.72 | acc 0.670

E09 train | loss 0.8900 | ppl   2.44 | acc 0.750
E09 val   | loss 1.3130 | ppl   3.72 | acc 0.671

E10 train | loss 0.8662 | ppl   2.38 | acc 0.756
E10 val   | loss 1.3236 | ppl   3.76 | ac

In [16]:
import torch
import torch.nn.functional as F

def sample(
    model,
    start_token: int,
    max_length: 512,
    temperature: float = 1.0,
    device: str = "cuda",
):
    """
    Autoregressively sample from a trained MusicTransformer.

    Args
    ----
    model : nn.Module
        Trained MusicTransformer instance.
    start_token : int
        Vocabulary id that marks the beginning-of-song.
    max_length : int
        Maximum number of tokens (including the start_token).
    temperature : float
        Softmax temperature (>1 = more random, <1 = more greedy).
    device : str
        'cuda', 'mps', or 'cpu'.

    Returns
    -------
    List[int] : generated token ids (including the start_token).
    """
    model.to(device)
    model.eval()

    # running sequence buffer (1, t)
    generated = [start_token]
    input_tokens = torch.tensor([generated], dtype=torch.long, device=device)

    with torch.no_grad():
        for _ in range(max_length - 1):          # we already have one token
            logits, _ = model(input_tokens)      # (1, t, vocab)
            logits = logits[:, -1, :]            # last step (1, vocab)
            logits = logits / temperature

            probs = F.softmax(logits, dim=-1)    # (1, vocab)
            next_token = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_token)

            # stop if EOS or PAD
            if next_token in (
                VOCABULARY[END_OF_SONG_TOKEN],
                VOCABULARY[PAD_TOKEN],
            ):
                break

            # append and continue
            input_tokens = torch.tensor(
                [generated], dtype=torch.long, device=device
            )

    return generated
start_token = VOCABULARY[BEGINNING_OF_SONG_TOKEN]
device = 'cuda' if torch.cuda.is_available() else 'cpu'

generated_ids = sample(trained_model, start_token, max_length=512, device=device)
generated_tokens = [ID_TO_TOKEN[i] for i in generated_ids]

midi_obj = tokens_to_midi(generated_tokens)
midi_obj.write('generated_midi2.mid')

generated_ids = sample(trained_model, start_token, max_length=512, device=device)
generated_tokens = [ID_TO_TOKEN[i] for i in generated_ids]

midi_obj = tokens_to_midi(generated_tokens)
midi_obj.write('generated_midi3.mid')


generated_ids = sample(trained_model, start_token, max_length=512, device=device)
generated_tokens = [ID_TO_TOKEN[i] for i in generated_ids]

midi_obj = tokens_to_midi(generated_tokens)
midi_obj.write('generated_midi4.mid')


generated_ids = sample(trained_model, start_token, max_length=512, device=device)
generated_tokens = [ID_TO_TOKEN[i] for i in generated_ids]

midi_obj = tokens_to_midi(generated_tokens)
midi_obj.write('generated_midi5.mid')





### Train Soloists

In [17]:
def valid_soloist_midi_files(filepaths, instrument_program):
    valid_files = []
    for i, filepath in enumerate(filepaths):
        try:
            midi_data = pretty_midi.PrettyMIDI(filepath)
            if len(midi_data.instruments) > 0 and any(instrument.program == instrument_program for instrument in midi_data.instruments):
                valid_files.append(filepath)
        except Exception as e:
            print(f"Error processing {filepath}: {e}")
        print_progress_bar(i+1, len(filepaths), prefix='Validating MIDI files')
    return valid_files

In [18]:
unique_instruments = [80, 81, 38, 121]

In [19]:
soloist_80_midi_train_filepaths = valid_soloist_midi_files(midi_train_filepaths, instrument_program=80)
soloist_81_midi_train_filepaths = valid_soloist_midi_files(midi_train_filepaths, instrument_program=81)
soloist_38_midi_train_filepaths = valid_soloist_midi_files(midi_train_filepaths, instrument_program=38)
soloist_121_midi_train_filepaths = valid_soloist_midi_files(midi_train_filepaths, instrument_program=121)

soloist_80_midi_val_filepaths = valid_soloist_midi_files(midi_val_filepaths, instrument_program=80)
soloist_81_midi_val_filepaths = valid_soloist_midi_files(midi_val_filepaths, instrument_program=81)
soloist_38_midi_val_filepaths = valid_soloist_midi_files(midi_val_filepaths, instrument_program=38)
soloist_121_midi_val_filepaths = valid_soloist_midi_files(midi_val_filepaths, instrument_program=121)

Validating MIDI files |██████████████████████████████████████████████████| 100.0% Complete
Validating MIDI files |██████████████████████████████████████████████████| 100.0% Complete
Validating MIDI files |██████████████████████████████████████████████████| 100.0% Complete
Validating MIDI files |██████████████████████████████████████████████████| 100.0% Complete
Validating MIDI files |██████████████████████████████████████████████████| 100.0% Complete
Validating MIDI files |██████████████████████████████████████████████████| 100.0% Complete
Validating MIDI files |██████████████████████████████████████████████████| 100.0% Complete
Validating MIDI files |██████████████████████████████████████████████████| 100.0% Complete


In [20]:
def get_soloist_vocabulary():
    vocabulary = dict()
    index = 0

    for special_token in [PAD_TOKEN, BEGINNING_OF_SONG_TOKEN, END_OF_SONG_TOKEN]:
        vocabulary[special_token] = index
        index += 1

    for time_shift in range(1, MAX_SHIFT_STEPS + 1):
        vocabulary[f'time_shift_{time_shift}'] = index
        index += 1

    for action in ["note_on", "note_off"]:
        for pitch in range(128):
            vocabulary[f'{action}_{pitch}'] = index
            index += 1
    return vocabulary

In [21]:
def get_soloist_id_to_token():
    vocabulary = get_soloist_vocabulary()
    return {v: k for k, v in vocabulary.items()}

In [22]:
def soloist_midi_to_tokens(pm: pretty_midi.PrettyMIDI, instrument_program):
    events = []

    for instrument in pm.instruments:
        if instrument.program == instrument_program:
            for note in instrument.notes:
                events.append((note.start, f'note_on_{note.pitch}'))
                events.append((note.end, f'note_off_{note.pitch}'))
        
            break
        
    
    events.sort()  # Sort by time

    tokens = []
    last_time = 0.0
    for time, event in events:
        delta = time - last_time
        steps = max(round(delta / TIME_SHIFT_RESOLUTION), 1) # force at least 1 step (no side by side notes)

        while steps > 0:
            shift = min(steps, MAX_SHIFT_STEPS)
            tokens.append(f'time_shift_{shift}')
            steps -= shift
        
        tokens.append(event)
        last_time = time
    return [BEGINNING_OF_SONG_TOKEN] + tokens + [END_OF_SONG_TOKEN]

def soloist_tokens_to_midi(tokens, instrument_program):
    pm = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=instrument_program)

    active_notes = dict()
    active_pitch = None
    active_start = None

    current_time = 0.0
    for token in tokens:
        # print(current_time, token)
        if token.startswith('time_shift_'):
            shift_steps = int(token.split('_')[-1])
            current_time += shift_steps * TIME_SHIFT_RESOLUTION
        elif token.startswith('note_on_'):
            if active_pitch is not None:
                print(f"Warning: Skipping {token}, other note: {active_pitch} is still active.")
                continue 
            pitch = int(token.split('_')[2])
            # active_notes[pitch] = current_time
            active_pitch = pitch
            active_start = current_time
        elif token.startswith('note_off_'):
            pitch = int(token.split('_')[2])
            if pitch != active_pitch:
                print(f"Warning: Note off for {pitch} without matching note on.")
                continue

            if current_time > active_start:
                note = pretty_midi.Note(
                    velocity=100, pitch=pitch, start=active_start, end=current_time
                )

                instrument.notes.append(note)
            else:
                print(f"Warning: Note off for {pitch} at {current_time} note after note on at {active_start}. Ignoring.")
            # del active_notes[pitch]
            active_pitch = None
            active_start = None

    pm.instruments.append(instrument)
    
    return pm

In [22]:
midi = pretty_midi.PrettyMIDI(all_filepaths[0])
tokens_80 = soloist_midi_to_tokens(midi, 80)
tokens_81 = soloist_midi_to_tokens(midi, 81)
tokens_38 = soloist_midi_to_tokens(midi, 38)
tokens_121 = soloist_midi_to_tokens(midi, 121)

midi_reconstructed_80 = soloist_tokens_to_midi(tokens_80, instrument_program=80)
midi_reconstructed_81 = soloist_tokens_to_midi(tokens_81, instrument_program=81)
midi_reconstructed_38 = soloist_tokens_to_midi(tokens_38, instrument_program=38)
midi_reconstructed_121 = soloist_tokens_to_midi(tokens_121, instrument_program=121)

# midi_reconstructed_80.write('reconstructed_midi_80.mid')
# midi_reconstructed_81.write('reconstructed_midi_81.mid')
# midi_reconstructed_38.write('reconstructed_midi_38.mid')
# midi_reconstructed_121.write('reconstructed_midi_121.mid')

# midi.write('original_midi.mid')

In [23]:
def load_soloist_sequences(filepaths, instrument_program):
    vocabulary = get_soloist_vocabulary()
    sequences = []
    for i, filepath in enumerate(filepaths):
        pm = pretty_midi.PrettyMIDI(filepath)
        tokens = soloist_midi_to_tokens(pm, instrument_program)
        if not tokens:
            raise ValueError(f'No tokens generated for {filepath}')
        sequences.append([vocabulary[token] for token in tokens])
        print_progress_bar(i+1, len(filepaths), prefix='Loading sequences')
    return sequences

In [28]:
soloist_80_train_sequences = load_soloist_sequences(soloist_80_midi_train_filepaths, instrument_program=80)
soloist_81_train_sequences = load_soloist_sequences(soloist_81_midi_train_filepaths, instrument_program=81)
soloist_38_train_sequences = load_soloist_sequences(soloist_38_midi_train_filepaths, instrument_program=38)
soloist_121_train_sequences = load_soloist_sequences(soloist_121_midi_train_filepaths, instrument_program=121)

soloist_80_val_sequences = load_soloist_sequences(soloist_80_midi_val_filepaths, instrument_program=80)
soloist_81_val_sequences = load_soloist_sequences(soloist_81_midi_val_filepaths, instrument_program=81)
soloist_38_val_sequences = load_soloist_sequences(soloist_38_midi_val_filepaths, instrument_program=38)
soloist_121_val_sequences = load_soloist_sequences(soloist_121_midi_val_filepaths, instrument_program=121)

NameError: name 'soloist_80_midi_train_filepaths' is not defined

In [None]:
def get_soloist_model():
    return MusicRNN(vocab_size=len(get_soloist_vocabulary()), embedding_dim=64, hidden_dim=512, num_layers=2)

In [24]:
def get_soloist_model_transformer():
    return  MusicTransformer(
    vocab_size=len(get_soloist_vocabulary()),
    d_model=256,           # Model dimension (replaces embedding_dim)
    nhead=8,               # Number of attention heads
    num_layers=2,          # Number of transformer layers
    dim_feedforward=2048,  # MLP hidden dimension
    dropout=0.1,           # Dropout rate
    max_seq_length=512    # Maximum sequence length
)

In [37]:
def train_soloists(soloist_sequences: dict):
    models = {}
    for instrument_program, (train_sequences, val_sequences) in soloist_sequences.items():
        train_dataset = MIDITokenDataset(train_sequences, seq_length=512)
        val_dataset = MIDITokenDataset(val_sequences, seq_length=512)

        batch_size = 8
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = get_soloist_model()
        train(model, train_loader, val_loader, vocab_size=len(get_soloist_vocabulary()), device=device)
        models[instrument_program] = model
    return models


In [25]:

def train_soloists_transformer(soloist_sequences: dict):
    models = {}
    for instrument_program, (train_sequences, val_sequences) in soloist_sequences.items():
        train_dataset = MIDITokenDataset(train_sequences, seq_length=512)
        val_dataset = MIDITokenDataset(val_sequences, seq_length=512)

        batch_size = 8
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = get_soloist_model_transformer()
        train(model, train_loader, val_loader, vocab_size=len(get_soloist_vocabulary()), device=device)
        models[instrument_program] = model
    return models


In [31]:
soloist_sequences = {
    80: (soloist_80_train_sequences, soloist_80_val_sequences),
    81: (soloist_81_train_sequences, soloist_81_val_sequences),
    38: (soloist_38_train_sequences, soloist_38_val_sequences),
    121: (soloist_121_train_sequences, soloist_121_val_sequences)
}
soloist_models = train_soloists(soloist_sequences)

NameError: name 'soloist_80_train_sequences' is not defined

In [28]:
soloist_sequences = {
    80: (soloist_80_train_sequences, soloist_80_val_sequences),
    81: (soloist_81_train_sequences, soloist_81_val_sequences),
    38: (soloist_38_train_sequences, soloist_38_val_sequences),
    121: (soloist_121_train_sequences, soloist_121_val_sequences)
}
soloist_models = train_soloists_transformer(soloist_sequences)

Processing sequences |███████████████████████████████████████████-------| 87.9% Complete

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Processing sequences |██████████████████████████████████████████████████| 100.0% Complete
E01 train | loss 1.6864 | ppl   5.40 | acc 0.580
E01 val   | loss 1.0829 | ppl   2.95 | acc 0.702

E02 train | loss 0.9695 | ppl   2.64 | acc 0.725
E02 val   | loss 0.9542 | ppl   2.60 | acc 0.737

E03 train | loss 0.8705 | ppl   2.39 | acc 0.751
E03 val   | loss 0.9450 | ppl   2.57 | acc 0.742

E04 train | loss 0.8233 | ppl   2.28 | acc 0.762
E04 val   | loss 0.9467 | ppl   2.58 | acc 0.741

E05 train | loss 0.7884 | ppl   2.20 | acc 0.771
E05 val   | loss 0.9164 | ppl   2.50 | acc 0.750

E06 train | loss 0.7572 | ppl   2.13 | acc 0.780
E06 val   | loss 0.8804 | ppl   2.41 | acc 0.759

E07 train | loss 0.7263 | ppl   2.07 | acc 0.789
E07 val   | loss 0.8210 | ppl   2.27 | acc 0.781

E08 train | loss 0.6853 | ppl   1.98 | acc 0.802
E08 val   | loss 0.7839 | ppl   2.19 | acc 0.794

E09 train | loss 0.6584 | ppl   1.93 | acc 0.810
E09 val   | loss 0.7784 | ppl   2.18 | acc 0.797

E10 train | loss 0.

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Processing sequences |██████████████████████████████████████████████████| 100.0% Complete
E01 train | loss 1.6761 | ppl   5.34 | acc 0.580
E01 val   | loss 1.0120 | ppl   2.75 | acc 0.724

E02 train | loss 0.9537 | ppl   2.60 | acc 0.728
E02 val   | loss 0.9039 | ppl   2.47 | acc 0.754

E03 train | loss 0.8656 | ppl   2.38 | acc 0.751
E03 val   | loss 0.8346 | ppl   2.30 | acc 0.776

E04 train | loss 0.8184 | ppl   2.27 | acc 0.763
E04 val   | loss 0.8566 | ppl   2.36 | acc 0.767

E05 train | loss 0.7857 | ppl   2.19 | acc 0.772
E05 val   | loss 0.8168 | ppl   2.26 | acc 0.780

E06 train | loss 0.7454 | ppl   2.11 | acc 0.785
E06 val   | loss 0.7439 | ppl   2.10 | acc 0.803

E07 train | loss 0.6999 | ppl   2.01 | acc 0.800
E07 val   | loss 0.7053 | ppl   2.02 | acc 0.815

E08 train | loss 0.6757 | ppl   1.97 | acc 0.807
E08 val   | loss 0.6977 | ppl   2.01 | acc 0.815

E09 train | loss 0.6522 | ppl   1.92 | acc 0.813
E09 val   | loss 0.6717 | ppl   1.96 | acc 0.823

E10 train | loss 0.

In [48]:
def save_model_weights(models):
    for instrument_program, model in models.items():
        torch.save(model.state_dict(), f'soloist_model_{instrument_program}.pth')

In [59]:
print(soloist_models[80])
torch.save(soloist_models[80].state_dict(), f'soloist_model_80kio.pth')

MusicTransformer(
  (embedding): Embedding(359, 256)
  (pos_embedding): Embedding(512, 256)
  (transformer_decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (multihead_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [49]:
save_model_weights(soloist_models)

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [26]:
def load_model_weight(path, model):
    model.load_state_dict(torch.load(path, map_location='cpu'))
    return model

In [None]:
def sample_soloists(models, outdir='./'):
    soloist_vocabulary = get_soloist_vocabulary()
    midis = dict()
    for instrument_program, model in models.items():
        model.eval()  # Set to eval mode for inference
        start_token = soloist_vocabulary[BEGINNING_OF_SONG_TOKEN]
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        generated_sequence = sample(model, start_token, max_length=1024, device=device)
        generated_tokens = [get_soloist_id_to_token()[token] for token in generated_sequence]
        tokens_reconstructed = soloist_tokens_to_midi(generated_tokens, instrument_program=instrument_program)
        midis[instrument_program] = tokens_reconstructed
    
    together_midi = pretty_midi.PrettyMIDI()
    for instrument_program, midi in midis.items():
        together_midi.instruments.extend(midi.instruments)

    os.makedirs(outdir, exist_ok=True)
    outpath = os.path.join(outdir, f'{datetime.datetime.now()}.mid')
    
    together_midi.write(outpath)


In [13]:
from collections import OrderedDict     # <- add this
import os, datetime, torch, pretty_midi
from typing import Dict
def sample(
    model,
    start_token: int,
    vocab: Dict[str, int],
    max_length: int = 1024,
    temperature: float = 1.0,
    device: str = "cuda",
):
    """Autoregressive sampling for our MusicTransformer variant."""
    model.to(device).eval()

    eos_id = vocab[END_OF_SONG_TOKEN]
    pad_id = vocab[PAD_TOKEN]

    generated = [start_token]
    tokens = torch.tensor([generated], dtype=torch.long, device=device)

    with torch.no_grad():
        for _ in range(max_length - 1):
            logits, _ = model(tokens)          # (1, T, |V|)
            logits = logits[:, -1, :] / temperature
            probs  = torch.softmax(logits, dim=-1)

            next_tok = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_tok)

            if next_tok in (eos_id, pad_id):
                break

            tokens = torch.tensor([generated], dtype=torch.long, device=device)

    return generated
    
def sample_soloists(
    models: Dict[int, torch.nn.Module],
    outdir: str = "./",
    max_length: int = 512,
    temperature: float = 1.0,
    device: str = None,
):
    """
    Generate a track for each instrument-program in `models`
    and merge them into one multitrack MIDI file.
    """
    device = (
        device
        or ("cuda" if torch.cuda.is_available() else "cpu")
    )

    vocab          = get_soloist_vocabulary()
    id_to_token    = get_soloist_id_to_token()
    midis          = OrderedDict()

    # --- generate each solo line ------------------------------------------------
    for prog, model in models.items():
        start_tok   = vocab[BEGINNING_OF_SONG_TOKEN]
        token_ids   = sample(
            model,
            start_tok,
            vocab=vocab,
            max_length=max_length,
            temperature=temperature,
            device=device,
        )
        tokens      = [id_to_token[i] for i in token_ids]
        midi_track  = soloist_tokens_to_midi(tokens, instrument_program=prog)
        midis[prog] = midi_track

    # --- merge into a single PrettyMIDI ----------------------------------------
    together = pretty_midi.PrettyMIDI()
    for midi in midis.values():
        together.instruments.extend(midi.instruments)

    # --- make sure output dir exists and save ----------------------------------
    os.makedirs(outdir, exist_ok=True)
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    outpath   = os.path.join(outdir, f"soloists_{timestamp}.mid")

    together.write(outpath)
    print(f"✨  Wrote multitrack MIDI to: {outpath}")

    return outpath

In [20]:
unique_instruments = [38]

In [21]:
# models = [load_model_weight(f'soloist_model_{instrument_program}.pth', get_soloist_model()) for instrument_program in unique_instruments]
model_weights_dir = os.path.join('lstm_models', 'soloists', '2')
# models = {80: load_model_weight(os.path.join(model_weights_dir, 'soloist_model_80.pth'), get_soloist_model())}
models = {instr_program: load_model_weight( f'soloist_model_{instr_program}.pth', get_soloist_model_transformer()) for instr_program in unique_instruments}
sample_soloists(models, outdir=os.path.join(model_weights_dir, 'samples'))

✨  Wrote multitrack MIDI to: lstm_models/soloists/2/samples/soloists_20250602_034729.mid


'lstm_models/soloists/2/samples/soloists_20250602_034729.mid'

In [None]:
side_by_side_note_count = 0
total_tokens = 0
for i, sequence in enumerate(soloist_80_train_sequences):
    pairwise = zip(sequence[:-1], sequence[1:])
    total_tokens += len(sequence)
    for token1, token2 in pairwise:
        token1 = get_soloist_id_to_token()[token1]
        token2 = get_soloist_id_to_token()[token2]
        if token1.startswith('note_on_') and token2.startswith('note_off_'):
            if token1.split('_')[2] == token2.split('_')[2]:
                side_by_side_note_count += 1

    print_progress_bar(i+1, len(soloist_80_train_sequences), prefix='Counting side by side notes')
    
print(f'Side by side note count: {side_by_side_note_count}')
print(f'Total tokens: {total_tokens}')

### Task 2 - Conditional generation

In [23]:
def extract_game_name(filepath):
    filename = os.path.basename(filepath)
    game_name_chunks = filename.split('_')[1:-2]
    return '_'.join(game_name_chunks)

def print_game_distribution(filepaths):
    game_names = [extract_game_name(filepath) for filepath in filepaths]
    game_name_dist = Counter(game_names)
    game_counts = game_name_dist.values()
    game_counts_dist = Counter(game_counts)
    print(f'Game name distribution: {dict(game_name_dist)}')
    print(f'Unique game names: {len(game_name_dist)}')
    print(f'Distribution of game counts: {dict(game_counts_dist)}')

print_game_distribution(all_filepaths)

Game name distribution: {'10_YardFight': 2, '1942': 6, '720_': 4, '98in1': 1, 'Abadox_TheDeadlyInnerWar': 12, 'Adam_amp_Eve': 5, 'AfterBurner': 5, 'AfterBurnerII': 6, 'AighinanoYogen_BalubalouknoDensetsuYori': 14, 'AlienSyndrome': 10, 'Aliens_Alien2': 9, 'Argus': 10, 'ArmWrestling': 1, 'ArumananoKiseki': 12, 'Athena': 13, 'AtlantisnoNazo': 7, 'BabelnoTou': 8, 'BalloonFight': 12, 'Baseball': 6, 'BatmanReturns': 20, 'BinaryLand': 6, 'Batman_ReturnofTheJoker': 15, 'Batman_TheVideoGame': 11, 'BattleCity': 3, 'BioMiracleBokutteUpa': 13, 'BioSenshiDan_IncreasertonoTatakai': 24, 'Blackjack': 3, 'BlasterMaster': 14, 'Bomberman': 10, 'BombermanII': 18, 'BuraiFighter': 9, 'CaptainTsubasaVol_II_SuperStriker': 40, 'Castelian': 5, 'CastleofDragon': 16, 'Castlevania': 16, 'CastlevaniaIII_Dracula_sCurse': 28, 'CastlevaniaII_Simon_sQuest': 9, 'Chack_nPop': 5, 'Challenger': 6, 'ChaosWorld': 24, 'ChesterField_EpisodeIIAnkokuShinenoChousen': 19, 'ChoujinSentaiJetman': 18, 'CircusCaper': 26, 'CircusCharli

In [24]:
MIN_GAME_COUNT = 10
game_names = [extract_game_name(filepath) for filepath in all_filepaths]
game_name_dist = Counter(game_names)
valid_conditional_filepaths = [filepath for filepath in all_filepaths if game_name_dist[extract_game_name(filepath)] >= MIN_GAME_COUNT]
print("original filepaths count:", len(all_filepaths))
print("valid conditional filepaths count:", len(valid_conditional_filepaths))
print("number of unique games:", sum(1 for count in game_name_dist.values() if count >= MIN_GAME_COUNT))

original filepaths count: 5244
valid conditional filepaths count: 4418
number of unique games: 243


In [25]:
VALID_GAME_NAMES = set([extract_game_name(filepath) for filepath in valid_conditional_filepaths])
def get_soloist_conditional_vocabulary():
    vocabulary = dict()
    index = 0

    for special_token in [PAD_TOKEN, BEGINNING_OF_SONG_TOKEN, END_OF_SONG_TOKEN]:
        vocabulary[special_token] = index
        index += 1

    for time_shift in range(1, MAX_SHIFT_STEPS + 1):
        vocabulary[f'time_shift_{time_shift}'] = index
        index += 1

    for action in ["note_on", "note_off"]:
        for pitch in range(128):
            vocabulary[f'{action}_{pitch}'] = index
            index += 1
            
    for game_name in VALID_GAME_NAMES:
        vocabulary[f'game_{game_name}'] = index
        index += 1

    return vocabulary

def get_soloist_conditional_id_to_token():
    vocabulary = get_soloist_conditional_vocabulary()
    return {v: k for k, v in vocabulary.items()}

In [26]:
soloist_conditional_vocabulary = get_soloist_conditional_vocabulary()
valid_game_name_counts = Counter([extract_game_name(filepath) for filepath in valid_conditional_filepaths])

print("total soloist conditional vocabulary length", len(soloist_conditional_vocabulary))
print("(game_name, count) sorted: ", sorted(valid_game_name_counts.items(), key=lambda x: x[1], reverse=True)[:10])

total soloist conditional vocabulary length 602
(game_name, count) sorted:  [('GanbareGoemonGaiden2_TenkanoZaih_', 99), ('FinalFantasyIII', 59), ('DragonWarriorIV', 46), ('WaiWaiWorld2_SOS__ParsleyJou', 46), ('Kirby_sAdventure', 42), ('CaptainTsubasaVol_II_SuperStriker', 40), ('Mother', 40), ('RadiaSenki_ReimeiHen', 39), ('MetalMax', 38), ('FamicomJump_HeroRetsuden', 36)]


In [27]:
def conditional_soloist_filepath_to_tokens(filepath, instrument_program, midi_file = None):
    game_name = extract_game_name(filepath)
    if game_name not in VALID_GAME_NAMES:
        raise ValueError(f'Game name {game_name} not in valid game names.')
    
    pm = midi_file if midi_file else pretty_midi.PrettyMIDI(filepath)
    tokens = soloist_midi_to_tokens(pm, instrument_program=instrument_program)
    tokens = [tokens[0]] + [f'game_{game_name}'] + tokens[1:] # prefix the sequence with the game name token
    return tokens

In [28]:
example_filepath = valid_conditional_filepaths[0]
example_tokens = conditional_soloist_filepath_to_tokens(valid_conditional_filepaths[0], instrument_program=80)
example_ids = [get_soloist_conditional_vocabulary()[token] for token in example_tokens]
print(f'Example tokens for {os.path.basename(example_filepath)}: {example_tokens[:10]}')
print(f'Example IDs for {os.path.basename(example_filepath)}: {example_ids[:10]}')

Example tokens for 005_Abadox_TheDeadlyInnerWar_00_01OpeningSE.mid: ['<BOS>', 'game_Abadox_TheDeadlyInnerWar', 'time_shift_1', 'note_on_45', 'time_shift_1', 'note_off_45', 'time_shift_1', 'note_on_44', 'time_shift_5', 'note_off_44']
Example IDs for 005_Abadox_TheDeadlyInnerWar_00_01OpeningSE.mid: [1, 392, 3, 148, 3, 276, 3, 147, 7, 275]


In [29]:
tokens_80 = conditional_soloist_filepath_to_tokens(example_filepath, 80)
tokens_81 = conditional_soloist_filepath_to_tokens(example_filepath, 81)
tokens_38 = conditional_soloist_filepath_to_tokens(example_filepath, 38)
tokens_121 = conditional_soloist_filepath_to_tokens(example_filepath, 121)

midi_reconstructed_80 = soloist_tokens_to_midi(tokens_80, instrument_program=80)
midi_reconstructed_81 = soloist_tokens_to_midi(tokens_81, instrument_program=81)
midi_reconstructed_38 = soloist_tokens_to_midi(tokens_38, instrument_program=38)
midi_reconstructed_121 = soloist_tokens_to_midi(tokens_121, instrument_program=121)

# midi_reconstructed_80.write('reconstructed_midi_80.mid')
# midi_reconstructed_81.write('reconstructed_midi_81.mid')
# midi_reconstructed_38.write('reconstructed_midi_38.mid')
# midi_reconstructed_121.write('reconstructed_midi_121.mid')

# original_midi = pretty_midi.PrettyMIDI(example_filepath)
# original_midi.write('original_midi.mid')

In [30]:
def load_conditional_soloist_sequences(filepaths, instruments):
    vocabulary = get_soloist_conditional_vocabulary()
    sequences = {instrument: [] for instrument in instruments}

    for i, filepath in enumerate(filepaths):
        midi_file = pretty_midi.PrettyMIDI(filepath)
        for instrument in instruments:
            tokens = conditional_soloist_filepath_to_tokens(filepath, instrument, midi_file=midi_file)
            if not tokens:
                raise ValueError(f'No tokens generated for {filepath} with instrument {instrument}')
            sequences[instrument].append([vocabulary[token] for token in tokens])
        print_progress_bar(i+1, len(filepaths), prefix='Loading sequences')
    return sequences

In [31]:
all_sequences = load_conditional_soloist_sequences(valid_conditional_filepaths, unique_instruments)

Loading sequences |██████████████████████████████████████████████████| 100.0% Complete


In [32]:
all_conditional_80_sequences = all_sequences[80]
all_conditional_81_sequences = all_sequences[81]
all_conditional_38_sequences = all_sequences[38]
all_conditional_121_sequences = all_sequences[121]

train_conditional_80_sequences, val_conditional_80_sequences = train_test_split(all_conditional_80_sequences, test_size=0.2, random_state=42)
train_conditional_81_sequences, val_conditional_81_sequences = train_test_split(all_conditional_81_sequences, test_size=0.2, random_state=42)
train_conditional_38_sequences, val_conditional_38_sequences = train_test_split(all_conditional_38_sequences, test_size=0.2, random_state=42)
train_conditional_121_sequences, val_conditional_121_sequences = train_test_split(all_conditional_121_sequences, test_size=0.2, random_state=42)


In [41]:
def get_conditional_soloist_model():
    return MusicRNN(vocab_size=len(get_soloist_conditional_vocabulary()), embedding_dim=64, hidden_dim=512, num_layers=2)

In [33]:
def get_conditional_soloist_model():
    return  MusicTransformer(
    vocab_size=len(get_soloist_conditional_vocabulary()),
    d_model=256,           # Model dimension (replaces embedding_dim)
    nhead=8,               # Number of attention heads
    num_layers=2,          # Number of transformer layers
    dim_feedforward=2048,  # MLP hidden dimension
    dropout=0.1,           # Dropout rate
    max_seq_length=512    # Maximum sequence length
)

In [34]:
def train_conditional_soloists(soloist_sequences: dict):
    models = {}
    for instrument_program, (train_sequences, val_sequences) in soloist_sequences.items():
        train_dataset = MIDITokenDataset(train_sequences, seq_length=512)
        val_dataset = MIDITokenDataset(val_sequences, seq_length=512)

        batch_size = 32
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = get_conditional_soloist_model()
        train(model, train_loader, val_loader, vocab_size=len(get_soloist_conditional_vocabulary()), device=device)
        models[instrument_program] = model
    return models

In [35]:
soloist_sequences = {
    80: (train_conditional_80_sequences, val_conditional_80_sequences),
    81: (train_conditional_81_sequences, val_conditional_81_sequences),
    38: (train_conditional_38_sequences, val_conditional_38_sequences),
    121: (train_conditional_121_sequences, val_conditional_121_sequences)
}
soloist_models = train_conditional_soloists(soloist_sequences)

Processing sequences |██████████████████████████████████████████████████| 100.0% Complete
Processing sequences |██████████████████████████████████████████████████| 100.0% Complete
E01 train | loss 2.3007 | ppl   9.98 | acc 0.479
E01 val   | loss 1.7419 | ppl   5.71 | acc 0.538

E02 train | loss 1.4319 | ppl   4.19 | acc 0.611
E02 val   | loss 1.1517 | ppl   3.16 | acc 0.670

E03 train | loss 1.0600 | ppl   2.89 | acc 0.696
E03 val   | loss 0.9910 | ppl   2.69 | acc 0.713

E04 train | loss 0.9357 | ppl   2.55 | acc 0.730
E04 val   | loss 0.9187 | ppl   2.51 | acc 0.737

E05 train | loss 0.8685 | ppl   2.38 | acc 0.749
E05 val   | loss 0.8842 | ppl   2.42 | acc 0.748

E06 train | loss 0.8224 | ppl   2.28 | acc 0.761
E06 val   | loss 0.8620 | ppl   2.37 | acc 0.754

E07 train | loss 0.7858 | ppl   2.19 | acc 0.771
E07 val   | loss 0.8501 | ppl   2.34 | acc 0.757

E08 train | loss 0.7560 | ppl   2.13 | acc 0.778
E08 val   | loss 0.8385 | ppl   2.31 | acc 0.760

E09 train | loss 0.7302 | pp

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Processing sequences |██████████████████████████████████████████████████| 100.0% Complete
E01 train | loss 2.3455 | ppl  10.44 | acc 0.458
E01 val   | loss 1.8152 | ppl   6.14 | acc 0.513

E02 train | loss 1.4387 | ppl   4.22 | acc 0.598
E02 val   | loss 1.2063 | ppl   3.34 | acc 0.657

E03 train | loss 1.0170 | ppl   2.76 | acc 0.705
E03 val   | loss 0.9769 | ppl   2.66 | acc 0.717

E04 train | loss 0.8740 | ppl   2.40 | acc 0.745
E04 val   | loss 0.9301 | ppl   2.53 | acc 0.731

E05 train | loss 0.8121 | ppl   2.25 | acc 0.760
E05 val   | loss 0.9081 | ppl   2.48 | acc 0.736

E06 train | loss 0.7693 | ppl   2.16 | acc 0.771
E06 val   | loss 0.8913 | ppl   2.44 | acc 0.740

E07 train | loss 0.7365 | ppl   2.09 | acc 0.780
E07 val   | loss 0.8841 | ppl   2.42 | acc 0.743

E08 train | loss 0.7088 | ppl   2.03 | acc 0.787
E08 val   | loss 0.8783 | ppl   2.41 | acc 0.744

E09 train | loss 0.6853 | ppl   1.98 | acc 0.794
E09 val   | loss 0.8784 | ppl   2.41 | acc 0.747

E10 train | loss 0.

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Processing sequences |██████████████████████████████████████████████████| 100.0% Complete
E01 train | loss 1.7553 | ppl   5.79 | acc 0.507
E01 val   | loss 1.2424 | ppl   3.46 | acc 0.581

E02 train | loss 1.0591 | ppl   2.88 | acc 0.633
E02 val   | loss 0.8802 | ppl   2.41 | acc 0.678

E03 train | loss 0.7983 | ppl   2.22 | acc 0.714
E03 val   | loss 0.6520 | ppl   1.92 | acc 0.770

E04 train | loss 0.5808 | ppl   1.79 | acc 0.799
E04 val   | loss 0.4766 | ppl   1.61 | acc 0.842

E05 train | loss 0.4637 | ppl   1.59 | acc 0.842
E05 val   | loss 0.4171 | ppl   1.52 | acc 0.863

E06 train | loss 0.4079 | ppl   1.50 | acc 0.861
E06 val   | loss 0.3980 | ppl   1.49 | acc 0.868

E07 train | loss 0.3772 | ppl   1.46 | acc 0.871
E07 val   | loss 0.3906 | ppl   1.48 | acc 0.870

E08 train | loss 0.3550 | ppl   1.43 | acc 0.878
E08 val   | loss 0.3811 | ppl   1.46 | acc 0.872

E09 train | loss 0.3428 | ppl   1.41 | acc 0.881
E09 val   | loss 0.3815 | ppl   1.46 | acc 0.873

E10 train | loss 0.

In [39]:
# --- imports ---------------------------------------------------------------
import os, datetime, torch, torch.nn.functional as F, pretty_midi
from pathlib import Path
from collections import OrderedDict
from typing import Dict, List

# -------------------------------------------------------------------------- #
# 1.  Conditional autoregressive sampling for a MusicTransformer
# -------------------------------------------------------------------------- #
def sample_conditional_transformer(
    model: torch.nn.Module,
    prefix_tokens: List[int],                 # e.g. [BOS, game_token]
    vocab: Dict[str, int],
    max_length: int = 512,                    # keep ≤ pos-embed size
    temperature: float = 1.0,
    device: str = "cuda",
) -> List[int]:
    """
    Generate a continuation given a fixed token prefix (no gradient).

    Returns the full sequence: prefix + generated suffix.
    """
    model = model.to(device).eval()

    # Fast sanity-check: do we exceed the learned positional table?
    seq_cap = getattr(model, "max_seq_length", model.pos_embedding.num_embeddings)
    if max_length > seq_cap:
        max_length = seq_cap                    # hard-clip to avoid CUDA asserts

    eos_id, pad_id = vocab[END_OF_SONG_TOKEN], vocab[PAD_TOKEN]

    generated = list(prefix_tokens)
    tokens = torch.tensor([generated], dtype=torch.long, device=device)

    with torch.no_grad():
        for _ in range(max_length - len(prefix_tokens)):
            logits, _ = model(tokens)           # (1, T, |V|)
            logits = logits[:, -1, :] / temperature
            probs  = F.softmax(logits, dim=-1)

            next_tok = torch.multinomial(probs, 1).item()
            generated.append(next_tok)

            if next_tok in (eos_id, pad_id):
                break

            tokens = torch.tensor([generated], dtype=torch.long, device=device)

    return generated

# -------------------------------------------------------------------------- #
# 2.  Convenience wrapper: generate every solo line and merge to one MIDI
# -------------------------------------------------------------------------- #
def sample_conditional_soloists(
    models: Dict[int, torch.nn.Module],
    game_name: str,
    outdir: str = "./",
    max_length: int = 512,
    temperature: float = 1.0,
    device: str | None = None,
) -> Path:
    """
    For each instrument program in `models`, generate a transformer-based
    conditional solo line (conditioned on `game_name`) and write a merged MIDI.

    Returns the path of the generated file.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    vocab         = get_soloist_conditional_vocabulary()
    id_to_token   = get_soloist_conditional_id_to_token()
    game_tok_key  = f"game_{game_name}"
    if game_tok_key not in vocab:
        raise KeyError(f"'{game_tok_key}' not found in conditional vocabulary")

    prefix = [vocab[BEGINNING_OF_SONG_TOKEN], vocab[game_tok_key]]

    midis = OrderedDict()
    for prog, model in models.items():
        tokens_ids = sample_conditional_transformer(
            model,
            prefix_tokens=prefix,
            vocab=vocab,
            max_length=max_length,
            temperature=temperature,
            device=device,
        )
        tokens_txt = [id_to_token[i] for i in tokens_ids]
        midi_track = soloist_tokens_to_midi(tokens_txt, instrument_program=prog)
        midis[prog] = midi_track

    # Merge the generated instrument tracks
    merged = pretty_midi.PrettyMIDI()
    for midi in midis.values():
        merged.instruments.extend(midi.instruments)

    outdir = Path(outdir) / game_name
    outdir.mkdir(parents=True, exist_ok=True)
    outpath = outdir / f"{int(datetime.datetime.now().timestamp()*1000)}.mid"
    merged.write(str(outpath))     # or outpath.as_posix()


    print(f"🎮  Conditional soloists written to {outpath}")
    return outpath


In [40]:
games_to_sample = [
    "GanbareGoemonGaiden2_TenkanoZaih_", "FinalFantasyIII",
    "WaiWaiWorld2_SOS__ParsleyJou", "DragonWarriorIV", "Kirby_sAdventure"
]

weights_dir = Path("transformer_models") / "conditioned_soloists" / "1"
# conditional_models = {
#     prog: load_model_weight(
#         weights_dir / f"soloist_conditional_model_{prog}.pth",
#         get_conditional_soloist_transformer()  # <-- transformer factory
#     )
#     for prog in unique_instruments
# }

for game in games_to_sample:
    sample_conditional_soloists(
        soloist_models,
        game,
        outdir=weights_dir / "samples",
        max_length=512,         # must be ≤ pos_embed size
        temperature=1.1,
    )


🎮  Conditional soloists written to transformer_models/conditioned_soloists/1/samples/GanbareGoemonGaiden2_TenkanoZaih_/1748910911958.mid
🎮  Conditional soloists written to transformer_models/conditioned_soloists/1/samples/FinalFantasyIII/1748910932632.mid
🎮  Conditional soloists written to transformer_models/conditioned_soloists/1/samples/WaiWaiWorld2_SOS__ParsleyJou/1748910949039.mid
🎮  Conditional soloists written to transformer_models/conditioned_soloists/1/samples/DragonWarriorIV/1748910964348.mid
🎮  Conditional soloists written to transformer_models/conditioned_soloists/1/samples/Kirby_sAdventure/1748910982547.mid


In [None]:
def sample_conditional(model, start_tokens, max_length=100, temperature=1.0, device='cuda'):
    model = model.to(device)
    model.eval()

    generated = start_tokens
    input_token = torch.tensor([generated], device=device)  # (1, 1)

    hidden = None

    for _ in range(max_length):
        output, hidden = model(input_token, hidden)  # output: (1, 1, vocab_size)
        output = output[:, -1, :]  # take the last output
        output = output / temperature  # adjust randomness

        probs = F.softmax(output, dim=-1)  # (1, vocab_size)
        next_token = torch.multinomial(probs, num_samples=1).item()
        generated.append(next_token)
        if next_token == VOCABULARY[END_OF_SONG_TOKEN] or VOCABULARY[PAD_TOKEN]: # reach end of sequence
          break

        input_token = torch.tensor([[next_token]], device=device)

    return generated

In [None]:
def sample_conditional_soloists(models, game_name, outdir='./'):
    soloist_vocabulary = get_soloist_conditional_vocabulary()
    midis = dict()
    for instrument_program, model in models.items():
        model.eval()  # Set to eval mode for inference
        game_token = f'game_{game_name}'
        start_tokens = [soloist_vocabulary[BEGINNING_OF_SONG_TOKEN], soloist_vocabulary[game_token]]
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        generated_sequence = sample_conditional(model, start_tokens, max_length=1024, device=device)
        generated_tokens = [get_soloist_conditional_id_to_token()[token] for token in generated_sequence]
        tokens_reconstructed = soloist_tokens_to_midi(generated_tokens, instrument_program=instrument_program)
        midis[instrument_program] = tokens_reconstructed
    
    together_midi = pretty_midi.PrettyMIDI()
    for instrument_program, midi in midis.items():
        together_midi.instruments.extend(midi.instruments)

    outdir = os.path.join(outdir, game_name)
    os.makedirs(outdir, exist_ok=True)
    outpath = os.path.join(outdir, f'{int(datetime.datetime.now().timestamp() * 1000)}.mid')
    
    together_midi.write(outpath)

In [None]:
games_to_sample = ['GanbareGoemonGaiden2_TenkanoZaih_', 'FinalFantasyIII', 'WaiWaiWorld2_SOS__ParsleyJou', 'DragonWarriorIV', 'Kirby_sAdventure']
model_weights_dir = os.path.join('lstm_models', 'conditioned_soloists', '1')
conditioned_models = {instr_program: load_model_weight(os.path.join(model_weights_dir, f'soloist_conditional_model_{instr_program}.pth'), get_conditional_soloist_model()) for instr_program in unique_instruments}

for game in games_to_sample:
    sample_conditional_soloists(conditioned_models, game, outdir=os.path.join(model_weights_dir, 'samples'))

### Data exploration

In [None]:
midi_dirpath = 'nesmdb_midi/'
midi_train_dirpath = os.path.join(midi_dirpath, 'train')
midi_test_dirpath = os.path.join(midi_dirpath, 'test')
midi_val_dirpath = os.path.join(midi_dirpath, 'valid')
midi_train_filesnames = os.listdir(midi_train_dirpath)
midi_test_filesnames = os.listdir(midi_test_dirpath)
midi_val_filenames = os.listdir(midi_val_dirpath)

midi_train_filepaths = [os.path.join(midi_train_dirpath, filename) for filename in midi_train_filesnames]
midi_test_filepaths = [os.path.join(midi_test_dirpath, filename) for filename in midi_test_filesnames]
midi_val_filepaths = [os.path.join(midi_val_dirpath, filename) for filename in midi_val_filenames]
all_filepaths = midi_train_filepaths + midi_test_filepaths + midi_val_filepaths

In [None]:
import pandas as pd

program_counts = Counter()
num_programs_counts = Counter()

data = []

unique_programs = set()

note_counts = Counter()
num_notes_counts = Counter()
for i, filepath in enumerate(all_filepaths):
    try:
        midi = pretty_midi.PrettyMIDI(filepath)

        length = midi.get_end_time()
        instruments = midi.instruments
        num_instruments = len(instruments)


        num_notes = sum(len(instr.notes) for instr in instruments)
        num_notes_counts[num_notes] += 1
        

        for instrument in instruments:
            program_counts[instrument.program] += 1
            unique_programs.add(instrument.program)

            for note in instrument.notes:
                note_counts[note.pitch] += 1
                
        num_programs_counts[len(instruments)] += 1
    
        data.append({
            "length_sec": length,
            "over_10_sec": length > 10,
            "note_count": num_notes
        })

    except Exception as e:
        print(f"Error processing {filepath}: {e}")

    print_progress_bar(i+1, len(all_filepaths), prefix='Loading all MIDI files')

In [None]:
df = pd.DataFrame(data)
all_instruments = []
for number_of_instruments, count in num_programs_counts.items():
    all_instruments += [number_of_instruments] * count

summary = {
    "Total MIDI files": len(df),
    "Total number of notes": df["note_count"].sum(),
    "Total length (hours)": round(df["length_sec"].sum()/60/60, 1),
    "Average file duration (sec)": round(df["length_sec"].mean(), 1),
    "MIDI files > 10s": df["over_10_sec"].sum(),
    "MIDI files > 45s": df[df["length_sec"] > 45].shape[0],
    "Unique instruments": len(unique_programs),
    "Average number of instruments per file": round(np.mean(all_instruments), 1),
    }
pd.DataFrame.from_dict(summary, orient="index", columns=["Value"])

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

x = sorted(num_programs_counts.keys())
y = [num_programs_counts[k] for k in x]

plt.bar(x, y)
plt.title("Distribution of Number of Instruments per MIDI File")
plt.xlabel("Number of Instruments")
plt.ylabel("Count")
plt.xticks(x)  # Show each instrument count as a tick
plt.show()


In [None]:

# Optional: Convert to instrument names
from pretty_midi import program_to_instrument_name

program_names = {
    program_to_instrument_name(p): c for p, c in program_counts.items()
}
df_instr = pd.DataFrame(program_names.items(), columns=["Instrument", "Count"]).sort_values("Count", ascending=False)

sns.barplot(data=df_instr, x="Count", y="Instrument")
plt.title("Occurrences of Each Instrument")
plt.xlabel("Count")
plt.ylabel("Instrument")
plt.show()

In [None]:
fig, axs = plt.subplots(1, 2, figsize=(14, 5))

sns.histplot(df["length_sec"], bins=50, ax=axs[0])
axs[0].set_title("Distribution of MIDI Lengths")
axs[0].set_xlabel("Length (seconds)")
axs[0].set_ylabel("Count")

sns.histplot(df["note_count"], bins=50, ax=axs[1])
axs[1].set_title("Distribution of Note Counts")
axs[1].set_xlabel("Note Count")
axs[1].set_ylabel("Count")

plt.tight_layout()
plt.show()

In [None]:
df = pd.DataFrame(data)
all_instruments = []
for number_of_instruments, count in num_programs_counts.items():
    all_instruments += [number_of_instruments] * count

MIN_GAME_COUNT = 10
game_names = [extract_game_name(filepath) for filepath in all_filepaths]
game_name_dist = Counter(game_names)
valid_conditional_filepaths = [filepath for filepath in all_filepaths if game_name_dist[extract_game_name(filepath)] >= MIN_GAME_COUNT]
valid_game_names = set([extract_game_name(filepath) for filepath in valid_conditional_filepaths])

summary = {
    "Total MIDI files": len(df),
    "Total number of notes": df["note_count"].sum(),
    "Total length (hours)": round(df["length_sec"].sum()/60/60, 1),
    "Average file duration (sec)": round(df["length_sec"].mean(), 1),
    "MIDI files > 10s": df["over_10_sec"].sum(),
    "MIDI files > 45s": df[df["length_sec"] > 45].shape[0],
    "Unique instruments": len(unique_programs),
    "Average number of instruments per file": round(np.mean(all_instruments), 1),
    "Number of games": len(set(game_names)),
    "Games with at least 10 files": len(set(valid_game_names)),
    }
pd.DataFrame.from_dict(summary, orient="index", columns=["Value"])

In [None]:
from collections import Counter
import matplotlib.pyplot as plt

game_name_counts_dist = Counter(game_name_dist.values())
x = sorted(game_name_counts_dist.keys())  # x: number of files per game
y = [game_name_counts_dist[k] for k in x]  # y: how many games have that many files

# plt.figure(figsize=(10, 5))
plt.figure(figsize=(10, 5))

plt.bar(x, y)
plt.title("Distribution of Number of Files per Game")
plt.xlabel("Number of Files")
plt.ylabel("Games with This Many Files")
plt.xticks(x[::2])
plt.tight_layout()
plt.xlim(right=50)
plt.show()