In [None]:
from data_preprocessing import get_midis_by_composer, process_midis_to_text

# Restrict the corpus to these composers. How the names are matched to files is
# defined in data_preprocessing.get_midis_by_composer — presumably by filename; verify there.
composers = ["haydn", "mozart", "beethoven"]
midis = get_midis_by_composer(composers)

# Convert every split's MIDIs to token text, preserving split order:
# [[train texts], [val texts], [test texts]]
midi_texts = [process_midis_to_text(split_midis) for split_midis in midis]

# Later cells index midi_texts[0], [1] and [2]; fail loudly here rather than
# training or evaluating on a missing or empty split.
assert len(midi_texts) == 3, f"expected 3 splits (train/val/test), got {len(midi_texts)}"
assert all(midi_texts), "at least one split produced no texts"

Now loading MIDIs from data\train.
Could not load data\train\mozart-piano_sonatas-nueva_carpeta-k281_piano_sonata_n03_3mov.mid: Could not decode key with 2 flats and mode 2
Could not load data\train\unknown_artist-i_o-mozart_k550.mid: MThd not found. Probably not a MIDI file
Loaded 311 MIDI files from data\train
Now loading MIDIs from data\val.
Loaded 29 MIDI files from data\val
Now loading MIDIs from data\test.
Could not load data\test\unknown_artist-i_o-mozart_q1_2.mid: MThd not found. Probably not a MIDI file
Loaded 28 MIDI files from data\test
368 MIDI files retrieved.
Successfully processed 311 MIDIs into text.
Successfully processed 29 MIDIs into text.
Successfully processed 28 MIDIs into text.


In [20]:
from data_preprocessing import VocabBuilder
import torch

# Training sequences
training_texts = midi_texts[0]

# Build the vocabulary from the training split only, so validation/test are
# encoded with the same token->id mapping the model is trained on.
vb = VocabBuilder(training_texts)
train_ids = vb.train_ids


def encode_texts(texts, vocab):
    """Encode MIDI token-text sequences with ``vocab`` and concatenate them.

    Args:
        texts: iterable of token-text strings (one per piece).
        vocab: a VocabBuilder whose ``encode`` maps one text to a list of ids.

    Returns:
        A 1-D ``torch.long`` tensor with every piece's ids laid end to end.
    """
    # NOTE(review): how vocab.encode treats tokens unseen in training is defined
    # in data_preprocessing — confirm it does not raise on val/test-only tokens.
    flat_ids = [tok for seq in texts for tok in vocab.encode(seq)]
    return torch.tensor(flat_ids, dtype=torch.long)


# Encode validation and testing data texts
val_ids = encode_texts(midi_texts[1], vb)
test_ids = encode_texts(midi_texts[2], vb)

Vocabulary size (train only): 627


In [None]:
import mido

# Scratch cell: print the text encoding of one MIDI file, then the first training
# text, for a side-by-side sanity check of the conversion. Nothing is written to disk.
# TODO: remove before sharing, or add an explicit export if saving was intended.
SAMPLE_MIDI_PATH = "data/test/mozart-symphonies-symphony_n28_k200_4mov.mid"

midi_obj = mido.MidiFile(SAMPLE_MIDI_PATH)
# Each item is a (MidiFile, composer label) pair; the label shows up as the
# COMPOSER_ token in the resulting text.
print(process_midis_to_text([(midi_obj, "mozart")]))
print(midi_texts[0][0])

Successfully processed 1 MIDIs into text.
['<SOS> COMPOSER_mozart KEY_C TIME_SIGNATURE_4/4 TEMPO_BPM_250 MEASURE BEAT POS_0 NOTE_60 DUR_20 VEL_2 NOTE_79 DUR_8 VEL_2 POS_8 NOTE_81 DUR_8 VEL_2 POS_16 NOTE_79 DUR_8 VEL_2 POS_24 NOTE_62 DUR_20 VEL_2 NOTE_77 DUR_12 VEL_2 POS_36 NOTE_79 DUR_12 VEL_2 BEAT POS_0 NOTE_64 DUR_20 VEL_2 POS_24 NOTE_60 DUR_20 VEL_2 BEAT POS_0 NOTE_62 DUR_20 VEL_2 NOTE_77 DUR_8 VEL_2 POS_8 NOTE_79 DUR_8 VEL_2 POS_16 NOTE_77 DUR_8 VEL_2 POS_24 NOTE_64 DUR_20 VEL_2 NOTE_76 DUR_12 VEL_2 POS_36 NOTE_77 DUR_12 VEL_2 BEAT POS_0 NOTE_65 DUR_20 VEL_2 POS_24 NOTE_62 DUR_20 VEL_2 MEASURE BEAT POS_0 NOTE_64 DUR_20 VEL_2 NOTE_76 DUR_8 VEL_2 POS_8 NOTE_77 DUR_8 VEL_2 POS_16 NOTE_76 DUR_8 VEL_2 POS_24 NOTE_65 DUR_20 VEL_2 NOTE_74 DUR_12 VEL_2 POS_36 NOTE_76 DUR_12 VEL_2 BEAT POS_0 NOTE_67 DUR_20 VEL_2 POS_24 NOTE_64 DUR_20 VEL_2 BEAT POS_0 NOTE_65 DUR_20 VEL_2 NOTE_74 DUR_8 VEL_2 POS_8 NOTE_76 DUR_8 VEL_2 POS_16 NOTE_74 DUR_8 VEL_2 POS_24 NOTE_67 DUR_20 VEL_2 NOTE_72 DUR_12 VEL_2

In [5]:
import torch

from models import MidiTextTransformer, train_midi_text_transformer, generate_midi_tokens_with_transformer
from data_preprocessing import SEQ_SOS, SEQ_EOS

# Seed torch's RNG so weight initialisation (and, presumably, batch sampling inside
# train_midi_text_transformer) is reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Model hyperparameters
D_MODEL = 512
N_HEAD = 8
N_LAYER = 8
DIM_FF = 1024
BLOCK_SIZE = 1024  # presumably the maximum context length in tokens — confirm in models.py

# Training hyperparameters
MAX_ITERS = 10_000
EVAL_INTERVAL = 500
LEARNING_RATE = 3e-4

vocab_size = vb.vocab_size

model = MidiTextTransformer(vocab_size=vocab_size, d_model=D_MODEL, n_head=N_HEAD, n_layer=N_LAYER,
                            dim_ff=DIM_FF, block_size=BLOCK_SIZE)

# NOTE(review): later cells generate from and save `model`, not `trained_model`;
# this assumes training updates `model` in place — confirm `trained_model is model`.
trained_model = train_midi_text_transformer(
    model,
    train_ids=train_ids,
    val_ids=val_ids,
    vocab_size=vocab_size,
    max_iters=MAX_ITERS,
    eval_interval=EVAL_INTERVAL,
    lr=LEARNING_RATE,
)

step 0: train loss 5.663, acc 0.075 | val loss 5.738, acc 0.073
step 500: train loss 1.514, acc 0.577 | val loss 1.714, acc 0.542
step 1000: train loss 1.329, acc 0.619 | val loss 1.575, acc 0.566
step 1500: train loss 1.246, acc 0.645 | val loss 1.487, acc 0.591
step 2000: train loss 1.206, acc 0.653 | val loss 1.438, acc 0.612
step 2500: train loss 1.136, acc 0.672 | val loss 1.435, acc 0.610
step 3000: train loss 1.075, acc 0.684 | val loss 1.437, acc 0.608
step 3500: train loss 1.074, acc 0.687 | val loss 1.339, acc 0.635
step 4000: train loss 1.087, acc 0.689 | val loss 1.310, acc 0.640
step 4500: train loss 1.047, acc 0.698 | val loss 1.315, acc 0.641
step 5000: train loss 1.077, acc 0.688 | val loss 1.282, acc 0.650
step 5500: train loss 0.960, acc 0.720 | val loss 1.324, acc 0.639
step 6000: train loss 0.984, acc 0.714 | val loss 1.340, acc 0.638
step 6500: train loss 0.960, acc 0.720 | val loss 1.282, acc 0.656
step 7000: train loss 0.927, acc 0.730 | val loss 1.293, acc 0.652

In [9]:
# Vocabulary ids of the sequence start/end markers, needed by the generator.
SOS_ID, EOS_ID = vb.stoi[SEQ_SOS], vb.stoi[SEQ_EOS]

# Prime generation with the opening tokens of the first test-set piece. The prefix
# includes that piece's header tokens (COMPOSER_haydn in the recorded run).
SEED_TOKEN_COUNT = 300
first_test_text = midi_texts[2][0]
seed_tokens = vb.encode(first_test_text)[:SEED_TOKEN_COUNT]

generated_ids = generate_midi_tokens_with_transformer(
    model,
    sos_id=SOS_ID,
    eos_id=EOS_ID,
    start_tokens=seed_tokens,
    max_new_tokens=2000,
)
generated_text = vb.decode(generated_ids)

# Show a short preview; the full decoded string stays in `generated_text`.
PREVIEW_CHARS = 100
print("First 100 chars of generated text:\n")
print(generated_text[:PREVIEW_CHARS])

First 100 chars of generated text:

<SOS> COMPOSER_haydn KEY_F TIME_SIGNATURE_4/4 TEMPO_BPM_90 MEASURE BEAT POS_0 NOTE_41 DUR_22 VEL_5 N


In [10]:
import os
import re

from midi_conversion import text_to_midi

mid = text_to_midi(generated_text)

# Name the file after the composer token actually present in the generated text.
# The seed is an arbitrary test piece, so it is not necessarily Mozart (the
# recorded run produced COMPOSER_haydn). \w+ keeps the name filesystem-safe.
composer_match = re.search(r"COMPOSER_(\w+)", generated_text)
composer = composer_match.group(1) if composer_match else "unknown"

# Create output directory if it doesn't exist
os.makedirs("generated", exist_ok=True)

# Save to path
output_path = os.path.join("generated", f"{composer}_output_on_expanded_training_data.mid")
mid.save(output_path)

In [14]:
import os

import torch

MODEL_DIR = os.path.join("models", "transformer")
# torch.save does not create missing parent directories, so ensure the target exists.
os.makedirs(MODEL_DIR, exist_ok=True)

# Save the model weights (state_dict only). Rebuilding the model requires calling
# MidiTextTransformer with the same constructor hyperparameters used for training.
torch.save(model.state_dict(), os.path.join(MODEL_DIR, "transformer_weights.pt"))

# Save vocab components so the same token<->id mapping can be restored later
torch.save({
    "stoi": vb.stoi,
    "itos": vb.itos,
    "vocab_size": vb.vocab_size,
    "train_ids": vb.train_ids,
}, os.path.join(MODEL_DIR, "vocab.pt"))

