# How Neural Network Design Choices Affect Loss Landscapes
In this notebook, we explore the appearance of the loss landscapes of neural networks on the MNIST image classification task, under a number of different transformations. We will be using the `loss-landscapes` package to compute low-dimensional approximations of the loss function. We will be implementing the networks in PyTorch, which is the only supported neural network library as of March 2019 (more will be added later).

In [1]:
# standard-library imports; os and sys are only needed to put the project source on sys.path
import os
import sys
import copy
import itertools

The `os` and `sys` imports are only needed when the module source has to be added to Python's `sys.path` so that this notebook can import the `loss-landscapes` package. This is not required in general.

In [2]:
# libraries
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from tqdm import tqdm

matplotlib.rcParams['figure.figsize'] = [20, 8]

# code from this library - import the compute module
import loss_landscapes.compute

Finally, before we begin, we set some hyperparameters as constants for ease of reference.

In [3]:
# input dimension and output dimension for an MNIST classifier
IN_DIM = 28 * 28
OUT_DIM = 10
# training settings
LR = 10 ** -3
EPOCHS = 1

## FCFF-NN Loss Landscapes on MNIST Classification Tasks
We will explore the effect of several architectural design choices on the loss landscape of a fully connected feedforward neural network (FCFF-NN) on the MNIST image classification task. We begin by defining such a network in PyTorch, along with a flattening transformation to pass to the MNIST dataset loader. We also define a function that returns the model's loss in its current state. Note that this is not an "evaluation" function in the usual sense, where we would use a test set. We deliberately evaluate on the training set, because we wish to visualize the loss landscape the model experiences during training.
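As a quick sanity check of the shapes involved, the sketch below flattens a synthetic batch of 28x28 arrays and computes a cross-entropy loss on it. The random images, the labels, and the single-layer `toy_model` are stand-ins invented for illustration; they are not MNIST data or the network defined in this notebook.

```python
import numpy as np
import torch

# synthetic stand-in for a batch of 8 grayscale 28x28 images (not real MNIST data)
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(8, 28, 28), dtype=np.uint8)
labels = torch.randint(0, 10, (8,))

# flatten each image to a float32 vector of length 28 * 28 = 784
flat = torch.from_numpy(
    np.stack([np.array(img, dtype=np.float32).flatten() for img in images])
)

# hypothetical single-layer classifier, used only to show the shapes
toy_model = torch.nn.Linear(28 * 28, 10)
criterion = torch.nn.CrossEntropyLoss()

with torch.no_grad():
    logits = toy_model(flat)          # shape (8, 10): one score per class
    loss = criterion(logits, labels)  # scalar: mean loss over the batch

print(flat.shape, logits.shape, loss.item())
```

`CrossEntropyLoss` takes unnormalized class scores and integer class labels, which is why the sketch passes the linear layer's output to it directly.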

In [4]:
class MLP(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear_1 = torch.nn.Linear(IN_DIM, 512)
        self.linear_2 = torch.nn.Linear(512, 256)
        self.linear_3 = torch.nn.Linear(256, 128)
        self.linear_4 = torch.nn.Linear(128, 64)
        self.linear_5 = torch.nn.Linear(64, OUT_DIM)
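        # no softmax layer here: CrossEntropyLoss applies log-softmax itself
        # and expects raw logits of shape (batch, OUT_DIM)
        
    def forward(self, x):
        x = F.relu(self.linear_1(x))
        x = F.relu(self.linear_2(x))
        x = F.relu(self.linear_3(x))
        x = F.relu(self.linear_4(x))
        # return raw logits; the loss function handles normalization
        x = self.linear_5(x)
        return x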
    

class Flatten(object):
    """ Transforms a PIL image to a flat numpy array. """
    def __init__(self):
        pass

    def __call__(self, sample):
        return np.array(sample, dtype=np.float32).flatten()
    

def evaluate(model):
    mnist_train = datasets.MNIST(root='../data/', train=True, download=True, transform=Flatten())
    trainloader = torch.utils.data.DataLoader(mnist_train, batch_size=32, shuffle=True)
    
    criterion = torch.nn.CrossEntropyLoss()
    
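    average_loss = 0.0
    
    # only evaluate on 10 batches for speed - ideally you'd want to evaluate on all data
    # no_grad: we only need the loss value here, not gradients
    with torch.no_grad():
        for batch in itertools.islice(trainloader, 10):
            x, y = batch
            
            pred = model(x)
            loss = criterion(pred, y)
            # accumulate a plain Python float rather than a tensor
            average_loss += loss.item()
    
    average_loss /= 10
    return average_loss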
    

Now we can carry out a few experiments.

### Batch Size
In this first experiment we explore the effect of batch size on the loss landscape of our neural network when learning a straightforward MNIST classifier. To do so, we train the model with a batch size of 1, keep a copy of the model's initial and final parameters, and plot the loss along a linear interpolation between the two points.

In [5]:
# download MNIST
mnist_train = datasets.MNIST(root='../data/', train=True, download=True, transform=Flatten())
trainloader = torch.utils.data.DataLoader(mnist_train, batch_size=1, shuffle=True)

# define model
model = MLP()
optimizer = optim.Adam(model.parameters(), lr=LR)
criterion = torch.nn.CrossEntropyLoss()

# save initial state
model_initial = copy.deepcopy(model)
params_initial = copy.deepcopy(list(model_initial.parameters()))

# train model
for epoch in tqdm(range(EPOCHS)):
    for count, batch in enumerate(tqdm(trainloader, 'Batches'), 0):
        # stop after 100 batches to keep the demonstration fast
        if count == 100:
            break
            
        x, y = batch
        optimizer.zero_grad()
        
        pred = model(x)
        loss = criterion(pred, y)
        loss.backward()
        optimizer.step()

# save final state
model_final = copy.deepcopy(model)
params_final = copy.deepcopy(list(model.parameters()))


Now, we use the `loss-landscapes` library to plot the loss along a line in parameter space from the initial parameters to the final parameters.
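Conceptually, such a line is the set of parameters θ(α) = (1 − α)·θ_initial + α·θ_final for α in [0, 1]. The sketch below is a hand-rolled illustration of that idea on a toy problem. It is not the `loss-landscapes` implementation, and `interpolate_loss`, `toy_loss`, and the random toy data are names and values invented here.

```python
import copy
import torch


def interpolate_loss(model_start, model_end, loss_fn, steps=10):
    """Hypothetical helper: evaluate loss_fn at `steps` evenly spaced points
    on the straight line between the parameters of two models."""
    start = [p.detach().clone() for p in model_start.parameters()]
    end = [p.detach().clone() for p in model_end.parameters()]
    probe = copy.deepcopy(model_start)  # scratch model whose weights are overwritten
    losses = []
    with torch.no_grad():
        for alpha in torch.linspace(0.0, 1.0, steps):
            # theta(alpha) = (1 - alpha) * theta_start + alpha * theta_end
            for p, a, b in zip(probe.parameters(), start, end):
                p.copy_((1 - alpha) * a + alpha * b)
            losses.append(float(loss_fn(probe)))
    return losses


# toy classification problem (random data, for illustration only)
torch.manual_seed(0)
x = torch.randn(64, 4)
y = torch.randint(0, 3, (64,))
criterion = torch.nn.CrossEntropyLoss()

model_a = torch.nn.Linear(4, 3)     # "initial" model
model_b = copy.deepcopy(model_a)    # becomes the "final" model after training
optimizer = torch.optim.SGD(model_b.parameters(), lr=0.1)
for _ in range(50):
    optimizer.zero_grad()
    criterion(model_b(x), y).backward()
    optimizer.step()


def toy_loss(model):
    return criterion(model(x), y)


curve = interpolate_loss(model_a, model_b, toy_loss, steps=11)
print(curve[0], curve[-1])
```

The first and last entries of `curve` are the losses of the initial and final models themselves; the entries in between are the losses at the intermediate points on the line.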

In [6]:
# keep the variable name consistent between computing and plotting
loss_data = loss_landscapes.compute.linear_interpolation(model_initial, model_final, evaluate)
plt.plot(loss_data)
plt.title('Linear Interpolation of Loss')
plt.xlabel('Parameter Space')
plt.ylabel('Loss')
plt.show()

0.0


ZeroDivisionError: float division by zero

A linear interpolation plot computes the model's loss at discrete intervals along a line between two points in parameter space. A common use for such a plot is computing the loss along the "straight line path" from the model's initialization to the model's final (trained) parameters. Note that there is no guarantee that the path followed by the optimization procedure was "close" to this line. So what happens if we increase the batch size?