# Hyperparameters sensitivity analysis

This notebook conducts a post-hoc Laplace approximation hyperparameter sensitivity analysis on pre-trained WideLeNet and ResNet18 models trained on MNIST. 

> Note: this experiment script was run on Kaggle

In [1]:
!pip install torch
!pip install numpy
!pip install torchmetrics
!pip install torchvision
!pip install tqdm
!pip install seaborn
!pip install matplotlib
!pip install backpack
!pip install curvlinops-for-pytorch
!pip install backpack-for-pytorch
!pip install asdfghjkl
!pip install pandas
!pip install git+https://git@github.com/wiseodd/asdl@asdfghjkl

Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)
  Downloading nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)
  Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)
  Downloading nvidia_nvjitlink_cu12-1

### Kaggle dataset (laplace dir) setup and device==cuda

In [None]:
# Copy everything into a writable path
!cp -r /kaggle/input/laplace-project1/laplace-project1 /kaggle/working/laplace-project1

# Prepare folder for torchvision to download datasets
import os
os.makedirs("/kaggle/working/data", exist_ok=True)

# Set up imports and root path
from pathlib import Path
import sys
project_root = Path("/kaggle/working/laplace-project1")
sys.path.append(str(project_root / "hyperparams"))   # for models
sys.path.append(str(project_root))                   # for laplace/

# Set device
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("Project root:", project_root)
print("Using device:", device)

Project root: /kaggle/working/laplace-project1
Using device: cuda


### Imports and loading of pretrained models (all seeds)

In [3]:
import torch.nn.functional as F
import random
import pandas as pd
from torchvision import datasets, transforms
from torchvision.transforms import functional as TF
from torch.utils.data import DataLoader, random_split
from laplace.laplace import Laplace
from laplace.baselaplace import BaseLaplace, ParametricLaplace, FullLaplace, KronLaplace, LowRankLaplace, DiagLaplace, FunctionalLaplace
from laplace.lllaplace import LLLaplace, FunctionalLLLaplace, DiagLLLaplace, FullLLLaplace, KronLLLaplace
from models.lenet.lenet5 import LeNet5
from models.wideresnet.wideresnet import WideResNet
from models.mlp.mlp import MLP
from models.widelenet.widelenet import WideLeNet
from models.resnet.resnet18 import ResNet18

seeds = [0, 42, 123]

# Instantiate and load models
def load_widelenet_models(seeds):
    models = []
    for seed in seeds:
        model = WideLeNet()
        pth = f"{project_root}/hyperparams/models/widelenet/pretrained/widenet_seed{seed}.pth"
        model.load_state_dict(torch.load(pth, map_location=device))
        model.to(device).eval()
        models.append(model)
    return models

def load_resnet_models(seeds):
    models = []
    for seed in seeds:
        model = ResNet18()
        pth = f"{project_root}/hyperparams/models/resnet/pretrained/resnet18_mnist_seed{seed}.pth"
        model.load_state_dict(torch.load(pth, map_location=device))
        model.to(device).eval()
        models.append(model)
    return models

widelenet_models = load_widelenet_models(seeds)
resnet_models = load_resnet_models(seeds)

#check shape
x_dummy = torch.randn(1, 1, 32, 32).to(device)  # for MNIST
print("WideLenet output shape:", widelenet_models[0](x_dummy).shape)
print("Resnet output shape:", resnet_models[0](x_dummy).shape)




WideLenet output shape: torch.Size([1, 10])
Resnet output shape: torch.Size([1, 10])


### Rotate MNIST

In [4]:
class FixedRotation:
    def __init__(self, angle):
        self.angle = angle

    def __call__(self, x):
        x = TF.rotate(x, self.angle)  # deterministic
        x = transforms.Resize(32)(x)
        x = transforms.ToTensor()(x)
        x = transforms.Normalize((0.1307,), (0.3081,))(x)
        return x

### Download MNIST dataset and Rotate it to create RMNIST for all specified angles

In [5]:
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

os.makedirs("/kaggle/working/data", exist_ok=True)
# ID - MNIST
mnist_train = datasets.MNIST(root="/kaggle/working/data", train=True, download=True, transform=transform)
mnist_train_loader = DataLoader(mnist_train, batch_size=128, shuffle=False)

mnist_test = datasets.MNIST(root="/kaggle/working/data", train=False, download=True, transform=transform)
mnist_test_loader = DataLoader(mnist_test, batch_size=128, shuffle=False)

# OOD - Fashion MNIST
fashionmnist_test  = datasets.FashionMNIST(root="/kaggle/working/data", train=False, download=True, transform=transform)
fashionmnist_loader = DataLoader(fashionmnist_test, batch_size=128, shuffle=False)

# OOD - RMNIST
rotation_angles = [5, 15, 30, 45, 60, 90, 120, 160, 180]
rotated_loaders = {}

for angle in rotation_angles:
    transform = FixedRotation(angle)
    ds = datasets.MNIST(root="/kaggle/working/data", train=False, download=True, transform=transform)
    rotated_loaders[angle] = DataLoader(ds, batch_size=128, shuffle=False)


100%|██████████| 9.91M/9.91M [00:00<00:00, 52.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.75MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.8MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.93MB/s]
100%|██████████| 26.4M/26.4M [00:01<00:00, 13.3MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 303kB/s]
100%|██████████| 4.42M/4.42M [00:00<00:00, 5.53MB/s]
100%|██████████| 5.15k/5.15k [00:00<00:00, 11.8MB/s]


### ECE and NLL computation functions

In [6]:
# Metrics: ECE and NLL
def compute_ece(probs, labels, n_bins=15):
    confidences, preds = probs.max(1)
    accuracies = preds.eq(labels)
    bins = torch.linspace(0, 1, n_bins + 1, device=probs.device)
    ece = torch.tensor(0., device=probs.device)
    for i in range(n_bins):
        mask = (confidences > bins[i]) & (confidences <= bins[i+1])
        if mask.any():
            ece += mask.float().mean() * (accuracies[mask].float().mean() - confidences[mask].mean()).abs()
    return ece.item()

def compute_nll(probs: torch.Tensor, labels: torch.Tensor) -> float:
    """
    Computes the Negative Log-Likelihood (NLL) given class probabilities and ground-truth labels.
    Assumes `probs` is of shape [batch_size, num_classes] with valid probabilities (summing to 1).
    """
    log_probs = torch.log(probs + 1e-12)  # prevent log(0)
    nll = -log_probs[torch.arange(labels.size(0), device=labels.device), labels]
    return nll.mean().item()

### Hyperparams choices

In [None]:
'''
Discarded draft: too many combinations (not feasible in my PC)

# 1. Common hyperparameter values 
prior_precisions   = [1e-6, 1e-2, 1.0, 100.0]
temperatures       = [0.1, 1.0, 2.0]
# For last‐layer non‐GP: use diag/kron/full
hess_structs_main  = ['diag', 'kron', 'full']
# For functional‐GP: only 'gp'
hess_structs_gp    = ['gp']

# 2. Additional options for each pred_type
# 'nn' uses MC sampling
n_samples_nn       = [32, 128, 512]

# 'glm' uses analytic mean/variance plus link‐approx
link_options_glm   = ['probit', 'bridge']
joint_options      = [False, True]
diag_out_options   = [False, True]
# I'll set n_samples=1 for analytic mode

# 'gp' (FunctionalLLLaplace) requires an 'n_subset' choice
n_subset_options   = [256, 512, 1024]
# GP also supports link_approx ∈ {'probit','bridge'} and joint/diagonal_output
# and can do MC‐sampling if n_samples>1 (we’ll leave n_samples=1 here for the analytic predictive)

# 3) Build the hyperparameter grid
hp_grid = []

# pred_type = 'nn'
for pp, temp, hess, n_samp in product(
    prior_precisions,
    temperatures,
    hess_structs_main,
    n_samples_nn
):
    hp_grid.append({
        'prior_precision':   pp,
        'temperature':       temp,
        'hessian_structure': hess,         # one of 'diag','kron','full'
        'pred_type':         'nn',
        'n_samples':         n_samp,       # Monte Carlo draws
        'link_approx':       'mc',         # forced for 'nn'
        'joint':             None,         # ignored for 'nn'
        'diagonal_output':   None,         # ignored for 'nn'
        'subset_of_weights': 'last_layer'
    })

# pred_type = 'glm'
for pp, temp, hess, link, joint_flag, diag_out in product(
    prior_precisions,
    temperatures,
    hess_structs_main,
    link_options_glm,
    joint_options,
    diag_out_options
):
    hp_grid.append({
        'prior_precision':   pp,
        'temperature':       temp,
        'hessian_structure': hess,         # one of 'diag','kron','full'
        'pred_type':         'glm',
        'n_samples':         1,            # analytic GLM; set >1 if you want MC‐samples
        'link_approx':       link,         # 'probit' or 'bridge'
        'joint':             joint_flag,   # whether to compute full C×C covariance
        'diagonal_output':   diag_out,     # whether to return only diag of functional‐cov
        'subset_of_weights': 'last_layer'
    })

# pred_type = 'gp'
for pp, temp, n_sub, link, joint_flag, diag_out in product(
    prior_precisions,
    temperatures,
    n_subset_options,
    link_options_glm,   # GP predictive can also choose 'probit' or 'bridge'
    joint_options,
    diag_out_options
):
    hp_grid.append({
        'prior_precision':   pp,
        'temperature':       temp,
        'hessian_structure': 'gp',         # triggers FunctionalLLLaplace
        'pred_type':         'gp',
        'n_subset':          n_sub,        # # of points to build the GP covariance
        'n_samples':         1,            # analytic GP; set >1 if you want MC‐samples
        'link_approx':       link,         # 'probit' or 'bridge'
        'joint':             joint_flag,   # full functional‐covariance if True
        'diagonal_output':   diag_out,     # use only diagonal of functional‐cov if True
        'subset_of_weights': 'last_layer'
    })

print(f"Total valid configurations: {len(hp_grid)}")
'''

# Hyperparameter grid: prior_precision & hessian_structure, random sample ~25 combos
prior_precisions = [1e-6, 1e-4, 1e-2, 1.0, 10.0, 100.0]
temperatures = [0.1, 0.5, 1.0]
hessian_structs  = ['diag', 'kron']
link_approxs = ['probit', 'bridge']

all_configs = [
    {
      'prior_precision': pp,
      'hessian_structure': hess,
      'temperature': T,
      'link_approx': L,
      'subset_of_weights': 'last_layer'
    }
    for pp in prior_precisions
    for hess in hessian_structs
    for T in temperatures
    for L in link_approxs
]

hp_grid = all_configs
print(f"Sampling {len(hp_grid)} configurations for sensitivity analysis")


Sampling 72 configurations for sensitivity analysis


### Experiments

In [None]:
import time

# Run experiments (full config sweep: pp, hess, temp, link) with Accuracy & Brier
models_dict = {
    "widelenet": widelenet_models,
    "resnet18": resnet_models
}

output_dir = Path("/kaggle/working/laplace_outputs")
output_dir.mkdir(parents=True, exist_ok=True)

results = []
total_runs = len(models_dict) * len(hp_grid) * len(rotation_angles) * len(seeds)
run_idx = 0

start_all = time.time()
for arch, model_list in models_dict.items():
    for config in hp_grid:
        pp    = config['prior_precision']
        hess  = config['hessian_structure']
        temp  = config['temperature']
        link  = config['link_approx']

        for base_idx, base in enumerate(model_list, start=1):
            run_idx += 1
            seed = seeds[base_idx-1]
            print(f"[{run_idx}/{total_runs}] arch={arch}, seed={seed}, "
                  f"pp={pp}, hess={hess}, temp={temp}, link={link}")

            # Load & fit once per (arch, seed, config) 
            m = type(base)().to(device)
            m.load_state_dict(base.state_dict())
            m.eval()
            la = Laplace(
                m,
                likelihood='classification',
                subset_of_weights='last_layer',
                hessian_structure=hess,
                prior_precision=pp,
                temperature=temp,
                enable_backprop=False
            )
            t0 = time.time()
            la.fit(mnist_train_loader)
            print(f"  Fit done in {time.time() - t0:.1f}s")

            # ID eval
            t0 = time.time()
            probs_id, labels_id = [], []
            with torch.no_grad():
                for X, y in mnist_test_loader:
                    p = la(X.to(device), pred_type='glm', link_approx=link, n_samples=1)
                    probs_id.append(p.cpu()); labels_id.append(y)
            probs_id = torch.cat(probs_id)
            labels_id = torch.cat(labels_id)
            nll_id = compute_nll(probs_id, labels_id)
            ece_id = compute_ece(probs_id, labels_id)
            acc_id = (probs_id.argmax(1) == labels_id).float().mean().item()
            one_hot = F.one_hot(labels_id, num_classes=probs_id.size(-1)).float()
            brier_id = ((probs_id - one_hot)**2).sum(1).mean().item()
            print(f"  ID eval done in {time.time() - t0:.1f}s")

            # OOD eval: rotated MNIST
            for angle in rotation_angles:
                t0 = time.time()
                probs_ood, labels_ood = [], []
                with torch.no_grad():
                    loader = rotated_loaders[angle]
                    for X, y in loader:
                        p = la(X.to(device), pred_type='glm', link_approx=link, n_samples=1)
                        probs_ood.append(p.cpu()); labels_ood.append(y)
                probs_ood = torch.cat(probs_ood)
                labels_ood = torch.cat(labels_ood)
                nll_ood = compute_nll(probs_ood, labels_ood)
                ece_ood = compute_ece(probs_ood, labels_ood)
                acc_ood = (probs_ood.argmax(1) == labels_ood).float().mean().item()
                one_hot_ood = F.one_hot(labels_ood, num_classes=probs_ood.size(-1)).float()
                brier_ood = ((probs_ood - one_hot_ood)**2).sum(1).mean().item()
                print(f"    OOD {angle}° eval done in {time.time() - t0:.1f}s")

                # Append result row
                results.append({
                    "arch": arch,
                    "seed": seed,
                    "prior_precision": pp,
                    "hessian_structure": hess,
                    "temperature": temp,
                    "link_approx": link,
                    "rotation": angle,
                    "nll_id":    nll_id,
                    "ece_id":    ece_id,
                    "acc_id":    acc_id,
                    "brier_id":  brier_id,
                    "nll_ood":   nll_ood,
                    "ece_ood":   ece_ood,
                    "acc_ood":   acc_ood,
                    "brier_ood": brier_ood
                })

            # Clean up
            del m, la
            torch.cuda.empty_cache()

            # FLUSH partial results
            # Save current results to CSV after each seed finishes
            pd.DataFrame(results).to_csv(output_dir / "sensitivity_results.csv", index=False)
            print(f"  Flushed {len(results)} rows to {output_dir/'sensitivity_results.csv'}")

# Final save
df = pd.DataFrame(results)
df.to_csv(output_dir / "sensitivity_results.csv", index=False)
print(f"Experiment complete in {time.time() - start_all:.1f}s")
print(f"Final results saved to {output_dir/'sensitivity_results.csv'}")


[1/3888] arch=widelenet, seed=0, pp=1e-06, hess=diag, temp=0.1, link=probit
  Fit done in 16.4s
  ID eval done in 2.9s
    OOD 5° eval done in 4.0s
    OOD 15° eval done in 3.9s
    OOD 30° eval done in 4.0s
    OOD 45° eval done in 3.9s
    OOD 60° eval done in 3.8s
    OOD 90° eval done in 3.6s
    OOD 120° eval done in 3.9s
    OOD 160° eval done in 3.9s
    OOD 180° eval done in 3.5s
  Flushed 9 rows to /kaggle/working/laplace_outputs/sensitivity_results.csv
[2/3888] arch=widelenet, seed=42, pp=1e-06, hess=diag, temp=0.1, link=probit
  Fit done in 16.1s
  ID eval done in 2.7s
    OOD 5° eval done in 3.9s
    OOD 15° eval done in 4.0s
    OOD 30° eval done in 3.9s
    OOD 45° eval done in 3.9s
    OOD 60° eval done in 3.9s
    OOD 90° eval done in 3.5s
    OOD 120° eval done in 3.9s
    OOD 160° eval done in 3.9s
    OOD 180° eval done in 3.5s
  Flushed 18 rows to /kaggle/working/laplace_outputs/sensitivity_results.csv
[3/3888] arch=widelenet, seed=123, pp=1e-06, hess=diag, temp=0.1