In [1]:
# Install the third-party packages this notebook relies on into the kernel's environment.
# NOTE(review): no versions are pinned, so a later run may resolve different releases — consider pinning.
%pip install numpy seaborn matplotlib pandas pillow tqdm scikit-learn torch torchvision backpack-for-pytorch opt_einsum

[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.3.1[0m[39;49m -> [0m[32;49m25.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
# Reload imported modules automatically before each cell executes, so edits to
# the project's own `src` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

import os
import sys

# Make the project importable from this notebook by appending the directories
# two levels and one level above the working directory to sys.path
# (grandparent first, then parent — the same order as before).
#
# os.path.dirname is used instead of splitting the path on '/' so the lookup
# also works on platforms whose path separator is not '/', and the previously
# copy-pasted stanza is expressed once.
nb_dir = os.getcwd()
ancestor_dirs = []
for _ in range(2):
    nb_dir = os.path.dirname(nb_dir)
    ancestor_dirs.append(nb_dir)  # [parent, grandparent]

for nb_dir in reversed(ancestor_dirs):
    if nb_dir not in sys.path:
        sys.path.append(nb_dir)

# Kept for backward compatibility: the original cell left N_up == 1 and
# nb_dir == <parent directory> behind in the notebook namespace.
N_up = 1

In [3]:
import os
import time
# Configure CUDA device selection for this process through environment
# variables: device ordering by PCI bus id, and only device "0" visible.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)

# Step #0: Define Experimental Setup

In [None]:
from pathlib import Path

# ---------------------------------------------------------------------------
# General setup
# ---------------------------------------------------------------------------
device = "cuda:0"  # device string handed to the model, masking and Laplace code
model_name, dataset = "resnet18", "MNIST"

# Batch sizes for training / subnetwork selection, GGN fitting and evaluation.
train_batch_size, laplace_batch_size, test_batch_size = 256, 64, 64

n_test_data = None  # forwarded unchanged to the evaluate_* helpers
rotations = list(range(0, 181, 30))  # 0, 30, ..., 180
n_out = 10
loss = "cross_entropy"

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
data_dir = "./data"
root_dir = Path(f"./{model_name}_{dataset}")
# NOTE: despite its name this is the path of a file (ggn.pt), not a directory.
ggn_dir = root_dir.joinpath("ggn.pt")

# ---------------------------------------------------------------------------
# Subnetwork selection
# ---------------------------------------------------------------------------
n_weights_subnet = 5000
subnet_selection = "snr"  # "snr", "magnitude", "min-wass", "random"
layer_weight = None
# Keys under which results are stored in the result dictionaries.
# NOTE(review): not derived from subnet_selection / n_weights_subnet, so it has
# to be kept in sync with them by hand.
methods = ["snr_5K"]

# ---------------------------------------------------------------------------
# Prediction
# ---------------------------------------------------------------------------
pred_lambda = 42.0

In [46]:
# Wall-clock reference point for the whole pipeline; the last timing cell
# subtracts it from end_time1 to report the total runtime.
start_time1 = time.time()

# Step #1: Train or Load Model

In [45]:
from src.scripts.train_classification import train_loop
from src.utils import list_batchnorm_layers, get_n_params, model_to_device, instantiate_model
# Time the whole train-or-resume step.
start_time = time.time()

# Build the network and move it to the configured device.
# NOTE(review): the meaning of the third positional argument (0.1) is not
# visible in this notebook — confirm against instantiate_model.
model = instantiate_model(model_name, dataset, 0.1)
model = model_to_device(model, device)

# Batch-norm layers are listed once and reused by the parameter count here and
# by the masking functions in the subnetwork-selection cell.
bn_layers = list_batchnorm_layers(model)
print('Nparams:', get_n_params(model, bn_layers))

gpu = 0
# `resume` points at the best checkpoint inside root_dir; how train_loop treats
# a missing checkpoint file is not visible from this notebook.
checkpoint_path = str(root_dir / 'model_best.pth.tar')
train_loop(
    model,
    dname=dataset,
    data_dir=data_dir,
    epochs=20,
    workers=4,
    gpu=gpu,
    resume=checkpoint_path,
    weight_decay=1e-4,
    save_dir=str(root_dir),
    milestones=[10, 17],
    MC_samples=1,
    batch_size=train_batch_size,
)

end_time = time.time()
elapsed_seconds = end_time - start_time
mins_elapsed = round(elapsed_seconds / 60, 2)
print(f"{mins_elapsed} mins elapsed...")

Nparams: 11163200
Use GPU: 0 for training
=> loading checkpoint 'resnet18_MNIST/model_best.pth.tar'
=> loaded checkpoint 'resnet18_MNIST/model_best.pth.tar' (epoch 22)
=> found progress file at 'resnet18_MNIST/stats_array.pkl'
=> Loaded progress file at 'resnet18_MNIST/stats_array.pkl'
Ntrain: 60000, Nval: 10000
0.0 mins elapsed...


# Step #2: Select Subnetwork

In [47]:
from src.datasets.image_loaders import get_image_loader
from src.utils import print_nonzeros
from src.masking.masking import random_mask, smallest_magnitude_mask, wasserstein_mask, snr_mask

start_time = time.time()

# Compute the subnetwork mask with the strategy named by `subnet_selection`.
# Every branch defines the same three names: mask, index_mask, weight_score_vec.
if n_weights_subnet is None:
    # No subnetwork requested. All three names are defined (not just `mask`) so
    # that the Laplace cell, which passes `mask` and `index_mask` on, does not
    # fail with a NameError.
    mask = index_mask = weight_score_vec = None

elif subnet_selection == "random":
    mask, index_mask, weight_score_vec = random_mask(model, bn_layers, n_weights_subnet, device=device)

elif subnet_selection in ("min-wass", "snr"):
    # Both strategies are given a data loader built with the training batch size.
    # NOTE(review): element [1] of get_image_loader's return value is treated as
    # the training loader here — confirm against get_image_loader.
    train_loader = get_image_loader(dataset, batch_size=train_batch_size, cuda=True, workers=4,
                                    distributed=False, data_dir=data_dir)[1]

    if subnet_selection == "min-wass":
        mask, index_mask, weight_score_vec = wasserstein_mask(model,
                                                              bn_layers,
                                                              n_weights_subnet,
                                                              train_loader,
                                                              device,
                                                              layer_weight=layer_weight)
    else:
        mask, index_mask, weight_score_vec = snr_mask(model,
                                                      bn_layers,
                                                      n_weights_subnet,
                                                      train_loader,
                                                      device)

elif subnet_selection == "magnitude":
    mask, index_mask, weight_score_vec = smallest_magnitude_mask(model, bn_layers, n_weights_subnet)

else:
    raise NotImplementedError("Supported subnetwork selection methods: snr, random, min-wass, magnitude.")

if mask is not None:
    # print mask information
    print_nonzeros(mask)

end_time = time.time()
mins_elapsed = round((end_time - start_time) / 60, 2)
print(f"{mins_elapsed} mins elapsed...")

Ntrain: 60000, Nval: 10000


100%|██████████| 256/256 [36:31<00:00,  8.56s/it]


conv1.0.weight                      | remaining =       0 /     576 (  0.00%) | pruned =     576 | shape = torch.Size([64, 1, 3, 3])
layer_list.0.conv1.weight           | remaining =      35 /   36864 (  0.09%) | pruned =   36829 | shape = torch.Size([64, 64, 3, 3])
layer_list.0.conv2.weight           | remaining =      50 /   36864 (  0.14%) | pruned =   36814 | shape = torch.Size([64, 64, 3, 3])
layer_list.1.conv1.weight           | remaining =      27 /   36864 (  0.07%) | pruned =   36837 | shape = torch.Size([64, 64, 3, 3])
layer_list.1.conv2.weight           | remaining =      25 /   36864 (  0.07%) | pruned =   36839 | shape = torch.Size([64, 64, 3, 3])
layer_list.2.conv1.weight           | remaining =      79 /   73728 (  0.11%) | pruned =   73649 | shape = torch.Size([128, 64, 3, 3])
layer_list.2.conv2.weight           | remaining =      74 /  147456 (  0.05%) | pruned =  147382 | shape = torch.Size([128, 128, 3, 3])
layer_list.2.downsample.0.weight    | remaining =       8 / 

In [32]:
# mask, index_mask, weight_score_vec = wasserstein_mask(model, bn_layers, n_weights_subnet, train_loader, device, weight_score_vec=weight_score_vec)

# Step #3: Do Linearized Laplace Inference

In [48]:
# torch is needed for torch.load below; it was not imported anywhere in the
# notebook, so the "load GGN from disk" branch raised a NameError.
import torch

from src.laplace.laplace import Laplace
from src.datasets.image_loaders import get_image_loader

start_time = time.time()

# Instantiate the Laplace wrapper around the trained model and the selected
# subnetwork; `save_path` is the file the wrapper is given for its own state.
laplace_dir = root_dir / "laplace.pth.tar"
laplace_model = Laplace(model,
                        mask=mask,
                        index_mask=index_mask,
                        save_path=laplace_dir,
                        device=device,
                        loss=loss,
                        n_out=n_out)

# Load or fit the Hessian approximation, preferring cached artefacts:
#   1. a stored GGN file (ggn_dir),
#   2. a stored Laplace model (laplace_dir),
#   3. otherwise fit from the data loader.
if ggn_dir.exists():
    print("Loading GGN from disk...")
    laplace_model.H = torch.load(ggn_dir)

elif laplace_dir.exists():
    print("Loading Laplace model from disk...")
    laplace_model.load()

else:
    print("Computing Hessian/GGN...")
    train_loader = get_image_loader(dataset, batch_size=laplace_batch_size, cuda=True, workers=2, distributed=False, data_dir=data_dir)[1]
    laplace_model.fit_laplace(train_loader)

end_time = time.time()
mins_elapsed = round((end_time - start_time) / 60, 2)
print(f"{mins_elapsed} mins elapsed...")

Loading Laplace model from disk...
0.01 mins elapsed...


# Step #4: Make OOD Prediction

In [49]:
import numpy as np

# File in which the OOD evaluation results are cached between runs.
results_ood_path = root_dir / "results_ood.npy"

# The cache is a dict written with np.save; .item() unwraps the loaded 0-d
# object array back into that dict. Start from an empty dict when no cache exists.
# NOTE: allow_pickle=True unpickles the file contents — only load files that
# were produced by this notebook.
results_dict_ood = (
    np.load(results_ood_path, allow_pickle=True).item()
    if results_ood_path.exists()
    else {}
)

In [50]:
from src.evaluation.evaluate_ood import evaluate_laplace_ood

start_time = time.time()

# Evaluate each method on the out-of-distribution target dataset, skipping
# methods whose results are already present in results_dict_ood.
for method in methods:
    if method in results_dict_ood:
        continue

    print(f"Computing predictions for {method}...")

    # The result is stored only after evaluation has returned. Previously an
    # empty placeholder dict was written first, so a failed or interrupted
    # evaluation left `{}` behind and the method was skipped as "already
    # computed" when the cell was run again.
    results_dict_ood[method] = evaluate_laplace_ood(laplace_model, dataset, data_dir, target_dataset="Fashion",
                                                     batch_size=test_batch_size, λ=pred_lambda, n_test_data=n_test_data)

# save result dictionary
np.save(results_ood_path, results_dict_ood)

end_time = time.time()
mins_elapsed = round((end_time - start_time) / 60, 2)
print(f"{mins_elapsed} mins elapsed...")

Computing predictions for snr_5K...
Ntrain: 60000, Nval: 10000
Computing covariance matrix...


  0%|          | 0/157 [00:00<?, ?it/s]



Computing covariance matrix...


  0%|          | 0/157 [00:00<?, ?it/s]

6.86 mins elapsed...


  reference = np.array([expanded_preds[bin_idxs == nbin].mean() for nbin in range(n_bins)])
  ret = ret.dtype.type(ret / rcount)


# Step #5: Make Rotation-wise Predictions

In [51]:
# File in which the rotation-wise evaluation results are cached between runs.
results_path = root_dir / "results.npy"

# The cache is a dict written with np.save; .item() unwraps the loaded 0-d
# object array back into that dict. Start from an empty dict when no cache exists.
# NOTE: allow_pickle=True unpickles the file contents — only load files that
# were produced by this notebook.
results_dict = (
    np.load(results_path, allow_pickle=True).item()
    if results_path.exists()
    else {}
)

In [52]:
import numpy as np

from src.evaluation.evaluate_laplace import evaluate_laplace
from src.evaluation.evaluate_baselines import evaluate_map

start_time = time.time()

# Evaluate every (method, rotation) pair that is not yet present in
# results_dict; results are nested as results_dict[method][rot].
for method in methods:
    for rot in rotations:
        # Fetch (or create) this method's per-rotation dict, then skip
        # rotations that already have an entry.
        cached_rotations = results_dict.setdefault(method, {})
        if rot in cached_rotations:
            continue

        print(f"Computing predictions for {method} at rotation={rot}...")

        if method == "MAP":
            # Baseline: evaluate the plain trained model.
            rot_metrics = evaluate_map(model, dataset, data_dir, device, loss, corruption=None,
                                       rotation=rot, batch_size=test_batch_size, n_test_data=n_test_data)
        else:
            # Any other method name is evaluated through the Laplace model.
            rot_metrics = evaluate_laplace(laplace_model, dataset, data_dir, corruption=None, rotation=rot,
                                           batch_size=test_batch_size, λ=pred_lambda, n_test_data=n_test_data)

        cached_rotations[rot] = rot_metrics

# save result dictionary
np.save(results_path, results_dict)

end_time = time.time()
elapsed_seconds = end_time - start_time
mins_elapsed = round(elapsed_seconds / 60, 2)
print(f"{mins_elapsed} mins elapsed...")

Computing predictions for snr_5K at rotation=0...
Ntrain: 60000, Nval: 10000
Computing covariance matrix...


  0%|          | 0/157 [00:00<?, ?it/s]



Computing predictions for snr_5K at rotation=30...
Computing covariance matrix...


  reference = np.array([expanded_preds[bin_idxs == nbin].mean() for nbin in range(n_bins)])
  ret = ret.dtype.type(ret / rcount)


  0%|          | 0/157 [00:00<?, ?it/s]

Computing predictions for snr_5K at rotation=60...
Computing covariance matrix...


  0%|          | 0/157 [00:00<?, ?it/s]

Computing predictions for snr_5K at rotation=90...
Computing covariance matrix...


  0%|          | 0/157 [00:00<?, ?it/s]

Computing predictions for snr_5K at rotation=120...
Computing covariance matrix...


  0%|          | 0/157 [00:00<?, ?it/s]

Computing predictions for snr_5K at rotation=150...
Computing covariance matrix...


  0%|          | 0/157 [00:00<?, ?it/s]

Computing predictions for snr_5K at rotation=180...
Computing covariance matrix...


  0%|          | 0/157 [00:00<?, ?it/s]

23.96 mins elapsed...


In [53]:
import pandas as pd
# Flatten the nested {method: {rotation: metrics}} results into one record per
# (method, rotation) pair; the entries of each metrics dict become additional
# columns next to 'method' and 'rotation'.
rows_list = [
    {'method': method_name, 'rotation': rot_angle, **rot_metrics}
    for method_name, per_rotation in results_dict.items()
    for rot_angle, rot_metrics in per_rotation.items()
]

df = pd.DataFrame(rows_list)

display(df)

Unnamed: 0,method,rotation,err,ll,brier,ece
0,MAP,0,0.0055,-0.017919,0.008828,0.001297
1,MAP,30,0.0916,-0.317241,0.142648,0.039864
2,MAP,60,0.6238,-3.433685,1.032552,0.459178
3,MAP,90,0.8357,-6.069594,1.399243,0.640741
4,MAP,120,0.7754,-6.582309,1.305964,0.593431
...,...,...,...,...,...,...
128,snr_5K,60,0.6209,-2.922373,0.990425,0.420624
129,snr_5K,90,0.8355,-5.115520,1.341547,0.597469
130,snr_5K,120,0.7765,-5.563317,1.261235,0.555133
131,snr_5K,150,0.6279,-5.174534,1.032490,0.446203


In [54]:
from datetime import datetime
# Timestamped file name so that every run keeps its own copy of the results.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"results/Results_{timestamp}.csv"

# Create the output directory on first use; without it the timestamped
# to_csv call fails on a fresh checkout because "results/" does not exist.
os.makedirs(os.path.dirname(filename), exist_ok=True)

# NOTE(review): this file is written WITH the DataFrame index while the
# timestamped copy below is written without it (index=False) — confirm which
# layout downstream readers expect before unifying the two.
df.to_csv('Results_.csv') # always overwrites this one, contains latest
df.to_csv(filename, index=False) # keeps track of all files
print(f"Results saved as: {filename}")

Results saved as: results/Results_20250505_161435.csv


In [56]:
# Progress marker printed after the evaluation and result-export cells.
print("Evaluation completed...")

Evaluation completed...


In [57]:
# Report the wall-clock time that has passed since start_time1 was recorded.
end_time1 = time.time()
total_seconds = end_time1 - start_time1
mins_elapsed = round(total_seconds / 60, 2)
print(f"Total {mins_elapsed} mins elapsed...")

Total 67.52 mins elapsed...


In [27]:
# Flatten the nested {method: {rotation: metrics}} results into one record per
# (method, rotation) pair; the entries of each metrics dict become additional
# columns next to 'method' and 'rotation'.
rows_list = [
    {'method': method_name, 'rotation': rot_angle, **rot_metrics}
    for method_name, per_rotation in results_dict.items()
    for rot_angle, rot_metrics in per_rotation.items()
]

df = pd.DataFrame(rows_list)

display(df)

Unnamed: 0,method,rotation,err,ll,brier,ece
0,min_wass_5K,0,0.004,-0.012224,0.005949,0.002052
1,min_wass_5K,30,0.086,-0.26064,0.126127,0.025123
2,min_wass_5K,60,0.6309,-2.978986,1.001125,0.433694
3,min_wass_5K,90,0.833,-5.068996,1.326677,0.594812
4,min_wass_5K,120,0.7751,-5.30901,1.253082,0.557794
5,min_wass_5K,150,0.6206,-5.59635,1.060961,0.477014
6,min_wass_5K,180,0.5729,-6.082541,1.013237,0.470645


In [29]:
from datetime import datetime
# Timestamped file name so that every run keeps its own copy of the results.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"results/Results_{timestamp}.csv"

# Create the output directory on first use; without it the timestamped
# to_csv call fails on a fresh checkout because "results/" does not exist.
os.makedirs(os.path.dirname(filename), exist_ok=True)

# NOTE(review): this file is written WITH the DataFrame index while the
# timestamped copy below is written without it (index=False) — confirm which
# layout downstream readers expect before unifying the two.
df.to_csv('Results_.csv') # always overwrites this one, contains latest
df.to_csv(filename, index=False) # keeps track of all files
print(f"Results saved as: {filename}")

Results saved as: results/Results_20250505_205557.csv
