In [26]:
import os

import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW

import importlib
import data_preprocessing, midi_conversion, model_helpers, models, text_processing

from transformers import (
    GPT2LMHeadModel,
    GPT2Config,
    get_linear_schedule_with_warmup
)

# Re-execute the project modules so that edits saved to disk after the kernel
# first imported them take effect. Names already bound with
# `from <module> import <name>` are not rebound by a reload; re-run those
# import statements afterwards.
for _project_module in (
    data_preprocessing,
    midi_conversion,
    model_helpers,
    models,
    text_processing,
):
    # Looping (instead of ending the cell on a bare reload call) also keeps the
    # module repr from being echoed as the cell's output.
    importlib.reload(_project_module)


<module 'text_processing' from 'd:\\classical-music-generation-model\\text_processing.py'>

In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

# Seed PyTorch's random number generators so that runs are repeatable
# (some GPU kernels may still be nondeterministic).
SEED = 42
torch.manual_seed(SEED)

# Paths: the vocab file lives inside the text-export directory.
DATA_DIR = "data/midi_text_exports"
VOCAB_FILE = f"{DATA_DIR}/midi_vocab.txt"

# Training hyper-parameters consumed by the dataset, loader, optimizer and
# training-loop cells.
BLOCK_SIZE = 512
BATCH_SIZE = 24
NUM_EPOCHS = 5
LEARNING_RATE = 3e-4
WEIGHT_DECAY = 0.01

# Output locations.
MODEL_SAVE_DIR = "midi_gpt2_model"
TOKENIZER_SAVE_DIR = "midi_gpt2_tokenizer"

from midi_conversion import text_to_midi

Using device: cuda


In [29]:
from data_preprocessing import get_midis_by_composer, midi_split_to_text_split

# Composer names handed to get_midis_by_composer.
composers = ["mozart", "haydn", "beethoven"]
midis = get_midis_by_composer(composers)

# midi_texts layout: [[train texts], [val texts], [test texts]]
# Export into DATA_DIR (rather than a second hard-coded copy of the path) so
# the vocab-building and dataset cells, which read from DATA_DIR, always see
# the files written here.
midi_texts = midi_split_to_text_split(midis, save_to_directory=DATA_DIR)


Now loading MIDIs from data\train.
Could not load data\train\beethoven-anhang_14_3.mid: Could not decode key with 3 flats and mode 255
Could not load data\train\mozart-piano_sonatas-nueva_carpeta-k281_piano_sonata_n03_3mov.mid: Could not decode key with 2 flats and mode 2
Could not load data\train\unknown_artist-i_o-mozart_k550.mid: MThd not found. Probably not a MIDI file
Loaded 500 MIDI files from data\train
Now loading MIDIs from data\val.
Loaded 47 MIDI files from data\val
Now loading MIDIs from data\test.
Could not load data\test\unknown_artist-i_o-mozart_q1_2.mid: MThd not found. Probably not a MIDI file
Loaded 43 MIDI files from data\test
590 MIDI files retrieved.
Successfully processed 500 MIDIs into text.
Successfully processed 47 MIDIs into text.
Successfully processed 43 MIDIs into text.
Saved 500 files to                       data/midi_text_exports\train
Saved 47 files to                       data/midi_text_exports\val
Saved 43 files to                       data/midi_tex

In [None]:
from text_processing import build_vocab_from_dir

# Reuse the vocab file when it is already on disk; otherwise build it from the
# exported text files and write one token per line, special tokens first.
if os.path.exists(VOCAB_FILE):
    print(f"Found existing vocab file: {VOCAB_FILE}")
else:
    print(f"{VOCAB_FILE} not found, building from {DATA_DIR}...")
    counter = build_vocab_from_dir(DATA_DIR)
    specials = ["<pad>", "<bos>", "<eos>", "<unk>"]
    base_tokens = sorted(counter.keys())
    vocab = [*specials, *base_tokens]
    with open(VOCAB_FILE, "w", encoding="utf-8") as vocab_out:
        vocab_out.writelines(token + "\n" for token in vocab)
    print(f"Saved vocab with {len(vocab)} tokens to {VOCAB_FILE}")

Found existing vocab file: data/midi_text_exports/midi_vocab.txt
data/midi_text_exports/midi_vocab.txt not found, building from data/midi_text_exports...
Saved vocab with 640 tokens to data/midi_text_exports/midi_vocab.txt


In [25]:
from text_processing import MidiTokenizer

# Instantiate the tokenizer from the vocab file and report its vocabulary size;
# vocab_size is reused when the model config is built.
tokenizer = MidiTokenizer(VOCAB_FILE)
midi_vocab = tokenizer.get_vocab()
vocab_size = len(midi_vocab)
print("MIDI vocab size:", vocab_size)


MIDI vocab size: 640


In [27]:
from text_processing import MidiTextDataset
from model_helpers import collate_fn


def _pad_collate(batch):
    """Forward `batch` to `collate_fn` together with the tokenizer's pad token id."""
    return collate_fn(batch, tokenizer.pad_token_id)


# One dataset per split directory under DATA_DIR (train first, then val).
train_dataset, val_dataset = (
    MidiTextDataset(os.path.join(DATA_DIR, split), tokenizer, block_size=BLOCK_SIZE)
    for split in ("train", "val")
)

# Only the training loader shuffles; both share batch size and collate wrapper.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=_pad_collate
)
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=_pad_collate
)

Loaded 11156 sequences from data/midi_text_exports\train
Loaded 1193 sequences from data/midi_text_exports\val


In [28]:
base_model_name = "gpt2"  # you can try "gpt2-medium" if you have VRAM
print("Loading pretrained GPT-2...")

pretrained_model = GPT2LMHeadModel.from_pretrained(base_model_name)
base_config = pretrained_model.config
hidden_size = base_config.n_embd
print("GPT-2 hidden size:", hidden_size)

# Architecture fields copied verbatim from the pretrained config; only the
# vocabulary size and the special-token ids come from the MIDI tokenizer.
inherited_fields = ("n_positions", "n_ctx", "n_embd", "n_layer", "n_head")
new_config = GPT2Config(
    vocab_size=vocab_size,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    **{field: getattr(base_config, field) for field in inherited_fields},
)

model = GPT2LMHeadModel(new_config)

# Transplant every pretrained weight that does not depend on the vocabulary:
# positional embeddings, all transformer blocks and the final layer norm.
# Token embeddings (wte) and lm_head are deliberately left at their fresh
# initialization because they are sized for the new vocab.
pretrained_body = pretrained_model.transformer
with torch.no_grad():
    model.transformer.wpe.weight.copy_(pretrained_body.wpe.weight)

    for target_block, source_block in zip(model.transformer.h, pretrained_body.h):
        target_block.load_state_dict(source_block.state_dict())

    model.transformer.ln_f.load_state_dict(pretrained_body.ln_f.state_dict())

model = model.to(DEVICE)
print("Model ready. New vocab size:", model.config.vocab_size)

Loading pretrained GPT-2...
GPT-2 hidden size: 768
Model ready. New vocab size: 640


In [None]:
from tqdm import tqdm

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
total_steps = NUM_EPOCHS * len(train_loader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=total_steps // 10,  # warm up over the first 10% of updates
    num_training_steps=total_steps,
)


def _batch_loss(batch):
    """Run one forward pass of `model` on `batch` and return the loss tensor.

    `batch` is a dict as yielded by the data loaders; its "input_ids",
    "attention_mask" and "labels" entries are moved to DEVICE before the call.
    Shared by the training and validation passes so both compute the loss the
    same way.
    """
    outputs = model(
        input_ids=batch["input_ids"].to(DEVICE),
        attention_mask=batch["attention_mask"].to(DEVICE),
        labels=batch["labels"].to(DEVICE),
    )
    return outputs.loss


best_val_loss = float("inf")

for epoch in range(NUM_EPOCHS):
    # ---- Training pass ----
    model.train()
    total_train_loss = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=True)

    for batch in pbar:
        loss = _batch_loss(batch)

        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()  # the schedule advances once per optimizer update

        # Read the scalar once; each .item() call copies the value to the host.
        batch_loss = loss.item()
        total_train_loss += batch_loss
        pbar.set_postfix({"batch_loss": f"{batch_loss:.4f}"})

    avg_train_loss = total_train_loss / len(train_loader)

    # ---- Validation pass (no gradients, eval mode) ----
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            total_val_loss += _batch_loss(batch).item()

    avg_val_loss = total_val_loss / len(val_loader)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS} | train loss: {avg_train_loss:.4f} | val loss: {avg_val_loss:.4f}")

    # Checkpoint only when the mean validation loss improves.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        print("  -> saving best model")
        os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
        model.save_pretrained(MODEL_SAVE_DIR)

# The tokenizer is not saved here: MidiTokenizer is constructed directly from
# VOCAB_FILE wherever it is needed.


Epoch 1/5: 100%|██████████| 465/465 [05:08<00:00,  1.51it/s, batch_loss=0.9263]


Epoch 1/5 | train loss: 0.9619 | val loss: 1.0214
  -> saving best model


Epoch 2/5: 100%|██████████| 465/465 [05:10<00:00,  1.50it/s, batch_loss=0.9868]


Epoch 2/5 | train loss: 0.8682 | val loss: 0.9644
  -> saving best model


Epoch 3/5: 100%|██████████| 465/465 [05:10<00:00,  1.50it/s, batch_loss=0.7682]


Epoch 3/5 | train loss: 0.8006 | val loss: 0.9415
  -> saving best model


Epoch 4/5: 100%|██████████| 465/465 [05:07<00:00,  1.51it/s, batch_loss=0.9243]


Epoch 4/5 | train loss: 0.7530 | val loss: 0.9312
  -> saving best model


Epoch 5/5: 100%|██████████| 465/465 [05:07<00:00,  1.51it/s, batch_loss=0.7319]


Epoch 5/5 | train loss: 0.7184 | val loss: 0.9246
  -> saving best model


AttributeError: 'MidiTokenizer' object has no attribute 'save_pretrained'

In [12]:
from transformers import GPT2LMHeadModel
import torch, os

def generate_midi_tokens(
    prompt_text: str,
    max_new_tokens: int = 256,
    temperature: float = 1.0,
    top_k: int = 50,
):
    """Sample a MIDI token sequence from the model saved in MODEL_SAVE_DIR.

    The tokenizer is rebuilt from VOCAB_FILE and the model is reloaded from
    MODEL_SAVE_DIR on every call.

    Args:
        prompt_text: Token text to condition on; it is encoded without special
            tokens and a BOS token is prepended. May be empty.
        max_new_tokens: Upper bound on the number of sampled tokens. It is
            reduced when BOS + prompt + new tokens would not fit into the
            model's position embeddings (`config.n_positions`).
        temperature: Sampling temperature passed to `generate`.
        top_k: Top-k cutoff passed to `generate`.

    Returns:
        The decoded text of the full output sequence (prompt included), with
        special tokens skipped by the tokenizer's `decode`.

    Raises:
        ValueError: If BOS + prompt already fills the model's context window,
            leaving no room to generate.
    """
    tok = MidiTokenizer(VOCAB_FILE)
    mdl = GPT2LMHeadModel.from_pretrained(MODEL_SAVE_DIR).to(DEVICE)
    mdl.eval()

    prompt_ids = tok.encode(prompt_text, add_special_tokens=False)

    # Build input: [BOS] + prompt tokens  (no EOS)
    input_ids = torch.tensor([[tok.bos_token_id] + prompt_ids], dtype=torch.long).to(DEVICE)
    attention_mask = (input_ids != tok.pad_token_id).long()

    # GPT-2 uses learned absolute position embeddings, so the total sequence
    # must stay within n_positions; running past it fails inside the
    # embedding lookup. Clamp the budget instead.
    remaining_positions = mdl.config.n_positions - input_ids.shape[1]
    if remaining_positions <= 0:
        raise ValueError(
            f"Prompt has {input_ids.shape[1]} tokens (including BOS) but the model "
            f"supports at most {mdl.config.n_positions} positions; shorten the prompt."
        )
    max_new_tokens = min(max_new_tokens, remaining_positions)

    with torch.no_grad():
        output_ids = mdl.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_k=top_k,
            pad_token_id=tok.pad_token_id,
            eos_token_id=tok.eos_token_id,
        )

    # Single-sequence batch: take row 0 and decode it back to token text.
    generated_ids = output_ids[0].tolist()
    generated_text = tok.decode(generated_ids, skip_special_tokens=True)
    return generated_text

# Example prompt (empty: generation is conditioned on the BOS token only)
example_prompt = ""
generated_tokens = generate_midi_tokens(example_prompt, max_new_tokens=1024)

# Show a short preview; the ellipsis appears exactly when the text was cut.
PREVIEW_CHARS = 100
print("Generated token sequence:")
print(generated_tokens[:PREVIEW_CHARS], "..." if len(generated_tokens) > PREVIEW_CHARS else "")

# Convert generated text to MIDI and save
generated_mid = text_to_midi(generated_tokens)
os.makedirs("generated", exist_ok=True)
midi_path = os.path.join("generated", "gpt2_generated_sample.mid")
generated_mid.save(midi_path)
print("Saved generated MIDI to:", midi_path)


Generated token sequence:
<SOS> COMPOSER_mozart KEY_C TIME_SIGNATURE_4/4 TEMPO_BPM_112 MEASURE BEAT POS_0 NOTE_36 DUR_22 VEL_6 ...
Saved generated MIDI to: generated\gpt2_generated_sample.mid
