In [1]:
# !pip install miditoolkit 

In [1]:
import os
import random
from miditoolkit import MidiFile, Note, Instrument
import matplotlib.pyplot as plt
from collections import Counter
import subprocess
import datetime
from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetMIDI, DataCollator
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import time


  from .autonotebook import tqdm as notebook_tqdm


### Utils

In [2]:
def print_progress_bar(iteration, total, prefix='', length=50):
    percent = ("{0:.1f}").format(100 * (iteration / float(total)))
    filled_length = int(length * iteration // total)
    bar = '█' * filled_length + '-' * (length - filled_length)
    print(f'\r{prefix} |{bar}| {percent}% Complete', end='\r', flush=True)
    if iteration == total:
        print()

In [3]:
def to_midi_objs(filepaths: list, num_samples: int = 100) -> list: 
    sampled_filepaths = random.sample(filepaths, num_samples)
    midis = []
    for i, filepath in enumerate(sampled_filepaths):
        print_progress_bar(i + 1, num_samples, prefix='Converting .mid files to MidiFile')
        try:
            midi = MidiFile(filepath)
            midis.append(midi)
        except Exception as e:
            print(f"Error processing {filepath}: {e}")
        
    return midis

### Data Exploration

In [4]:
midi_dirpath = 'nesmdb_midi/'
midi_train_dirpath = os.path.join(midi_dirpath, 'train')
midi_test_dirpath = os.path.join(midi_dirpath, 'test')
midi_val_dirpath = os.path.join(midi_dirpath, 'valid')
midi_train_filesnames = os.listdir(midi_train_dirpath)
midi_test_filesnames = os.listdir(midi_test_dirpath)
midi_val_filenames = os.listdir(midi_val_dirpath)

midi_train_filepaths = [os.path.join(midi_train_dirpath, filename) for filename in midi_train_filesnames]
midi_test_filepaths = [os.path.join(midi_test_dirpath, filename) for filename in midi_test_filesnames]
midi_val_filepaths = [os.path.join(midi_val_dirpath, filename) for filename in midi_val_filenames]
all_filepaths = midi_train_filepaths + midi_test_filepaths + midi_val_filepaths

In [5]:
print("num_train_files", len(midi_train_filepaths))
print("num_test_files", len(midi_test_filepaths))
print("num_val_files", len(midi_val_filepaths))
print("total_files", len(all_filepaths))

num_train_files 4502
num_test_files 373
num_val_files 403
total_files 5278


In [6]:
midis = to_midi_objs(all_filepaths, num_samples=len(all_filepaths))

Converting .mid files to MidiFile |██████████████████████████████████████████████████| 100.0% Complete


In [7]:
instrument_count_distribution = Counter(len(midi.instruments) for midi in midis)
unique_instruments = set([int(instrument.program) for midi in midis for instrument in midi.instruments])
unique_instruments_distribution = Counter([int(instrument.program) for midi in midis for instrument in midi.instruments])
instrument_sets_distribution = Counter(tuple(sorted([int(instrument.program) for instrument in midi.instruments])) for midi in midis)
number_of_tempo_changes = Counter([len(m.tempo_changes) for m in midis])
tempo_dist = Counter(c.tempo for m in midis for c in m.tempo_changes)
ticks_per_sec_dist = Counter(m.ticks_per_beat for m in midis)
print("Instrument count distribution:", dict(instrument_count_distribution))
print("Unique instruments:", unique_instruments)
print("Unique instruments distribution:", dict(unique_instruments_distribution))
print("Instrument set distribution:", dict(instrument_sets_distribution))
print("Number of tempo changes:", number_of_tempo_changes)
print("Tempo distribution:", tempo_dist)
print("ticks_per_sec_dist:", ticks_per_sec_dist)

Instrument count distribution: {2: 598, 3: 1730, 4: 2821, 1: 125, 0: 4}
Unique instruments: {80, 81, 38, 121}
Unique instruments distribution: {80: 5075, 38: 4676, 81: 4970, 121: 3074}
Instrument set distribution: {(38, 80): 126, (38, 80, 81): 1519, (38, 80, 81, 121): 2821, (38, 80, 121): 57, (80, 81, 121): 114, (80, 81): 382, (38, 81): 61, (81,): 28, (38, 81, 121): 40, (80,): 47, (121,): 13, (): 4, (38, 121): 15, (38,): 37, (80, 121): 9, (81, 121): 5}
Number of tempo changes: Counter({1: 5278})
Tempo distribution: Counter({120.0: 5278})
ticks_per_sec_dist: Counter({22050: 5278})


In [8]:
midis_with_instr = [m for m in midis if len(m.instruments)]
def piece_duration(midi_obj):
    ticks_per_beat = 22050
    bpm = 120
    bps = bpm / 60
    max_tick = max(note.end for inst in midi_obj.instruments for note in inst.notes)
    piece_duration_in_s = max_tick / ticks_per_beat / bps
    return piece_duration_in_s

durations = [piece_duration(m) for m in midis_with_instr]
print("max duration:", max(durations))
print("90th percentile duration:", np.percentile(durations, 90))

max duration: 1517.683560090703
90th percentile duration: 64.9059365079365


In [9]:
midis[1]

ticks per beat: 22050
max tick: 282514
tempo changes: 1
time sig: 2
key sig: 0
markers: 0
lyrics: False
instruments: 3

### Custom LSTM

In [10]:
class MusicRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(MusicRNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        # x: (batch_size, seq_length)
        x = self.embedding(x)  # (batch_size, seq_length, embedding_dim)
        out, hidden = self.rnn(x, hidden)  # out: (batch_size, seq_length, hidden_dim)
        out = self.fc(out)  # (batch_size, seq_length, vocab_size)
        return out, hidden

In [68]:
class CustomMIDITokenizer( ):
    def __init__(self, time_divisions=100, time_shift_increment=0.01, target_instrument=80):
        self.time_divisions = time_divisions
        self.time_shift_increment = time_shift_increment
        self.special_tokens = {
            "PAD": 0,
            "BOS": 1,
            "EOS": 2,
        }
        
        self.ticks_per_beat = 22050 # from data exploration
        self.bpm = 120 # from data exploration
        self.target_instrument = target_instrument
        
        self.note_on_offset = 3
        self.note_off_offset = self.note_on_offset + 128
        self.time_offset = self.note_off_offset + 128
        self.vocab_size = self.time_offset + self.time_divisions + 1

    def quantize_duration(self, delta_seconds):
        idx = int(delta_seconds / self.time_shift_increment)
        return min(idx, self.time_divisions - 1)  # Cap at 99

    def get_note_on_token(self, note):
        return self.note_on_offset + note

    def get_note_off_token(self, note):
        return self.note_off_offset + note

    def get_time_shift_token(self, quantized_delta):
        return self.time_offset + quantized_delta
        
        
    def get_quantized_time_shift_deltas(self, prev_event_time_in_s, next_event_time_in_s):
        quantized_times = []
        while next_event_time_in_s - prev_event_time_in_s > self.time_shift_increment:
            delta = next_event_time_in_s - prev_event_time_in_s
            prev_event_time_in_s += min(delta, self.time_divisions * self.time_shift_increment)
            quantized = self.quantize_duration(delta)
            quantized_times.append(quantized)
        return quantized_times
    
    def __getitem__(self, item):
        return self.special_tokens[item.split("_")[0]]

    @property
    def pad_token_id(self):
        return self.special_tokens["PAD"]
    
    def encode(self, midi_obj):
        ticks_per_beat = midi_obj.ticks_per_beat
        bps = self.bpm / 60

        targets = list(filter(lambda x: x.program == self.target_instrument, list(midi_obj.instruments)))

        max_tick = max(note.end for inst in midi_obj.instruments for note in inst.notes)
        piece_duration_in_s = max_tick / ticks_per_beat / bps

        note_events = []
        if len(targets):
            for note in targets[0].notes:
                start_time_in_s = note.start / ticks_per_beat / bps
                end_time_in_s = note.end / ticks_per_beat / bps
                if end_time_in_s - start_time_in_s >= self.time_shift_increment:
                    note_events.append((start_time_in_s, note.pitch, "start"))
                    note_events.append((end_time_in_s, note.pitch, "end"))
     
        
        note_events = sorted(note_events)
        
        debug_tokens = []
        prev_time = 0.0
        for time_start, pitch, action in note_events:
            times_shift_deltas = self.get_quantized_time_shift_deltas(prev_time, time_start)
            for delta in times_shift_deltas:
                ts_token = self.get_time_shift_token(delta)
                debug_tokens.append((ts_token, "time_shift_token"))
            prev_time = time_start

            if (action == "start"):
                debug_tokens.append((self.get_note_on_token(pitch), "note_on"))
            else:
                debug_tokens.append((self.get_note_off_token(pitch), "note_off"))

        
        # print(debug_tokens)
        tokens = [t[0] for t in debug_tokens]
        
        return tokens

    def decode(self, tokens, output_path):
        midi_obj = MidiFile()
        instrument = Instrument(program=self.target_instrument, is_drum=False)

        current_time = 0.0 # in seconds
        active_notes = {}

        ticks_per_second = self.ticks_per_beat * self.bpm / 60
        
        for token in tokens:
            if token == self.special_tokens["BOS"]:
                continue # handle BOS
                
            elif token == self.special_tokens["EOS"]: 
                continue # handle EOS
                
            elif token == self.special_tokens["PAD"]: # will model output PAD tokens?
                continue # handle PAD
                
            elif token < self.note_off_offset: # are note pitches 0 indexed?
                # handle token is a note start event
                pitch = token - self.note_on_offset
                active_notes[pitch] = current_time
                
            elif token < self.time_offset: # are note pitches 0 indexed?
               # handle token is a note end event
                pitch = token - self.note_off_offset
                if pitch in active_notes:
                    start = active_notes.pop(pitch)
                    end = current_time
                    note = Note(
                        pitch=pitch,
                        start=int(start * ticks_per_second),
                        end=int(end * ticks_per_second),
                        velocity=100,  # constant velocity
                    )
                    instrument.notes.append(note)
                
            elif token < self.vocab_size:
                # handle token is a time shift event
                shift_amount = (token - self.time_offset + 1) * 0.01
                current_time += shift_amount
                
            else:
                raise Exception("unknown token value: " + token)

        midi_obj.instruments.append(instrument)
        midi_obj.ticks_per_beat = self.ticks_per_beat
        midi_obj.dump(output_path)
        

In [67]:
target_instrument = 80 # unique instruments = 80, 81, 38, 121
tokenizer = CustomMIDITokenizer(target_instrument=target_instrument)
midi_obj = midis[150]
tokens = tokenizer.encode(midi_obj)
print(len(tokens))
tokenizer.decode(tokens, f"decoded_{target_instrument}.mid") 
midi_obj.dump("original.mid")

49


In [13]:
tokenizer = CustomMIDITokenizer()

lengths = []
failed_count = 0
for m in midis:
    try: 
        tokens = tokenizer.encode(m)
        lengths.append(len(tokens))
    except Exception:
        failed_count += 1

print("max length:", max(lengths))
print("90th percentile length:", np.percentile(lengths, 90))
print("failed midis: ", failed_count)

max length: 53791
90th percentile length: 922.6999999999998
failed midis:  4


In [23]:
class MIDIDataset(Dataset):
    def __init__(self, file_paths, tokenizer, max_seq_len, midi_objs):
        self.file_paths = file_paths
        self.midi_objs = midi_objs
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.bos_token = tokenizer["BOS_None"]
        self.eos_token = tokenizer["EOS_None"]

    def __len__(self):
        return len(self.midi_objs)
        

    def __getitem__(self, idx):
        midi_obj = self.midi_objs[idx]
        tokens = self.tokenizer.encode(midi_obj)
        tokens = [self.bos_token] + tokens + [self.eos_token]
        tokens = tokens[:self.max_seq_len]  # truncate if needed
        
        return torch.tensor(tokens, dtype=torch.long)

def safe_collate(batch, pad_token_id=0):
    # Remove None items (failed MIDI files)
    batch = [b for b in batch if b is not None]
    if not batch:
        print("NO NOT NONE BATCH!")
        return {"input_ids": torch.empty(0, dtype=torch.long)}

    lengths = [len(x) for x in batch]
    max_len = max(lengths)

    padded = torch.full((len(batch), max_len), pad_token_id, dtype=torch.long)
    for i, seq in enumerate(batch):
        padded[i, :len(seq)] = seq

    return {"input_ids": padded}

In [15]:
train_midi_objs = to_midi_objs(midi_train_filepaths, num_samples=len(midi_train_filepaths))
test_midi_objs = to_midi_objs(midi_test_filepaths, num_samples=len(midi_test_filepaths))

Converting .mid files to MidiFile |██████████████████████████████████████████████████| 100.0% Complete
Converting .mid files to MidiFile |██████████████████████████████████████████████████| 100.0% Complete


In [16]:
train_midi_objs = list(filter(lambda midi_obj: len(midi_obj.instruments) > 0, train_midi_objs))
test_midi_objs = list(filter(lambda midi_obj: len(midi_obj.instruments) > 0, test_midi_objs))

In [18]:
def train(model, train_loader, val_loader, vocab_size, num_epochs=20, lr=0.001, device='cuda'):
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # early stopping
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    patience = 2  # or any number of your choice
    
    for epoch in range(num_epochs):
        start_time = time.time()
        # --------- Training ---------
        model.train()
        total_train_loss = 0

        for i, batch in enumerate(train_loader):
            print_progress_bar(i + 1, len(train_loader), "training...")
            batch = batch['input_ids'].to(device)  # (batch_size, seq_length)
            
            if batch.ndim != 2:
                print(f"Malformed train batch at epoch {epoch}, step {i}: shape={batch.shape}")
            
            inputs = batch[:, :-1]
            targets = batch[:, 1:]

            optimizer.zero_grad()
            outputs, _ = model(inputs)
            outputs = outputs.reshape(-1, vocab_size)
            targets = targets.reshape(-1)

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)

        # --------- Validation ---------
        model.eval()
        total_val_loss = 0
        total_correct = 0
        total_tokens = 0        
        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                print_progress_bar(i + 1, len(val_loader), "validation...")
                batch = batch['input_ids'].to(device)
                if batch.ndim != 2:
                    print(f"Malformed val batch at epoch {epoch}, step {i}: shape={batch.shape}")
                
                inputs = batch[:, :-1]
                targets = batch[:, 1:]

                outputs, _ = model(inputs)
                outputs = outputs.reshape(-1, vocab_size)
                targets = targets.reshape(-1)

                loss = criterion(outputs, targets)
                total_val_loss += loss.item()

                predicted = torch.argmax(outputs, dim=1)
                total_correct += (predicted == targets).sum().item()
                total_tokens += targets.numel()

        avg_val_loss = total_val_loss / len(val_loader)
        accuracy = total_correct / total_tokens
        duration = time.time() - start_time
        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Accuracy: {accuracy:.4f} | Time: {duration:.2f}s")

        # stop early if we haven't improved
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            epochs_without_improvement = 0
            # Optionally save the model
            torch.save(model.state_dict(), "best_model.pt")
        else:
            epochs_without_improvement += 1
        
        if epochs_without_improvement >= patience:
            print(f"Early stopping at epoch {epoch+1} due to no improvement in validation loss.")
            print(f"Best val_loss: {best_val_loss}")
            break

In [28]:
models = dict()
target_instruments = [80, 81, 38, 121]
# target_instruments = [80]
for target_instrument in target_instruments:
    try:
        print(f"TRAINING INSTRUMENT: {target_instrument}\n")
        MAX_SEQUENCE_LENGTH = 512
        tokenizer = CustomMIDITokenizer(target_instrument=target_instrument)
        pad_token_id = tokenizer.pad_token_id  # or hardcoded if you want
    
        train_dataset = MIDIDataset(midi_train_filepaths, tokenizer, MAX_SEQUENCE_LENGTH, train_midi_objs)
        train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=lambda x: safe_collate(x, pad_token_id))
    
        test_dataset = MIDIDataset(midi_test_filepaths, tokenizer, MAX_SEQUENCE_LENGTH, test_midi_objs)
        test_loader = DataLoader(test_dataset, batch_size=4, shuffle=True, collate_fn=lambda x: safe_collate(x, pad_token_id))
    
        vocab_size = tokenizer.vocab_size
        embedding_dim = 256
        # hidden_dim = 512
        # hidden_dim = 256
        hidden_dim = 64 # matching the NES MDB paper
        num_layers = 2

        model = MusicRNN(vocab_size, embedding_dim, hidden_dim, num_layers)
        train(model, train_loader, test_loader, vocab_size)
        models[target_instrument] = model
        print("\n")
    except Exception as e:
        print(f"target {target_instrument} failed", e)
    

TRAINING INSTRUMENT: 80

training... |██████████████████████████████████████████████████| 100.0% Complete
validation... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 1/20 | Train Loss: 2.1544 | Val Loss: 1.4168 | Accuracy: 0.7068 | Time: 12.83s
training... |██████████████████████████████████████████████████| 100.0% Complete
validation... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 2/20 | Train Loss: 1.2961 | Val Loss: 1.0791 | Accuracy: 0.7567 | Time: 12.36s
training... |██████████████████████████████████████████████████| 100.0% Complete
validation... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 3/20 | Train Loss: 1.1014 | Val Loss: 1.0065 | Accuracy: 0.7611 | Time: 12.44s
training... |██████████████████████████████████████████████████| 100.0% Complete
validation... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 4/20 | Train Loss: 1.0334 | Val Loss: 0.9236 | Accuracy: 0

In [29]:
def save_model_weights(models):
    out_dir = os.path.join("weights", str(datetime.datetime.now()))
    os.mkdir(out_dir)
    for target_instrument, model in models.items():
        out_path = os.path.join(out_dir, f"instrument_{target_instrument}.pth")
        torch.save(model.state_dict(), out_path)
save_model_weights(models)

In [30]:
def sample(model, start_token, max_length=100, temperature=1.0, device='cuda'):
    model = model.to(device)
    model.eval()

    generated = [start_token]
    input_token = torch.tensor([[start_token]], device=device)  # (1, 1)

    hidden = None

    for _ in range(max_length):
        output, hidden = model(input_token, hidden)  # output: (1, 1, vocab_size)
        output = output[:, -1, :]  # take the last output
        output = output / temperature  # adjust randomness

        probs = F.softmax(output, dim=-1)  # (1, vocab_size)
        next_token = torch.multinomial(probs, num_samples=1).item()
        generated.append(next_token)
        if next_token == 2 or next_token == 0: # reach end of sequence
          break

        input_token = torch.tensor([[next_token]], device=device)

    return generated

In [78]:
def sample_models(models):
    out_dir = os.path.join("samples", str(datetime.datetime.now()))
    os.mkdir(out_dir)
    for target_instrument, model in models.items():
        tokenizer = CustomMIDITokenizer(target_instrument=target_instrument)
        start_token = tokenizer.special_tokens["BOS"]
        generated_sequence = sample(model, start_token, max_length=1024)
        out_path = os.path.join(out_dir, f"instrument_{target_instrument}.mid")
        torch.save(model.state_dict(), out_path)
        tokenizer.decode(generated_sequence, out_path)

sample_models(models)

In [114]:
start_token = tokenizer.special_tokens["BOS"]
generated_sequence = sample(model, start_token, max_length=1024 * 20, min_length=1024 * 10)
print(len(generated_sequence))
tokenizer.decode(generated_sequence, "rnn_2048_long.mid")

BREAK 10241
10241


In [79]:
import os
from miditoolkit import MidiFile

samples_dir = 'samples'

for folder_name in os.listdir(samples_dir):
    folder_path = os.path.join(samples_dir, folder_name)
    if not os.path.isdir(folder_path):
        continue

    combined_midi = MidiFile()
    instruments = []

    for file_name in os.listdir(folder_path):
        if not file_name.endswith('.mid'):
            continue

        midi_path = os.path.join(folder_path, file_name)
        midi = MidiFile(midi_path)
        instruments.extend(midi.instruments)

        # Use the first valid ticks_per_beat
        if combined_midi.ticks_per_beat == 480:
            combined_midi.ticks_per_beat = midi.ticks_per_beat

    combined_midi.instruments = instruments
    combined_path = os.path.join(folder_path, 'all.mid')
    combined_midi.dump(combined_path)