# In[ ]:
import copy
import itertools

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
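# Define ClimbingMLPImproved(hidden_dim=..., dropout=...) here (or in an earlier cell); train_model() below instantiates it.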



# --- Utility: Compute MSE for a whole dataset ---
def dataset_mse(model, loader, loss_fn, device):
    """Return the sample-weighted mean loss of ``model`` over all batches in ``loader``.

    The model is put in eval mode and run under ``torch.no_grad()``. Each
    batch's loss (assumed mean-reduced, as with ``nn.MSELoss()``) is weighted
    by the batch size, so a short final batch does not skew the average.

    Args:
        model: Module called as ``model(holds, global_feats)``.
        loader: Iterable of ``(holds, global_feats, targets)`` batches.
        loss_fn: Loss function applied as ``loss_fn(preds, targets)``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The average per-sample loss as a Python float.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for holds, global_feats, targets in loader:
            holds = holds.to(device)
            global_feats = global_feats.to(device)
            targets = targets.to(device)
            # reshape to the target shape rather than squeeze(): squeeze()
            # collapses a batch of one to a 0-d tensor that no longer matches
            # targets of shape (1,).
            preds = model(holds, global_feats).reshape(targets.shape)
            batch_size = targets.shape[0]
            total_loss += loss_fn(preds, targets).item() * batch_size
            n_samples += batch_size
    if n_samples == 0:
        raise ValueError("dataset_mse: loader yielded no samples")
    return total_loss / n_samples


# --- Training function (used for each hyperparam config) ---
def train_model(dataset, num_epochs=500, lr=1e-3, weight_decay=1e-3,
                hidden_dim=256, dropout=0.4, patience=50, batch_size=64, verbose=False,
                split_seed=42):
    """Train one ``ClimbingMLPImproved`` with early stopping and score it.

    ``dataset`` is split 80/10/10 into train/validation/test subsets. The model
    is trained with Adam on MSE. After every epoch the validation MSE is
    computed and the best-scoring weights are kept. Training stops once the
    validation MSE has not improved for ``patience`` consecutive epochs. The
    best weights are then evaluated on the test subset.

    Args:
        dataset: Map-style dataset yielding ``(holds, global_feats, target)``.
        num_epochs: Maximum number of training epochs.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        hidden_dim: Forwarded to ``ClimbingMLPImproved``.
        dropout: Forwarded to ``ClimbingMLPImproved``.
        patience: Epochs without validation improvement before stopping early.
        batch_size: Batch size for the train, validation and test loaders.
        verbose: If true, print train/validation MSE every epoch.
        split_seed: Seed for the train/val/test split. The fixed default makes
            repeated calls (e.g. from a hyperparameter search) use the same
            split, so their validation scores are comparable. Pass ``None``
            to draw the split from torch's global RNG instead.

    Returns:
        ``(best_val_mse, test_mse)``. If no epoch produced a validation MSE
        below ``inf`` (e.g. ``num_epochs == 0`` or the loss became NaN), the
        final weights are tested and ``best_val_mse`` is ``inf``.

    Raises:
        ValueError: If the dataset is too small for every split to be non-empty.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loss_fn = nn.MSELoss()

    # --- Split dataset 80/10/10 ---
    n_total = len(dataset)
    n_train = int(0.8 * n_total)
    n_val = int(0.1 * n_total)
    n_test = n_total - n_train - n_val
    if min(n_train, n_val, n_test) == 0:
        raise ValueError(
            f"dataset of {n_total} samples is too small for an 80/10/10 split "
            f"(train={n_train}, val={n_val}, test={n_test})"
        )
    split_kwargs = {}
    if split_seed is not None:
        # A dedicated generator fixes the split without touching the global RNG.
        split_kwargs["generator"] = torch.Generator().manual_seed(split_seed)
    train_set, val_set, test_set = random_split(
        dataset, [n_train, n_val, n_test], **split_kwargs
    )

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

    model = ClimbingMLPImproved(hidden_dim=hidden_dim, dropout=dropout).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    best_val = float('inf')
    epochs_no_improve = 0
    best_model_state = None

    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0.0
        for holds, global_feats, targets in train_loader:
            holds = holds.to(device)
            global_feats = global_feats.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            # reshape instead of squeeze(): a final batch of one would
            # otherwise become a 0-d tensor that mismatches targets (1,).
            preds = model(holds, global_feats).reshape(targets.shape)
            loss = loss_fn(preds, targets)
            loss.backward()
            optimizer.step()
            total_train_loss += loss.item() * targets.shape[0]

        train_mse = total_train_loss / n_train
        val_mse = dataset_mse(model, val_loader, loss_fn, device)

        if verbose:
            print(f"Epoch {epoch+1:03d}: Train MSE={train_mse:.3f}, Val MSE={val_mse:.3f}")

        if val_mse < best_val:
            best_val = val_mse
            # deepcopy: state_dict() returns live references to the parameters.
            best_model_state = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                break

    # best_model_state stays None if no epoch ever improved on inf
    # (num_epochs == 0, or NaN losses); keep the current weights then.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    test_mse = dataset_mse(model, test_loader, loss_fn, device)
    return best_val, test_mse


# --- Hyperparameter search ---
def hyperparam_search(dataset, lrs=None, weight_decays=None, hidden_dims=None,
                      dropouts=None, num_epochs=500, patience=50, seed=0):
    """Grid-search ``train_model`` hyperparameters and rank configs by validation MSE.

    Each combination of the four grids is trained once via ``train_model``.
    Progress is printed per config, and the best config is printed at the end.

    Args:
        dataset: Dataset forwarded unchanged to ``train_model``.
        lrs: Learning rates to try (default ``[1e-4, 3e-4, 1e-3]``).
        weight_decays: Weight decays to try (default ``[1e-4, 1e-3, 1e-2]``).
        hidden_dims: Hidden sizes to try (default ``[128, 256, 512]``).
        dropouts: Dropout rates to try (default ``[0.3, 0.4, 0.5]``).
        num_epochs: Forwarded to ``train_model``.
        patience: Forwarded to ``train_model``.
        seed: If not ``None``, ``torch.manual_seed(seed)`` is called before each
            config. Every run then starts from the same global torch RNG state,
            so configs differ only in their hyperparameters. Note that this
            resets torch's global RNG as a side effect.

    Returns:
        A list of dicts with keys ``lr``, ``weight_decay``, ``hidden_dim``,
        ``dropout``, ``val_mse`` and ``test_mse``, sorted by ascending
        ``val_mse``.

    Raises:
        ValueError: If any grid is empty, leaving nothing to search.
    """
    lrs = [1e-4, 3e-4, 1e-3] if lrs is None else list(lrs)
    weight_decays = [1e-4, 1e-3, 1e-2] if weight_decays is None else list(weight_decays)
    hidden_dims = [128, 256, 512] if hidden_dims is None else list(hidden_dims)
    dropouts = [0.3, 0.4, 0.5] if dropouts is None else list(dropouts)

    search_space = list(itertools.product(lrs, weight_decays, hidden_dims, dropouts))
    if not search_space:
        raise ValueError("hyperparam_search: every grid must contain at least one value")
    n_configs = len(search_space)
    results = []

    print(f"Running {n_configs} experiments...\n")

    for i, (lr, wd, hdim, drop) in enumerate(search_space, 1):
        print(f"▶️ [{i}/{n_configs}] LR={lr}, WD={wd}, H={hdim}, Drop={drop}")
        if seed is not None:
            torch.manual_seed(seed)
        val_mse, test_mse = train_model(
            dataset,
            lr=lr,
            weight_decay=wd,
            hidden_dim=hdim,
            dropout=drop,
            num_epochs=num_epochs,
            patience=patience,
            verbose=False
        )
        results.append({
            'lr': lr,
            'weight_decay': wd,
            'hidden_dim': hdim,
            'dropout': drop,
            'val_mse': val_mse,
            'test_mse': test_mse
        })
        print(f"   → val={val_mse:.3f}, test={test_mse:.3f}\n")

    # Select on validation only. test_mse is recorded for every config, but
    # only the selected config's test score is an unbiased estimate.
    results.sort(key=lambda r: r['val_mse'])
    print("✅ Best Config:")
    print(results[0])
    return results

class ClimbDataset(torch.utils.data.Dataset):
    """Map-style dataset of ``(hold_features, global_features, target)`` triples.

    All three inputs (lists, NumPy arrays or tensors) are copied into float32
    tensors on construction. They must have the same length along the first
    axis, and item ``idx`` is the ``idx``-th row of each.
    """

    def __init__(self, hold_features, global_features, targets):
        self.hold_features = self._to_float_tensor(hold_features)
        self.global_features = self._to_float_tensor(global_features)
        self.targets = self._to_float_tensor(targets)
        # A length mismatch would otherwise surface later as an IndexError
        # inside a DataLoader, or silently drop trailing feature rows.
        n_hold = len(self.hold_features)
        n_global = len(self.global_features)
        n_targets = len(self.targets)
        if not n_hold == n_global == n_targets:
            raise ValueError(
                "ClimbDataset inputs must have the same number of samples; got "
                f"hold_features={n_hold}, global_features={n_global}, targets={n_targets}"
            )

    @staticmethod
    def _to_float_tensor(values):
        # torch.tensor() warns when handed an existing tensor. as_tensor()
        # accepts tensors, arrays and lists alike, and clone() keeps the
        # original always-copy semantics: later edits to the caller's array
        # do not leak into the dataset.
        return torch.as_tensor(values, dtype=torch.float32).clone()

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return self.hold_features[idx], self.global_features[idx], self.targets[idx]

# Build the dataset from arrays prepared in earlier cells, then run the grid search.
# NOTE(review): X_global is passed as the *hold* features, next to a separate
# global_features array. Confirm this pairing is intended.
dataset = ClimbDataset(
    hold_features=X_global,
    global_features=global_features,
    targets=y,
)

results = hyperparam_search(dataset)