<h2>Setting up the Environment</h2>

We will be using PyTorch to train a convolutional neural network to recognize MNIST's handwritten digits. Like TensorFlow, CNTK and Caffe2, PyTorch is a popular framework for deep learning. Unlike the static graphs those frameworks were originally built around, however, PyTorch uses dynamic execution graphs: the computation graph is created on the fly as the code runs.
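A minimal sketch of what "dynamic" means in practice. The function repeated_double is a hypothetical example, not part of the tutorial: the number of loop iterations depends on the input values, so each call may record a different graph, and autograd differentiates whichever graph was actually built on that call.

```python
import torch

def repeated_double(x: torch.Tensor) -> torch.Tensor:
    # Ordinary Python control flow: how often we loop depends on the data itself,
    # so the recorded computation graph can differ from call to call.
    y = x
    while y.norm() < 10:
        y = y * 2
    return y.sum()

x_small = torch.tensor([1.0, 1.0], requires_grad=True)
repeated_double(x_small).backward()  # three doublings: y = 8 * x

x_large = torch.tensor([4.0, 4.0], requires_grad=True)
repeated_double(x_large).backward()  # one doubling: y = 2 * x

print(x_small.grad)  # tensor([8., 8.])
print(x_large.grad)  # tensor([2., 2.])
```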

In [None]:
import torch
import torchvision

<h2>Preparing the Dataset</h2>

With the imports in place we can go ahead and prepare the data we will be using. Before that, though, let's define the hyperparameters for the experiment. Here n_epochs defines how many times we loop over the complete training dataset, while learning_rate and momentum are hyperparameters for the optimizer we will use later on.

In [None]:
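n_epochs = 3            # passes over the complete training set
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01    # SGD step size
momentum = 0.5          # SGD momentum factor
log_interval = 10       # log (and checkpoint) every 10 training batches

random_seed = 1
torch.backends.cudnn.enabled = False  # avoid cuDNN's nondeterministic algorithms
torch.manual_seed(random_seed)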

For repeatable experiments we have to set random seeds for everything that uses random number generation, which means numpy and random as well if they are used! It's also worth mentioning that cuDNN can use nondeterministic algorithms; these can be avoided by disabling cuDNN with torch.backends.cudnn.enabled = False.
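A minimal sketch of seeding every generator in one place. The helper name seed_everything is hypothetical. It also shows an alternative to disabling cuDNN entirely: asking it for deterministic kernels via torch.backends.cudnn.deterministic and turning off benchmark mode. Even then some operations may stay nondeterministic; torch.use_deterministic_algorithms(True) is the stricter option that raises an error for such operations.

```python
import random

import numpy as np
import torch

def seed_everything(seed: int) -> None:
    """Seed Python's random, NumPy and PyTorch (hypothetical helper)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and all CUDA devices
    # Optional alternative to torch.backends.cudnn.enabled = False:
    torch.backends.cudnn.deterministic = True  # prefer deterministic cuDNN kernels
    torch.backends.cudnn.benchmark = False     # don't auto-select (possibly varying) kernels

seed_everything(1)
first = (random.random(), np.random.rand(), torch.rand(1).item())
seed_everything(1)
second = (random.random(), np.random.rand(), torch.rand(1).item())
print(first == second)  # True
```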

Now we also need DataLoaders for the dataset. This is where TorchVision comes into play: it lets us load the MNIST dataset in a handy way. We use a batch_size of 64 for training and 1000 for testing. The values 0.1307 and 0.3081 used for the Normalize() transformation below are the global mean and standard deviation of the MNIST dataset; we'll take them as given here.

TorchVision offers a lot of handy transformations, such as cropping or normalization.
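To make the two transformations concrete, here is a plain-Python sketch of what ToTensor() followed by Normalize((0.1307,), (0.3081,)) does to a single 8-bit grayscale pixel. The function to_normalized is a hypothetical illustration; the real transforms operate on whole tensors.

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081

def to_normalized(pixel: int) -> float:
    """Mimic ToTensor() then Normalize() for one uint8 pixel (hypothetical helper)."""
    scaled = pixel / 255.0                     # ToTensor(): [0, 255] -> [0.0, 1.0]
    return (scaled - MNIST_MEAN) / MNIST_STD   # Normalize(): (x - mean) / std

print(round(to_normalized(0), 4))    # -0.4242 (black background)
print(round(to_normalized(255), 4))  # 2.8215 (white stroke)
```

After this shift and scale, the training pixels have roughly zero mean and unit standard deviation, which tends to make optimization better behaved.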

In [None]:
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./data/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./data/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

PyTorch's DataLoader offers a few interesting options besides the dataset and batch size. For example, we could use num_workers > 0 to load data asynchronously in worker subprocesses, or use pinned (page-locked) memory via pin_memory=True to speed up RAM-to-GPU transfers. Since these mostly matter when training on a GPU, we can omit them here.
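A sketch of the common pattern of enabling these options only when a GPU is present. The random tensors stand in for MNIST so the snippet runs without downloading anything, and num_workers=1 is an illustrative value, not a recommendation.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

use_cuda = torch.cuda.is_available()
# Worker subprocesses and pinned memory only pay off when feeding a GPU.
loader_kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

# Hypothetical stand-in data: 256 fake 1x28x28 images with labels 0-9.
fake_images = torch.randn(256, 1, 28, 28)
fake_labels = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(fake_images, fake_labels),
                    batch_size=64, shuffle=True, **loader_kwargs)

images, labels = next(iter(loader))
print(images.shape, labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])
```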

Now let's take a look at some examples. We'll use test_loader for this.

In [None]:
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

OK, let's see what one test data batch consists of.

In [None]:
example_data.shape

So one test data batch is a tensor of shape torch.Size([1000, 1, 28, 28]). This means we have 1000 examples of 28x28 pixels in grayscale (i.e. no RGB channels, hence the 1). Let's plot some of them using matplotlib.

In [None]:
import matplotlib.pyplot as plt

fig = plt.figure()
for i in range(6):
    plt.subplot(2,3,i+1)
    plt.tight_layout()
    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
    plt.title("Ground Truth: {}".format(example_targets[i]))
    plt.xticks([])
    plt.yticks([])
fig

Alright, those shouldn't be too hard to recognize after some training.

<h2>Building the Network</h2>

Now let's go ahead and build our network. We will use two 2-D convolutional layers followed by two fully-connected (or linear) layers. As activation function we choose rectified linear units (ReLUs for short), and as a means of regularization we use two dropout layers. In PyTorch a nice way to build a network is to create a new class for it. Let's import a few submodules here for more readable code.
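The first linear layer of the network below expects 320 input features. That number follows from the standard output-size formula for convolution and pooling layers, floor((size + 2*padding - kernel) / stride) + 1 (dilation 1). A plain-Python sketch of the bookkeeping, using a hypothetical helper conv_out:

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output height/width of a Conv2d or MaxPool2d layer with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                                  # MNIST images are 28x28
size = conv_out(size, kernel=5)            # conv1 (kernel 5): 28 -> 24
size = conv_out(size, kernel=2, stride=2)  # max_pool2d(..., 2): 24 -> 12
size = conv_out(size, kernel=5)            # conv2 (kernel 5): 12 -> 8
size = conv_out(size, kernel=2, stride=2)  # max_pool2d(..., 2): 8 -> 4
flat_features = 20 * size * size           # conv2 has 20 output channels
print(flat_features)  # 320
```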

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim = 1)

Broadly speaking, think of the torch.nn layers as modules that contain trainable parameters, while torch.nn.functional functions are purely functional. The forward() method defines how the output is computed from the given layers and functions. It is perfectly fine to print out tensors somewhere in the forward pass for easier debugging, which comes in handy when experimenting with more complex models. Note that the forward pass can also use e.g. a member variable or even the data itself to determine the execution path, and it can take multiple arguments!
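A small sketch of both points, using a hypothetical TinyNet rather than the tutorial's Net: only the nn.Conv2d layer owns parameters, while the functional calls add none, and forward() takes an extra verbose argument that switches on a debugging print.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyNet(nn.Module):
    """Hypothetical toy model contrasting nn layers with functional calls."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 4, kernel_size=5)  # owns a trainable weight and bias

    def forward(self, x, verbose=False):
        x = F.relu(F.max_pool2d(self.conv(x), 2))  # functional ops: no parameters
        if verbose:  # plain Python, so printing and branching are allowed here
            print('after conv/pool:', tuple(x.shape))
        return x

net = TinyNet()
# Only the Conv2d contributes: 4*1*5*5 weights + 4 biases = 104 parameters.
n_params = sum(p.numel() for p in net.parameters())
print(n_params)  # 104
out = net(torch.randn(2, 1, 28, 28), verbose=True)  # prints: after conv/pool: (2, 4, 12, 12)
```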

Now let's initialize the network and the optimizer.

In [None]:
network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                      momentum=momentum)

Note: If we were using a GPU for training, we should also have sent the network parameters to the GPU using e.g. network.cuda(). It is important to transfer the network's parameters to the appropriate device before passing them to the optimizer; otherwise the optimizer will not be able to keep track of them in the right way.
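A device-agnostic sketch of that ordering. A small nn.Linear stands in for Net() so the snippet is self-contained; in the tutorial this would correspond to network = Net().to(device), created before the optimizer, plus moving data and target to the device inside train() and test().

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stand-in for Net(): move the model BEFORE handing its parameters to the optimizer.
model = nn.Linear(320, 10).to(device)
opt = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

# Inside the training loop, every batch must live on the same device as the model.
data = torch.randn(64, 320).to(device)
target = torch.randint(0, 10, (64,)).to(device)

opt.zero_grad()
loss = F.nll_loss(F.log_softmax(model(data), dim=1), target)
loss.backward()
opt.step()
print(next(model.parameters()).device.type)  # 'cuda' or 'cpu'
```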

<h2>Training the Model</h2>

Time to build the training loop. First we make sure that the network is in training mode. Then we iterate over all training data once per epoch; loading the individual batches is handled by the DataLoader. For each batch we first set the gradients to zero using optimizer.zero_grad(), since PyTorch accumulates gradients by default. We then produce the output of our network (forward pass) and compute a negative log-likelihood loss between the output and the ground truth label. The backward() call then computes a new set of gradients via backpropagation and stores them in each parameter's .grad attribute, and optimizer.step() uses them to update the network's parameters.
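A minimal sketch of why optimizer.zero_grad() is needed, using a single hypothetical scalar parameter w: a second backward() call adds to the existing gradient instead of replacing it, and step() applies the update w <- w - lr * w.grad.

```python
import torch

w = torch.tensor(3.0, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)

(w * 2).backward()
print(w.grad)  # tensor(2.)

(w * 2).backward()  # no zeroing in between: the new gradient is ADDED to w.grad
print(w.grad)  # tensor(4.)

opt.zero_grad()     # reset (recent PyTorch versions set .grad to None)
(w * 2).backward()
print(w.grad)  # tensor(2.)

opt.step()          # w <- w - lr * w.grad = 3.0 - 0.1 * 2.0
print(round(w.item(), 4))  # 2.8
```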

We will also keep track of the progress with some printouts. To create a nice training curve later on, we also create lists for saving the training and testing losses together with their x-axis positions. On the x-axis we will display the number of training examples the network has seen during training.
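A plain-Python sketch of the x-axis bookkeeping, assuming the standard MNIST training set size of 60,000 images and the hyperparameters defined earlier. It shows how many batches one epoch contains and where the logged training and test losses land on the x-axis.

```python
import math

train_size = 60000      # standard MNIST training set size
batch_size_train = 64
log_interval = 10
n_epochs = 3

batches_per_epoch = math.ceil(train_size / batch_size_train)
print(batches_per_epoch)  # 938 (the last batch holds only 32 images)

# One test evaluation before training, then one after every epoch.
test_counter = [i * train_size for i in range(n_epochs + 1)]
print(test_counter)  # [0, 60000, 120000, 180000]

# Logged training losses in epoch 1: one entry every log_interval batches.
epoch = 1
train_counter = [batch_idx * batch_size_train + (epoch - 1) * train_size
                 for batch_idx in range(0, batches_per_epoch, log_interval)]
print(len(train_counter), train_counter[:3])  # 94 [0, 640, 1280]
```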

In [None]:
train_losses = []
train_counter = []
test_losses = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]
test_counter

In [None]:
len(train_loader.dataset)
len(train_losses)

We will run the test loop once before starting the training, to see what accuracy and loss the network achieves with just its randomly initialized parameters.

In [None]:
import os

os.makedirs('./results', exist_ok=True)  # torch.save() fails if the folder is missing

def train(epoch):
    network.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = network(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            train_losses.append(loss.item())
            train_counter.append(
                (batch_idx * batch_size_train) + ((epoch - 1) * len(train_loader.dataset)))
            # Checkpoint the model and optimizer state
            torch.save(network.state_dict(), './results/model.pth')
            torch.save(optimizer.state_dict(), './results/optimizer.pth')

Neural network modules as well as optimizers can save and load their internal state using .state_dict(). With this we can continue training from previously saved state dicts if needed; we would just need to call .load_state_dict(state_dict).
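A sketch of a save/restore round trip. It bundles the model state, optimizer state and epoch into one dictionary, which is a common variation on the two separate files used in this tutorial; the key names are hypothetical, and an in-memory buffer stands in for a file so the snippet runs anywhere.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
opt = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

# One step so the optimizer has momentum buffers worth saving.
model(torch.randn(8, 4)).sum().backward()
opt.step()

buffer = io.BytesIO()  # stand-in for a checkpoint file on disk
torch.save({'epoch': 1,
            'model': model.state_dict(),
            'optimizer': opt.state_dict()}, buffer)
buffer.seek(0)

# Fresh objects, then restore their internal state.
restored_model = nn.Linear(4, 2)
restored_opt = optim.SGD(restored_model.parameters(), lr=0.01, momentum=0.5)
checkpoint = torch.load(buffer)
restored_model.load_state_dict(checkpoint['model'])
restored_opt.load_state_dict(checkpoint['optimizer'])

print(torch.equal(model.weight, restored_model.weight))  # True
```

Note that the restored optimizer is built from restored_model's parameters, not the original model's, so its loaded momentum buffers apply to the right tensors.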

Now for the test loop. Here we sum up the test loss and keep track of correctly classified digits to compute the accuracy of the network.

In [None]:
def test():
    network.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = network(data)
            # Sum (not average) the batch loss; we divide by the dataset size below
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)  # index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    test_losses.append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

Using the context manager no_grad() we avoid recording the operations that produce the network's output in the computation graph, which saves memory and computation during evaluation.
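A toy sketch of both halves of the evaluation step. The 3x3 weights matrix is a hypothetical stand-in for the network: outside no_grad() the output is tracked for backward(), inside it is not, and the argmax/eq/sum pattern from test() turns log-probabilities into an accuracy.

```python
import torch

# Hypothetical toy "model": maps one-hot inputs to class scores.
weights = (5 * torch.eye(3)).requires_grad_()
inputs = torch.eye(3)[[0, 1, 2, 0]]   # 4 one-hot examples
targets = torch.tensor([0, 1, 2, 1])  # the last label deliberately disagrees

print(torch.log_softmax(inputs @ weights, dim=1).requires_grad)  # True: graph recorded

with torch.no_grad():
    output = torch.log_softmax(inputs @ weights, dim=1)
print(output.requires_grad)  # False: nothing recorded for backward()

pred = output.argmax(dim=1)              # predicted class per example: [0, 1, 2, 0]
correct = pred.eq(targets).sum().item()  # number of matching labels
print(correct, 100.0 * correct / len(targets))  # 3 75.0
```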

Time to run the training! We manually add a test() call before looping over n_epochs, to evaluate the model with its randomly initialized parameters.

In [None]:
test()
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()

<h2>Evaluating the Model's Performance</h2>

And that's it. With just 3 epochs of training we already managed to achieve 97% accuracy on the test set! We started out with randomly initialized parameters and, as expected, got only about 10% accuracy on the test set before training, which is chance level for 10 classes.

Let's plot the training curve.

In [None]:
def plt_figure(train_counter=train_counter, train_losses=train_losses):
    plt.plot(train_counter, train_losses, color='blue')
    plt.scatter(test_counter, test_losses, color='red')
    plt.legend(['Train Loss', 'Test Loss'], loc='upper right')
    plt.xlabel('number of training examples seen')
    plt.ylabel('negative log likelihood loss')
    plt.show()

In [None]:
plt_figure()

Judging from the training curve, it looks like we could even continue training for a few more epochs!

But before that, let's again look at a few examples as we did earlier and compare them with the model's output.

In [None]:
with torch.no_grad():
    output = network(example_data)

In [None]:
fig = plt.figure()
for i in range(6):
    plt.subplot(2,3,i+1)
    plt.tight_layout()
    plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
    plt.title("Prediction: {}".format(
        output.data.max(1, keepdim=True)[1][i].item()))
    plt.xticks([])
    plt.yticks([])
fig

The model's predictions seem to be on point for those examples!

<h2>Continued Training from Checkpoints</h2>

Now let's continue training the network, or rather see how we can continue training from the state_dicts we saved during the first training run. We initialize a new network and optimizer.

In [None]:
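continued_network = Net()
# Pass the NEW network's parameters; the saved optimizer state is loaded into this optimizer next
continued_optimizer = optim.SGD(continued_network.parameters(), lr=learning_rate,
                                momentum=momentum)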

Using .load_state_dict() we can now load the internal state of the network and optimizer from when we last saved them.

In [None]:
network_state_dict = torch.load('./results/model.pth')
continued_network.load_state_dict(network_state_dict)

optimizer_state_dict = torch.load('./results/optimizer.pth')
continued_optimizer.load_state_dict(optimizer_state_dict)

Running a training loop again should immediately pick up the training where we left off. To check this, we simply reuse the same lists as before to keep track of the loss values. Because of the way we constructed the test counter for the number of training examples seen, we have to append to it manually here.

In [None]:
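# train() and test() use the global network and optimizer, so point them at the restored objects
network = continued_network
optimizer = continued_optimizer

for i in range(n_epochs + 1, n_epochs + 1 + 5):
    test_counter.append(i * len(train_loader.dataset))
    train(i)
    test()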

Great! Once again the test set accuracy increases from epoch to epoch. Let's visualize this to further inspect the training progress.

In [None]:
plt_figure()