In [1]:
import math
import torch
import hess
import hess.utils as utils
import hess.nets
import numpy as np
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from hess.utils import get_hessian_eigs
import matplotlib.pyplot as plt
from gpytorch.utils.lanczos import lanczos_tridiag, lanczos_tridiag_to_diag

In [2]:
class Net(nn.Module):
    """Small convolutional classifier: two conv/ReLU/max-pool stages, then three linear layers.

    The network takes 3-channel images and emits 10 raw scores per sample.
    The flatten step assumes each sample yields 16 * 5 * 5 features after the
    second pooling stage (which is what a 32x32 input produces with these
    kernel and pool sizes).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept as-is: the names are the
        # state-dict keys, and the order determines how the RNG is consumed
        # during default initialisation.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the 10 un-normalised class scores for each sample in ``x``."""
        # Feature extractor: conv -> ReLU -> shared 2x2 max-pool, applied twice.
        for conv_layer in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv_layer(x)))

        # Collapse every sample to a 400-long feature vector.
        x = x.view(-1, 16 * 5 * 5)

        # Classifier head: two hidden ReLU layers, then a linear output layer.
        for hidden_layer in (self.fc1, self.fc2):
            x = F.relu(hidden_layer(x))
        return self.fc3(x)

In [3]:
# Start from a freshly initialised network, then overwrite its parameters
# with the saved state dict (tensors are mapped onto the CPU while loading).
model = Net()
checkpoint = torch.load("./model_dict.pt", map_location=torch.device('cpu'))
# Kept as the last expression so the notebook shows the key-matching report.
model.load_state_dict(checkpoint)

<All keys matched successfully>

## Is it trained?

In [4]:
# Data-pipeline settings, defined once so the train and test pipelines cannot
# drift apart.
# NOTE(review): DATA_ROOT is a machine-specific absolute path; point it at
# your own dataset directory before re-running this notebook elsewhere.
DATA_ROOT = '/home/wesley/Documents/datasets/'
BATCH_SIZE = 4
NUM_WORKERS = 2

# Convert images to tensors, then normalise each of the three channels with
# mean 0.5 and standard deviation 0.5.
transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Training split. The loader shuffles, so the batch order (and anything
# computed from "the first batch") changes between runs unless the RNG is
# seeded beforehand.
trainset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE,
                                          shuffle=True, num_workers=NUM_WORKERS)

# Test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(root=DATA_ROOT, train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=NUM_WORKERS)

# Human-readable names for the 10 class indices.
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Create an iterator over the test loader. Nothing in the cells shown here
# pulls a batch from it — NOTE(review): confirm it is still needed.
dataiter = iter(testloader)

### It's definitely sort of trained.

## Let's compute some eigenvectors

In [6]:
# Loss passed to the eigenvalue routine in a later cell.
criterion = nn.CrossEntropyLoss()

# SGD-with-momentum optimiser over all model parameters; the hyper-parameters
# are collected in one place for readability.
sgd_kwargs = dict(lr=0.001, momentum=0.9)
optimizer = torch.optim.SGD(model.parameters(), **sgd_kwargs)


In [7]:
# Sanity check: draw one batch from the training loader and display the
# model's raw outputs for its first element (the inputs).
sample_batch = next(iter(trainloader))
model(sample_batch[0])

tensor([[ 1.2995, -0.1033, -0.1400, -0.4831,  0.1963, -1.8073, -1.8516, -2.2150,
          3.1614,  1.1122],
        [-1.0028, -1.4596,  0.3556,  1.7450, -0.3046,  2.1966,  0.3917, -0.0745,
         -1.2852, -0.5849],
        [ 3.9742,  5.1580, -2.4446, -2.5135, -4.5396, -4.6319, -6.4897, -5.0900,
          8.2268,  3.9972],
        [-1.0439, -2.8318,  0.0641,  1.0225,  2.6743,  1.1772, -0.8312,  2.6130,
         -2.6329, -1.7608]], grad_fn=<AddmmBackward>)

In [9]:
# Request 10 eigenvalues from hess.utils.get_hessian_eigs using the
# cross-entropy criterion, the loaded model and the training loader
# (presumably eigenvalues of the loss Hessian w.r.t. the model parameters,
# judging by the function name — confirm against hess.utils).
#
# The model was loaded onto the CPU, so only ask for CUDA when a device is
# actually present; a hard-coded use_cuda=True fails on CPU-only machines.
# NOTE(review): assumes get_hessian_eigs accepts use_cuda=False — confirm.
evals = get_hessian_eigs(loss=criterion,
                         model=model,
                         use_cuda=torch.cuda.is_available(),
                         n_eigs=10,
                         loader=trainloader)

norm of hvp is:  tensor(1800.8514, device='cuda:0')
norm of hvp is:  tensor(113037.4844, device='cuda:0')
norm of hvp is:  tensor(119397.0938, device='cuda:0')
norm of hvp is:  tensor(105285.0625, device='cuda:0')
norm of hvp is:  tensor(92617.9688, device='cuda:0')
norm of hvp is:  tensor(121537.2656, device='cuda:0')
norm of hvp is:  tensor(74416.7344, device='cuda:0')
norm of hvp is:  tensor(65032.6367, device='cuda:0')
norm of hvp is:  tensor(71717.9297, device='cuda:0')
norm of hvp is:  tensor(60045.8438, device='cuda:0')


In [10]:
# Display the values returned by get_hessian_eigs (called with n_eigs=10).
evals

tensor([1.0000e+00, 1.3527e+01, 5.2005e+03, 2.3199e+04, 3.8299e+04, 6.8720e+04,
        9.4441e+04, 1.0223e+05, 1.5229e+05, 1.7249e+05], device='cuda:0')