# Hyperparameter sensitivity experiment
This notebook conducts a post-hoc Laplace approximation hyperparameter sensitivity analysis on pre-trained LeNet5 and 2-layers MLP models trained on MNIST.

## 1. Setup
Imports and model loading.

In [None]:
import sys
from pathlib import Path
import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

# Add root directory to sys.path
project_root = Path().resolve().parent
sys.path.append(str(project_root))

from laplace.laplace import Laplace
from laplace.baselaplace import BaseLaplace, ParametricLaplace, FullLaplace, KronLaplace, LowRankLaplace, DiagLaplace, FunctionalLaplace
from laplace.lllaplace import LLLaplace, FunctionalLLLaplace, DiagLLLaplace, FullLLLaplace, KronLLLaplace
from models.wideresnet.wideresnet import WideResNet
from models.lenet.lenet5 import LeNet5
from models.mlp.mlp import MLP
from models.widelenet.widelenet import WideLeNet
from models.resnet.resnet18 import ResNet18

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

seeds = [12, 37, 42]

# Instantiate and load models
def load_widelenet_models(seeds):
    models = []
    for seed in seeds:
        model = LeNet5()
        pth = f"{project_root}/hyperparams/models/lenet/pretrained/lenet_mnist_seed{seed}.pth"
        model.load_state_dict(torch.load(pth, map_location=device))
        model.to(device).eval()
        models.append(model)
    return models

def load_mlp_models(seeds):
    models = []
    for seed in seeds:
        model = MLP()
        pth = f"{project_root}/hyperparams/models/mlp/pretrained/mlp_mnist_seed{seed}.pth"
        model.load_state_dict(torch.load(pth, map_location=device))
        model.to(device).eval()
        models.append(model)
    return models

lenet_models = load_lenet_models(seeds)
mlp_models = load_mlp_models(seeds)

#check shape
x_dummy = torch.randn(1, 1, 28, 28).to(device)  # for MNIST
print("LeNet5 output shape:", lenet_models[0](x_dummy).shape)
print("MLP output shape:", mlp_models[0](x_dummy.view(1, -1)).shape)


  from .autonotebook import tqdm as notebook_tqdm


LeNet5 output shape: torch.Size([1, 10])
MLP output shape: torch.Size([1, 10])


  model.load_state_dict(torch.load(pth, map_location=device))
  model.load_state_dict(torch.load(pth, map_location=device))


## 2. Data preparation
Load MNIST test set, both ID (MNIST) and OOD (Fashion-MNIST).

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# ID - MNIST
mnist_train = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
mnist_train_loader = DataLoader(mnist_train, batch_size=128, shuffle=False)

mnist_test = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
mnist_test_loader = DataLoader(mnist_test, batch_size=128, shuffle=False)

# OOD - Fashion MNIST
fashionmnist_test  = datasets.FashionMNIST(root="./data", train=False, download=True, transform=transform)
fashionmnist_loader = DataLoader(fashionmnist_test, batch_size=128, shuffle=False)

## 3. Default settings and hyperparameter options

In [5]:
# hp_options = {
#     'prior_precision':    [1e-6, 1e-4, 1e-2, 1.0, 100.0], #torch.logspace(-6, 2, 20).tolist()
#     'temperature':        [0.1, 0.5, 1.0, 2.0], #torch.logspace(-1, 1, 10).tolist()
#     'hessian_structure':  ['diag', 'kron', 'full', 'lowrank', 'gp'],
#     'link_approx':        ['probit', 'mc', 'bridge'],  # we'll switch pred_type when sweeping this
#     'n_samples':          [32, 128, 512],
#     'joint':              [False, True],
#     'diagonal_output':    [False, True],
#     'pred_type':          ['nn', 'glm', 'gp'],
#     'subset_of_weights':  ['last_layer'], # not using 'subnetwork' and 'all'
# }

# from itertools import product

# # Common across all
# prior_precisions = [1e-6, 1e-4, 1e-2, 1.0, 100.0]
# temperatures = [0.1, 0.5, 1.0, 2.0]
# hessian_structures = ['diag', 'kron', 'full']  # 'lowrank' not supported for last_layer
# subset_of_weights = ['last_layer']

# # pred_type = 'nn' → uses sampling
# nn_grid = list(product(
#     prior_precisions,
#     temperatures,
#     hessian_structures,
#     ['nn'],        # pred_type
#     [128],         # n_samples
#     [None],        # joint
#     [None]         # diagonal_output
# ))

# # pred_type = 'glm' → uses analytic mean/variance
# glm_grid = list(product(
#     prior_precisions,
#     temperatures,
#     hessian_structures,
#     ['glm'],       # pred_type
#     [None],        # n_samples
#     [False, True], # joint
#     [False, True]  # diagonal_output
# ))

# # pred_type = 'gp' → FunctionalLLLaplace (requires separate setup)
# # Not included in default grid since it requires n_subset and a different class

# # Combine
# full_grid = nn_grid + glm_grid

# # Format as list of dicts
# hp_grid = [
#     {
#         'prior_precision': pp,
#         'temperature': temp,
#         'hessian_structure': hess,
#         'pred_type': pred,
#         'n_samples': n_samp,
#         'joint': joint,
#         'diagonal_output': diag_out,
#         'subset_of_weights': 'last_layer'
#     }
#     for pp, temp, hess, pred, n_samp, joint, diag_out in full_grid
# ]

# print(f"Total valid configurations: {len(hp_grid)}")


In [6]:
'''
1) prior_precision
   - The precision (inverse variance) of the Gaussian prior placed 
     on the last-layer weights (θ). A higher prior_precision means a stronger 
     regularization (i.e., the prior is narrower), while a lower value puts 
     less weight on the prior and more on the data likelihood.
   - Example values: 1e-6 (very weak prior), 1e-4, 1e-2, 1.0, 100.0 (very strong prior).

2) temperature
   - A scaling factor τ applied to the likelihood's Hessian.
     Equivalently, it rescales the curvature that the Laplace posterior uses. 
     - τ < 1.0 “sharpens” the likelihood, making the posterior narrower (underestimate uncertainty).
     - τ > 1.0 “flattens” the likelihood, making the posterior wider (overestimate uncertainty).
   - Example values: 0.1 (sharper), 0.5, 1.0 (no rescaling), 2.0 (flatter).

3) hessian_structure
   - Chooses how we approximate or represent the Hessian (second derivative 
     of the negative log-likelihood) restricted to the last-layer parameters. Only 
     these options are valid when subset_of_weights='last_layer':
     - 'diag' → Diagonal approximation of the Hessian. Keeps only per-weight variances.
     - 'kron' → Kronecker-factorized approximation (blockwise over input/output dimensions). 
                  More accurate than diag but still efficient.
     - 'full' → Exact (dense) Hessian on the last-layer. Most accurate but computationally costly.
     - 'gp'   → Functional-GP approximation of the last-layer. Internally builds a Gaussian 
                  process on top of fixed features, using a subset of training points (n_subset). 
     (Note: 'lowrank' is not valid for last_layer)

4) pred_type
   - Determines how we build the predictive distribution p(y*|x*) from the Laplace posterior:
     - 'nn'  → “Neural‐network sampling” branch. We draw n_samples weight vectors θ ∼ N(θ_MAP, Σ) 
                and pass each through the network to estimate predictive probabilities. This is 
                fully Monte Carlo (MC)–based.
     - 'glm' → “Generalized linear model” branch. We approximate the last-layer outputs f* as a 
                Gaussian (mean + variance) and then apply a closed-form (e.g. probit or bridge) 
                approximation to get p(y*|f*). Can also draw MC samples of f* (if n_samples>1) 
                but typically uses analytic formulas to compute predictive mean/variance.
     - 'gp'  → “Functional-GP” branch (requires hessian_structure='gp'). We treat the last-layer 
                as an exact Gaussian process on top of fixed features, using n_subset training points 
                to build a finite GP approximation. Prediction can be analytic or MC-based.

5) n_samples
   - Number of Monte Carlo samples to draw when pred_type='nn' or when you explicitly 
     want to draw MC samples in the 'glm' or 'gp' branches. 
     - If pred_type='nn', n_samples≥1 denotes how many weight draws θ to sample from the Laplace 
       posterior. Each θ is a full last-layer parameter sample, so n_samples controls sampling variance.
     - If pred_type='glm', setting n_samples=1 means “use analytic mean/variance → no sampling.” 
       If n_samples>1, you first compute f_mu and f_var, then draw f ∼ N(f_mu, f_var) n_samples times 
       and average predictive probabilities.
     - If pred_type='gp', setting n_samples=1 means “use analytic GP predictive.” If n_samples>1, 
       you can draw posterior samples of f* from the GP to approximate predictive uncertainty.
   - Example values for 'nn': [32, 128, 512]. For 'glm' or 'gp', keep n_samples=1 or set >1 when desired.

6) link_approx
   - Only relevant when pred_type is 'glm' or 'gp'. Chooses how to approximate the 
     multi-class softmax integral under a Gaussian posterior on f*.
     - 'probit' → Extended probit approximation that gives a closed-form estimate of p(y*|f_mu, f_var). Fast and common.
     - 'bridge' → “Laplace-bridge” (Dirichlet‐Laplace) approximation to the same integral. 
       Typically more accurate in high uncertainty regimes but slightly more expensive.
     - 'mc'     → Only valid if pred_type='nn'. Means “sample logits → apply softmax → average.” 
       Not used in the 'glm' or 'gp' branches.

7) joint
   - Only applies when pred_type='glm' or pred_type='gp'. Controls whether we return 
     the **full** C×C covariance matrix of the multi-class predictive distribution (joint=True) 
     or only the diagonal variances (joint=False).  
     - If joint=True, the predictive code will compute and optionally sample from the full 
       covariance across all classes at once (cost ∼ C^2 per input).  
     - If joint=False, only the diagonal of that covariance is used → cost ∼ C per input.  
     - If pred_type='nn', joint is ignored (set to None to indicate “not applicable”).

8) diagonal_output
   - Also only applies when pred_type='glm' or 'gp'. Controls whether, after sampling 
     from a full covariance, you drop off-diagonal terms and only keep per-class variances.  
     - If diagonal_output=True, you force the predictive code to treat any per-class covariance as 
       diagonal (i.e. ignore cross-class covariances).  
     - If diagonal_output=False, you keep the full covariance matrix.  
     - If pred_type='nn', diagonal_output is ignored (set to None).

9) n_subset
   - Only used when hessian_structure='gp' (Functional-GP). Specifies how many 
     training points to randomly subsample (or approximate in blocks) to form the GP’s Gram matrix.  
     - Smaller n_subset → cheaper fitting/inference (GP uses a low-rank approximation with n_subset points).  
     - Larger n_subset → more accurate GP posterior but higher cost.  
     - If hessian_structure!='gp', set n_subset=None (ignored).
   - Example values: [256, 512, 1024].

10) subset_of_weights
    - Determines which parameters of the neural network to place a Laplace posterior on.  
      For this entire script, I fix subset_of_weights='last_layer', meaning I only treat the final 
      linear layer as random under a Gaussian posterior. All earlier layers remain 
      point estimates (no uncertainty).
    - Valid values (in general library): 'all' (full‐network), 'subnetwork' (user‐specified layers), 
      'last_layer' (only the last linear).  
    - By setting 'last_layer', I ensure the Laplace machinery only builds Hessians or GP covariances 
      over the final layer’s weight and bias.
'''

"\n1) prior_precision\n   - The precision (inverse variance) of the Gaussian prior placed \n     on the last-layer weights (θ). A higher prior_precision means a stronger \n     regularization (i.e., the prior is narrower), while a lower value puts \n     less weight on the prior and more on the data likelihood.\n   - Example values: 1e-6 (very weak prior), 1e-4, 1e-2, 1.0, 100.0 (very strong prior).\n\n2) temperature\n   - A scaling factor τ applied to the likelihood's Hessian.\n     Equivalently, it rescales the curvature that the Laplace posterior uses. \n     - τ < 1.0 “sharpens” the likelihood, making the posterior narrower (underestimate uncertainty).\n     - τ > 1.0 “flattens” the likelihood, making the posterior wider (overestimate uncertainty).\n   - Example values: 0.1 (sharper), 0.5, 1.0 (no rescaling), 2.0 (flatter).\n\n3) hessian_structure\n   - Chooses how we approximate or represent the Hessian (second derivative \n     of the negative log-likelihood) restricted to the la

In [11]:
# #456 configurations

# from itertools import product

# # 1. Common hyperparameter values 
# prior_precisions   = [1e-6, 1e-2, 1.0, 100.0]
# temperatures       = [0.5, 1.0]
# # For last‐layer non‐GP: use diag/kron/full
# hess_structs_main  = ['diag', 'kron', 'full']
# # For functional‐GP: only 'gp'
# hess_structs_gp    = ['gp']

# # 2. Additional options for each pred_type
# # 'nn' uses MC sampling
# n_samples_nn       = [32, 128, 512]

# # 'glm' uses analytic mean/variance plus link‐approx
# link_options_glm   = ['probit', 'bridge']
# joint_options      = [False, True]
# diag_out_options   = [False, True]
# # I'll set n_samples=1 for analytic mode

# # 'gp' (FunctionalLLLaplace) requires an 'n_subset' choice
# n_subset_options   = [256, 512, 1024]
# # GP also supports link_approx ∈ {'probit','bridge'} and joint/diagonal_output
# # and can do MC‐sampling if n_samples>1 (we’ll leave n_samples=1 here for the analytic predictive)

# # 3) Build the hyperparameter grid
# hp_grid = []

# # pred_type = 'nn'
# for pp, temp, hess, n_samp in product(
#     prior_precisions,
#     temperatures,
#     hess_structs_main,
#     n_samples_nn
# ):
#     hp_grid.append({
#         'prior_precision':   pp,
#         'temperature':       temp,
#         'hessian_structure': hess,         # one of 'diag','kron','full'
#         'pred_type':         'nn',
#         'n_samples':         n_samp,       # Monte Carlo draws
#         'link_approx':       'mc',         # forced for 'nn'
#         'joint':             None,         # ignored for 'nn'
#         'diagonal_output':   None,         # ignored for 'nn'
#         'subset_of_weights': 'last_layer'
#     })

# # pred_type = 'glm'
# for pp, temp, hess, link, joint_flag, diag_out in product(
#     prior_precisions,
#     temperatures,
#     hess_structs_main,
#     link_options_glm,
#     joint_options,
#     diag_out_options
# ):
#     hp_grid.append({
#         'prior_precision':   pp,
#         'temperature':       temp,
#         'hessian_structure': hess,         # one of 'diag','kron','full'
#         'pred_type':         'glm',
#         'n_samples':         1,            # analytic GLM; set >1 if you want MC‐samples
#         'link_approx':       link,         # 'probit' or 'bridge'
#         'joint':             joint_flag,   # whether to compute full C×C covariance
#         'diagonal_output':   diag_out,     # whether to return only diag of functional‐cov
#         'subset_of_weights': 'last_layer'
#     })

# # pred_type = 'gp'
# for pp, temp, n_sub, link, joint_flag, diag_out in product(
#     prior_precisions,
#     temperatures,
#     n_subset_options,
#     link_options_glm,   # GP predictive can also choose 'probit' or 'bridge'
#     joint_options,
#     diag_out_options
# ):
#     hp_grid.append({
#         'prior_precision':   pp,
#         'temperature':       temp,
#         'hessian_structure': 'gp',         # triggers FunctionalLLLaplace
#         'pred_type':         'gp',
#         'n_subset':          n_sub,        # # of points to build the GP covariance
#         'n_samples':         1,            # analytic GP; set >1 if you want MC‐samples
#         'link_approx':       link,         # 'probit' or 'bridge'
#         'joint':             joint_flag,   # full functional‐covariance if True
#         'diagonal_output':   diag_out,     # use only diagonal of functional‐cov if True
#         'subset_of_weights': 'last_layer'
#     })

# print(f"Total valid configurations: {len(hp_grid)}")

from itertools import product

# Common hyperparameter values
prior_precisions  = [1e-6, 1e-2, 1.0, 100.0]
temperatures      = [0.5, 1.0]
# For last‐layer (‘subset_of_weights=last_layer’), we only keep ‘diag’ and ‘kron’
hess_structs_main = ['diag', 'kron']

# 2. GLM link‐approx options
link_options_glm = ['probit', 'bridge']
diag_out_options = [False, True]
# We only keep joint=False for simplicity (no full output covariance)
joint_options = [False]

# 3. Build the shrunken hyperparameter grid:
#    (no “nn+MC sampling” block)
hp_grid = []
for pp, temp, hess, link, diag_out, joint in product(
        prior_precisions,
        temperatures,
        hess_structs_main,
        link_options_glm,
        diag_out_options,
        joint_options
    ):
    hp_grid.append({
        'prior_precision':   pp,
        'temperature':       temp,
        'hessian_structure': hess,         # 'diag' or 'kron'
        'pred_type':         'glm',        # always analytic GLM for last‐layer
        'n_samples':         1,            # analytic mode
        'link_approx':       link,         # 'probit' or 'bridge'
        'joint':             joint,        # False
        'diagonal_output':   diag_out,     # True/False
        'subset_of_weights': 'last_layer'
    })

print(f"Total valid GLM configurations: {len(hp_grid)}")

Total valid GLM configurations: 64


## 4. Utility: ECE computation

In [10]:
def compute_ece(probs: torch.Tensor, labels: torch.Tensor, n_bins: int = 15) -> float:
    """
    Compute expected calibration error (ECE) for classification.
    """
    confidences, predictions = torch.max(probs, dim=1)
    accuracies = predictions.eq(labels)
    bins = torch.linspace(0, 1, n_bins + 1, device=probs.device)
    ece = torch.zeros(1, device=probs.device)
    for i in range(n_bins):
        # find predictions with confidence in (bins[i], bins[i+1]]
        mask = (confidences > bins[i]) & (confidences <= bins[i + 1])
        if mask.any():
            # weight by fraction of data points in this bin
            ece += (mask.float().mean() * 
                    torch.abs(accuracies[mask].float().mean() - confidences[mask].mean()))
    return ece.item()

In [None]:
# import pandas as pd
# import os

# results_id = []
# results_ood = []

# arch_model_lists = {
#     "lenet": lenet_models,
#     "mlp":   mlp_models,
# }

# # Freeze gradients
# for model_list in arch_model_lists.values():
#     for m in model_list:
#         m.requires_grad_(False)

# # Compute total number of (architecture × hyperparameter) configurations
# num_archs = len(arch_model_lists)
# num_hps   = len(hp_grid)
# total_configs = num_archs * num_hps
# config_counter = 0

# # Loop over architectures and hyperparameter configurations
# for model_name, models_list in arch_model_lists.items():
#     for config in hp_grid:
#         pp          = config['prior_precision']
#         temp        = config['temperature']
#         hess        = config['hessian_structure']
#         pred_type   = config['pred_type']
#         n_samp      = config.get('n_samples', None)
#         link        = config.get('link_approx', None)
#         joint_flag  = config.get('joint', None)
#         diag_out    = config.get('diagonal_output', None)
#         n_sub       = config.get('n_subset', None)

#         # Accumulators for metrics across seeds
#         acc_id_list  = []
#         ece_id_list  = []
#         acc_ood_list = []
#         ece_ood_list = []

#         # Loop over seeds
#         for base_model in models_list:
#             # Copy pretrained model
#             model_copy = type(base_model)().to(device)
#             model_copy.load_state_dict(base_model.state_dict())
#             model_copy.eval()

#             # Instantiate Laplace or FunctionalLLLaplace
#             if hess == 'gp':
#                 la = FunctionalLLLaplace(
#                     model_copy,
#                     'classification',
#                     n_subset=n_sub,
#                     prior_precision=pp,
#                     temperature=temp,
#                     enable_backprop=False
#                 )
#             else:
#                 la = Laplace(
#                     model_copy,
#                     'classification',
#                     subset_of_weights='last_layer',
#                     hessian_structure=hess,
#                     prior_precision=pp,
#                     temperature=temp,
#                     enable_backprop=False
#                 )

#             # Fit on MNIST train
#             la.fit(mnist_train_loader)

#             # Evaluate ID
#             all_probs_id  = []
#             all_labels_id = []
#             with torch.no_grad():
#                 for X, y in mnist_test_loader:
#                     X, y = X.to(device), y.to(device)

#                     if pred_type == 'nn':
#                         probs = la(
#                             X,
#                             pred_type='nn',
#                             link_approx='mc',
#                             n_samples=n_samp
#                         )
#                     elif pred_type == 'glm':
#                         probs = la(
#                             X,
#                             pred_type='glm',
#                             joint=joint_flag,
#                             link_approx=link,
#                             n_samples=1,
#                             diagonal_output=diag_out
#                         )
#                     else:  # 'gp'
#                         probs = la(
#                             X,
#                             pred_type='gp',
#                             joint=joint_flag,
#                             link_approx=link,
#                             n_samples=1,
#                             diagonal_output=diag_out
#                         )

#                     all_probs_id.append(probs.cpu())
#                     all_labels_id.append(y.cpu())

#             all_probs_id  = torch.cat(all_probs_id, dim=0)
#             all_labels_id = torch.cat(all_labels_id, dim=0)

#             _, preds_id = torch.max(all_probs_id, dim=1)
#             acc_id = (preds_id == all_labels_id).float().mean().item()
#             ece_id = compute_ece(all_probs_id, all_labels_id)

#             # Evaluate OOD
#             all_probs_ood  = []
#             all_labels_ood = []
#             with torch.no_grad():
#                 for X, y in fashionmnist_loader:
#                     X, y = X.to(device), y.to(device)

#                     if pred_type == 'nn':
#                         probs = la(
#                             X,
#                             pred_type='nn',
#                             link_approx='mc',
#                             n_samples=n_samp
#                         )
#                     elif pred_type == 'glm':
#                         probs = la(
#                             X,
#                             pred_type='glm',
#                             joint=joint_flag,
#                             link_approx=link,
#                             n_samples=1,
#                             diagonal_output=diag_out
#                         )
#                     else:  # 'gp'
#                         probs = la(
#                             X,
#                             pred_type='gp',
#                             joint=joint_flag,
#                             link_approx=link,
#                             n_samples=1,
#                             diagonal_output=diag_out
#                         )

#                     all_probs_ood.append(probs.cpu())
#                     all_labels_ood.append(y.cpu())

#             all_probs_ood  = torch.cat(all_probs_ood, dim=0)
#             all_labels_ood = torch.cat(all_labels_ood, dim=0)

#             _, preds_ood = torch.max(all_probs_ood, dim=1)
#             acc_ood = (preds_ood == all_labels_ood).float().mean().item()
#             ece_ood = compute_ece(all_probs_ood, all_labels_ood)

#             # Collect this seed’s metrics
#             acc_id_list.append(acc_id)
#             ece_id_list.append(ece_id)
#             acc_ood_list.append(acc_ood)
#             ece_ood_list.append(ece_ood)

#             # Free memory
#             del la, model_copy, all_probs_id, all_probs_ood
#             torch.cuda.empty_cache()

#         # Compute mean and std across seeds
#         acc_id_mean  = float(torch.tensor(acc_id_list).mean().item())
#         acc_id_std   = float(torch.tensor(acc_id_list).std(unbiased=False).item())
#         ece_id_mean  = float(torch.tensor(ece_id_list).mean().item())
#         ece_id_std   = float(torch.tensor(ece_id_list).std(unbiased=False).item())

#         acc_ood_mean = float(torch.tensor(acc_ood_list).mean().item())
#         acc_ood_std  = float(torch.tensor(acc_ood_list).std(unbiased=False).item())
#         ece_ood_mean = float(torch.tensor(ece_ood_list).mean().item())
#         ece_ood_std  = float(torch.tensor(ece_ood_list).std(unbiased=False).item())

#         # Record ID metrics in one list and OOD in another
#         row_id = {
#             "model_type":        model_name,
#             "prior_precision":   pp,
#             "temperature":       temp,
#             "hessian_structure": hess,
#             "pred_type":         pred_type,
#             "n_samples":         n_samp if pred_type == 'nn' else None,
#             "link_approx":       link,
#             "joint":             joint_flag,
#             "diagonal_output":   diag_out,
#             "n_subset":          n_sub if hess == 'gp' else None,
#             "acc_id_mean":       acc_id_mean,
#             "acc_id_std":        acc_id_std,
#             "ece_id_mean":       ece_id_mean,
#             "ece_id_std":        ece_id_std,
#         }
#         results_id.append(row_id)

#         row_ood = {
#             "model_type":         model_name,
#             "prior_precision":    pp,
#             "temperature":        temp,
#             "hessian_structure":  hess,
#             "pred_type":          pred_type,
#             "n_samples":          n_samp if pred_type == 'nn' else None,
#             "link_approx":        link,
#             "joint":              joint_flag,
#             "diagonal_output":    diag_out,
#             "n_subset":           n_sub if hess == 'gp' else None,
#             "acc_ood_mean":       acc_ood_mean,
#             "acc_ood_std":        acc_ood_std,
#             "ece_ood_mean":       ece_ood_mean,
#             "ece_ood_std":        ece_ood_std,
#         }
#         results_ood.append(row_ood)

#         # Increment counter, print progress, and flush every 100 configs
#         config_counter += 1
#         print(f"[{config_counter}/{total_configs}] "
#               f"Finished {model_name} | hp = "
#               f"(pp={pp}, τ={temp}, hess={hess}, pred={pred_type}, "
#               f"n_samp={n_samp}, link={link}, joint={joint_flag}, "
#               f"diag_out={diag_out}, n_sub={n_sub})")

#         # Every 100 configurations, write partial CSVs to disk
#         if config_counter % 100 == 0:
#             df_id  = pd.DataFrame(results_id)
#             df_ood = pd.DataFrame(results_ood)
#             df_id.to_csv(os.path.join(project_root, "last_layer_laplace_metrics_id.csv"), index=False)
#             df_ood.to_csv(os.path.join(project_root, "last_layer_laplace_metrics_ood.csv"), index=False)
#             print(f"  → Flushed {config_counter} rows to CSV files.")

# # After all loops, write final CSVs
# df_id_final  = pd.DataFrame(results_id)
# df_ood_final = pd.DataFrame(results_ood)
# df_id_final.to_csv(os.path.join(project_root, "last_layer_laplace_metrics_id.csv"), index=False)
# df_ood_final.to_csv(os.path.join(project_root, "last_layer_laplace_metrics_ood.csv"), index=False)
# print("Saved final CSVs for ID and OOD metrics.")

import pandas as pd
import os

results_id = []
results_ood = []

arch_model_lists = {
    "lenet": lenet_models,
    "mlp":   mlp_models,
}

# Prepare an output directory
output_dir = Path("/kaggle/working/laplace_outputs")
output_dir.mkdir(parents=True, exist_ok=True)

# Freeze gradients
for model_list in arch_model_lists.values():
    for m in model_list:
        m.requires_grad_(False)

# Compute total number of (architecture × hyperparameter) configurations
num_archs = len(arch_model_lists)
num_hps   = len(hp_grid)
total_configs = num_archs * num_hps
config_counter = 0

for model_name, models_list in arch_model_lists.items():
    for config in hp_grid:
        pp         = config['prior_precision']
        temp       = config['temperature']
        hess       = config['hessian_structure']
        link       = config['link_approx']
        joint_flag = config['joint']
        diag_out   = config['diagonal_output']

        # Collect per‐seed metrics (though we always use analytic GLM, so each seed is just one run)
        acc_id_list  = []
        ece_id_list  = []
        acc_ood_list = []
        ece_ood_list = []

        # Loop over seeds
        for base_model in models_list:
            # Copy pretrained model into a fresh instance
            model_copy = type(base_model)().to(device)
            model_copy.load_state_dict(base_model.state_dict())
            model_copy.eval()

            # Build a Laplace object (analytic GLM on last layer)
            la = Laplace(
                model_copy,
                'classification',
                subset_of_weights='last_layer',
                hessian_structure=hess,
                prior_precision=pp,
                temperature=temp,
                enable_backprop=False
            )

            # Fit on the original MNIST training set
            la.fit(mnist_train_loader)

            # -- Evaluate ID (MNIST test) --
            all_probs_id  = []
            all_labels_id = []
            skip_this_hp = False

            with torch.no_grad():
                for X, y in mnist_test_loader:
                    X, y = X.to(device), y.to(device)
                    try:
                        probs = la(
                            X,
                            pred_type='glm',
                            link_approx=link,
                            n_samples=1,
                            diagonal_output=diag_out
                        )
                    except (RuntimeError, torch._C._LinAlgError):
                        # If Hessian not PD → skip entire ID/OOD for this seed+HP
                        skip_this_hp = True
                        break

                    all_probs_id.append(probs.cpu())
                    all_labels_id.append(y.cpu())

            if skip_this_hp or len(all_probs_id)==0:
                # non‐PD Hessian → record NaN
                acc_id = float('nan')
                ece_id = float('nan')
            else:
                all_probs_id  = torch.cat(all_probs_id, dim=0)
                all_labels_id = torch.cat(all_labels_id, dim=0)
                _, preds_id   = torch.max(all_probs_id, dim=1)
                acc_id        = (preds_id == all_labels_id).float().mean().item()
                ece_id        = compute_ece(all_probs_id, all_labels_id)

            # -- Evaluate OOD (FashionMNIST) --
            all_probs_ood  = []
            all_labels_ood = []
            if skip_this_hp:
                acc_ood = float('nan')
                ece_ood = float('nan')
            else:
                with torch.no_grad():
                    for X, y in fashionmnist_loader:
                        X, y = X.to(device), y.to(device)
                        try:
                            probs = la(
                                X,
                                pred_type='glm',
                                link_approx=link,
                                n_samples=1,
                                diagonal_output=diag_out
                            )
                        except (RuntimeError, torch._C._LinAlgError):
                            skip_this_hp = True
                            break

                        all_probs_ood.append(probs.cpu())
                        all_labels_ood.append(y.cpu())

                if skip_this_hp or len(all_probs_ood)==0:
                    acc_ood = float('nan')
                    ece_ood = float('nan')
                else:
                    all_probs_ood  = torch.cat(all_probs_ood, dim=0)
                    all_labels_ood = torch.cat(all_labels_ood, dim=0)
                    _, preds_ood   = torch.max(all_probs_ood, dim=1)
                    acc_ood        = (preds_ood == all_labels_ood).float().mean().item()
                    ece_ood        = compute_ece(all_probs_ood, all_labels_ood)

            # Collect this seed’s results
            acc_id_list.append(acc_id)
            ece_id_list.append(ece_id)
            acc_ood_list.append(acc_ood)
            ece_ood_list.append(ece_ood)

            # Clean up GPU memory for this seed
            del la, model_copy, all_probs_id, all_probs_ood
            torch.cuda.empty_cache()

        # Compute nan‐aware mean and std via module‐level torch.nanmean / torch.nanstd
        t_acc_id  = torch.tensor(acc_id_list,  dtype=torch.float32, device='cpu')
        valid_acc_id = t_acc_id[~torch.isnan(t_acc_id)]
        acc_id_mean = valid_acc_id.mean().item()
        acc_id_std  = valid_acc_id.std(unbiased=False).item()
        
        t_ece_id  = torch.tensor(ece_id_list,  dtype=torch.float32, device='cpu')
        valid_ece_id = t_ece_id[~torch.isnan(t_ece_id)]
        ece_id_mean = valid_ece_id.mean().item()
        ece_id_std  = valid_ece_id.std(unbiased=False).item()
        
        t_acc_ood = torch.tensor(acc_ood_list, dtype=torch.float32, device='cpu')
        valid_acc_ood = t_acc_ood[~torch.isnan(t_acc_ood)]
        acc_ood_mean = valid_acc_ood.mean().item()
        acc_ood_std  = valid_acc_ood.std(unbiased=False).item()
        
        t_ece_ood = torch.tensor(ece_ood_list, dtype=torch.float32, device='cpu')
        valid_ece_ood = t_ece_ood[~torch.isnan(t_ece_ood)]
        ece_ood_mean = valid_ece_ood.mean().item()
        ece_ood_std  = valid_ece_ood.std(unbiased=False).item()
        
        # Record one row for ID and one for OOD
        results_id.append({
            "model_type":        model_name,
            "prior_precision":   pp,
            "temperature":       temp,
            "hessian_structure": hess,
            "link_approx":       link,
            "joint":             joint_flag,
            "diagonal_output":   diag_out,
            "acc_id_mean":       acc_id_mean,
            "acc_id_std":        acc_id_std,
            "ece_id_mean":       ece_id_mean,
            "ece_id_std":        ece_id_std,
        })

        results_ood.append({
            "model_type":        model_name,
            "prior_precision":   pp,
            "temperature":       temp,
            "hessian_structure": hess,
            "link_approx":       link,
            "joint":             joint_flag,
            "diagonal_output":   diag_out,
            "acc_ood_mean":      acc_ood_mean,
            "acc_ood_std":       acc_ood_std,
            "ece_ood_mean":      ece_ood_mean,
            "ece_ood_std":       ece_ood_std,
        })

        config_counter += 1
        print(f"[{config_counter}/{total_configs}] "
              f"{model_name} | pp={pp}, τ={temp}, hess={hess}, link={link}, "
              f"diag_out={diag_out}  →  "
              f"ID_acc={acc_id_mean:.4f}±{acc_id_std:.4f}, ID_ECE={ece_id_mean:.4f}±{ece_id_std:.4f}  |  "
              f"OOD_acc={acc_ood_mean:.4f}±{acc_ood_std:.4f}, OOD_ECE={ece_ood_mean:.4f}±{ece_ood_std:.4f}")

       

        # Flush every 50 configurations (optional)
        if config_counter % 20 == 0:
            df_id  = pd.DataFrame(results_id)
            df_ood = pd.DataFrame(results_ood)
            df_id.to_csv(
                os.path.join(output_dir / "last_layer_laplace_metrics_id.csv"),
                index=False
            )
            df_ood.to_csv(
                os.path.join(output_dir / "last_layer_laplace_metrics_ood.csv"),
                index=False
            )
            print(f"  → Flushed {config_counter} rows to CSV.")

# ───────────────────────────────────────────────────────────────────
# 3) Save the final CSVs
# ───────────────────────────────────────────────────────────────────
df_id_final  = pd.DataFrame(results_id)
df_ood_final = pd.DataFrame(results_ood)
df_id_final.to_csv(output_dir / "last_layer_laplace_metrics_id.csv", index=False)
df_ood_final.to_csv(output_dir / "last_layer_laplace_metrics_ood.csv", index=False)
print("Saved final CSVs for ID and OOD metrics.")

[1/128] lenet | pp=1e-06, τ=0.5, hess=diag, link=probit, diag_out=False  →  ID_acc=0.9787±0.0017, ID_ECE=0.0024±0.0001  |  OOD_acc=0.1077±0.0062, OOD_ECE=0.5556±0.0722


KeyboardInterrupt: 

## 6. Plotting extended sensitivity results