In [None]:
from dotenv import load_dotenv
import torch
from torch.utils.data import DataLoader
from miditok import MIDILike, TokenizerConfig
from miditok.pytorch_data import DatasetMIDI, DataCollator
from pathlib import Path
import random
from torch.optim import AdamW

# ----------------------
# Load environment variables
# ----------------------
# Copy key=value pairs from a local .env file into os.environ; the wandb login
# further down presumably picks its API key up from there.
load_dotenv()

# ----------------------
# 1. CONFIGURATION
# ----------------------
MIDI_FOLDER = "../data/maestro"   # root directory searched recursively for MIDI files
MAX_SEQ_LEN = 1024                # tokens per example; also used as the model's context size
BATCH_SIZE = 8
EPOCHS = 20
LEARNING_RATE = 1e-4
_HAS_CUDA = torch.cuda.is_available()
DEVICE = torch.device("cuda" if _HAS_CUDA else "cpu")
RANDOM_SEED = 42

# Seed torch's global RNG (DataLoader shuffling, weight init) and Python's
# `random` (used to shuffle the file list before the train/val split).
torch.manual_seed(RANDOM_SEED)
random.seed(RANDOM_SEED)

# ----------------------
# 2. TOKENIZER SETUP
# ----------------------
# MIDILike tokenizer built from the library's default TokenizerConfig.
config = TokenizerConfig()
tokenizer = MIDILike(config)

# Special-token ids shared by the datasets, the collator and the loss masking.
# The "BOS_None"/"EOS_None" lookups assume BOS and EOS are among the default
# special tokens of TokenizerConfig.
PAD_ID, BOS_ID, EOS_ID = (
    tokenizer.pad_token_id,
    tokenizer["BOS_None"],
    tokenizer["EOS_None"],
)

# ----------------------
# 3. LOAD & SPLIT PATHS
# ----------------------
VAL_FRACTION = 0.1  # share of files held out for validation

# Collect MIDI files recursively, matching the extension case-insensitively
# and skipping directories. The list is sorted because rglob yields entries in
# filesystem-dependent order: shuffling that order directly would give a
# different "seeded" split on every machine.
all_midi = sorted(
    path for path in Path(MIDI_FOLDER).rglob("*")
    if path.is_file() and path.suffix.lower() in (".mid", ".midi")
)
if not all_midi:
    raise FileNotFoundError(f"No .mid/.midi files found under {MIDI_FOLDER!r}")

# Shuffle (with the seeded `random` module) before splitting.
random.shuffle(all_midi)
# Hold out VAL_FRACTION of the files. Keep at least one file for validation
# (an empty validation set breaks the per-epoch val loss) and never all of
# them (so training always has data). Datasets of 10+ files are unaffected.
n_val = min(max(1, int(VAL_FRACTION * len(all_midi))), len(all_midi) - 1)
train_paths, val_paths = all_midi[n_val:], all_midi[:n_val]

# ----------------------
# 4. DATASETS & DATALOADERS
# ----------------------
def _make_midi_dataset(paths):
    """Wrap `paths` in a DatasetMIDI sharing the tokenizer, length cap and BOS/EOS ids."""
    return DatasetMIDI(
        files_paths=paths,
        tokenizer=tokenizer,
        max_seq_len=MAX_SEQ_LEN,
        bos_token_id=BOS_ID,
        eos_token_id=EOS_ID,
    )


# NOTE(review): DatasetMIDI appears to keep only the first MAX_SEQ_LEN tokens
# of each file, so most of a long performance may go unused; miditok's
# split_files_for_training can chunk files beforehand. Confirm before changing.
train_ds = _make_midi_dataset(train_paths)
val_ds = _make_midi_dataset(val_paths)

# Pads each batch with PAD_ID and copies input_ids into labels for the LM loss.
collator = DataCollator(PAD_ID, copy_inputs_as_labels=True)

# Only the training split is shuffled; validation order does not affect its loss.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collator)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collator)

# ----------------------
# 5. WANDB INITIALIZATION
# ----------------------
import wandb

# Authenticate with Weights & Biases; the API key is expected to come from
# the .env file loaded at the top of the script.
wandb.login()

wandb.init(
    project="piano-transformer",
    config={
        "max_seq_len": MAX_SEQ_LEN,
        "batch_size": BATCH_SIZE,
        "epochs": EPOCHS,
        "learning_rate": LEARNING_RATE,
        "random_seed": RANDOM_SEED,
        # The model built in section 6 is a reduced GPT-2-style network
        # (6 layers, 512-dim, 8 heads in its GPT2Config), not the
        # 12-layer/768-dim "gpt2-small" preset, so it is not labelled as one.
        "model_arch": "gpt2-custom",
    },
)
# The training loop reads hyperparameters (e.g. epochs) back from the run config.
wandb_config = wandb.config

# ----------------------
# 6. MODEL & OPTIMIZER
# ----------------------

from transformers import GPT2Config, GPT2LMHeadModel

WEIGHT_DECAY = 0.01

# Minimal GPT-2-style causal LM sized to the MIDI tokenizer's vocabulary.
hf_config = GPT2Config(
    vocab_size=tokenizer.vocab_size,
    n_positions=MAX_SEQ_LEN,  # context length matches the dataset's max_seq_len
    n_ctx=MAX_SEQ_LEN,
    n_embd=512,
    n_layer=6,
    n_head=8,
    # GPT2Config defaults bos/eos to 50256 (GPT-2's text vocabulary), which is
    # outside this tokenizer's vocabulary. Point them at the MIDI special tokens
    # so generate() and any saved config agree with the training data.
    # (These ids are not used by the training loss.)
    bos_token_id=BOS_ID,
    eos_token_id=EOS_ID,
    pad_token_id=PAD_ID,
    loss_type=None,
)
model     = GPT2LMHeadModel(hf_config).to(DEVICE)
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# Record the actual architecture on the wandb run; wandb.init only logged the
# data/optimisation hyperparameters.
wandb.config.update({
    "n_layer": hf_config.n_layer,
    "n_embd": hf_config.n_embd,
    "n_head": hf_config.n_head,
    "weight_decay": WEIGHT_DECAY,
    "n_params": sum(p.numel() for p in model.parameters()),
})


# ----------------------
# 7. TRAINING LOOP
# ----------------------
from fastprogress import master_bar, progress_bar


def _prepare_batch(batch):
    """Move a collated batch to DEVICE and set PAD_ID label positions to -100.

    -100 is the ignore index of the Hugging Face LM loss, so padded positions
    do not contribute to it. NOTE(review): the miditok collator may already pad
    labels with -100; the masking is kept as a safeguard.
    """
    batch = {k: v.to(DEVICE) for k, v in batch.items()}
    batch["labels"][batch["labels"] == PAD_ID] = -100
    return batch


def _evaluate(model, loader):
    """Return the per-token mean loss of `model` over `loader`.

    Returns NaN when the loader yields no scorable tokens (e.g. it is empty).
    The model's loss is already a mean over a batch's non-ignored tokens, so
    each batch is weighted by that count. A plain mean of batch losses would
    over-weight short or partial batches.
    """
    model.eval()
    loss_sum, token_count = 0.0, 0
    with torch.no_grad():
        for batch in loader:
            batch = _prepare_batch(batch)
            # A causal LM predicts label t+1 from position t, so the first
            # label of each sequence is never scored.
            n_tokens = int((batch["labels"][..., 1:] != -100).sum())
            if n_tokens == 0:
                continue
            loss_sum += model(**batch).loss.item() * n_tokens
            token_count += n_tokens
    return loss_sum / token_count if token_count else float("nan")


try:
    mb = master_bar(range(wandb_config.epochs))
    for epoch in mb:
        model.train()
        for batch in progress_bar(train_loader, parent=mb):
            batch = _prepare_batch(batch)

            loss = model(**batch).loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # .item() forces a device sync; do it once per step and reuse it.
            loss_value = loss.item()
            wandb.log({"train/loss": loss_value})
            mb.child.comment = f"loss: {loss_value:.4f}"

        # validation pass
        avg_val_loss = _evaluate(model, val_loader)
        mb.write(f"Epoch {epoch+1} — val_loss: {avg_val_loss:.4f}")
        wandb.log({"val/loss": avg_val_loss, "epoch": epoch+1})
finally:
    # Close the wandb run even if training raises or is interrupted.
    wandb.finish()