In [1]:
import math
import torch
import hess
import hess.utils as utils
import hess.nets
import numpy as np
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

from compute_loss_surface import get_loss_surface
from min_max_evals import min_max_hessian_eigs
from hess.utils import get_hessian_eigs
import matplotlib.pyplot as plt
from gpytorch.utils.lanczos import lanczos_tridiag, lanczos_tridiag_to_diag

In [2]:
class Net(nn.Module):
    """Small convolutional classifier: two conv + max-pool stages, then three linear layers.

    ``forward`` flattens to 16 * 5 * 5 features per sample, which is what the
    conv/pool stack produces for 3 x 32 x 32 inputs; the last layer emits
    10 values per sample.
    """

    def __init__(self):
        super().__init__()
        # feature extractor (a single pooling layer is shared by both stages)
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # classifier head
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the unnormalised class scores for a batch of images ``x``."""
        feats = self.pool(F.relu(self.conv1(x)))
        feats = self.pool(F.relu(self.conv2(feats)))
        # one row of 400 features per sample
        flat = feats.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [3]:
# Run on GPU 0 whenever CUDA is available, otherwise stay on the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.set_device(0)

# To start from a saved checkpoint instead of fresh weights:
# model.load_state_dict(torch.load("./model.pt"))
model = Net().cuda() if use_cuda else Net()
criterion = nn.CrossEntropyLoss()

In [4]:
# CIFAR-10 training split: convert images to tensors, then normalise every
# channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(
    root='/datasets/cifar10/',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 128 images, loaded by two worker processes.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

Files already downloaded and verified


In [5]:
## Super Trainer ##
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(30):  # full passes over the training set
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        if use_cuda:
            inputs, labels = inputs.cuda(), labels.cuda()

        # reset gradients, then forward / backward / parameter update
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # report the mean loss over each window of 100 mini-batches
        running_loss += loss.item()
        if (i + 1) % 100 == 0:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0


[1,   100] loss: 2.029
[1,   200] loss: 1.756
[1,   300] loss: 1.623
[2,   100] loss: 1.480
[2,   200] loss: 1.423
[2,   300] loss: 1.404
[3,   100] loss: 1.318
[3,   200] loss: 1.293
[3,   300] loss: 1.287
[4,   100] loss: 1.222
[4,   200] loss: 1.193
[4,   300] loss: 1.221
[5,   100] loss: 1.159
[5,   200] loss: 1.160
[5,   300] loss: 1.142
[6,   100] loss: 1.104
[6,   200] loss: 1.107
[6,   300] loss: 1.105
[7,   100] loss: 1.060
[7,   200] loss: 1.050
[7,   300] loss: 1.058
[8,   100] loss: 1.022
[8,   200] loss: 1.011
[8,   300] loss: 1.006
[9,   100] loss: 0.989
[9,   200] loss: 0.962
[9,   300] loss: 0.988
[10,   100] loss: 0.964
[10,   200] loss: 0.949
[10,   300] loss: 0.934
[11,   100] loss: 0.903
[11,   200] loss: 0.929
[11,   300] loss: 0.927
[12,   100] loss: 0.912
[12,   200] loss: 0.888
[12,   300] loss: 0.899
[13,   100] loss: 0.866
[13,   200] loss: 0.877
[13,   300] loss: 0.884
[14,   100] loss: 0.826
[14,   200] loss: 0.867
[14,   300] loss: 0.847
[15,   100] loss: 0

KeyboardInterrupt: 

In [6]:
# Extreme Hessian eigenpairs of the training loss at the current weights.
# NOTE(review): the two positional 3s are unexplained here - presumably
# iteration / eigenpair counts for the positive and negative ends; confirm
# against min_max_evals.min_max_hessian_eigs.
output = min_max_hessian_eigs(
    model, trainloader, criterion, 3, 3, use_cuda=use_cuda
)

vec shape =  torch.Size([62006, 1])
padded shape =  torch.Size([62006, 1])
sliced shape =  torch.Size([62006, 1])
vec shape =  torch.Size([62006, 1])
padded shape =  torch.Size([62006, 1])
sliced shape =  torch.Size([62006, 1])
vec shape =  torch.Size([62006, 1])
padded shape =  torch.Size([62006, 1])
sliced shape =  torch.Size([62006, 1])
Pos Eigs Computed....

vec shape =  torch.Size([62006, 1])
padded shape =  torch.Size([62006, 1])
sliced shape =  torch.Size([62006, 1])
vec shape =  torch.Size([62006, 1])
padded shape =  torch.Size([62006, 1])
sliced shape =  torch.Size([62006, 1])
vec shape =  torch.Size([62006, 1])
padded shape =  torch.Size([62006, 1])
sliced shape =  torch.Size([62006, 1])
Neg Eigs Computed...



In [7]:
# Split the result tuple into eigenvalues / eigenvectors for each end of the spectrum.
pos_evals, pos_evecs, neg_evals, neg_evecs = output

In [8]:
def loss_getter(model, dataloader, criterion, use_cuda=False):
    """Sum ``criterion(model(inputs), labels)`` over every batch of ``dataloader``.

    The forward passes run under ``torch.no_grad()``: only the scalar loss
    values are read, so no autograd graph needs to be recorded.

    Args:
        model: callable mapping a batch of inputs to outputs.
        dataloader: iterable yielding ``(inputs, labels)`` pairs.
        criterion: callable ``(outputs, labels) -> scalar tensor``.
        use_cuda: when True, each batch is moved with ``.cuda()`` before the
            forward pass.

    Returns:
        float: the sum (not the mean) of the per-batch loss values, so the
        magnitude grows with the number of batches; ``0.`` when the
        dataloader yields nothing.
    """
    train_loss = 0.
    with torch.no_grad():
        for inputs, labels in dataloader:
            if use_cuda:
                inputs, labels = inputs.cuda(), labels.cuda()

            outputs = model(inputs)
            train_loss += criterion(outputs, labels).item()

    return train_loss

def get_loss_surface(basis, model,
                    dataloader,
                    criterion,
                    rng=0.1, n_pts=25,
                    use_cuda=False):
    """
    Evaluate the loss on an ``n_pts`` x ``n_pts`` grid of parameter perturbations.

    Two directions ``dir1, dir2`` are obtained from ``get_plane(basis)``. For
    every pair of offsets ``(a, b)`` taken from
    ``torch.linspace(-rng/2, rng/2, n_pts)`` the model parameters are shifted
    by ``a * dir1 + b * dir2``, the loss is measured with ``loss_getter``, and
    the original parameters are put back. The restore also happens if the
    evaluation is interrupted or raises.

    Args:
        basis: tensor handed to ``get_plane``; its columns span the subspace
            the plane is drawn from. Assumes the number of rows equals the
            total number of model parameters, since each direction is
            unflattened against ``model.parameters()`` - TODO confirm
            against ``utils.unflatten_like``.
        model: ``torch.nn.Module`` whose parameters are perturbed.
        dataloader: iterable of ``(inputs, labels)`` batches for ``loss_getter``.
        criterion: loss callable for ``loss_getter``.
        rng: total width of the offset interval along each direction.
        n_pts: number of grid points per direction.
        use_cuda: forwarded to ``loss_getter``.

    Returns:
        torch.Tensor of shape ``(n_pts, n_pts)``; entry ``[ii, jj]`` is the
        loss at offset ``vec_len[ii]`` along ``dir1`` and ``vec_len[jj]``
        along ``dir2``.
    """

    # Snapshot by value: state_dict() returns tensors that share storage with
    # the live parameters, so an un-cloned snapshot is only a valid restore
    # point as long as nothing ever modifies the parameters in place.
    start_pars = {name: tensor.clone()
                  for name, tensor in model.state_dict().items()}

    ## get out the plane ##
    dir1, dir2 = get_plane(basis)

    ## init loss surface and the vector multipliers ##
    loss_surf = torch.zeros(n_pts, n_pts)
    vec_len = torch.linspace(-rng/2., rng/2., n_pts)

    ## loop and get loss at each point ##
    for ii in range(n_pts):
        # the dir1 contribution is constant across the inner loop
        step1 = dir1.mul(vec_len[ii])
        for jj in range(n_pts):
            try:
                perturb = step1 + dir2.mul(vec_len[jj])
                perturb = utils.unflatten_like(perturb.t(), model.parameters())
                for i, par in enumerate(model.parameters()):
                    # follow each parameter's own device rather than a global flag
                    par.data = par.data + perturb[i].to(par.device)

                # only the scalar loss is needed, so skip autograd bookkeeping
                with torch.no_grad():
                    loss_surf[ii, jj] = loss_getter(model, dataloader,
                                                    criterion, use_cuda)
            finally:
                # always return to the unperturbed weights, even on
                # KeyboardInterrupt, so the caller's model is never left shifted
                model.load_state_dict(start_pars)

    return loss_surf

def get_plane(basis):
    """
    Return two orthonormal vectors spanning a random plane inside span(basis).

    Each direction is a random Gaussian combination of the columns of
    ``basis``; the second is then made orthogonal to the first with one
    Gram-Schmidt step and both are scaled to unit norm.

    Args:
        basis: 2-D tensor whose columns are the basis vectors. At least two
            columns are required, otherwise no plane exists.

    Returns:
        (dir1, dir2): tensors of shape ``(basis.size(0), 1)`` with unit norm
        and zero inner product (up to floating-point error).

    Raises:
        ValueError: if ``basis`` has fewer than two columns.
    """
    n_basis = basis.size(-1)
    if n_basis < 2:
        raise ValueError(
            "basis must have at least two columns to define a plane, got %d"
            % n_basis)

    # .to(basis) matches both the device and the dtype of the basis, so a
    # float64 basis does not fail in matmul.
    wghts = torch.randn(n_basis, 1).to(basis)
    dir1 = basis.matmul(wghts)

    wghts = torch.randn(n_basis, 1).to(basis)
    dir2 = basis.matmul(wghts)

    ## now gram schmidt these guys ##
    flat1 = dir1.reshape(-1)
    vu = dir2.reshape(-1).dot(flat1)
    uu = flat1.dot(flat1)

    # subtract the projection of dir2 onto dir1; subtracting a multiple of
    # dir2 itself would only rescale it and leave it non-orthogonal
    dir2 = dir2 - dir1.mul(vu).div(uu)

    ## normalize ##
    dir1 = dir1.div(dir1.norm())
    dir2 = dir2.div(dir2.norm())

    return dir1, dir2

In [None]:
# 5 x 5 loss grid on a random plane drawn from the span of pos_evecs.
ls1 = get_loss_surface(
    pos_evecs, model, trainloader, criterion,
    rng=0.1, n_pts=5, use_cuda=use_cuda,
)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Heat map of the loss grid. get_loss_surface fills loss_surf[ii, jj] with ii
# indexing the first plane direction and jj the second, so rows follow dir1
# and columns follow dir2.
fig, ax = plt.subplots()
im = ax.imshow(ls1)
ax.set_title("Loss surface on a random plane in span(pos_evecs)")
ax.set_xlabel("grid index along dir2")
ax.set_ylabel("grid index along dir1")
fig.colorbar(im, ax=ax, label="summed batch loss");