In [44]:
import pretty_midi
import os
from collections import defaultdict
from torch.utils.data import Dataset
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import time
from miditok.pytorch_data import DatasetMIDI, DataCollator

  from .autonotebook import tqdm as notebook_tqdm


In [53]:
def print_progress_bar(iteration, total, prefix='', length=50):
    percent = ("{0:.1f}").format(100 * (iteration / float(total)))
    filled_length = int(length * iteration // total)
    bar = '█' * filled_length + '-' * (length - filled_length)
    print(f'\r{prefix} |{bar}| {percent}% Complete', end='\r', flush=True)
    if iteration == total:
        print()

In [55]:
def valid_midi_files(filepaths):
    valid_files = []
    for i, filepath in enumerate(filepaths):
        try:
            midi_data = pretty_midi.PrettyMIDI(filepath)
            if len(midi_data.instruments) > 0:
                valid_files.append(filepath)
        except Exception as e:
            print(f"Error processing {filepath}: {e}")
        print_progress_bar(i+1, len(filepaths), prefix='Validating MIDI files')
    return valid_files

In [56]:
midi_dirpath = 'nesmdb_midi/'
midi_train_dirpath = os.path.join(midi_dirpath, 'train')
midi_test_dirpath = os.path.join(midi_dirpath, 'test')
midi_val_dirpath = os.path.join(midi_dirpath, 'valid')
midi_train_filesnames = os.listdir(midi_train_dirpath)
midi_test_filesnames = os.listdir(midi_test_dirpath)
midi_val_filenames = os.listdir(midi_val_dirpath)

midi_train_filepaths = valid_midi_files([os.path.join(midi_train_dirpath, filename) for filename in midi_train_filesnames])
midi_test_filepaths = valid_midi_files([os.path.join(midi_test_dirpath, filename) for filename in midi_test_filesnames])
midi_val_filepaths = valid_midi_files([os.path.join(midi_val_dirpath, filename) for filename in midi_val_filenames])
all_filepaths = midi_train_filepaths + midi_test_filepaths + midi_val_filepaths

Error processing nesmdb_midi/train/122_FireEmblem_AnkokuRyutoHikarinoTsurugi_30_31EndingOmnibus.mid: MIDI file has a largest tick of 13007350, it is likely corrupt
Error processing nesmdb_midi/train/215_Magician_15_16EpiloguePart1.mid: MIDI file has a largest tick of 24797107, it is likely corrupt
Error processing nesmdb_midi/train/298_SolarJetman_HuntfortheGoldenWarpship_18_19LemonteGameplay.mid: MIDI file has a largest tick of 12682092, it is likely corrupt
Error processing nesmdb_midi/train/122_FireEmblem_AnkokuRyutoHikarinoTsurugi_28_29EndingOmnibusBallad.mid: MIDI file has a largest tick of 17014635, it is likely corrupt
Error processing nesmdb_midi/train/215_Magician_08_09MountVunarCavernsAbadonsCastle.mid: MIDI file has a largest tick of 16907305, it is likely corrupt
Error processing nesmdb_midi/train/405_ZombieNation_03_04VergeofDangerRoundSelect.mid: MIDI file has a largest tick of 18033915, it is likely corrupt
Error processing nesmdb_midi/train/104_FamicomJumpII_Saikyono7_n

In [57]:
TIME_SHIFT_RESOLUTION = 0.01  # 50 ms
MAX_SHIFT_STEPS = 100  # Max 5 seconds
BEGINNING_OF_SONG_TOKEN = '<BOS>'
END_OF_SONG_TOKEN = '<EOS>'
PAD_TOKEN = '<PAD>'
VOCABULARY = dict()
index = 0

for special_token in [PAD_TOKEN, BEGINNING_OF_SONG_TOKEN, END_OF_SONG_TOKEN]:
    VOCABULARY[special_token] = index
    index += 1

for time_shift in range(1, MAX_SHIFT_STEPS + 1):
    VOCABULARY[f'time_shift_{time_shift}'] = index
    index += 1

for action in ["note_on", "note_off"]:
    for pitch in range(128):
        for program in [80, 81, 38, 121]:
            VOCABULARY[f'{action}_{pitch}_instrument_{program}'] = index
            index += 1

print(f'Vocabulary size: {len(VOCABULARY)}')
print(f'Vocabulary: {VOCABULARY}')

Vocabulary size: 1127
Vocabulary: {'<PAD>': 0, '<BOS>': 1, '<EOS>': 2, 'time_shift_1': 3, 'time_shift_2': 4, 'time_shift_3': 5, 'time_shift_4': 6, 'time_shift_5': 7, 'time_shift_6': 8, 'time_shift_7': 9, 'time_shift_8': 10, 'time_shift_9': 11, 'time_shift_10': 12, 'time_shift_11': 13, 'time_shift_12': 14, 'time_shift_13': 15, 'time_shift_14': 16, 'time_shift_15': 17, 'time_shift_16': 18, 'time_shift_17': 19, 'time_shift_18': 20, 'time_shift_19': 21, 'time_shift_20': 22, 'time_shift_21': 23, 'time_shift_22': 24, 'time_shift_23': 25, 'time_shift_24': 26, 'time_shift_25': 27, 'time_shift_26': 28, 'time_shift_27': 29, 'time_shift_28': 30, 'time_shift_29': 31, 'time_shift_30': 32, 'time_shift_31': 33, 'time_shift_32': 34, 'time_shift_33': 35, 'time_shift_34': 36, 'time_shift_35': 37, 'time_shift_36': 38, 'time_shift_37': 39, 'time_shift_38': 40, 'time_shift_39': 41, 'time_shift_40': 42, 'time_shift_41': 43, 'time_shift_42': 44, 'time_shift_43': 45, 'time_shift_44': 46, 'time_shift_45': 47, 

In [59]:
ID_TO_TOKEN = {v: k for k, v in VOCABULARY.items()}

In [82]:
def midi_to_tokens(pm: pretty_midi.PrettyMIDI):
    events = []

    for instrument in pm.instruments:
        for note in instrument.notes:
            events.append((note.start, f'note_on_{note.pitch}_instrument_{instrument.program}'))
            events.append((note.end, f'note_off_{note.pitch}_instrument_{instrument.program}'))
    
    events.sort()  # Sort by time

    tokens = []
    last_time = 0.0
    for time, event in events:
        delta = time - last_time
        steps = round(delta / TIME_SHIFT_RESOLUTION)

        while steps > 0:
            shift = min(steps, MAX_SHIFT_STEPS)
            tokens.append(f'time_shift_{shift}')
            steps -= shift
        
        tokens.append(event)
        last_time = time
    return [BEGINNING_OF_SONG_TOKEN] + tokens + [END_OF_SONG_TOKEN]

def tokens_to_midi(tokens):
    pm = pretty_midi.PrettyMIDI()
    instruments = dict()
    active_notes = dict()

    current_time = 0.0
    for token in tokens:
        if token.startswith('time_shift_'):
            shift_steps = int(token.split('_')[-1])
            current_time += shift_steps * TIME_SHIFT_RESOLUTION
        elif token.startswith('note_on_'):
            pitch = int(token.split('_')[2])
            instrument = int(token.split('_')[-1])
            active_notes[(pitch, instrument)] = current_time
        elif token.startswith('note_off_'):
            pitch = int(token.split('_')[2])
            instrument = int(token.split('_')[-1])
            if (pitch, instrument) not in active_notes:
                print(f"Warning: Note off for {pitch} on instrument {instrument} without matching note on.")
                continue
            start_time = active_notes[(pitch, instrument)]

            if instrument not in instruments:
                instruments[instrument] = pretty_midi.Instrument(program=instrument)

            note = pretty_midi.Note(
                velocity=100, pitch=pitch, start=start_time, end=current_time
            )

            instruments[instrument].notes.append(note)
            
    for instrument in instruments.values():
        pm.instruments.append(instrument)
    
    return pm

In [15]:
midi = pretty_midi.PrettyMIDI(all_filepaths[0])
tokens = midi_to_tokens(midi)
print(tokens[:30])
midi_reconstructed = tokens_to_midi(tokens)
# midi_reconstructed.write('reconstructed_midi.mid')
# midi.write('original_midi.mid')

['<BOS>', 'note_on_70_instrument_81', 'note_on_62_instrument_80', 'note_on_58_instrument_38', 'time_shift_37', 'note_off_58_instrument_38', 'time_shift_1', 'note_off_70_instrument_81', 'note_off_62_instrument_80', 'time_shift_2', 'note_on_70_instrument_81', 'note_on_62_instrument_80', 'note_on_58_instrument_38', 'time_shift_7', 'note_off_58_instrument_38', 'time_shift_1', 'note_off_70_instrument_81', 'note_off_62_instrument_80', 'time_shift_2', 'note_on_70_instrument_81', 'note_on_62_instrument_80', 'note_on_58_instrument_38', 'time_shift_7', 'note_off_58_instrument_38', 'time_shift_1', 'note_off_70_instrument_81', 'note_off_62_instrument_80', 'time_shift_2', 'note_on_70_instrument_81', 'note_on_62_instrument_80']


In [74]:
def load_sequences(filepaths):
    sequences = []
    for i, filepath in enumerate(filepaths):
        pm = pretty_midi.PrettyMIDI(filepath)
        tokens = midi_to_tokens(pm)
        if not tokens:
            raise ValueError(f'No tokens generated for {filepath}')
        sequences.append([VOCABULARY[token] for token in tokens])
        print_progress_bar(i+1, len(filepaths), prefix='Loading sequences')
    return sequences

class MIDITokenDataset(Dataset):
    def __init__(self, sequences, seq_length=512):
        self.inputs = []
        self.targets = []
        for ind, seq in enumerate(sequences):
            num_chunks = len(seq) // (seq_length + 1)
            for chunk in range(num_chunks + 1):
                chunk_start = chunk * (seq_length + 1)
                chunk_end = chunk_start + seq_length + 1
                chunk = seq[chunk_start:chunk_end]

                if len(chunk) < seq_length / 4:
                    continue

                input = chunk[:-1]
                input = np.pad(input, (0, seq_length - len(input)), constant_values=VOCABULARY[PAD_TOKEN])
                
                target = chunk[1:]
                target = np.pad(target, (0, seq_length - len(target)), constant_values=VOCABULARY[PAD_TOKEN])
                
                self.inputs.append(torch.tensor(input, dtype=torch.long))
                self.targets.append(torch.tensor(target, dtype=torch.long))
            
            print_progress_bar(ind+1, len(sequences), prefix='Processing sequences')

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

In [61]:
# Load and convert
train_sequences = load_sequences(midi_train_filepaths)
val_sequences = load_sequences(midi_val_filepaths)

Loading sequences |██████████████████████████████████████████████████| 100.0% Complete
Loading sequences |██████████████████████████████████████████████████| 100.0% Complete


In [62]:

def print_sequence_percentiles(sequences, prefix=''):
    print(f'{prefix} sequences: {len(sequences)}')
    print(f"{prefix} 90th percile length: {np.percentile([len(seq) for seq in sequences], 90)}")
    print(f"{prefix} 50th percile length: {np.percentile([len(seq) for seq in sequences], 50)}")
    print(f"{prefix} 25th percile length: {np.percentile([len(seq) for seq in sequences], 25)}")
    print(f"{prefix} 10th percile length: {np.percentile([len(seq) for seq in sequences], 10)}")
    print(f"{prefix} max train sequence length: {max(len(seq) for seq in sequences)}")
    print(f"{prefix} min train sequence length: {min(len(seq) for seq in sequences)}")

print_sequence_percentiles(train_sequences, prefix='Train')
print_sequence_percentiles(val_sequences, prefix='Validation')

Train sequences: 4470
Train 90th percile length: 4361.699999999999
Train 50th percile length: 944.0
Train 25th percile length: 286.25
Train 10th percile length: 99.0
Train max train sequence length: 27506
Train min train sequence length: 7
Validation sequences: 402
Validation 90th percile length: 4329.300000000001
Validation 50th percile length: 682.0
Validation 25th percile length: 176.0
Validation 10th percile length: 73.0
Validation max train sequence length: 14466
Validation min train sequence length: 14


In [78]:
# Create datasets
train_dataset = MIDITokenDataset(train_sequences[:4], seq_length=512)
val_dataset = MIDITokenDataset(val_sequences[:1], seq_length=512)
print(f'Train dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(val_dataset)}')

Processing sequences |██████████████████████████████████████████████████| 100.0% Complete
Processing sequences |██████████████████████████████████████████████████| 100.0% Complete
Train dataset size: 7
Validation dataset size: 5


In [38]:
class MusicRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(MusicRNN, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.rnn = nn.LSTM(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, hidden=None):
        # x: (batch_size, seq_length)
        x = self.embedding(x)  # (batch_size, seq_length, embedding_dim)
        out, hidden = self.rnn(x, hidden)  # out: (batch_size, seq_length, hidden_dim)
        out = self.fc(out)  # (batch_size, seq_length, vocab_size)
        return out, hidden

In [None]:
def train(model, train_loader, val_loader, vocab_size, num_epochs=20, lr=0.001, device='cuda'):
    time_start = time.time()
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        # --------- Training ---------
        model.train()
        total_train_loss = 0

        for i, batch in enumerate(train_loader):
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)

            # batch = batch['input_ids'].to(device)  # (batch_size, seq_length)

            # inputs = batch[:, :-1]
            # targets = batch[:, 1:]

            optimizer.zero_grad()
            outputs, _ = model(inputs)

            outputs = outputs.reshape(-1, vocab_size)
            targets = targets.reshape(-1)

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            print_progress_bar(i+1, len(train_loader), prefix='Training...')

        avg_train_loss = total_train_loss / len(train_loader)

        # --------- Validation ---------
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for i, batch in enumerate(val_loader):
                inputs = inputs.to(device)
                targets = targets.to(device)
                # batch = batch['input_ids'].to(device)

                # inputs = batch[:, :-1]
                # targets = batch[:, 1:]

                outputs, _ = model(inputs)
                outputs = outputs.reshape(-1, vocab_size)
                targets = targets.reshape(-1)

                loss = criterion(outputs, targets)
                total_val_loss += loss.item()

                print_progress_bar(i+1, len(val_loader), prefix='Validating...')

        avg_val_loss = total_val_loss / len(val_loader)

        print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")


In [79]:
batch_size = 8  # You can adjust this depending on your GPU memory
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MusicRNN(vocab_size=len(VOCABULARY), embedding_dim=256, hidden_dim=512, num_layers=2)
train(model, train_loader, val_loader, vocab_size=len(VOCABULARY), device=device)

Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 1/20 | Train Loss: 7.0312 | Val Loss: 6.8262
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 2/20 | Train Loss: 6.8262 | Val Loss: 6.4490
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 3/20 | Train Loss: 6.4490 | Val Loss: 6.0285
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |██████████████████████████████████████████████████| 100.0% Complete
Epoch 4/20 | Train Loss: 6.0285 | Val Loss: 5.7496
Training... |██████████████████████████████████████████████████| 100.0% Complete
Validating... |████████████████████████████████████████████

In [None]:
def sample(model, start_token, max_length=100, temperature=1.0, device='cuda'):
    model = model.to(device)
    model.eval()

    generated = [start_token]
    input_token = torch.tensor([[start_token]], device=device)  # (1, 1)

    hidden = None

    for _ in range(max_length):
        output, hidden = model(input_token, hidden)  # output: (1, 1, vocab_size)
        output = output[:, -1, :]  # take the last output
        output = output / temperature  # adjust randomness

        probs = F.softmax(output, dim=-1)  # (1, vocab_size)
        next_token = torch.multinomial(probs, num_samples=1).item()
        generated.append(next_token)
        if next_token == VOCABULARY[END_OF_SONG_TOKEN] or VOCABULARY[PAD_TOKEN]: # reach end of sequence
          break

        input_token = torch.tensor([[next_token]], device=device)

    return generated

start_token = VOCABULARY[BEGINNING_OF_SONG_TOKEN]
device = 'cuda' if torch.cuda.is_available() else 'cpu'
generated_sequence = sample(model, start_token, max_length=1024, device=device)
generated_tokens = [ID_TO_TOKEN[token] for token in generated_sequence]
tokens_reconstructed = tokens_to_midi(generated_tokens)
tokens_reconstructed.write('generated_midi.mid')



### TRAIN SOLOISTS