In [164]:
# Imports (the pyhessian package must already be installed in this environment)
import torch
import torch.nn as nn
import torch.optim as optim
from pyhessian import hessian
from pyhessian.utils import get_params_grad

# Define a simple MLP (same as above)
class MLP(nn.Module):
    """Bias-free two-layer perceptron: Linear -> ReLU -> Linear."""

    def __init__(self, input_size, hidden_size, output_size):
        """Create the two linear layers and the activation.

        Args:
            input_size: number of input features.
            hidden_size: width of the hidden layer.
            output_size: number of output units.
        """
        super().__init__()
        # Creation order (fc1, then fc3) fixes both the parameter ordering
        # and the order in which random initial weights are drawn.
        self.fc1 = nn.Linear(input_size, hidden_size, bias=False)
        self.fc3 = nn.Linear(hidden_size, output_size, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply fc1, then ReLU, then fc3, and return the raw fc3 output."""
        hidden = self.fc1(x)
        activated = self.relu(hidden)
        return self.fc3(activated)

# Example usage
input_size = 784  # For example, for MNIST dataset
hidden_size = 1024
output_size = 10

model = MLP(input_size, hidden_size, output_size).cuda()

# Dummy data: one random batch of 1024 samples with random integer labels.
x = torch.randn(1024, input_size).cuda()
y = torch.randint(0, output_size, (1024,)).cuda()
# Aliases for the same batch under the names `inputs` / `targets`.
inputs = x
targets = y
# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Forward pass
output = model(x)
loss = criterion(output, y)

# Backward pass. The optimizer step stays disabled, so the parameters are
# unchanged and the Hessian below is evaluated at the initial weights.
optimizer.zero_grad()
loss.backward()
# optimizer.step()

# Clear the gradients left by the manual backward pass above.
# NOTE(review): pyhessian's `hessian` constructor appears to run its own
# backward pass into `.grad` without zeroing first -- confirm for the installed
# version. Without this reset the gradients read below would be accumulated
# (doubled) rather than the gradient of a single loss evaluation.
optimizer.zero_grad()

# Calculate the Hessian and its top eigenvalues
hessian_comp = hessian(model, criterion, data=(x, y), cuda=True)
top_eigenvalues, top_eigenvectors = hessian_comp.eigenvalues(top_n=4)

# Derive the count in the label from the result so it always matches `top_n`
# (the label used to claim 10 eigenvalues while only 4 were requested).
print(f"Top {len(top_eigenvalues)} eigenvalues of the Hessian matrix:", top_eigenvalues)

# Compute the gradient
# NOTE(review): `get_params_grad` appears to return a (params, grads) pair
# rather than the gradients alone -- confirm before using `gradients` directly.
gradients = get_params_grad(model)

Top 10 eigenvalues of the Hessian matrix: [6.045412063598633, 5.936735153198242, 5.620837688446045, 5.510869979858398]


In [188]:
def get_projected_gradients(gradient, eigenvectors):
    """Project a flattened gradient onto each eigenvector.

    Args:
        gradient: flattened gradient tensor (expected to be 1-D).
        eigenvectors: sequence of flattened eigenvector tensors, each expected
            to have the same length, dtype and device as ``gradient``.

    Returns:
        A 1-D tensor with one entry per eigenvector holding the dot product
        ``gradient @ eigenvector``. The result lives on the same device as
        the inputs and is detached from the autograd graph.
    """
    # Nothing to project onto: return an empty tensor like the gradient.
    if len(eigenvectors) == 0:
        return gradient.new_empty(0)
    # torch.stack keeps the 0-d dot products on their current device, instead
    # of round-tripping every scalar through Python and then forcing "cuda"
    # (which fails on CPU-only machines). detach() keeps the result free of
    # autograd history, as a freshly built tensor would be.
    return torch.stack([gradient @ eigenv for eigenv in eigenvectors]).detach()
    
def process_eigenvectors(eigenvectors):
    """Flatten every eigenvector into a single 1-D tensor.

    Each eigenvector is given as an iterable of tensors; the pieces are
    flattened and concatenated in order.

    Returns:
        A list with one concatenated 1-D tensor per input eigenvector.
    """
    return [
        torch.cat([piece.flatten() for piece in vector], dim=0)
        for vector in eigenvectors
    ]
def top_k_hessian_alignment(projected_gradients, gradients, k):
    """Ratio of the first ``k`` projected components to the full gradient.

    Args:
        projected_gradients: tensor of gradient projections; only the first
            ``k`` entries are used.
        gradients: full gradient tensor used for normalisation.
        k: number of leading projections to keep.

    Returns:
        Python float: the self-product of the first ``k`` projections divided
        by the squared norm of ``gradients``.
    """
    leading = projected_gradients[:k]
    captured = leading @ leading
    total = gradients.norm() ** 2
    return (captured / total).item()

def process_gradients(grads):
    """Concatenate an iterable of gradient tensors into one flat 1-D tensor.

    Each tensor is flattened and the pieces are joined in iteration order.
    """
    flat_pieces = []
    for grad in grads:
        flat_pieces.append(grad.flatten())
    return torch.cat(flat_pieces, dim=0)

In [200]:
# Values of k of interest; eigenpairs are computed once for the largest k.
ks = [1, 2, 5]
kmax = max(ks)
# Reset any gradients already stored on the model so this cell gives the same
# `gs` no matter how many times it (or an earlier cell) has been run.
# NOTE(review): pyhessian's `hessian` constructor appears to back-propagate
# into `.grad` without zeroing first -- confirm for the installed version.
# Without this reset, `gradsH` would include stale accumulated gradients.
model.zero_grad()
hessian_comp = hessian(model, criterion, data=(inputs, targets), cuda=True)
# Flatten the per-parameter gradients held by the Hessian object into one vector.
gs = process_gradients(hessian_comp.gradsH)
top_eigenvalues, top_eigenvectors = hessian_comp.eigenvalues(top_n=kmax, maxIter=1000, tol=1e-6)
# Flatten each eigenvector the same way so it can be dotted with `gs`.
top_eigenvectors = process_eigenvectors(top_eigenvectors)
# One dot product of the gradient with each of the kmax eigenvectors.
proj_g = get_projected_gradients(gs, top_eigenvectors)

In [201]:
# Display the kmax eigenvalues returned by hessian_comp.eigenvalues.
top_eigenvalues

[6.134904384613037,
 5.89288854598999,
 5.671705722808838,
 5.594947814941406,
 5.48469352722168]

In [193]:
# Display the eigenvectors after process_eigenvectors flattened each into one tensor.
top_eigenvectors

[tensor([9.5275e-05, 1.3720e-05, 6.4385e-05,  ..., 1.0472e-03, 1.1759e-03,
         9.7791e-04], device='cuda:0'),
 tensor([ 8.2209e-05,  8.1544e-05, -1.3706e-04,  ...,  1.2942e-03,
          1.5109e-03,  1.1640e-03], device='cuda:0'),
 tensor([ 2.9659e-05,  5.5363e-05, -1.0326e-04,  ..., -1.8919e-04,
         -1.0706e-04, -1.4825e-04], device='cuda:0'),
 tensor([-1.4612e-05,  1.6542e-05, -2.6401e-05,  ...,  1.1904e-03,
          1.2793e-03,  1.1327e-03], device='cuda:0'),
 tensor([ 5.0928e-05,  3.2958e-05, -8.1254e-05,  ...,  1.1652e-03,
          1.2562e-03,  1.1057e-03], device='cuda:0')]

In [184]:
# Self-product of the first kmax gradient projections divided by the squared
# norm of the flattened gradient `gs`.
top_k_hessian_alignment(proj_g, gs, kmax)

0.026163375005126