In [1]:
import math
import torch
import hess
import hess.utils as utils
import hess.nets
import numpy as np
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from hess.utils import get_hessian_eigs
import matplotlib.pyplot as plt
from gpytorch.utils.lanczos import lanczos_tridiag, lanczos_tridiag_to_diag

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [3]:
model = Net()
model.load_state_dict(torch.load("./model_dict.pt", map_location=torch.device('cpu')))

<All keys matched successfully>

## Is it trained?

In [4]:
transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='/home/wesley/Documents/datasets/', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='/home/wesley/Documents/datasets/', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [5]:
dataiter = iter(testloader)

### it's definitely sort of trained.

## Let's compute some eigenvectors

In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)


In [7]:
model(next(iter(trainloader))[0])

tensor([[-2.7671, -2.0388,  0.4264,  3.9317, -1.4709,  6.1992, -1.8596,  1.1216,
         -3.4261, -0.6860],
        [ 0.1282, -1.8925,  0.5647,  1.3492, -0.4189,  1.5355,  0.6183,  0.9556,
         -2.7404, -0.9264],
        [-0.9680,  8.5552, -3.8322, -3.5463, -4.4402, -2.6245, -1.0444, -4.2101,
          1.2005,  5.8649],
        [-2.7179, -1.7871,  0.3432,  4.2173, -2.1845,  4.0973,  1.9996, -0.7459,
         -2.6026, -0.1680]], grad_fn=<AddmmBackward>)

In [8]:
evals, evecs = get_hessian_eigs(loss=criterion,
                     model=model, use_cuda=True, n_eigs=2,
                     loader=trainloader)

norm of hvp is:  tensor(1061.5405, device='cuda:0')
norm of hvp is:  tensor(82083.6016, device='cuda:0')


In [9]:
print(evals)
print(evecs)

tensor(32.0178, device='cuda:0')
tensor(70626.0234, device='cuda:0')


## Well that's broken.

## What's up with HVP and Lanczos?

In [10]:
total_pars = sum(p.numel() for p in model.parameters())

In [11]:
def hvp(rhs):
    padded_rhs = torch.zeros(total_pars, rhs.shape[-1],
                             device=rhs.device, dtype=rhs.dtype)
    print(padded_rhs.shape)
    padded_rhs = utils.unflatten_like(padded_rhs.t(), model.parameters())
    utils.eval_hess_vec_prod(padded_rhs, net=model,
                       criterion=criterion, inputs=None,
                       targets=None, dataloader=trainloader, use_cuda=False)
    full_hvp = utils.gradtensor_to_tensor(model, include_bn=True)
    return full_hvp.unsqueeze(-1)


In [12]:
e1 = torch.zeros(total_pars)
e1[1] = 1.
test = hvp(e1.unsqueeze(-1))

torch.Size([62006, 1])


RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same

In [None]:
data = next(iter(trainloader))[0]
dtype = data.dtype
device = data.device
qmat, tmat = lanczos_tridiag(hvp, 3, dtype=dtype,
                          device=device, matrix_shape=(total_pars,
                          total_pars))


In [None]:
tmat

## That's broken too :)