In [1]:
import math
from torch import nn
from torch.utils.data import DataLoader
from models import GRUModel

def train_epoch(model: nn.Module, loader: DataLoader, loss_fn, opt, device):
    model.train()
    total_loss, tokens = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad(set_to_none=True)
        logits = model(x)
        loss = loss_fn(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        total_loss += loss.item() * y.numel()
        tokens += y.numel()
    return math.exp(total_loss / tokens) # perplexity


def evaluate(model: nn.Module, loader: DataLoader, loss_fn, device):
    model.eval()
    total_loss, tokens = 0.0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            total_loss += loss.item() * y.numel()
            tokens += y.numel()
    return math.exp(total_loss / tokens)

In [2]:
import torch, json, os, csv
from torch.utils.data import random_split
from tokenizer import ABCTokenizer
from dataset   import LeadSheetDataset, collate_fn

# Read raw ABC file
with open("leadsheets.abc", "r") as f:
    raw_data = f.read()

with open("cache/abc_maps.json", "r", encoding="utf-8") as f:
    maps = json.load(f)

chord_map  = maps["chord_map"]
header_map = maps["header_map"]
inline_map = maps["inline_map"]

tunes = raw_data.strip().split("\n\n")
tokenizer = ABCTokenizer(chord_map=chord_map, header_map=header_map, inline_hdr_map=inline_map)
tokenizer.build_vocab(tunes)

dataset = LeadSheetDataset(tunes, tokenizer)
subset_size = int(0.2 * len(dataset))
subset, _ = random_split(dataset, [subset_size, len(dataset) - subset_size], generator=torch.Generator().manual_seed(42))
train_size = int(0.8 * len(subset))
val_size = len(subset) - train_size
train_ds, val_ds = random_split(subset, [train_size, val_size], generator=torch.Generator().manual_seed(42))

pad_idx = tokenizer.stoi[tokenizer.pad_token]
collate = lambda batch: collate_fn(batch, pad_idx)

In [None]:
import optuna

vocab_size = tokenizer.vocab_size()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def objective(trial):
    # Hyperparameters to explore
    num_layers = trial.suggest_int('num_layers', 1, 2)
    hidden_dim = trial.suggest_categorical('hidden_dim', [640, 768, 800])
    embed_dim = trial.suggest_categorical('embed_dim', [256, 320, 384])
    learning_rate = trial.suggest_float('lr', 5e-4, 1.2e-3, log=True)
    weight_decay = 1e-4
    if num_layers == 1:
        dropout = 0.0
    else:
        dropout = trial.suggest_float('dropout', 0.0, 0.3)

    batch_size = 16
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=collate)

    # Model, Optimizer, Loss
    model = GRUModel(
        vocab_size=vocab_size,
        embed_dim=embed_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        dropout=dropout,
        pad_idx=pad_idx
    ).to(device)

    loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Train for a small number of epochs
    try:
        for epoch in range(4):
            train_epoch(model, train_loader, loss_fn, optimizer, device)
            val_ppl = evaluate(model, val_loader, loss_fn, device)

            # Report to Optuna
            trial.report(val_ppl, epoch)

            # Prune bad trials
            if trial.should_prune():
                raise optuna.TrialPruned()

        return val_ppl

    except RuntimeError as e:
        print("OOM error: pruning trial.")
        raise optuna.TrialPruned()

# -----------------------------------------------------------------------------------
# SETUP STUDY
pruner = optuna.pruners.MedianPruner(n_warmup_steps=2)
study = optuna.create_study(direction="minimize", study_name="gru_hyperparam_tuning", pruner=pruner)
study.optimize(objective, n_trials=25)

# -----------------------------------------------------------------------------------
# BEST PARAMETERS
print("Best hyperparameters:", study.best_params)
print("Best validation perplexity:", study.best_value)

In [25]:
import pandas as pd
import optuna.visualization

# Print best params
print("Best hyperparameters:", study.best_params)
print("Best validation perplexity:", study.best_value)

# Full DataFrame of all trials
df = study.trials_dataframe()
display(df)

# Optimization history plot
optuna.visualization.plot_optimization_history(study)

# Hyperparameter importance plot
optuna.visualization.plot_param_importances(study)

Best hyperparameters: {'num_layers': 2, 'hidden_dim': 800, 'embed_dim': 320, 'lr': 0.0009386283804842192, 'dropout': 0.12843562130373742}
Best validation perplexity: 2.801684043847251


Unnamed: 0,number,value,datetime_start,datetime_complete,duration,params_dropout,params_embed_dim,params_hidden_dim,params_lr,params_num_layers,state
0,0,2.916082,2025-04-26 22:06:58.243244,2025-04-26 22:08:47.350205,0 days 00:01:49.106961,,384,640,0.001094,1,COMPLETE
1,1,2.83353,2025-04-26 22:08:47.351074,2025-04-26 22:13:36.761703,0 days 00:04:49.410629,0.243941,320,800,0.000768,2,COMPLETE
2,2,2.866726,2025-04-26 22:13:36.762630,2025-04-26 22:15:49.879968,0 days 00:02:13.117338,,256,768,0.000931,1,COMPLETE
3,3,2.874386,2025-04-26 22:15:49.880850,2025-04-26 22:19:27.086270,0 days 00:03:37.205420,0.1727,256,640,0.000681,2,COMPLETE
4,4,2.827424,2025-04-26 22:19:27.087188,2025-04-26 22:24:00.548896,0 days 00:04:33.461708,0.039281,384,768,0.00075,2,COMPLETE
5,5,2.829983,2025-04-26 22:24:00.549870,2025-04-26 22:28:55.354307,0 days 00:04:54.804437,0.116808,384,800,0.001001,2,COMPLETE
6,6,2.985752,2025-04-26 22:28:55.355310,2025-04-26 22:31:39.992944,0 days 00:02:44.637634,0.092318,256,640,0.000624,2,PRUNED
7,7,2.95818,2025-04-26 22:31:39.993794,2025-04-26 22:35:06.822615,0 days 00:03:26.828821,0.037873,320,768,0.000589,2,PRUNED
8,8,2.813022,2025-04-26 22:35:06.823655,2025-04-26 22:39:40.443134,0 days 00:04:33.619479,0.274282,384,768,0.000838,2,COMPLETE
9,9,2.836137,2025-04-26 22:39:40.444073,2025-04-26 22:44:13.229032,0 days 00:04:32.784959,0.029345,320,768,0.000798,2,PRUNED


In [26]:
df.to_csv("optuna_trials2.csv", index=False)
import joblib
joblib.dump(study, "optuna_study2.pkl")

fig1 = optuna.visualization.plot_optimization_history(study)
fig1.write_html("optimization_history2.html")
fig2 = optuna.visualization.plot_param_importances(study)
fig2.write_html("param_importances2.html")