# Metrics on saved weights

This notebook computes Hessian-based metrics (trace, stable rank, and sharpness) for every checkpoint in the existing `model_weights` directory, which is populated with the model weights after each step of training.

### Imports and setup

In [1]:
from src.metrics import analyze_hessian_eigenvalues, analyze_hessian_stable_rank
from src.train_model import CifarLoader, CifarNet
import torch
import os
import matplotlib.pyplot as plt
import numpy as np
import time

# Define constants
weights_path = "../model_weights"  # Directory containing the saved checkpoint files
device = "cpu"


def _checkpoint_index(filename):
    """Return the integer between the last "_" and the following "." in `filename`.

    Returns None when that part of the name is not an integer, so callers can
    skip the entry instead of crashing on it.
    """
    try:
        return int(filename.split("_")[-1].split(".")[0])
    except ValueError:
        return None


# os.listdir returns every directory entry in arbitrary order. Keep only regular
# files whose name carries a numeric index (stray entries such as .DS_Store or
# .ipynb_checkpoints would otherwise raise ValueError in the sort key), then sort
# numerically so that e.g. "..._10" comes after "..._2".
weight_files = sorted(
    (
        name
        for name in os.listdir(weights_path)
        if os.path.isfile(os.path.join(weights_path, name))
        and _checkpoint_index(name) is not None
    ),
    key=_checkpoint_index,
)

# Load a batch: the first batch the loader yields (batch_size=200) is the only
# data the metrics are evaluated on.
# NOTE(review): whether this is the same batch on every run depends on how
# CifarLoader orders/shuffles its data — confirm if reproducibility matters.
train_loader = CifarLoader("cifar10", train=True, batch_size=200)
batch = next(iter(train_loader))

Using device: mps


# Run for each epoch

In [2]:
def _plot_metric(ax, x, y, ylabel, title, **line_style):
    """Draw one metric curve on `ax` with the shared axis styling.

    `line_style` is forwarded to `ax.plot` (marker, color, ...).
    """
    ax.plot(x, y, linewidth=2, markersize=6, **line_style)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True, alpha=0.3)


# One entry per processed checkpoint, in the order of `weight_files`.
epochs = []
stable_ranks = []
sharpnesses = []

traces = []
densities = []  # never filled in this cell; kept in case another cell uses it
for weight_file in weight_files:
    weight_path = os.path.join(weights_path, weight_file)

    start_time = time.time()

    # Load the model at this checkpoint
    model = CifarNet().to(torch.device(device))
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. Only point `weights_path` at checkpoints you trust (consider passing
    # weights_only=True if the installed PyTorch version supports it).
    checkpoint = torch.load(weight_path, map_location=torch.device(device))
    # Accept both a wrapped checkpoint ({"state_dict": ...}) and a bare state dict.
    if "state_dict" in checkpoint:
        model.load_state_dict(checkpoint["state_dict"])
    else:
        model.load_state_dict(checkpoint)
    model.eval()

    # Compute the metrics (only `trace` is used from the second call here).
    stable_rank, sharpness = analyze_hessian_stable_rank(model, batch, device=device)
    top_eigenvalues, trace = analyze_hessian_eigenvalues(model, batch, ev_num=10, device=device)

    # Record results only after both analyses finished, so the lists stay aligned.
    # The x value is the number embedded in the filename (the same number
    # `weight_files` is sorted by) rather than the loop position, so the plot is
    # still correct when checkpoints are not numbered 0, 1, 2, ... contiguously.
    epochs.append(int(weight_file.split("_")[-1].split(".")[0]))
    traces.append(trace)
    stable_ranks.append(stable_rank)
    sharpnesses.append(sharpness)
    elapsed_time = time.time() - start_time
    print(f"Processed {weight_file} in {elapsed_time:.2f} seconds")

# Create plot with three subplots
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 5))

# Plot trace
_plot_metric(ax1, epochs, traces, "Trace", "Hessian Trace vs Epoch", marker="o")

# Plot stable rank
_plot_metric(ax2, epochs, stable_ranks, "Stable Rank", "Hessian Stable Rank vs Epoch", marker="o")

# Plot sharpness (top eigenvalue)
_plot_metric(ax3, epochs, sharpnesses, "Sharpness", "Sharpness vs Epoch", marker="s", color="orange")

plt.tight_layout()
plt.savefig("hessian_metrics.png", dpi=150, bbox_inches="tight")
print("Plot saved to hessian_metrics.png")

Hessian dimension: (146136, 146136)
Total parameters: 146136
Computing squared Frobenius norm using 100 matrix-vector products (rademacher distribution)...
Squared Frobenius norm (||H||_F²): 3.224394e+02
Computing top 1 eigenvalues...


KeyboardInterrupt: 