# Experimenting with PyHessian

Note: If you run this notebook on the `mps` device, you need to edit line 62 of `hessian.py` (in the installed `pyhessian` package) so that it uses `"mps"` instead of `"cpu"`.

In [5]:
import torch
import os
from pyhessian import hessian  # Hessian computation
# Import the helper classes from the src package
from src import CifarLoader, CifarNet

# Single source of truth for the compute device and the RNG seed, so the
# model, the checkpoint mapping and the batch cannot drift apart.
DEVICE = torch.device("mps")
SEED = 0

# Seed torch's global RNG so a Restart & Run All is repeatable.
# NOTE(review): this only pins the result if CifarLoader's augmentation and
# pyhessian draw from torch's global RNG — confirm.
torch.manual_seed(SEED)

# Minimal example: create loaders and a model instance
test_loader = CifarLoader("cifar10", train=False, batch_size=500)
train_loader = CifarLoader("cifar10", train=True, batch_size=500, aug=dict(flip=True, translate=2))
model = CifarNet().to(DEVICE)

# Load weights. The checkpoint is a dict whose 'state_dict' entry holds the
# model parameters.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust (consider weights_only=True if the file contains nothing but tensors).
ckpt_path = os.path.join('..', 'model_weights', 'model_epoch_1.pt')
checkpoint = torch.load(ckpt_path, map_location=DEVICE)
model.load_state_dict(checkpoint['state_dict'])

# Example: compute Hessian for one batch (replace with actual inputs/targets)
inputs, targets = next(iter(train_loader))
# Make sure the batch lives on the same device as the model; this is a no-op
# when the loader already yields tensors on DEVICE.
inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
crit = torch.nn.CrossEntropyLoss()
# cuda=False because the model is not on a CUDA device; see the note at the
# top of the notebook about the hessian.py change needed for mps.
hessian_comp = hessian(model, crit, data=(inputs, targets), cuda=False)


In [6]:
# Number of top eigenvalue/eigenvector pairs requested from PyHessian.
TOP_N = 10
eig_vals, eig_vecs = hessian_comp.eigenvalues(top_n=TOP_N)

In [3]:
# Display the eigenvalues returned by hessian_comp.eigenvalues(...).
# NOTE(review): this cell's saved execution count is lower than that of the
# cells above it, so the stored output may come from an earlier run. Re-run
# with Restart & Run All and sanity-check the values (the last entry in the
# saved output is far larger than the rest).
eig_vals

[273.8551940917969,
 196.0106201171875,
 144.30859375,
 129.8615264892578,
 101.46768188476562,
 75.86442565917969,
 65.5568618774414,
 54.035369873046875,
 43.73713684082031,
 5940080.5]