In [None]:
import torch                                                                                          
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms, datasets
import numpy as np
import time 

class conv_net(nn.Module):
    """Small CNN: two 3x3 convolutions, one 2x2 max-pool and two linear layers.

    The first linear layer takes 9216 features, which equals 64 * 12 * 12, i.e.
    the flattened activation produced by a single-channel 28x28 input
    (28 -> 26 -> 24 after the two un-padded 3x3 convolutions, -> 12 after the
    2x2 max-pool). The network returns 10 raw scores per sample; no softmax
    is applied.
    """

    def __init__(self):
        super(conv_net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        # Channel-wise dropout: applied to the 4-D feature map after pooling.
        self.dropout1 = nn.Dropout2d(0.25)
        # Element-wise dropout: applied to the 2-D (N, 128) output of fc1.
        # Dropout2d on a 2-D input is deprecated in PyTorch (it warns and is
        # announced to become an error); nn.Dropout keeps the same behaviour.
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Map a batch of inputs to per-class scores.

        Args:
            x: input batch; the layer sizes above fit tensors of shape
                (N, 1, 28, 28).

        Returns:
            Tensor of shape (N, 10) holding the un-normalised output of fc2.
        """
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        # Keep the batch dimension, flatten channels and spatial dims.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        return x
    
def get_hessian(batch_loss, model):
    """Build the dense Hessian of ``batch_loss`` w.r.t. the parameters of ``model``.

    The parameters are flattened and concatenated in the order yielded by
    ``model.parameters()``; entry ``[i, j]`` is the second derivative of the
    loss with respect to flattened parameters ``i`` and ``j``. One backward
    pass is run per parameter element and the result is an ``n x n`` matrix,
    so this is only practical for small models. The elapsed wall-clock time
    is printed before returning.

    Args:
        batch_loss: scalar tensor that still carries its autograd graph.
        model: module whose parameters the loss is differentiated against.
            Every parameter is expected to have ``requires_grad=True``.

    Returns:
        Tensor of shape ``(n, n)`` with the dtype and device of the
        gradients, detached from the autograd graph. Parameters the loss (or
        a gradient entry) does not depend on contribute zeros. A model
        without parameters yields an empty ``(0, 0)`` tensor.

    Raises:
        RuntimeError: propagated from ``torch.autograd.grad``, e.g. when
            ``batch_loss`` has no autograd graph.
    """
    t = time.time()

    # model.parameters() returns a fresh generator on every call; materialise
    # it once so all autograd calls differentiate against the same list.
    params = list(model.parameters())

    if not params:
        hessian = batch_loss.new_zeros((0, 0))
    else:
        first_order = torch.autograd.grad(
            batch_loss, params, create_graph=True, allow_unused=True
        )
        # reshape (not view) also copes with non-contiguous gradients; a
        # parameter the loss never touches has an all-zero gradient.
        flat_grad = torch.cat([
            (torch.zeros_like(p) if g is None else g).reshape(-1)
            for g, p in zip(first_order, params)
        ])

        n = flat_grad.shape[0]
        # Match the gradients' dtype/device so float64 models are not
        # silently truncated and GPU models are not copied row by row.
        hessian = torch.zeros((n, n), dtype=flat_grad.dtype, device=flat_grad.device)

        # If no gradient entry depends on the parameters (loss linear in all
        # of them) the Hessian is identically zero and there is nothing to
        # differentiate.
        if flat_grad.requires_grad:
            for i, g in enumerate(flat_grad):
                second_order = torch.autograd.grad(
                    g, params, retain_graph=True, allow_unused=True
                )
                hessian[i, :] = torch.cat([
                    (torch.zeros_like(p) if h is None else h).reshape(-1)
                    for h, p in zip(second_order, params)
                ])

    print(f"Time elapsed:{time.time() - t}")
    return hessian

In [None]:
# Seed the global RNG so the random input, the weight initialisation and the
# dropout masks drawn by the forward pass are identical on every
# Restart & Run All.
torch.manual_seed(0)

# A single random single-channel 28x28 sample.
input_data = torch.randn(1, 1, 28, 28)
# Integer class index used as the target of the cross-entropy loss.
label_data = torch.tensor([1], dtype=torch.long)

model = conv_net()
# Training mode: the dropout layers of conv_net stay active in the forward pass.
model.train()
criterion = nn.CrossEntropyLoss()

In [None]:
# Forward pass on the single sample; `outputs` holds the raw scores returned by fc2.
outputs = model(input_data)
# Cross-entropy loss of those scores against the target class; its autograd
# graph is what get_hessian differentiates.
batch_loss = criterion(outputs, label_data)

In [None]:
# Bind the result to a name so later cells can use it, and display only its
# shape instead of dumping the whole matrix as cell output.
# NOTE: conv_net has 1,199,882 parameters (320 + 18,496 + 1,179,776 + 1,290),
# so the dense Hessian has about 1.44e12 entries (~5.8 TB in float32) and needs
# one backward pass per entry row -- expect this cell to run out of memory
# unless the model is made much smaller.
hessian = get_hessian(batch_loss, model)
hessian.shape