In [1]:
import os

import torch
from torch.utils.data import DataLoader

import importlib
import data_preprocessing, midi_conversion, model_helpers, models, text_processing

from transformers import (
    GPT2LMHeadModel,
    GPT2Config
)

importlib.reload(data_preprocessing)
importlib.reload(midi_conversion)
importlib.reload(model_helpers)
importlib.reload(models)
importlib.reload(text_processing)


<module 'text_processing' from 'd:\\classical-music-generation-model\\text_processing.py'>

In [2]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

DATA_DIR = "data/midi_text_exports"
VOCAB_FILE = "data/midi_text_exports/midi_vocab.txt"
BLOCK_SIZE = 512
BATCH_SIZE = 24
NUM_EPOCHS = 20
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 0.01
MODEL_SAVE_DIR = "models/midi_gpt2_model"

Using device: cuda


In [8]:
from data_preprocessing import get_midis_by_composer, midi_split_to_text_split

composers = ["mozart", "haydn", "beethoven"]
midis = get_midis_by_composer(composers)

# [[train texts], [val texts], [test texts]]
# Export dir: "data/midi_text_exports"
midi_texts = midi_split_to_text_split(midis, save_to_directory="data/midi_text_exports")


Now loading MIDIs from data\train.
Could not load data\train\beethoven-anhang_14_3.mid: Could not decode key with 3 flats and mode 255
Could not load data\train\mozart-piano_sonatas-nueva_carpeta-k281_piano_sonata_n03_3mov.mid: Could not decode key with 2 flats and mode 2
Could not load data\train\unknown_artist-i_o-mozart_k550.mid: MThd not found. Probably not a MIDI file
Loaded 500 MIDI files from data\train
Now loading MIDIs from data\val.
Loaded 47 MIDI files from data\val
Now loading MIDIs from data\test.
Could not load data\test\unknown_artist-i_o-mozart_q1_2.mid: MThd not found. Probably not a MIDI file
Loaded 43 MIDI files from data\test
590 MIDI files retrieved.
Successfully processed 500 MIDIs into text.
Successfully processed 47 MIDIs into text.
Successfully processed 43 MIDIs into text.
Saved 500 files to                       data/midi_text_exports\train
Saved 47 files to                       data/midi_text_exports\val
Saved 43 files to                       data/midi_tex

In [3]:
from text_processing import build_vocab_from_dir
from text_processing import MidiTokenizer


if not os.path.exists(VOCAB_FILE):
    print(f"{VOCAB_FILE} not found, building from {DATA_DIR}...")
    counter = build_vocab_from_dir(DATA_DIR)
    base_tokens = sorted(counter.keys())
    specials = ["<pad>", "<bos>", "<eos>", "<unk>"]
    vocab = specials + base_tokens
    with open(VOCAB_FILE, "w", encoding="utf-8") as f:
        for tok in vocab:
            f.write(tok + "\n")
    print(f"Saved vocab with {len(vocab)} tokens to {VOCAB_FILE}")
else:
    print(f"Found existing vocab file: {VOCAB_FILE}")

tokenizer = MidiTokenizer(VOCAB_FILE)
vocab_size = len(tokenizer.get_vocab())
print("MIDI vocab size:", vocab_size)

Found existing vocab file: data/midi_text_exports/midi_vocab.txt
MIDI vocab size: 701


In [4]:
from text_processing import MidiTextDataset
from model_helpers import collate_fn


train_dataset = MidiTextDataset(os.path.join(DATA_DIR, "train"), tokenizer, block_size=BLOCK_SIZE)
val_dataset   = MidiTextDataset(os.path.join(DATA_DIR, "val"),   tokenizer, block_size=BLOCK_SIZE)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=lambda b: collate_fn(b, tokenizer.pad_token_id),
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=lambda b: collate_fn(b, tokenizer.pad_token_id),
)

Loaded 21696 sequences from data/midi_text_exports\train
Loaded 1962 sequences from data/midi_text_exports\val


In [7]:
print("Loading pretrained GPT-2...")
base_model_name = "gpt2"

pretrained_model = GPT2LMHeadModel.from_pretrained(base_model_name)
base_config = pretrained_model.config

hidden_size = base_config.n_embd
print("GPT-2 hidden size:", hidden_size)

config = GPT2Config(
    vocab_size=vocab_size,
    n_positions=base_config.n_positions,
    n_ctx=base_config.n_ctx,
    n_embd=base_config.n_embd,
    n_layer=base_config.n_layer,
    n_head=base_config.n_head,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    resid_pdrop=0.2,  # for mitigating overfitting
    embd_pdrop=0.2,
    attn_pdrop=0.2,
)

model = GPT2LMHeadModel(config)

# Copy transformer blocks and positional embeddings from pretrained
with torch.no_grad():
    # positional embeddings
    model.transformer.wpe.weight.copy_(pretrained_model.transformer.wpe.weight)

    # transformer blocks
    for new_block, old_block in zip(model.transformer.h, pretrained_model.transformer.h):
        new_block.load_state_dict(old_block.state_dict())

    # final layer norm
    model.transformer.ln_f.load_state_dict(pretrained_model.transformer.ln_f.state_dict())

    # We intentionally leave token embeddings (wte) and lm_head randomly initialized
    # to match new vocab.

model = model.to(DEVICE)
print("Model ready. New vocab size:", model.config.vocab_size)

Loading pretrained GPT-2...
GPT-2 hidden size: 768
Model ready. New vocab size: 701


In [8]:
from models import train_gpt_2

train_gpt_2(model, train_loader, val_loader, num_epochs=NUM_EPOCHS, lr=LEARNING_RATE,
            weight_decay=WEIGHT_DECAY, device=DEVICE, model_save_dir=MODEL_SAVE_DIR)


Epoch 1/20:   0%|          | 0/904 [00:00<?, ?it/s]`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.
Epoch 1/20: 100%|██████████| 904/904 [09:55<00:00,  1.52it/s, batch_loss=1.3197]


Epoch 1/20 | train loss:               2.9362 | val loss: 1.1838
  -> saving best model


Epoch 2/20: 100%|██████████| 904/904 [09:57<00:00,  1.51it/s, batch_loss=0.9003]


Epoch 2/20 | train loss:               1.0451 | val loss: 0.9974
  -> saving best model


Epoch 3/20: 100%|██████████| 904/904 [09:56<00:00,  1.51it/s, batch_loss=0.9492]


Epoch 3/20 | train loss:               0.8934 | val loss: 0.9128
  -> saving best model


Epoch 4/20: 100%|██████████| 904/904 [09:49<00:00,  1.53it/s, batch_loss=0.7193]


Epoch 4/20 | train loss:               0.8196 | val loss: 0.8757
  -> saving best model


Epoch 5/20: 100%|██████████| 904/904 [09:53<00:00,  1.52it/s, batch_loss=0.7589]


Epoch 5/20 | train loss:               0.7721 | val loss: 0.8528
  -> saving best model


Epoch 6/20: 100%|██████████| 904/904 [09:46<00:00,  1.54it/s, batch_loss=0.6977]


Epoch 6/20 | train loss:               0.7367 | val loss: 0.8412
  -> saving best model


Epoch 7/20: 100%|██████████| 904/904 [09:54<00:00,  1.52it/s, batch_loss=0.7371]


Epoch 7/20 | train loss:               0.7077 | val loss: 0.8269
  -> saving best model


Epoch 8/20: 100%|██████████| 904/904 [09:46<00:00,  1.54it/s, batch_loss=0.7448]


Epoch 8/20 | train loss:               0.6821 | val loss: 0.8296
No improvement in val loss for 1 epoch(s).


Epoch 9/20: 100%|██████████| 904/904 [09:47<00:00,  1.54it/s, batch_loss=0.6226]


Epoch 9/20 | train loss:               0.6592 | val loss: 0.8240
  -> saving best model


Epoch 10/20: 100%|██████████| 904/904 [09:52<00:00,  1.52it/s, batch_loss=0.6080]


Epoch 10/20 | train loss:               0.6385 | val loss: 0.8264
No improvement in val loss for 1 epoch(s).


Epoch 11/20: 100%|██████████| 904/904 [09:50<00:00,  1.53it/s, batch_loss=0.4933]


Epoch 11/20 | train loss:               0.6199 | val loss: 0.8190
  -> saving best model


Epoch 12/20: 100%|██████████| 904/904 [09:51<00:00,  1.53it/s, batch_loss=0.5760]


Epoch 12/20 | train loss:               0.6023 | val loss: 0.8194
No improvement in val loss for 1 epoch(s).


Epoch 13/20: 100%|██████████| 904/904 [09:49<00:00,  1.53it/s, batch_loss=0.6017]


Epoch 13/20 | train loss:               0.5869 | val loss: 0.8296
No improvement in val loss for 2 epoch(s).


Epoch 14/20: 100%|██████████| 904/904 [09:53<00:00,  1.52it/s, batch_loss=0.5280]


Epoch 14/20 | train loss:               0.5724 | val loss: 0.8302
No improvement in val loss for 3 epoch(s).
Early stopping triggered: no val-loss improvement for 3                         epochs.
Training complete. Best val loss: 0.8190 at epoch 11.


{'train_losses': [2.9362449400720343,
  1.0451273052433951,
  0.8933934326720449,
  0.8196259317672359,
  0.7721305268000712,
  0.7367371545560593,
  0.7077136263251305,
  0.6820601439515573,
  0.6592394471234453,
  0.6385357220352224,
  0.6198714019780138,
  0.6023377043285728,
  0.5868506243996388,
  0.5723666941038276],
 'val_losses': [1.1838221390072892,
  0.9973697291641701,
  0.9127541882235829,
  0.8757176675447603,
  0.8527709754501901,
  0.8411976560586836,
  0.8269084248600936,
  0.8296223806171883,
  0.8240174240455395,
  0.8264162758501564,
  0.8190460681188397,
  0.8194475988062416,
  0.829629559705897,
  0.8302457176330613],
 'best_val_loss': 0.8190460681188397,
 'best_epoch': 10}

In [9]:
from models import generate_midi_tokens_with_gpt_model
from midi_conversion import text_to_midi
import util

# Example generation
prompt = ""  # must decrease max_new_tokens if adding prompt
generated_tokens = generate_midi_tokens_with_gpt_model(
    prompt, VOCAB_FILE, MODEL_SAVE_DIR, max_new_tokens=1024, temp=1.0)
print("Generated token sequence:")
print(generated_tokens[:100], "..." if len(generated_tokens) > 500 else "")

# Convert generated text to MIDI and save
generated_mid = text_to_midi(generated_tokens)
util.mkdir("generated")
midi_path = util.path_join("generated", "gpt2_generated_sample.mid")
generated_mid.save(midi_path)
print("Saved generated MIDI to:", midi_path)


Generated token sequence:
<SOS> COMPOSER_beethoven KEY_F TIME_SIGNATURE_6/8 TEMPO_BPM_36 MEASURE BEAT POS_0 NOTE_60 DUR_24 VEL ...
Saved generated MIDI to: generated\gpt2_generated_sample.mid
