In [None]:
def train_with_hyperparams(embed_dim, hidden_dim, learning_rate, dropout, weight_decay, epochs=20, patience=5, *, return_model=False):
    """Train a fresh ``MelodyModel`` with the given hyperparameters, using early stopping.

    A new model and Adam optimizer are built on every call. Each epoch runs
    ``train_one_epoch`` on ``train_loader`` and ``evaluate`` on ``val_loader``.
    Training stops early once the validation loss has failed to improve for
    ``patience`` consecutive epochs. Before returning, the weights from the epoch
    with the lowest validation loss are loaded back into the model.

    Args:
        embed_dim: Embedding size passed to ``MelodyModel``.
        hidden_dim: Hidden size passed to ``MelodyModel``.
        learning_rate: Adam learning rate.
        dropout: Dropout probability passed to ``MelodyModel``.
        weight_decay: Adam ``weight_decay`` value.
        epochs: Maximum number of training epochs.
        patience: Number of consecutive non-improving epochs tolerated before stopping.
        return_model: If True, also return the trained model (best weights loaded).

    Returns:
        ``(train_loss, best_val_loss)``. ``train_loss`` is the training loss of the
        same epoch that achieved ``best_val_loss``, so the two numbers describe one
        model. If ``return_model`` is True, the model is appended as a third element.
        If no epoch ever improves on the initial ``inf`` (e.g. ``epochs=0``), the
        result is ``(nan, inf)``.
    """
    model = MelodyModel(
        vocab_size=len(vocab),
        embed_dim=embed_dim,
        hidden_dim=hidden_dim,
        dropout=dropout
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    loss_fn = nn.CrossEntropyLoss()

    best_val_loss = float('inf')
    best_train_loss = float('nan')  # stays nan if no epoch improves (e.g. epochs=0)
    best_model_state = None
    patience_counter = 0

    for epoch in range(epochs):
        train_loss = train_one_epoch(model, train_loader, loss_fn, optimizer)
        val_loss = evaluate(model, val_loader, loss_fn)
        print(f"[Epoch {epoch+1}] Train: {train_loss:.4f} | Val: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_train_loss = train_loss
            # state_dict() holds references to the live parameter tensors, which later
            # optimizer steps update in place; clone them so the snapshot stays frozen.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping triggered at epoch {epoch+1}. Best val loss: {best_val_loss:.4f}")
                break

    # Roll the model back to its best epoch rather than keeping the last (possibly overfit) weights.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    if return_model:
        return best_train_loss, best_val_loss, model
    return best_train_loss, best_val_loss

In [None]:
import random

# Discrete candidate values for each hyperparameter explored by the random search.
# Keys match the keyword arguments of train_with_hyperparams; list order is kept
# as-is so seeded sampling draws the same values.
search_space = dict(
    embed_dim=[64, 128],
    hidden_dim=[256, 512],
    learning_rate=[5e-4, 1e-3, 5e-3],
    dropout=[0.4, 0.5],
    weight_decay=[1e-4, 5e-4, 1e-5],
)

def sample_hyperparams(space, rng=None):
    """Draw one random configuration from a discrete search space.

    Args:
        space: Mapping from hyperparameter name to a non-empty sequence of
            candidate values.
        rng: Object with a ``choice`` method (e.g. ``random.Random(seed)``) used for
            the draws. Defaults to the module-level ``random`` functions, so a
            ``random.seed(...)`` call still controls the result.

    Returns:
        Dict mapping every name in ``space`` to one uniformly chosen candidate,
        in the same key order as ``space``.

    Raises:
        IndexError: If any candidate sequence is empty (from ``choice``).
    """
    rng = random if rng is None else rng
    # Iterate the space itself rather than a hard-coded key list, so a new
    # hyperparameter only has to be added to the search space.
    return {name: rng.choice(options) for name, options in space.items()}

# Random search: train one model per sampled configuration and record its losses.
# The search space is small, so sampling with replacement would often retrain an
# identical configuration; configurations already tried are skipped instead.
N_TRIALS = 20
MAX_RESAMPLE_ATTEMPTS = 1000  # bound on redraws before treating the space as exhausted

results = []
tried_configs = set()  # assumes hyperparameter values are hashable (true for the scalars used here)

for trial in range(N_TRIALS):
    # Redraw until an untried configuration appears; stop the search if none turns up.
    for _ in range(MAX_RESAMPLE_ATTEMPTS):
        params = sample_hyperparams(search_space)
        config_key = frozenset(params.items())
        if config_key not in tried_configs:
            break
    else:
        print(f"\nNo untried configuration found after {MAX_RESAMPLE_ATTEMPTS} draws; stopping search.")
        break
    tried_configs.add(config_key)

    print(f"\nTrial {trial+1}")
    print("Params:", params)

    train_loss, val_loss = train_with_hyperparams(**params)

    results.append({
        "Trial": trial + 1,
        **params,
        "Train Loss": train_loss,
        "Val Loss": val_loss
    })

In [None]:
# Rank every trial by its validation loss (lowest first) and show the five best.
sorted_results = list(results)
sorted_results.sort(key=lambda trial_result: trial_result["Val Loss"])

print("\nTop Results:")
for r in sorted_results[:5]:
    print(r)