# Hyperparameter sensitivity experiment
This notebook conducts a post-hoc Laplace approximation hyperparameter sensitivity analysis on a pre-trained WideResNet-16-4 model trained on CIFAR-10.

## 1. Setup
Imports and model loading.

In [1]:
import sys
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
sys.path.append('./')

from laplace import Laplace
from models.wideresnet.wideresnet import WideResNet

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = WideResNet(depth=16, num_classes=10, widen_factor=4).to(device)
checkpoint = torch.load('models/wideresnet/pretrained/model_best.pth.tar', map_location='cpu')
model.load_state_dict(checkpoint['state_dict'])
model.eval()

  from .autonotebook import tqdm as notebook_tqdm
  checkpoint = torch.load('models/wideresnet/pretrained/model_best.pth.tar', map_location='cpu')


WideResNet(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (block1): NetworkBlock(
    (layer): Sequential(
      (0): BasicBlock(
        (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(16, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu2): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (convShortcut): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (1): BasicBlock(
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, 

## 2. Data preparation
Split CIFAR-10 into train/val/test.

In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914,0.4822,0.4465), (0.2023,0.1994,0.2010))
])
full_train = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_set, val_set = random_split(full_train, [45000, 5000], generator=torch.Generator().manual_seed(42))
test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4)
val_loader   = DataLoader(val_set, batch_size=256, shuffle=False, num_workers=4)
test_loader  = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=4)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:11<00:00, 15.4MB/s] 


Extracting ./data\cifar-10-python.tar.gz to ./data
Files already downloaded and verified


## 3. Default settings and hyperparameter options

In [10]:
# Default settings for one‐at‐a‐time sweeps
default_settings = {
    'prior_precision':    1.0,
    'temperature':        1.0,
    'hessian_structure':  'kron',
    'link_approx':        'mc',            # ← mc is the only valid default for nn
    'n_samples':          128,
    'joint':              False,
    'diagonal_output':    False,
    'pred_type':          'nn',
    'subset_of_weights':  'last_layer',
}

hp_options = {
    'prior_precision':    torch.logspace(-6, 2, 20).tolist(),
    'temperature':        torch.logspace(-1, 1, 10).tolist(),
    'hessian_structure':  ['diag', 'kron', 'full', 'lowrank', 'gp'],
    'link_approx':        ['probit', 'mc', 'bridge'],  # we'll switch pred_type when sweeping this
    'n_samples':          [32, 128, 512],
    'joint':              [False, True],
    'diagonal_output':    [False, True],
    'pred_type':          ['nn', 'glm', 'gp'],
    'subset_of_weights':  ['last_layer', 'subnetwork', 'all'],
}



## 4. Utility: ECE computation

In [11]:
def compute_ece(probs, labels, n_bins=15):
    confidences, predictions = torch.max(probs, 1)
    accuracies = predictions.eq(labels)
    bins = torch.linspace(0, 1, n_bins + 1, device=probs.device)
    ece = torch.zeros(1, device=probs.device)
    for i in range(n_bins):
        mask = (confidences > bins[i]) & (confidences <= bins[i+1])
        if mask.any():
            ece += (mask.float().mean() * torch.abs(accuracies[mask].float().mean() - confidences[mask].mean()))
    return ece.item()

## 5. One-at-a-time hyperparameter sweeps

In [12]:
import pandas as pd

results_ext = []
for hp_name, values in hp_options.items():
    for val in values:
        # copy & overwrite this one hyperparam
        settings = default_settings.copy()
        settings[hp_name] = val

        # if we're sweeping link_approx, move to a GLM predictor
        if hp_name == 'link_approx':
            settings['pred_type'] = 'glm'

        # nn + non‐mc is invalid → skip
        if settings['pred_type'] == 'nn' and settings['link_approx'] != 'mc':
            continue

        # now it's safe to build the Laplace
        la = Laplace(
            model, 'classification',
            subset_of_weights=settings['subset_of_weights'],
            hessian_structure=settings['hessian_structure'],
            prior_precision=settings['prior_precision'],
            temperature=settings['temperature']
        )
        la.fit(train_loader)

        # predict + metrics
        all_probs, all_targets = [], []
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device)
                out = la(
                    x,
                    pred_type=settings['pred_type'],
                    link_approx=settings['link_approx'],
                    n_samples=settings['n_samples'],
                    joint=settings['joint'],
                    diagonal_output=settings['diagonal_output']
                )
                probs = F.softmax(out['mean'], dim=1).cpu()
                all_probs.append(probs)
                all_targets.append(y)
        all_probs = torch.cat(all_probs)
        all_targets = torch.cat(all_targets)

        nll = F.cross_entropy(torch.log(all_probs), all_targets).item()
        acc = (all_probs.argmax(1) == all_targets).float().mean().item()
        ece = compute_ece(all_probs, all_targets)

        results_ext.append({
            'hyperparam': hp_name,
            'value':      val,
            'nll':        nll,
            'accuracy':   acc,
            'ece':        ece,
        })

df_ext = pd.DataFrame(results_ext)
df_ext.to_csv('/mnt/data/hyperparam_sensitivity_extended_results.csv', index=False)
print('Extended results saved to /mnt/data/hyperparam_sensitivity_extended_results.csv')

IndexError: too many indices for tensor of dimension 2

## 6. Plotting extended sensitivity results

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

df_ext = pd.read_csv('/mnt/data/hyperparam_sensitivity_extended_results.csv')

# Numeric hyperparameters to plot as curves
for hp in ['prior_precision', 'temperature', 'n_samples']:
    sub = df_ext[df_ext['hyperparam'] == hp]
    plt.figure()
    # Ensure numeric sorting
    sub = sub.sort_values(by='value')
    plt.plot(np.log10(sub['value']) if hp in ['prior_precision','temperature'] else sub['value'],
             sub['nll'], marker='o')
    plt.xlabel('log10(value)' if hp in ['prior_precision','temperature'] else hp)
    plt.ylabel('NLL')
    plt.title(f'NLL vs {hp}')
    plt.show()

# Categorical hyperparameters as bar charts
for hp in ['hessian_structure', 'link_approx', 'joint', 'diagonal_output', 'pred_type', 'subset_of_weights']:
    sub = df_ext[df_ext['hyperparam'] == hp]
    mean_metrics = sub.groupby('value')[['nll','accuracy','ece']].mean()
    mean_metrics.plot(kind='bar', subplots=True, layout=(1,3), figsize=(12,4), legend=False, sharex=True)
    plt.suptitle(f'Metrics vs {hp}')
    plt.show()