In [2]:
import pretty_midi
import os
from collections import defaultdict
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import time
from miditok.pytorch_data import DatasetMIDI, DataCollator
import datetime

In [3]:
def print_progress_bar(iteration, total, prefix='', length=50):
    percent = ("{0:.1f}").format(100 * (iteration / float(total)))
    filled_length = int(length * iteration // total)
    bar = '█' * filled_length + '-' * (length - filled_length)
    print(f'\r{prefix} |{bar}| {percent}% Complete', end='\r', flush=True)
    if iteration == total:
        print()

In [4]:
def valid_midi_files(filepaths):
    valid_files = []
    for i, filepath in enumerate(filepaths):
        try:
            midi_data = pretty_midi.PrettyMIDI(filepath)
            if len(midi_data.instruments) > 0:
                valid_files.append(filepath)
        except Exception as e:
            print(f"Error processing {filepath}: {e}")
        print_progress_bar(i+1, len(filepaths), prefix='Validating MIDI files')
    return valid_files

In [5]:
midi_dirpath = 'nesmdb_midi/'
midi_train_dirpath = os.path.join(midi_dirpath, 'train')
midi_test_dirpath = os.path.join(midi_dirpath, 'test')
midi_val_dirpath = os.path.join(midi_dirpath, 'valid')
midi_train_filesnames = os.listdir(midi_train_dirpath)
midi_test_filesnames = os.listdir(midi_test_dirpath)
midi_val_filenames = os.listdir(midi_val_dirpath)

midi_train_filepaths = valid_midi_files([os.path.join(midi_train_dirpath, filename) for filename in midi_train_filesnames])
midi_test_filepaths = valid_midi_files([os.path.join(midi_test_dirpath, filename) for filename in midi_test_filesnames])
midi_val_filepaths = valid_midi_files([os.path.join(midi_val_dirpath, filename) for filename in midi_val_filenames])
all_filepaths = midi_train_filepaths + midi_test_filepaths + midi_val_filepaths

Validating MIDI files |--------------------------------------------------| 0.0% Complete

Error processing nesmdb_midi/train/122_FireEmblem_AnkokuRyutoHikarinoTsurugi_30_31EndingOmnibus.mid: MIDI file has a largest tick of 13007350, it is likely corrupt
Error processing nesmdb_midi/train/215_Magician_15_16EpiloguePart1.mid: MIDI file has a largest tick of 24797107, it is likely corrupt
Error processing nesmdb_midi/train/298_SolarJetman_HuntfortheGoldenWarpship_18_19LemonteGameplay.mid: MIDI file has a largest tick of 12682092, it is likely corrupt
Error processing nesmdb_midi/train/122_FireEmblem_AnkokuRyutoHikarinoTsurugi_28_29EndingOmnibusBallad.mid: MIDI file has a largest tick of 17014635, it is likely corrupt
Error processing nesmdb_midi/train/215_Magician_08_09MountVunarCavernsAbadonsCastle.mid: MIDI file has a largest tick of 16907305, it is likely corrupt
Error processing nesmdb_midi/train/405_ZombieNation_03_04VergeofDangerRoundSelect.mid: MIDI file has a largest tick of 18033915, it is likely corrupt
Error processing nesmdb_midi/train/104_FamicomJumpII_Saikyono7_n

In [6]:
TIME_SHIFT_RESOLUTION = 0.01  # 50 ms
MAX_SHIFT_STEPS = 100  # Max 5 seconds
BEGINNING_OF_SONG_TOKEN = '<BOS>'
END_OF_SONG_TOKEN = '<EOS>'
PAD_TOKEN = '<PAD>'
VOCABULARY = dict()
index = 0

for special_token in [PAD_TOKEN, BEGINNING_OF_SONG_TOKEN, END_OF_SONG_TOKEN]:
    VOCABULARY[special_token] = index
    index += 1

for time_shift in range(1, MAX_SHIFT_STEPS + 1):
    VOCABULARY[f'time_shift_{time_shift}'] = index
    index += 1

for action in ["note_on", "note_off"]:
    for pitch in range(128):
        for program in [80, 81, 38, 121]:
            VOCABULARY[f'{action}_{pitch}_instrument_{program}'] = index
            index += 1

print(f'Vocabulary size: {len(VOCABULARY)}')
print(f'Vocabulary: {VOCABULARY}')

Vocabulary size: 1127
Vocabulary: {'<PAD>': 0, '<BOS>': 1, '<EOS>': 2, 'time_shift_1': 3, 'time_shift_2': 4, 'time_shift_3': 5, 'time_shift_4': 6, 'time_shift_5': 7, 'time_shift_6': 8, 'time_shift_7': 9, 'time_shift_8': 10, 'time_shift_9': 11, 'time_shift_10': 12, 'time_shift_11': 13, 'time_shift_12': 14, 'time_shift_13': 15, 'time_shift_14': 16, 'time_shift_15': 17, 'time_shift_16': 18, 'time_shift_17': 19, 'time_shift_18': 20, 'time_shift_19': 21, 'time_shift_20': 22, 'time_shift_21': 23, 'time_shift_22': 24, 'time_shift_23': 25, 'time_shift_24': 26, 'time_shift_25': 27, 'time_shift_26': 28, 'time_shift_27': 29, 'time_shift_28': 30, 'time_shift_29': 31, 'time_shift_30': 32, 'time_shift_31': 33, 'time_shift_32': 34, 'time_shift_33': 35, 'time_shift_34': 36, 'time_shift_35': 37, 'time_shift_36': 38, 'time_shift_37': 39, 'time_shift_38': 40, 'time_shift_39': 41, 'time_shift_40': 42, 'time_shift_41': 43, 'time_shift_42': 44, 'time_shift_43': 45, 'time_shift_44': 46, 'time_shift_45': 47, 

In [7]:
ID_TO_TOKEN = {v: k for k, v in VOCABULARY.items()}

In [8]:
def midi_to_tokens(pm: pretty_midi.PrettyMIDI):
    events = []

    for instrument in pm.instruments:
        for note in instrument.notes:
            events.append((note.start, f'note_on_{note.pitch}_instrument_{instrument.program}'))
            events.append((note.end, f'note_off_{note.pitch}_instrument_{instrument.program}'))
    
    events.sort()  # Sort by time

    tokens = []
    last_time = 0.0
    for time, event in events:
        delta = time - last_time
        steps = round(delta / TIME_SHIFT_RESOLUTION)

        while steps > 0:
            shift = min(steps, MAX_SHIFT_STEPS)
            tokens.append(f'time_shift_{shift}')
            steps -= shift
        
        tokens.append(event)
        last_time = time
    return [BEGINNING_OF_SONG_TOKEN] + tokens + [END_OF_SONG_TOKEN]

def tokens_to_midi(tokens):
    pm = pretty_midi.PrettyMIDI()
    instruments = dict()
    active_notes = dict()

    current_time = 0.0
    for token in tokens:
        if token.startswith('time_shift_'):
            shift_steps = int(token.split('_')[-1])
            current_time += shift_steps * TIME_SHIFT_RESOLUTION
        elif token.startswith('note_on_'):
            pitch = int(token.split('_')[2])
            instrument = int(token.split('_')[-1])
            active_notes[(pitch, instrument)] = current_time
        elif token.startswith('note_off_'):
            pitch = int(token.split('_')[2])
            instrument = int(token.split('_')[-1])
            if (pitch, instrument) not in active_notes:
                print(f"Warning: Note off for {pitch} on instrument {instrument} without matching note on.")
                continue
            start_time = active_notes[(pitch, instrument)]

            if instrument not in instruments:
                instruments[instrument] = pretty_midi.Instrument(program=instrument)

            note = pretty_midi.Note(
                velocity=100, pitch=pitch, start=start_time, end=current_time
            )

            instruments[instrument].notes.append(note)
            
    for instrument in instruments.values():
        pm.instruments.append(instrument)
    
    return pm

In [9]:
midi = pretty_midi.PrettyMIDI(all_filepaths[0])
tokens = midi_to_tokens(midi)
print(tokens[:30])
midi_reconstructed = tokens_to_midi(tokens)
# midi_reconstructed.write('reconstructed_midi.mid')
# midi.write('original_midi.mid')

['<BOS>', 'note_on_70_instrument_81', 'note_on_62_instrument_80', 'note_on_58_instrument_38', 'time_shift_37', 'note_off_58_instrument_38', 'time_shift_1', 'note_off_70_instrument_81', 'note_off_62_instrument_80', 'time_shift_2', 'note_on_70_instrument_81', 'note_on_62_instrument_80', 'note_on_58_instrument_38', 'time_shift_7', 'note_off_58_instrument_38', 'time_shift_1', 'note_off_70_instrument_81', 'note_off_62_instrument_80', 'time_shift_2', 'note_on_70_instrument_81', 'note_on_62_instrument_80', 'note_on_58_instrument_38', 'time_shift_7', 'note_off_58_instrument_38', 'time_shift_1', 'note_off_70_instrument_81', 'note_off_62_instrument_80', 'time_shift_2', 'note_on_70_instrument_81', 'note_on_62_instrument_80']


In [10]:
def load_sequences(filepaths):
    sequences = []
    for i, filepath in enumerate(filepaths):
        pm = pretty_midi.PrettyMIDI(filepath)
        tokens = midi_to_tokens(pm)
        if not tokens:
            raise ValueError(f'No tokens generated for {filepath}')
        sequences.append([VOCABULARY[token] for token in tokens])
        print_progress_bar(i+1, len(filepaths), prefix='Loading sequences')
    return sequences

class MIDITokenDataset(Dataset):
    def __init__(self, sequences, seq_length=512):
        self.inputs = []
        self.targets = []
        for ind, seq in enumerate(sequences):
            num_chunks = len(seq) // (seq_length + 1)
            for chunk in range(num_chunks + 1):
                chunk_start = chunk * (seq_length + 1)
                chunk_end = chunk_start + seq_length + 1
                chunk = seq[chunk_start:chunk_end]

                if len(chunk) < seq_length / 4:
                    continue

                input = chunk[:-1]
                input = np.pad(input, (0, seq_length - len(input)), constant_values=VOCABULARY[PAD_TOKEN])
                
                target = chunk[1:]
                target = np.pad(target, (0, seq_length - len(target)), constant_values=VOCABULARY[PAD_TOKEN])
                
                self.inputs.append(torch.tensor(input, dtype=torch.long))
                self.targets.append(torch.tensor(target, dtype=torch.long))
            
            print_progress_bar(ind+1, len(sequences), prefix='Processing sequences')

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

In [11]:
# Load and convert
train_sequences = load_sequences(midi_train_filepaths)
val_sequences = load_sequences(midi_val_filepaths)

Loading sequences |██████████████████████████████████████████████████| 100.0% Complete
Loading sequences |██████████████████████████████████████████████████| 100.0% Complete


In [12]:

def print_sequence_percentiles(sequences, prefix=''):
    print(f'{prefix} sequences: {len(sequences)}')
    print(f"{prefix} 90th percile length: {np.percentile([len(seq) for seq in sequences], 90)}")
    print(f"{prefix} 50th percile length: {np.percentile([len(seq) for seq in sequences], 50)}")
    print(f"{prefix} 25th percile length: {np.percentile([len(seq) for seq in sequences], 25)}")
    print(f"{prefix} 10th percile length: {np.percentile([len(seq) for seq in sequences], 10)}")
    print(f"{prefix} max train sequence length: {max(len(seq) for seq in sequences)}")
    print(f"{prefix} min train sequence length: {min(len(seq) for seq in sequences)}")

print_sequence_percentiles(train_sequences, prefix='Train')
print_sequence_percentiles(val_sequences, prefix='Validation')

Train sequences: 4470
Train 90th percile length: 4361.699999999999
Train 50th percile length: 944.0
Train 25th percile length: 286.25
Train 10th percile length: 99.0
Train max train sequence length: 27506
Train min train sequence length: 7
Validation sequences: 402
Validation 90th percile length: 4329.300000000001
Validation 50th percile length: 682.0
Validation 25th percile length: 176.0
Validation 10th percile length: 73.0
Validation max train sequence length: 14466
Validation min train sequence length: 14


In [13]:
# Create datasets
train_dataset = MIDITokenDataset(train_sequences, seq_length=512)
val_dataset = MIDITokenDataset(val_sequences, seq_length=512)
print(f'Train dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(val_dataset)}')

Processing sequences |██████████████████████████████████████████████████| 100.0% Complete
Processing sequences |██████████████████████████████████████████████████| 100.0% Complete
Train dataset size: 16287
Validation dataset size: 1328


In [13]:
class MusicRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(MusicRNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        # x: (batch_size, seq_length)
        x = self.embedding(x)  # (batch_size, seq_length, embedding_dim)
        out, hidden = self.rnn(x, hidden)  # out: (batch_size, seq_length, hidden_dim)
        out = self.fc(out)  # (batch_size, seq_length, vocab_size)
        return out, hidden

In [63]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from music_transformer_fixed import MusicTransformer


def train_transformer(model, train_loader, val_loader, vocab_size, 
                     num_epochs=100, lr=0.001, device='cuda'):
    """
    Training function adapted for MusicTransformer with progress bars and batch size handling
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0
        mems = None

        for i, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)

            # reset mems at sequence boundaries you care about
            # e.g. if your loader shuffles songs:  mems = None

            outputs, mems = model(inputs, mems=mems)
            loss = criterion(outputs.view(-1, vocab_size),
                            targets.view(-1))

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # detach AFTER the backward pass
            if mems is not None:
                mems = [m.detach() for m in mems]

            total_train_loss += loss.item()
            print_progress_bar(i+1, len(train_loader), prefix='Training...')
        
        avg_train_loss = total_train_loss / len(train_loader)
        
        # Validation
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            mems = None
            for i, batch in enumerate(val_loader):
                inputs, targets = batch
                inputs = inputs.to(device)
                targets = targets.to(device)
                
                
                
                outputs, mems = model(inputs, mems=mems)
                outputs = outputs.reshape(-1, vocab_size)
                targets = targets.reshape(-1)
                
                loss = criterion(outputs.view(-1, vocab_size),
                         targets.view(-1))
                if mems is not None and mems[0].size(1) != inputs.size(0):
                    mems = None
                total_val_loss += loss.item()
                
                print_progress_bar(i+1, len(val_loader), prefix='Validating...')
        
        avg_val_loss = total_val_loss / len(val_loader)
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {avg_train_loss:.4f} | "
              f"Val Loss: {avg_val_loss:.4f}")


def sample_transformer(model, start_token, max_length=1024, 
                      temperature=1.0, device='cuda'):
    """
    Generate music using the transformer model
    """
    model = model.to(device)
    model.eval()
    
    generated = [start_token]
    input_token = torch.tensor([[start_token]], device=device)
    
    mems = None  # Initialize memory
    
    with torch.no_grad():
        for _ in range(max_length):
            # Get output and update memory
            output, mems = model(input_token, mems=mems)
            output = output[:, -1, :]  # Take the last output
            output = output / temperature  # Adjust randomness
            
            probs = torch.nn.functional.softmax(output, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).item()
            generated.append(next_token)
            
            # Check for end token (you'll need to define END_TOKEN)
            # if next_token == END_TOKEN:
            #     break
            
            input_token = torch.tensor([[next_token]], device=device)
    print(generated)
    return generated


# Example usage
xk = "__main__"
if xk == "__main__":
    # Configuration matching LakhNES
    vocab_size = 1127  # From your vocabulary
    
    # Create model with LakhNES configuration
    # model = MusicTransformer(
    #     vocab_size=vocab_size,
    #     d_model=512,
    #     n_head=8,
    #     d_head=64,
    #     d_inner=2048,
    #     n_layer=12,
    #     dropout=0.1,
    #     tgt_len=512,
    #     mem_len=512,
    #     tie_weight=True,
    #     pre_lnorm=False
    # )
    
    # # You can also create a smaller model for testing
    # small_model = MusicTransformer(
    #     vocab_size=vocab_size,
    #     d_model=256,
    #     n_head=4,
    #     d_head=64,
    #     d_inner=1024,
    #     n_layer=6,
    #     dropout=0.1,
    #     tgt_len=256,
    #     mem_len=256
    # )

    tiny_model = MusicTransformer(
        vocab_size=vocab_size,
        d_model=256,  # Keep same as small_model
        n_head=4,     # Keep same as small_model
        d_head=64,    # Keep same as small_model
        d_inner=512, # Keep same as small_model
        n_layer=1,    # Reduced from 6
        dropout=0.1,
        tgt_len=256,  # Keep same as small_model
        mem_len=256   # Keep same as small_model
    )

    # print(f"Full model parameters: {sum(p.numel() for p in model.parameters()):,}")
    # print(f"Small model parameters: {sum(p.numel() for p in small_model.parameters()):,}")
    print(f"Tiny model parameters: {sum(p.numel() for p in tiny_model.parameters()):,}")
    
    # To use in your notebook, you would:
    # 1. Import the MusicTransformer class
    # 2. Replace MusicRNN with MusicTransformer
    # 3. Use the adapted training and sampling functions 

Tiny model parameters: 946,791


In [60]:
batch_size = 32  # You can adjust this depending on your GPU memory
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Using device: {device}')
train_transformer(tiny_model, train_loader, val_loader, vocab_size=len(VOCABULARY), device=device)

Using device: mps
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 1/100 | Train Loss: 4.2405 | Val Loss: 4.2089
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 2/100 | Train Loss: 3.8382 | Val Loss: 4.1873
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 3/100 | Train Loss: 3.7976 | Val Loss: 4.1831
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 4/100 | Train Loss: 3.7781 | Val Loss: 4.1647
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████

In [61]:
# Copy and paste this code into a new cell in your Jupyter notebook
# after the cell where you sample from small_transformer

# Generate music from the transformer and save as MIDI
print("Generating music from the transformer...")
start_token = VOCABULARY[BEGINNING_OF_SONG_TOKEN]

# Generate sequence using the existing sample_transformer function
generated_sequence = sample_transformer(
    loaded_model, 
    start_token=start_token, 
    max_length=4092, 
    temperature=0.8, 
    device=device
)

# Convert token IDs back to token strings
generated_tokens = [ID_TO_TOKEN[token] for token in generated_sequence]

# Check if we hit the end token
if VOCABULARY[END_OF_SONG_TOKEN] in generated_sequence:
    end_idx = generated_sequence.index(VOCABULARY[END_OF_SONG_TOKEN])
    print(f"Hit end token at position {end_idx}")
else:
    print("Generated full sequence without hitting end token")

# Convert tokens to MIDI
print("\nConverting to MIDI...")
generated_midi = tokens_to_midi(generated_tokens)

# Save the MIDI file
output_filename = 'generated_transformer_music.mid'
generated_midi.write(output_filename)
print(f"\nGenerated MIDI saved as: {output_filename}")

# Display statistics about the generated music
print(f"\nGenerated sequence length: {len(generated_sequence)} tokens")
print(f"Number of instruments: {len(generated_midi.instruments)}")

if generated_midi.instruments:
    total_notes = sum(len(inst.notes) for inst in generated_midi.instruments)
    print(f"Total notes: {total_notes}")
    
    duration = generated_midi.get_end_time()
    print(f"Duration: {duration:.2f} seconds ({duration/60:.2f} minutes)")
    
    # Show instrument breakdown
    print("\nInstruments used:")
    instrument_names = {
        80: 'Square Lead', 
        81: 'Saw Lead', 
        38: 'Synth Bass', 
        121: 'Reverse Cymbal'
    }
    for inst in generated_midi.instruments:
        inst_name = instrument_names.get(inst.program, f'Program {inst.program}')
        print(f"  - {inst_name}: {len(inst.notes)} notes")

# Show sample of generated tokens
print(f"\nFirst 20 tokens:")
print(generated_tokens[:20])

print(f"\nLast 10 tokens:")
print(generated_tokens[-10:]) 

Generating music from the transformer...
[1, 162, 16, 162, 4, 162, 4, 162, 3, 626, 138, 277, 3, 674, 166, 841, 3, 839, 371, 883, 384, 289, 138, 4, 162, 4, 919, 411, 4, 785, 4, 793, 4, 265, 265, 777, 775, 4, 388, 900, 4, 392, 162, 9, 813, 10, 765, 3, 793, 8, 678, 162, 4, 796, 3, 162, 4, 285, 9, 138, 3, 887, 5, 297, 297, 3, 883, 363, 344, 3, 875, 471, 983, 379, 891, 383, 895, 5, 855, 4, 678, 6, 170, 682, 170, 3, 809, 3, 880, 4, 915, 4, 848, 4, 801, 7, 899, 4, 979, 3, 923, 9, 801, 4, 134, 4, 423, 935, 4, 920, 875, 4, 392, 12, 166, 7, 443, 3, 824, 4, 154, 4, 912, 404, 162, 317, 4, 400, 166, 285, 10, 3, 857, 8, 837, 4, 891, 4, 452, 964, 456, 4, 820, 4, 779, 263, 666, 6, 903, 387, 361, 8, 682, 162, 4, 861, 4, 313, 4, 388, 4, 733, 213, 4, 372, 154, 4, 408, 4, 396, 4, 158, 4, 781, 4, 924, 4, 439, 4, 416, 4, 900, 7, 843, 363, 130, 4, 455, 4, 162, 4, 912, 412, 134, 646, 4, 388, 16, 808, 809, 4, 166, 297, 4, 883, 371, 4, 900, 4, 380, 9, 146, 4, 856, 348, 154, 666, 4, 658, 150, 4, 908, 4, 316, 3, 

In [46]:
def save_transformer_model(model, optimizer, epoch, train_loss, val_loss, filepath='transformer_checkpoint.pt'):
    """
    Save model checkpoint with all necessary information for resuming training
    """
    # Get d_inner from the first layer
    first_layer = model.layers[0]
    d_inner = first_layer.d_inner
    
    # Check if weights are tied
    tie_weight = True  # Default assumption for MusicTransformer
    if hasattr(model, 'out_layer') and hasattr(model.out_layer, 'weight'):
        if hasattr(model.word_emb, 'weight'):
            # Check if they point to the same tensor
            tie_weight = model.out_layer.weight.data_ptr() == model.word_emb.weight.data_ptr()
    
    # Determine pre_lnorm by checking layer structure
    pre_lnorm = False  # Default for most transformer implementations
    
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'model_config': {
            'vocab_size': model.word_emb.num_embeddings,
            'd_model': model.d_model,
            'n_head': model.n_head,
            'd_head': model.d_head,
            'd_inner': d_inner,
            'n_layer': model.n_layer,
            'tgt_len': model.tgt_len,
            'mem_len': model.mem_len,
            'dropout': model.drop.p,
            'tie_weight': tie_weight,
            'pre_lnorm': pre_lnorm
        }
    }
    torch.save(checkpoint, filepath)
    print(f"Model saved to {filepath}")
    print(f"Configuration saved: n_layer={model.n_layer}, d_model={model.d_model}, d_inner={d_inner}")

def load_transformer_model(filepath='transformer_checkpoint.pt', device='cuda'):
    """
    Load model checkpoint and create model with saved configuration
    """
    checkpoint = torch.load(filepath, map_location=device, weights_only=False)
    
    if 'model_config' in checkpoint:
        config = checkpoint['model_config']
    else:
        raise ValueError("No model configuration found in checkpoint")
    
    # Create model
    model = MusicTransformer(**config)
    
    # Load weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    
    # Create and load optimizer
    optimizer = optim.Adam(model.parameters())
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    print(f"Model loaded from {filepath}")
    print(f"Configuration: n_layer={config['n_layer']}, d_model={config['d_model']}, d_inner={config['d_inner']}")
    print(f"Resumed from epoch {checkpoint['epoch']} with train loss: {checkpoint['train_loss']:.4f}, val loss: {checkpoint['val_loss']:.4f}")
    
    return model, optimizer, checkpoint['epoch'], checkpoint['train_loss'], checkpoint['val_loss']

# Simplified functions for common use cases
def quick_save(model, filepath='model_checkpoint.pt', epoch=0):
    """Quick save without optimizer state"""
    optimizer = optim.Adam(model.parameters())
    save_transformer_model(model, optimizer, epoch, 0.0, 0.0, filepath)

def quick_load(filepath='model_checkpoint.pt', device='cuda'):
    """Quick load just the model"""
    model, _, _, _, _ = load_transformer_model(filepath, device)
    return model



# Modified training function that supports resuming
def train_transformer_resume(model, train_loader, val_loader, vocab_size, 
                            num_epochs=20, lr=0.001, device='cuda', 
                            start_epoch=0, optimizer=None, checkpoint_path='transformer_checkpoint.pt',
                            save_every_epoch=False):
    """
    Training function with checkpoint saving and resume capability
    """
    model = model.to(device)
    pad_idx = VOCABULARY[PAD_TOKEN]
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
    
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=lr)
    
    best_val_loss = float('inf')
    
    for epoch in range(start_epoch, num_epochs):
        # Training
        model.train()
        total_train_loss = 0
        mems = None
        
        for i, batch in enumerate(train_loader):
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            
            # Reset memory if batch size changes
            if mems is not None and mems[0].size(1) != inputs.size(0):
                mems = None
            
            optimizer.zero_grad()
            
            outputs, mems = model(inputs, mems=mems)
            
            if mems is not None:
                mems = [m.detach() for m in mems]
            
            outputs = outputs.reshape(-1, vocab_size)
            targets = targets.reshape(-1)
            
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            total_train_loss += loss.item()
            print_progress_bar(i+1, len(train_loader), prefix='Training...')
        
        avg_train_loss = total_train_loss / len(train_loader)
        
        # Validation
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            mems = None
            for i, batch in enumerate(val_loader):
                inputs, targets = batch
                inputs = inputs.to(device)
                targets = targets.to(device)
                
                if mems is not None and mems[0].size(1) != inputs.size(0):
                    mems = None
                
                outputs, mems = model(inputs, mems=mems)
                outputs = outputs.reshape(-1, vocab_size)
                targets = targets.reshape(-1)
                
                loss = criterion(outputs, targets)
                total_val_loss += loss.item()
                
                print_progress_bar(i+1, len(val_loader), prefix='Validating...')
        
        avg_val_loss = total_val_loss / len(val_loader)
        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {avg_train_loss:.4f} | "
              f"Val Loss: {avg_val_loss:.4f}")
        
        # Save checkpoint
        if save_every_epoch or avg_val_loss < best_val_loss:
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                print(f"New best validation loss: {best_val_loss:.4f}")
            save_transformer_model(model, optimizer, epoch+1, avg_train_loss, avg_val_loss, checkpoint_path)


# Save your model
print("Saving tiny_model...")
save_transformer_model(
    loaded_model, 
    optimizer=optim.Adam(tiny_model.parameters()), 
    epoch=100,
    train_loss=2.3563,
    val_loss=2.5633,
    filepath='tiny_transformer_checkpoint.pt'
)

# Load it back
print("\nLoading model...")
loaded_model, loaded_optimizer, start_epoch, last_train_loss, last_val_loss = load_transformer_model(
    'tiny_transformer_checkpoint.pt', 
    device=device
)

# Test it works
print("\nTesting loaded model...")
with torch.no_grad():
    test_input = torch.randint(0, 1127, (1, 10)).to(device)
    test_output, _ = loaded_model(test_input)
    print(f"Model output shape: {test_output.shape}")

# train_transformer_resume(
#     loaded_model, 
#     train_loader, 
#     val_loader, 
#     vocab_size=len(VOCABULARY), 
#     num_epochs=start_epoch + 50,  # No curly braces
#     device=device,                 # No curly braces
#     start_epoch=start_epoch,       # No curly braces
#     optimizer=loaded_optimizer,
#     checkpoint_path='tiny_transformer_checkpoint.pt'
# )

Saving tiny_model...
Model saved to tiny_transformer_checkpoint.pt
Configuration saved: n_layer=3, d_model=256, d_inner=1024

Loading model...
Model loaded from tiny_transformer_checkpoint.pt
Configuration: n_layer=3, d_model=256, d_inner=1024
Resumed from epoch 100 with train loss: 2.3563, val loss: 2.5633

Testing loaded model...
Model output shape: torch.Size([1, 10, 1127])


In [14]:
def train(model, train_loader, val_loader, vocab_size, num_epochs=20, lr=0.001, device='cuda'):
    time_start = time.time()
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        # --------- Training ---------
        model.train()
        total_train_loss = 0

        for i, batch in enumerate(train_loader):
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)

            # batch = batch['input_ids'].to(device)  # (batch_size, seq_length)

            # inputs = batch[:, :-1]
            # targets = batch[:, 1:]

            optimizer.zero_grad()
            outputs, _ = model(inputs)

            outputs = outputs.reshape(-1, vocab_size)
            targets = targets.reshape(-1)

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            print_progress_bar(i+1, len(train_loader), prefix='Training...')

        avg_train_loss = total_train_loss / len(train_loader)

        # --------- Validation ---------
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                inputs = inputs.to(device)
                targets = targets.to(device)
                # batch = batch['input_ids'].to(device)

                # inputs = batch[:, :-1]
                # targets = batch[:, 1:]

                outputs, _ = model(inputs)
                outputs = outputs.reshape(-1, vocab_size)
                targets = targets.reshape(-1)

                loss = criterion(outputs, targets)
                total_val_loss += loss.item()

                print_progress_bar(i+1, len(val_loader), prefix='Validating...')

        avg_val_loss = total_val_loss / len(val_loader)

        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")


In [16]:
batch_size = 8  # You can adjust this depending on your GPU memory
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

device = 'cuda' if torch.cuda.is_available() else 'mps'
print(f'Using device: {device}')
model = MusicRNN(vocab_size=len(VOCABULARY), embedding_dim=256, hidden_dim=512, num_layers=2)
train(model, train_loader, val_loader, vocab_size=len(VOCABULARY), device=device)

Using device: mps
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 1/20 | Train Loss: 7.0168 | Val Loss: 6.7814
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 2/20 | Train Loss: 6.7814 | Val Loss: 6.3470
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 3/20 | Train Loss: 6.3470 | Val Loss: 5.9521
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 4/20 | Train Loss: 5.9521 | Val Loss: 5.7736
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████

In [40]:
def sample(model, start_token, max_length=100, temperature=1.0, device='cuda'):
    model = model.to(device)
    model.eval()

    generated = [start_token]
    input_token = torch.tensor([[start_token]], device=device)  # (1, 1)

    hidden = None

    for _ in range(max_length):
        output, hidden = model(input_token, hidden)  # output: (1, 1, vocab_size)
        output = output[:, -1, :]  # take the last output
        output = output / temperature  # adjust randomness

        probs = F.softmax(output, dim=-1)  # (1, vocab_size)
        next_token = torch.multinomial(probs, num_samples=1).item()
        generated.append(next_token)
        if next_token == VOCABULARY[END_OF_SONG_TOKEN] or VOCABULARY[PAD_TOKEN]: # reach end of sequence
          break

        input_token = torch.tensor([[next_token]], device=device)

    return generated



In [None]:
start_token = VOCABULARY[BEGINNING_OF_SONG_TOKEN]
device = 'cuda' if torch.cuda.is_available() else 'cpu'
generated_sequence = sample(model, start_token, max_length=1024, device=device)
generated_tokens = [ID_TO_TOKEN[token] for token in generated_sequence]
tokens_reconstructed = tokens_to_midi(generated_tokens)
tokens_reconstructed.write('generated_midi.mid')

### Train Soloists

In [10]:
def valid_soloist_midi_files(filepaths, instrument_program):
    valid_files = []
    for i, filepath in enumerate(filepaths):
        try:
            midi_data = pretty_midi.PrettyMIDI(filepath)
            if len(midi_data.instruments) > 0 and any(instrument.program == instrument_program for instrument in midi_data.instruments):
                valid_files.append(filepath)
        except Exception as e:
            print(f"Error processing {filepath}: {e}")
        print_progress_bar(i+1, len(filepaths), prefix='Validating MIDI files')
    return valid_files

In [8]:
unique_instruments = [80, 81, 38, 121]

In [11]:
soloist_80_midi_train_filepaths = valid_soloist_midi_files(midi_train_filepaths, instrument_program=80)
soloist_81_midi_train_filepaths = valid_soloist_midi_files(midi_train_filepaths, instrument_program=81)
soloist_38_midi_train_filepaths = valid_soloist_midi_files(midi_train_filepaths, instrument_program=38)
soloist_121_midi_train_filepaths = valid_soloist_midi_files(midi_train_filepaths, instrument_program=121)

soloist_80_midi_val_filepaths = valid_soloist_midi_files(midi_val_filepaths, instrument_program=80)
soloist_81_midi_val_filepaths = valid_soloist_midi_files(midi_val_filepaths, instrument_program=81)
soloist_38_midi_val_filepaths = valid_soloist_midi_files(midi_val_filepaths, instrument_program=38)
soloist_121_midi_val_filepaths = valid_soloist_midi_files(midi_val_filepaths, instrument_program=121)

Validating MIDI files |██████████████████████████████████████████████████| 100.0% Complete
Validating MIDI files |██████████████████████████████████████████████████| 100.0% Complete
Validating MIDI files |██████████████████████████████████████████████████| 100.0% Complete
Validating MIDI files |██████████████████████████████████████████████████| 100.0% Complete
Validating MIDI files |██████████████████████████████████████████████████| 100.0% Complete
Validating MIDI files |██████████████████████████████████████████████████| 100.0% Complete
Validating MIDI files |██████████████████████████████████████████████████| 100.0% Complete
Validating MIDI files |██████████████████████████████████████████████████| 100.0% Complete


In [22]:
def get_soloist_vocabulary():
    vocabulary = dict()
    index = 0

    for special_token in [PAD_TOKEN, BEGINNING_OF_SONG_TOKEN, END_OF_SONG_TOKEN]:
        vocabulary[special_token] = index
        index += 1

    for time_shift in range(1, MAX_SHIFT_STEPS + 1):
        vocabulary[f'time_shift_{time_shift}'] = index
        index += 1

    for action in ["note_on", "note_off"]:
        for pitch in range(128):
            vocabulary[f'{action}_{pitch}'] = index
            index += 1
    return vocabulary

In [42]:
def get_soloist_id_to_token():
    vocabulary = get_soloist_vocabulary()
    return {v: k for k, v in vocabulary.items()}

In [102]:
def soloist_midi_to_tokens(pm: pretty_midi.PrettyMIDI, instrument_program):
    events = []

    for instrument in pm.instruments:
        if instrument.program == instrument_program:
            for note in instrument.notes:
                events.append((note.start, f'note_on_{note.pitch}'))
                events.append((note.end, f'note_off_{note.pitch}'))
        
            break
        
    
    events.sort()  # Sort by time

    tokens = []
    last_time = 0.0
    for time, event in events:
        delta = time - last_time
        steps = max(round(delta / TIME_SHIFT_RESOLUTION), 1) # force at least 1 step (no side by side notes)

        while steps > 0:
            shift = min(steps, MAX_SHIFT_STEPS)
            tokens.append(f'time_shift_{shift}')
            steps -= shift
        
        tokens.append(event)
        last_time = time
    return [BEGINNING_OF_SONG_TOKEN] + tokens + [END_OF_SONG_TOKEN]

def soloist_tokens_to_midi(tokens, instrument_program):
    pm = pretty_midi.PrettyMIDI()
    instrument = pretty_midi.Instrument(program=instrument_program)

    active_notes = dict()
    active_pitch = None
    active_start = None

    current_time = 0.0
    for token in tokens:
        # print(current_time, token)
        if token.startswith('time_shift_'):
            shift_steps = int(token.split('_')[-1])
            current_time += shift_steps * TIME_SHIFT_RESOLUTION
        elif token.startswith('note_on_'):
            if active_pitch is not None:
                print(f"Warning: Skipping {token}, other note: {active_pitch} is still active.")
                continue 
            pitch = int(token.split('_')[2])
            # active_notes[pitch] = current_time
            active_pitch = pitch
            active_start = current_time
        elif token.startswith('note_off_'):
            pitch = int(token.split('_')[2])
            if pitch != active_pitch:
                print(f"Warning: Note off for {pitch} without matching note on.")
                continue

            if current_time > active_start:
                note = pretty_midi.Note(
                    velocity=100, pitch=pitch, start=active_start, end=current_time
                )

                instrument.notes.append(note)
            else:
                print(f"Warning: Note off for {pitch} at {current_time} note after note on at {active_start}. Ignoring.")
            # del active_notes[pitch]
            active_pitch = None
            active_start = None

    pm.instruments.append(instrument)
    
    return pm

In [19]:
midi = pretty_midi.PrettyMIDI(all_filepaths[0])
tokens_80 = soloist_midi_to_tokens(midi, 80)
tokens_81 = soloist_midi_to_tokens(midi, 81)
tokens_38 = soloist_midi_to_tokens(midi, 38)
tokens_121 = soloist_midi_to_tokens(midi, 121)

midi_reconstructed_80 = soloist_tokens_to_midi(tokens_80, instrument_program=80)
midi_reconstructed_81 = soloist_tokens_to_midi(tokens_81, instrument_program=81)
midi_reconstructed_38 = soloist_tokens_to_midi(tokens_38, instrument_program=38)
midi_reconstructed_121 = soloist_tokens_to_midi(tokens_121, instrument_program=121)

# midi_reconstructed_80.write('reconstructed_midi_80.mid')
# midi_reconstructed_81.write('reconstructed_midi_81.mid')
# midi_reconstructed_38.write('reconstructed_midi_38.mid')
# midi_reconstructed_121.write('reconstructed_midi_121.mid')

# midi.write('original_midi.mid')

In [23]:
def load_soloist_sequences(filepaths, instrument_program):
    vocabulary = get_soloist_vocabulary()
    sequences = []
    for i, filepath in enumerate(filepaths):
        pm = pretty_midi.PrettyMIDI(filepath)
        tokens = soloist_midi_to_tokens(pm, instrument_program)
        if not tokens:
            raise ValueError(f'No tokens generated for {filepath}')
        sequences.append([vocabulary[token] for token in tokens])
        print_progress_bar(i+1, len(filepaths), prefix='Loading sequences')
    return sequences

In [100]:
soloist_80_train_sequences = load_soloist_sequences(soloist_80_midi_train_filepaths, instrument_program=80)
soloist_81_train_sequences = load_soloist_sequences(soloist_81_midi_train_filepaths, instrument_program=81)
soloist_38_train_sequences = load_soloist_sequences(soloist_38_midi_train_filepaths, instrument_program=38)
soloist_121_train_sequences = load_soloist_sequences(soloist_121_midi_train_filepaths, instrument_program=121)

soloist_80_val_sequences = load_soloist_sequences(soloist_80_midi_val_filepaths, instrument_program=80)
soloist_81_val_sequences = load_soloist_sequences(soloist_81_midi_val_filepaths, instrument_program=81)
soloist_38_val_sequences = load_soloist_sequences(soloist_38_midi_val_filepaths, instrument_program=38)
soloist_121_val_sequences = load_soloist_sequences(soloist_121_midi_val_filepaths, instrument_program=121)

Loading sequences |██████████████████████████████████████████████████| 100.0% Complete
Loading sequences |██████████████████--------------------------------| 36.9% Complete

KeyboardInterrupt: 

In [34]:
def get_soloist_model():
    return MusicRNN(vocab_size=len(get_soloist_vocabulary()), embedding_dim=64, hidden_dim=512, num_layers=2)

In [None]:
def train_soloists(soloist_sequences: dict):
    models = {}
    for instrument_program, (train_sequences, val_sequences) in soloist_sequences.items():
        train_dataset = MIDITokenDataset(train_sequences[:2], seq_length=512)
        val_dataset = MIDITokenDataset(val_sequences[:2], seq_length=512)

        batch_size = 8
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model = get_soloist_model()
        train(model, train_loader, val_loader, vocab_size=len(get_soloist_vocabulary()), device=device)
        models[instrument_program] = model
    return models


In [31]:
soloist_sequences = {
    80: (soloist_80_train_sequences, soloist_80_val_sequences),
    81: (soloist_81_train_sequences, soloist_81_val_sequences),
    38: (soloist_38_train_sequences, soloist_38_val_sequences),
    121: (soloist_121_train_sequences, soloist_121_val_sequences)
}
soloist_models = train_soloists(soloist_sequences)

Processing sequences |██████████████████████████████████████████████████| 100.0% Complete
Processing sequences |██████████████████████████████████████████████████| 100.0% Complete
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 1/20 | Train Loss: 5.8776 | Val Loss: 5.7679
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 2/20 | Train Loss: 5.7679 | Val Loss: 5.5858
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 3/20 | Train Loss: 5.5858 | Val Loss: 5.1507
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 4/20 

In [32]:
def save_model_weights(models):
    for instrument_program, model in models.items():
        torch.save(model.state_dict(), f'soloist_model_{instrument_program}.pth')

In [33]:
save_model_weights(soloist_models)

In [51]:
def load_model_weight(path, model):
    model.load_state_dict(torch.load(path, map_location='cpu'))
    return model

In [113]:
def sample_soloists(models, outdir='./'):
    soloist_vocabulary = get_soloist_vocabulary()
    midis = dict()
    for instrument_program, model in models.items():
        model.eval()  # Set to eval mode for inference
        start_token = soloist_vocabulary[BEGINNING_OF_SONG_TOKEN]
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        generated_sequence = sample(model, start_token, max_length=1024, device=device)
        generated_tokens = [get_soloist_id_to_token()[token] for token in generated_sequence]
        tokens_reconstructed = soloist_tokens_to_midi(generated_tokens, instrument_program=instrument_program)
        midis[instrument_program] = tokens_reconstructed
    
    together_midi = pretty_midi.PrettyMIDI()
    for instrument_program, midi in midis.items():
        together_midi.instruments.extend(midi.instruments)

    os.makedirs(outdir, exist_ok=True)
    outpath = os.path.join(outdir, f'{datetime.datetime.now()}.mid')
    
    together_midi.write(outpath)


In [123]:
# models = [load_model_weight(f'soloist_model_{instrument_program}.pth', get_soloist_model()) for instrument_program in unique_instruments]
model_weights_dir = os.path.join('lstm_models', 'soloists', '2')
# models = {80: load_model_weight(os.path.join(model_weights_dir, 'soloist_model_80.pth'), get_soloist_model())}
models = {instr_program: load_model_weight(os.path.join(model_weights_dir, f'soloist_model_{instr_program}.pth'), get_soloist_model()) for instr_program in unique_instruments}
sample_soloists(models, outdir=os.path.join(model_weights_dir, 'samples'))

In [98]:
side_by_side_note_count = 0
total_tokens = 0
for i, sequence in enumerate(soloist_80_train_sequences):
    pairwise = zip(sequence[:-1], sequence[1:])
    total_tokens += len(sequence)
    for token1, token2 in pairwise:
        token1 = get_soloist_id_to_token()[token1]
        token2 = get_soloist_id_to_token()[token2]
        if token1.startswith('note_on_') and token2.startswith('note_off_'):
            if token1.split('_')[2] == token2.split('_')[2]:
                side_by_side_note_count += 1

    print_progress_bar(i+1, len(soloist_80_train_sequences), prefix='Counting side by side notes')
    
print(f'Side by side note count: {side_by_side_note_count}')
print(f'Total tokens: {total_tokens}')

Counting side by side notes |██████████████████████████████████████████████████| 100.0% Complete
Side by side note count: 121967
Total tokens: 2057224
