# Hyperparameter sensitivity experiment
This notebook conducts a post-hoc Laplace approximation hyperparameter sensitivity analysis on pre-trained LeNet5 and 2-layers MLP models trained on MNIST.

## 1. Setup
Imports and model loading.

In [12]:
import sys
from pathlib import Path
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Add root directory to sys.path
project_root = Path().resolve().parent
sys.path.append(str(project_root))

from laplace.laplace import Laplace
from laplace.baselaplace import BaseLaplace, ParametricLaplace, FullLaplace, KronLaplace, LowRankLaplace, DiagLaplace, FunctionalLaplace
from laplace.lllaplace import LLLaplace, FunctionalLLLaplace, DiagLLLaplace, FullLLLaplace, KronLLLaplace
from models.wideresnet.wideresnet import WideResNet
from models.lenet.lenet5 import LeNet5
from models.mlp.mlp import MLP

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

seeds = [6, 12, 37, 42, 113]

# Instantiate and load models
def load_lenet_models(seeds):
    models = []
    for seed in seeds:
        model = LeNet5()
        pth = f"{project_root}/hyperparams/models/lenet/pretrained/lenet_mnist_seed{seed}.pth"
        model.load_state_dict(torch.load(pth, map_location=device))
        model.to(device).eval()
        models.append(model)
    return models

def load_mlp_models(seeds):
    models = []
    for seed in seeds:
        model = MLP()
        pth = f"{project_root}/hyperparams/models/mlp/pretrained/mlp_mnist_seed{seed}.pth"
        model.load_state_dict(torch.load(pth, map_location=device))
        model.to(device).eval()
        models.append(model)
    return models

lenet_models = load_lenet_models(seeds)
mlp_models = load_mlp_models(seeds)

#check shape
x_dummy = torch.randn(1, 1, 28, 28).to(device)  # for MNIST
print("LeNet5 output shape:", lenet_models[0](x_dummy).shape)
print("MLP output shape:", mlp_models[0](x_dummy.view(1, -1)).shape)


LeNet5 output shape: torch.Size([1, 10])
MLP output shape: torch.Size([1, 10])


  model.load_state_dict(torch.load(pth, map_location=device))
  model.load_state_dict(torch.load(pth, map_location=device))


## 2. Data preparation
Load MNIST test set, both ID (MNIST) and OOD (Fashion-MNIST).

In [16]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# ID - MNIST
mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
mnist_train_loader = DataLoader(mnist_train, batch_size=128, shuffle=False)

mnist_test = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
mnist_test_loader = DataLoader(mnist_test, batch_size=128, shuffle=False)

# OOD - Fashion MNIST
fashionmnist_test  = datasets.FashionMNIST(root="./data", train=False, download=True, transform=transform)
fashionmnist_loader = DataLoader(fashionmnist_test, batch_size=128, shuffle=False)

## 3. Default settings and hyperparameter options

In [17]:
# Default settings for one‐at‐a‐time sweeps
'''
Key Hyperparameters
- prior_precision: Controls the strength of the Gaussian prior over weights.
                   Higher values imply stronger regularization.
                    Accepts:
                        Scalar (most common)
                        Per-parameter vector (for advanced use)
- temperature: Scales the log-likelihood, which affects posterior uncertainty.
               <1 makes the posterior sharper, >1 more diffuse.
               likelihood:
                  Either classification or regression.
                  Affects loss used in curvature estimation.
- hessian_structure: 'diag', 'kron', 'full', etc.
                      Major impact on runtime and accuracy.
                      Must be one of the supported structures in lllaplace.py.

- subset_of_weights: I only use 'last_layer'.
                     Others ('subnetwork', 'all') are available but much more expensive.

- backend_kwargs: For curvature approximation (e.g., using BackPACK, GGN).
                  Important for stochastic curvature approximations.
'''
    
# hp_options = {
#     'prior_precision':    [1e-6, 1e-4, 1e-2, 1.0, 100.0], #torch.logspace(-6, 2, 20).tolist()
#     'temperature':        [0.1, 0.5, 1.0, 2.0], #torch.logspace(-1, 1, 10).tolist()
#     'hessian_structure':  ['diag', 'kron', 'full', 'lowrank', 'gp'],
#     'link_approx':        ['probit', 'mc', 'bridge'],  # we'll switch pred_type when sweeping this
#     'n_samples':          [32, 128, 512],
#     'joint':              [False, True],
#     'diagonal_output':    [False, True],
#     'pred_type':          ['nn', 'glm', 'gp'],
#     'subset_of_weights':  ['last_layer'], # not using 'subnetwork' and 'all'
# }

from itertools import product

# Common across all
prior_precisions = [1e-6, 1e-4, 1e-2, 1.0, 100.0]
temperatures = [0.1, 0.5, 1.0, 2.0]
hessian_structures = ['diag', 'kron', 'full']  # 'lowrank' not supported for last_layer
subset_of_weights = ['last_layer']

# pred_type = 'nn' → uses sampling
nn_grid = list(product(
    prior_precisions,
    temperatures,
    hessian_structures,
    ['nn'],        # pred_type
    [128],         # n_samples
    [None],        # joint
    [None]         # diagonal_output
))

# pred_type = 'glm' → uses analytic mean/variance
glm_grid = list(product(
    prior_precisions,
    temperatures,
    hessian_structures,
    ['glm'],       # pred_type
    [None],        # n_samples
    [False, True], # joint
    [False, True]  # diagonal_output
))

# pred_type = 'gp' → FunctionalLLLaplace (requires separate setup)
# Not included in default grid since it requires n_subset and a different class

# Combine
full_grid = nn_grid + glm_grid

# Format as list of dicts
hp_grid = [
    {
        'prior_precision': pp,
        'temperature': temp,
        'hessian_structure': hess,
        'pred_type': pred,
        'n_samples': n_samp,
        'joint': joint,
        'diagonal_output': diag_out,
        'subset_of_weights': 'last_layer'
    }
    for pp, temp, hess, pred, n_samp, joint, diag_out in full_grid
]

print(f"Total valid configurations: {len(hp_grid)}")


Total valid configurations: 300


## 4. Utility: ECE computation

In [18]:
def compute_ece(probs, labels, n_bins=15):
    confidences, predictions = torch.max(probs, 1)
    accuracies = predictions.eq(labels)
    bins = torch.linspace(0, 1, n_bins + 1, device=probs.device)
    ece = torch.zeros(1, device=probs.device)
    for i in range(n_bins):
        mask = (confidences > bins[i]) & (confidences <= bins[i+1])
        if mask.any():
            ece += (mask.float().mean() * torch.abs(accuracies[mask].float().mean() - confidences[mask].mean()))
    return ece.item()

## 5. One-at-a-time hyperparameter sweeps

In [12]:
import pandas as pd

results_ext = []
for hp_name, values in hp_options.items():
    for val in values:
        # copy & overwrite this one hyperparam
        settings = default_settings.copy()
        settings[hp_name] = val

        # if we're sweeping link_approx, move to a GLM predictor
        if hp_name == 'link_approx':
            settings['pred_type'] = 'glm'

        # nn + non‐mc is invalid → skip
        if settings['pred_type'] == 'nn' and settings['link_approx'] != 'mc':
            continue

        # now it's safe to build the Laplace
        la = Laplace(
            model, 'classification',
            subset_of_weights=settings['subset_of_weights'],
            hessian_structure=settings['hessian_structure'],
            prior_precision=settings['prior_precision'],
            temperature=settings['temperature']
        )
        la.fit(train_loader)

        # predict + metrics
        all_probs, all_targets = [], []
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                out = la(
                    x,
                    pred_type=settings['pred_type'],
                    link_approx=settings['link_approx'],
                    n_samples=settings['n_samples'],
                    joint=settings['joint'],
                    diagonal_output=settings['diagonal_output']
                )
                probs = F.softmax(out['mean'], dim=1).cpu()
                all_probs.append(probs)
                all_targets.append(y)
        all_probs = torch.cat(all_probs)
        all_targets = torch.cat(all_targets)

        nll = F.cross_entropy(torch.log(all_probs), all_targets).item()
        acc = (all_probs.argmax(1) == all_targets).float().mean().item()
        ece = compute_ece(all_probs, all_targets)

        results_ext.append({
            'hyperparam': hp_name,
            'value':      val,
            'nll':        nll,
            'accuracy':   acc,
            'ece':        ece,
        })

df_ext = pd.DataFrame(results_ext)
df_ext.to_csv('/mnt/data/hyperparam_sensitivity_extended_results.csv', index=False)
print('Extended results saved to /mnt/data/hyperparam_sensitivity_extended_results.csv')

IndexError: too many indices for tensor of dimension 2

## 6. Plotting extended sensitivity results

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df_ext = pd.read_csv('/mnt/data/hyperparam_sensitivity_extended_results.csv')

# Numeric hyperparameters to plot as curves
for hp in ['prior_precision', 'temperature', 'n_samples']:
    sub = df_ext[df_ext['hyperparam'] == hp]
    plt.figure()
    # Ensure numeric sorting
    sub = sub.sort_values(by='value')
    plt.plot(np.log10(sub['value']) if hp in ['prior_precision','temperature'] else sub['value'],
             sub['nll'], marker='o')
    plt.xlabel('log10(value)' if hp in ['prior_precision','temperature'] else hp)
    plt.ylabel('NLL')
    plt.title(f'NLL vs {hp}')
    plt.show()

# Categorical hyperparameters as bar charts
for hp in ['hessian_structure', 'link_approx', 'joint', 'diagonal_output', 'pred_type', 'subset_of_weights']:
    sub = df_ext[df_ext['hyperparam'] == hp]
    mean_metrics = sub.groupby('value')[['nll','accuracy','ece']].mean()
    mean_metrics.plot(kind='bar', subplots=True, layout=(1,3), figsize=(12,4), legend=False, sharex=True)
    plt.suptitle(f'Metrics vs {hp}')
    plt.show()