In [4]:
import torch
from torch.autograd import grad

def layer_by_layer_hess_vec(dataloader, model, cuda=True, bn_train_mode=False):
    """Compute a per-parameter Hessian-vector product of the mean batch loss.

    For every trainable parameter ``p`` the result is the diagonal Hessian
    block of the loss with respect to ``p`` multiplied by a vector of ones
    shaped like ``p``. Cross-parameter Hessian blocks are not included.

    The loss is the average, over all batches yielded by ``dataloader``, of
    ``model(input_ids=x, labels=x).loss`` (reduced with ``mean()`` when it is
    not already a scalar). The forward pass is run once and its graph is
    reused for every parameter.

    Args:
        dataloader: Iterable of batches. Each batch is indexed with
            ``batch["input_ids"]``; that tensor is passed to the model as both
            ``input_ids`` and ``labels``.
        model: Module called as ``model(input_ids=..., labels=...)`` whose
            return value exposes a ``loss`` attribute. It is switched to eval
            mode and its gradients are zeroed as a side effect.
        cuda: If true, inputs are moved to ``"cuda"``, otherwise to ``"cpu"``.
            The model itself is not moved by this function.
        bn_train_mode: If true, ``_bn_train_mode`` is applied to every
            submodule (via ``model.apply``) after switching to eval mode.

    Returns:
        dict mapping each trainable parameter's name to a detached tensor of
        the same shape holding its Hessian-vector product.

    Raises:
        ValueError: If the model has trainable parameters but ``dataloader``
            yields no batches.
    """
    # Prepare the model
    model.eval()
    if bn_train_mode:
        # NOTE(review): `_bn_train_mode` is not defined in this block; it is
        # expected to be provided elsewhere in the notebook -- confirm.
        model.apply(_bn_train_mode)

    targets = [
        (name, parameter)
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]
    if not targets:
        # Nothing to differentiate; skip the forward pass entirely.
        return {}

    model.zero_grad()

    # Build the averaged loss ONCE. It does not depend on which parameter is
    # being differentiated, so repeating it per parameter is wasted work.
    device = "cuda" if cuda else "cpu"
    total_loss = 0
    num_batches = 0
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        loss = model(input_ids=input_ids, labels=input_ids).loss
        if loss.dim() > 0:  # Reduce per-element / per-device losses to a scalar
            loss = loss.mean()
        total_loss = total_loss + loss
        # Count while iterating so loaders without __len__ work as well.
        num_batches += 1

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; cannot compute a loss")
    total_loss = total_loss / num_batches

    layerwise_hvp = {}
    for name, parameter in targets:
        # ones_like already matches the parameter's device and dtype.
        vec = torch.ones_like(parameter)

        # create_graph=True keeps the gradient differentiable (and retains
        # the forward graph) so a second derivative can be taken.
        grad_loss = grad(total_loss, parameter, create_graph=True)[0]

        # retain_graph=True: the shared forward graph is needed again for the
        # next parameter; without it autograd frees the saved tensors.
        hvp = grad(grad_loss, parameter, grad_outputs=vec, retain_graph=True)[0]

        # Detach so the stored result does not keep the graph alive.
        layerwise_hvp[name] = hvp.detach()

    return layerwise_hvp



In [3]:
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Example model (for demonstration)
class SimpleNet(nn.Module):
    """Two-layer fully connected network used for demonstration.

    The input is flattened from dimension 1 onward, passed through a linear
    layer with ReLU, then through a second linear layer that produces the
    output scores.

    Args:
        in_features: Number of features per sample after flattening
            (default 784).
        hidden_features: Width of the hidden layer (default 100).
        num_classes: Number of output scores per sample (default 10).
    """

    def __init__(self, in_features=784, hidden_features=100, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return the output scores for a batch ``x`` (batch dimension first)."""
        x = torch.flatten(x, 1)  # keep dim 0, merge every remaining dim
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Seed the global RNG so the weight initialisation and the shuffled batch
# drawn below are the same on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Load a sample dataset (downloads MNIST into ./data on first run)
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Initialize the model and criterion
model = SimpleNet()
criterion = nn.CrossEntropyLoss()

# Get a single batch from the DataLoader
inputs, labels = next(iter(train_loader))

# Compute the layer-wise Hessian
# NOTE(review): `compute_layerwise_hessian` is not defined in this cell and
# must come from elsewhere in the notebook. The output recorded for this cell
# is a RuntimeError about backwarding through the graph a second time, so the
# helper presumably needs create_graph=True / retain_graph=True on its first
# autograd call -- confirm against its definition.
layer_hessians = compute_layerwise_hessian(model, inputs, labels, criterion)

# Example: print the shape of the Hessian for the first layer
print(layer_hessians['fc1.weight'].shape)

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.