# Environment Setup

In [1]:
%load_ext autoreload
%autoreload 2

# libraries
import copy
import time
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d
import numpy as np
import torch
import torch.nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
from tqdm import tqdm

matplotlib.rcParams['figure.figsize'] = [18, 12]

# code from this library - import the lines module
import loss_landscapes
import loss_landscapes.metrics
from loss_landscapes.model_interface.model_wrapper import ModelWrapper, wrap_model
from loss_landscapes.model_interface.model_parameters import ModelParameters, rand_u_like, rand_n_like, orthogonal_to
from loss_landscapes.contrib.functions import SimpleWarmupCaller, SimpleLossEvalCaller, log_refined_loss, _pacbayes_sigma

## 1. Preliminary: Classifying MNIST

This notebook demonstrates how to accomplish a simple task: visualizing the loss landscape of a small fully connected feed-forward neural network on the MNIST image classification task. In this section the preliminaries (the model and the training procedure) are setup.

In [2]:
# training hyperparameters
IN_DIM = 28 * 28
OUT_DIM = 10
LR = 10 ** -2
BATCH_SIZE = 512
EPOCHS = 25
# contour plot resolution
STEPS = 20

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

The cells in this section contain no code specific to the `loss-landscapes` library.

In [3]:
class MLPSmall(torch.nn.Module):
    """ Fully connected feed-forward neural network with one hidden layer. """
    def __init__(self, x_dim, y_dim):
        super().__init__()
        self.linear_1 = torch.nn.Linear(x_dim, 32)
        self.bn = torch.nn.BatchNorm1d(32)
        self.linear_2 = torch.nn.Linear(32, y_dim)

    def forward(self, x):
        h = F.relu(self.linear_1(x))
        h = self.bn(h)
        return F.softmax(self.linear_2(h), dim=1)


class Flatten(object):
    """ Transforms a PIL image to a flat numpy array. """
    def __call__(self, sample):
        return np.array(sample, dtype=np.float32).flatten()    
    

def train(model, optimizer, criterion, train_loader, epochs, device):
    """ Trains the given model with the given optimizer, loss function, etc. """
    model.train()
    # train model
    for _ in tqdm(range(epochs), 'Training'):
        for count, batch in enumerate(train_loader, 0):
            optimizer.zero_grad()
            x, y = batch
            x = x.to(device)
            y = y.to(device)

            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()

    model.eval()

We then create the model and an instance of the MNIST dataset.

In [4]:
# download MNIST and setup data loaders
mnist_train = datasets.MNIST(root='/global/cfs/cdirs/m636/geshi/data/', train=True, download=True, transform=Flatten())
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=BATCH_SIZE, shuffle=False)

# define model
model = MLPSmall(IN_DIM, OUT_DIM)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = torch.nn.CrossEntropyLoss()

In [5]:
# stores the initial point in parameter space
model_initial = copy.deepcopy(model)

In [6]:
train(model, optimizer, criterion, train_loader, EPOCHS, device)

model_final = copy.deepcopy(model)

Training: 100%|██████████| 25/25 [00:43<00:00,  1.74s/it]


## 2. Check Pac Bayes Distance

In [7]:
def test_accuracy(test_loader, net):
    """Evaluate testset accuracy of a model."""
    net.eval()
    acc_sum, count = 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            # send data to the GPU if cuda is availabel
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                labels = labels.cuda()
            count += inputs.size(0)
            
            outputs = net(inputs)
            _, preds = torch.max(outputs, 1)

            # labels = labels.long()
            acc_sum += torch.sum(preds == labels.data).item()
    return acc_sum/count

In [8]:
since = time.time()
acc = test_accuracy(train_loader, model_final)
print('time cost ', time.time()-since)
print(acc)

time cost  1.6319406032562256
0.9814666666666667


In [9]:
optimal_dist = _pacbayes_sigma(
    model_final,
    2,
    train_loader,
    accuracy = 0.981,
    search_depth = 10,
    montecarlo_samples = 10,
    accuracy_displacement = 0.1,
    displacement_tolerance = 1e-2,
    n_dim = 2,
    random = 'normal',
    normalization = 'layer',
    )

In [10]:
optimal_dist

0.5625