In [1]:
from data_preprocessing import get_midis_by_composer, process_midis_to_text

composers = ["mozart", "haydn"]

# One collection of MIDIs per data split. The log below shows them loaded in the
# order data\train, data\val, data\test.
midis = get_midis_by_composer(composers)

# [[train texts], [val texts], [test texts]]
# Built directly from `midis` so there is exactly one entry per split. The old
# approach pre-sized a 3-slot list and indexed into it, which would raise
# IndexError or leave stale empty slots if the split count ever changed.
midi_texts = [process_midis_to_text(split_midis) for split_midis in midis]

# Later cells index midi_texts[0] (train), [1] (val) and [2] (test).
assert len(midi_texts) == 3, f"expected train/val/test splits, got {len(midi_texts)}"

Now loading MIDIs from data\train.
Could not load data\train\mozart-piano_sonatas-nueva_carpeta-k281_piano_sonata_n03_3mov.mid: Could not decode key with 2 flats and mode 2
Could not load data\train\unknown_artist-i_o-mozart_k550.mid: MThd not found. Probably not a MIDI file
Loaded 311 MIDI files from data\train
Now loading MIDIs from data\val.
Loaded 29 MIDI files from data\val
Now loading MIDIs from data\test.
Could not load data\test\unknown_artist-i_o-mozart_q1_2.mid: MThd not found. Probably not a MIDI file
Loaded 28 MIDI files from data\test
368 MIDI files retrieved.
Successfully processed 311 MIDIs into text.
Successfully processed 29 MIDIs into text.
Successfully processed 28 MIDIs into text.


In [2]:
from data_preprocessing import VocabBuilder
import torch

# Training sequences. The vocabulary is built from the train split only:
# val/test are encoded with it but never contribute tokens to it.
training_texts = midi_texts[0]

# Build vocab from training data
vb = VocabBuilder(training_texts)
train_ids = vb.train_ids


def encode_split(texts, vocab):
    """Encode every text in ``texts`` with ``vocab`` and concatenate the results.

    Args:
        texts: list of MIDI-text strings for one data split.
        vocab: the VocabBuilder whose ``encode`` maps a text to token ids.

    Returns:
        A flat 1-D ``torch.long`` tensor with all pieces' token ids joined
        end to end, in the order the texts appear.

    NOTE(review): how ``vocab.encode`` treats tokens unseen in training is
    decided by VocabBuilder. Confirm it does not silently drop them.
    """
    return torch.tensor(
        [tok for seq in texts for tok in vocab.encode(seq)], dtype=torch.long
    )


# Encode validation and testing data texts with the train-only vocabulary
val_ids = encode_split(midi_texts[1], vb)
test_ids = encode_split(midi_texts[2], vb)

Vocabulary size (train only): 627


In [3]:
from models import MidiTextTransformer, train_midi_text_transformer, generate_midi_tokens_with_transformer
from data_preprocessing import SEQ_SOS, SEQ_EOS

vocab_size = vb.vocab_size

# Seed torch's global RNG before the model is built so that parameter
# initialisation is repeatable across Restart & Run All. This presumably also
# covers any torch-based randomness inside the trainer (batch sampling,
# dropout); confirm in models.py. torch.manual_seed seeds all devices,
# including CUDA.
TRAIN_SEED = 1337
torch.manual_seed(TRAIN_SEED)

# Transformer over the MIDI-text token vocabulary.
# NOTE(review): block_size=512 is given to the model but not to the trainer.
# Presumably the trainer reads it from the model or uses a matching default;
# confirm in models.py.
model = MidiTextTransformer(vocab_size=vocab_size, d_model=256, n_head=4, n_layer=6,
                            dim_ff=512, block_size=512)

# Trains on train_ids and reports train/val loss and accuracy every
# eval_interval steps (see the log below). The returned object is used for
# generation in the next cell.
trained_model = train_midi_text_transformer(
    model,
    train_ids=train_ids,
    val_ids=val_ids,
    vocab_size=vocab_size,
    max_iters=8000,
    eval_interval=500,
    lr=3e-4,
)

step 0: train loss 6.219, acc 0.022 | val loss 6.242, acc 0.018
step 500: train loss 1.776, acc 0.517 | val loss 1.959, acc 0.486
step 1000: train loss 1.502, acc 0.582 | val loss 1.695, acc 0.546
step 1500: train loss 1.369, acc 0.612 | val loss 1.682, acc 0.547
step 2000: train loss 1.339, acc 0.617 | val loss 1.588, acc 0.567
step 2500: train loss 1.290, acc 0.631 | val loss 1.501, acc 0.587
step 3000: train loss 1.301, acc 0.629 | val loss 1.486, acc 0.592
step 3500: train loss 1.222, acc 0.644 | val loss 1.508, acc 0.588
step 4000: train loss 1.229, acc 0.647 | val loss 1.464, acc 0.602
step 4500: train loss 1.209, acc 0.655 | val loss 1.416, acc 0.610
step 5000: train loss 1.194, acc 0.656 | val loss 1.454, acc 0.599
step 5500: train loss 1.191, acc 0.659 | val loss 1.413, acc 0.610
step 6000: train loss 1.133, acc 0.670 | val loss 1.468, acc 0.596
step 6500: train loss 1.135, acc 0.672 | val loss 1.440, acc 0.604
step 7000: train loss 1.136, acc 0.671 | val loss 1.432, acc 0.610

In [8]:
# IDs for special tokens
SOS_ID = vb.stoi[SEQ_SOS]
EOS_ID = vb.stoi[SEQ_EOS]

# The prompt is the first PROMPT_LEN tokens of test piece PROMPT_PIECE_IDX.
# Index 5 is the sixth test piece, not the first, and 300 tokens is more than
# "a few".
PROMPT_PIECE_IDX = 5
PROMPT_LEN = 300
GEN_RNG_SEED = 0

seed_tokens = vb.encode(midi_texts[2][PROMPT_PIECE_IDX])[:PROMPT_LEN]

# Make the continuation reproducible. This assumes the generator samples via
# torch's global RNG; confirm in models.py.
torch.manual_seed(GEN_RNG_SEED)

# Generate with the model returned by the training cell, not the pre-training
# `model` handle, so this stays correct even if the trainer returns a different
# object than it was given.
generated_ids = generate_midi_tokens_with_transformer(
    trained_model,
    sos_id=SOS_ID,
    eos_id=EOS_ID,
    start_tokens=seed_tokens,
    max_new_tokens=6000,
)

generated_text = vb.decode(generated_ids)

# NOTE(review): if the generator returns the prompt as a prefix, these first
# 100 characters are prompt text from the test piece, not newly generated
# material.
print("First 100 chars of generated text:\n")
print(generated_text[:100])

First 100 chars of generated text:

<SOS> COMPOSER_haydn KEY_G TIME_SIGNATURE_3/4 TEMPO_BPM_250 MEASURE BEAT BEAT BEAT MEASURE BEAT BEAT


In [10]:
from midi_conversion import text_to_midi
import os

# Turn the generated token text back into a MIDI object.
mid = text_to_midi(generated_text)

# Target location for the rendered piece.
# NOTE(review): the file name says "mozart", but training used
# ["mozart", "haydn"] and the prompt printed above starts with
# COMPOSER_haydn. Consider a more accurate name.
output_dir = "generated"
output_path = os.path.join(output_dir, "mozart_output_on_expanded_training_data.mid")

# exist_ok=True keeps this cell safe to re-run when the folder already exists.
os.makedirs(output_dir, exist_ok=True)
mid.save(output_path)