In [1]:
import math

import torch
from torch import nn
from torch.utils.data import DataLoader

from models import GRUModel

def train_epoch(model: nn.Module, loader: DataLoader, loss_fn, opt, device):
    """Run one optimisation pass over ``loader`` and return the training perplexity.

    Args:
        model: Network called as ``model(x)``. Its output is flattened to
            ``(-1, logits.size(-1))`` before being handed to ``loss_fn``.
        loader: Iterable yielding ``(x, y)`` tensor pairs.
        loss_fn: Callable ``loss_fn(flat_logits, flat_targets)`` returning a
            scalar tensor. It is assumed to return the *mean* loss over the
            targets it scores (e.g. ``nn.CrossEntropyLoss`` with the default
            ``reduction="mean"``). If it exposes an ``ignore_index``
            attribute, targets equal to that value are treated as unscored.
        opt: Optimizer providing ``zero_grad(set_to_none=True)`` and ``step()``.
        device: Device (or device string) the batch tensors are moved to.

    Returns:
        ``exp`` of the loss averaged over every scored target token seen in
        the epoch, or ``nan`` if no target token was scored at all (empty
        loader, or every batch fully ignored).
    """
    model.train()
    # A mean-reduced loss with ``ignore_index`` averages over the non-ignored
    # targets only, so the batch mean has to be re-weighted by that count --
    # weighting by ``y.numel()`` would let the amount of padding in a batch
    # skew the epoch average.
    ignore_index = getattr(loss_fn, "ignore_index", None)
    total_loss, tokens = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        if ignore_index is None:
            scored = y.numel()
        else:
            scored = int((y != ignore_index).sum().item())
        if scored == 0:
            # Nothing to learn from; a mean over zero targets is NaN and
            # back-propagating it would poison the weights.
            continue
        opt.zero_grad(set_to_none=True)
        logits = model(x)
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        loss.backward()
        # Cap the global gradient norm at 1.0 before the parameter update.
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        total_loss += loss.item() * scored
        tokens += scored
    if tokens == 0:
        return float("nan")
    return math.exp(total_loss / tokens)  # perplexity


def evaluate(model: nn.Module, loader: DataLoader, loss_fn, device):
    """Compute the perplexity of ``model`` over ``loader`` without updating it.

    Args:
        model: Network called as ``model(x)``. Its output is flattened to
            ``(-1, logits.size(-1))`` before being handed to ``loss_fn``.
        loader: Iterable yielding ``(x, y)`` tensor pairs.
        loss_fn: Callable ``loss_fn(flat_logits, flat_targets)`` returning a
            scalar tensor. It is assumed to return the *mean* loss over the
            targets it scores (e.g. ``nn.CrossEntropyLoss`` with the default
            ``reduction="mean"``). If it exposes an ``ignore_index``
            attribute, targets equal to that value are treated as unscored.
        device: Device (or device string) the batch tensors are moved to.

    Returns:
        ``exp`` of the loss averaged over every scored target token, or
        ``nan`` if no target token was scored at all.
    """
    model.eval()
    # Re-weight each batch mean by the number of targets that actually
    # contributed to it; counting padding (``y.numel()``) would bias the
    # average towards heavily padded batches.
    ignore_index = getattr(loss_fn, "ignore_index", None)
    total_loss, tokens = 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            if ignore_index is None:
                scored = y.numel()
            else:
                scored = int((y != ignore_index).sum().item())
            if scored == 0:
                # A mean over zero targets is NaN; skip so it cannot
                # contaminate the running total.
                continue
            logits = model(x)
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            total_loss += loss.item() * scored
            tokens += scored
    if tokens == 0:
        return float("nan")
    return math.exp(total_loss / tokens)

In [2]:
import torch, json, os, csv
from torch.utils.data import random_split
from tokenizer import ABCTokenizer
from dataset   import LeadSheetDataset, collate_fn

# --- Configuration: everything a reader is likely to tune -------------------
ABC_PATH = "leadsheets.abc"
MAPS_PATH = "cache/abc_maps.json"
SAVE_DIR = "./saved_models"
SPLIT_SEED = 42
TRAIN_FRAC, VAL_FRAC = 0.8, 0.1  # whatever remains becomes the test split
BATCH_SIZE = 16
LEARNING_RATE = 1e-3

# --- Load raw data and the cached token maps --------------------------------
# Read raw ABC file. The encoding is spelled out so decoding does not depend
# on the platform's locale default; UTF-8 matches how the maps file is read.
# NOTE(review): assumes the ABC corpus is UTF-8 encoded -- confirm.
with open(ABC_PATH, "r", encoding="utf-8") as f:
    raw_data = f.read()

with open(MAPS_PATH, "r", encoding="utf-8") as f:
    maps = json.load(f)

chord_map  = maps["chord_map"]
header_map = maps["header_map"]
inline_map = maps["inline_map"]

# --- Tokenizer and dataset ---------------------------------------------------
# Tunes are taken to be the chunks separated by a blank line.
tunes = raw_data.strip().split("\n\n")
tokenizer = ABCTokenizer(chord_map=chord_map, header_map=header_map, inline_hdr_map=inline_map)
tokenizer.build_vocab(tunes)

dataset = LeadSheetDataset(tunes, tokenizer)
train_size = int(TRAIN_FRAC * len(dataset))
val_size = int(VAL_FRAC * len(dataset))
test_size = len(dataset) - train_size - val_size
# A dedicated seeded generator keeps the split identical across runs.
train_ds, val_ds, test_ds = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

pad_idx = tokenizer.stoi[tokenizer.pad_token]


def collate(batch):
    """Forward ``batch`` to ``collate_fn`` together with ``pad_idx``."""
    return collate_fn(batch, pad_idx)


# --- Data loaders (only the training split is shuffled) ----------------------
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate)

# --- Model, loss and optimizer ----------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GRUModel(tokenizer.vocab_size(), pad_idx=pad_idx).to(device)
# Targets equal to pad_idx are excluded from the loss.
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

# --- Output directory and a fresh training log -------------------------------
os.makedirs(SAVE_DIR, exist_ok=True)

logfile = os.path.join(SAVE_DIR, "train_log.csv")
# Mode "w" truncates any previous log, so re-running this cell starts over.
with open(logfile, "w", newline='') as f:
    writer = csv.writer(f)
    writer.writerow(["epoch", "train_ppl", "val_ppl"])

In [None]:
# Train for a fixed number of epochs, logging perplexities to the CSV log and
# writing a checkpoint every time validation perplexity reaches a new minimum.
NUM_EPOCHS = 15

best_val = float("inf")
for epoch in range(1, NUM_EPOCHS + 1):
    train_ppl = train_epoch(model, train_loader, loss_fn, opt, device)
    val_ppl = evaluate(model, val_loader, loss_fn, device)
    print(f"Epoch {epoch:02d}: train ppl={train_ppl:.2f}  val ppl={val_ppl:.2f}")

    # Append one row per epoch; the header row is expected to exist already.
    with open(logfile, "a", newline='') as f:
        writer = csv.writer(f)
        writer.writerow([epoch, train_ppl, val_ppl])

    if val_ppl < best_val:
        best_val = val_ppl
        ckpt_path = os.path.join(SAVE_DIR, f"model_epoch{epoch:02d}_ppl{val_ppl:.2f}.pt")

        # Save a checkpoint holding the model and optimizer state dicts (not
        # the pickled module object) plus the epoch and validation perplexity.
        # Each improvement writes a new file; earlier ones are kept.
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
            'val_ppl': val_ppl,
        }, ckpt_path)
        print("  ↳ New best; model saved")