## MNIST: binary CNN implementation & training

In this tutorial, we will work with a simple CNN trained on two classes of MNIST.

Below, we implement the standard training pipeline and save the model for later inference and explanation.
With the code and its comments, the notebook is largely self-explanatory, so we add text only where necessary.

In [11]:
import numpy as np
import torch
import torch.nn as nn
from data_and_models import mnist_binary_cnn, create_mnist_set # implemented in a separate script for later use
from torch.utils.data import DataLoader

In [12]:
classes=[1,8] # we'll work with these only!

In [13]:
trainset=create_mnist_set(root='./data',train=True,classes=classes)

In [14]:
testset=create_mnist_set(root='./data',train=False,classes=classes)

In [15]:
train_dataloader=DataLoader(trainset,batch_size=50,shuffle=True)

In [16]:
test_dataloader=DataLoader(testset,batch_size=50,shuffle=True)

In [17]:
def binarize_labels(labels,classes): # maps classes[0] to 0 and classes[1] to 1
    
    # the network outputs a single probability (that of classes[1]),
    # so the original digit labels have to be mapped to 0/1 targets

    return ((labels-classes[0])/(classes[1]-classes[0])).to(torch.float32)
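As a quick check of the label mapping, the snippet below applies the same arithmetic to a small hand-made label tensor for `classes=[1,8]`. The tensor values are made up for illustration; the alternative with `==` is only a suggestion, not something the notebook uses.

```python
import torch

classes = [1, 8]
labels = torch.tensor([1, 8, 8, 1])  # made-up digit labels for illustration

# Same formula as binarize_labels: (label - 1) / (8 - 1)
binary = ((labels - classes[0]) / (classes[1] - classes[0])).to(torch.float32)
print(binary)  # tensor([0., 1., 1., 0.])

# An equivalent mapping that avoids the division altogether
alternative = (labels == classes[1]).to(torch.float32)
print(torch.equal(binary, alternative))  # True
```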

In [18]:
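def trainstep(model,optimizer,batch): # one optimization step on a single batch

    labels=binarize_labels(batch[1],classes) # 0/1 targets; uses the global `classes`

    optimizer.zero_grad() # clear gradients left over from the previous step
    probs=model(batch[0])[:,0] # model output has shape (batch,1); keep one probability per sample

    loss=model.loss(probs,labels)

    loss.backward()
    optimizer.step()

    return loss.detach() # detached, so it can be logged without keeping the graph alive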



In [19]:
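def accuracy(model,batch):
    # fraction of samples in the batch that are classified correctly at a 0.5 threshold

    n=batch[0].shape[0]
    
    labels=binarize_labels(batch[1],classes)
    with torch.no_grad(): # evaluation only: no gradients needed
        probabilities=model(batch[0])[:,0]
    
    correct=0

    for i in range(n):
        if probabilities[i]>0.5 and labels[i]>0.5: # predicted classes[1], true classes[1]
            correct+=1
        elif probabilities[i]<=0.5 and labels[i]<0.5: # predicted classes[0], true classes[0]
            correct+=1
    return correct/n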
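The per-sample loop in `accuracy` can also be written with tensor operations. The sketch below is one possible vectorized variant, not the notebook's implementation: `accuracy_vectorized` is a hypothetical helper that takes probabilities and binarized labels directly, and the input values are made up.

```python
import torch

def accuracy_vectorized(probabilities, labels):
    # Hypothetical helper: thresholds both tensors at 0.5 and compares them
    predicted = probabilities > 0.5   # True means "predicted classes[1]"
    target = labels > 0.5             # True means "true class is classes[1]"
    return (predicted == target).float().mean().item()

probs = torch.tensor([0.9, 0.2, 0.6, 0.5])   # made-up model outputs
labels = torch.tensor([1.0, 0.0, 0.0, 0.0])  # made-up binarized labels
print(accuracy_vectorized(probs, labels))    # 0.75
```

A probability of exactly 0.5 counts as a prediction of `classes[0]`, which matches the `<=0.5` branch of the loop version.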


In [28]:
model=mnist_binary_cnn() # default settings: input images are rescaled and the loss is cross-entropy (see the defaults in data_and_models)
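The actual architecture of `mnist_binary_cnn` lives in `data_and_models` and is not shown here. For readers who want something concrete, the following is a purely hypothetical stand-in that satisfies the interface the notebook relies on: an output of shape `(batch,1)` holding a probability, and a `loss` attribute implementing binary cross-entropy. The layer sizes, the class name `TinyBinaryCNN`, and the assumption that pixels arrive in the range 0 to 255 are all illustrative choices, not the real model.

```python
import torch
import torch.nn as nn

class TinyBinaryCNN(nn.Module):
    """Hypothetical stand-in; NOT the actual mnist_binary_cnn."""

    def __init__(self, rescale=1 / 255):  # assumed rescaling factor
        super().__init__()
        self.rescale = rescale
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),  # 1x28x28 -> 8x28x28
            nn.ReLU(),
            nn.MaxPool2d(2),                            # 8x28x28 -> 8x14x14
        )
        self.head = nn.Linear(8 * 14 * 14, 1)
        self.loss = nn.BCELoss()  # binary cross-entropy on probabilities

    def forward(self, x):
        x = x.float() * self.rescale          # bring pixel values into [0, 1]
        x = self.features(x).flatten(start_dim=1)
        return torch.sigmoid(self.head(x))    # shape (batch, 1), values in (0, 1)

model_sketch = TinyBinaryCNN()
dummy_images = torch.randint(0, 256, (4, 1, 28, 28))  # random fake images
dummy_probs = model_sketch(dummy_images)[:, 0]
dummy_loss = model_sketch.loss(dummy_probs, torch.tensor([0.0, 1.0, 1.0, 0.0]))
print(dummy_probs.shape, dummy_loss.item())
```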

In [29]:
optimizer=torch.optim.Adam(model.parameters(),lr=1e-4)

Now we can train the model. Since the goal is only to obtain a simple model, we do not set up validation cycles or similar machinery.

In [30]:
for i in range(10): # for such a simple problem, 10 epochs should do
    full_loss=0
    full_accuracy=0
    for j,batch in enumerate(train_dataloader):
        loss=trainstep(model,optimizer,batch)
        full_loss=full_loss*j/(j+1)+loss.item()/(j+1) # running mean over batches, updated incrementally (a form familiar from RL)
        full_accuracy=full_accuracy*j/(j+1)+accuracy(model,batch)/(j+1) # accuracy is measured after the update on this batch
    print('epoch %d completed'%(i+1))
    print('loss: %f' % (full_loss))
    print('accuracy: %f'%(full_accuracy))



epoch 1 completed
loss: 0.659585
accuracy: 0.720292
epoch 2 completed
loss: 0.362096
accuracy: 0.929168
epoch 3 completed
loss: 0.150800
accuracy: 0.956006
epoch 4 completed
loss: 0.106130
accuracy: 0.965397
epoch 5 completed
loss: 0.087141
accuracy: 0.970371
epoch 6 completed
loss: 0.075571
accuracy: 0.974908
epoch 7 completed
loss: 0.067968
accuracy: 0.977619
epoch 8 completed
loss: 0.062334
accuracy: 0.979273
epoch 9 completed
loss: 0.057479
accuracy: 0.980146
epoch 10 completed
loss: 0.053844
accuracy: 0.981429
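The incremental update used for `full_loss` and `full_accuracy` in the training loop, `mean = mean*j/(j+1) + value/(j+1)`, yields the ordinary arithmetic mean of all values seen so far. The small sketch below checks this on made-up numbers; the list `values` is illustrative only.

```python
values = [0.8, 0.6, 0.4, 0.2]  # made-up per-batch losses

running = 0.0
for j, value in enumerate(values):
    # After this step, `running` is the mean of values[0..j]
    running = running * j / (j + 1) + value / (j + 1)

plain = sum(values) / len(values)
print(running, plain)  # both are approximately 0.5
```

Note that every batch gets the same weight, so a smaller final batch counts as much as a full one.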


To get a rough sense of how much the model overfits, we compute the accuracy on a single batch of the test set. With only 50 samples, this estimate is coarse (it moves in steps of 0.02):

In [31]:
val_batch=next(iter(test_dataloader))

In [32]:
accuracy(model,val_batch)

0.96
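A single batch gives only a noisy estimate. One way to evaluate on the whole test set is sketched below. `dataset_accuracy` is a hypothetical helper, and to keep the example runnable on its own it is demonstrated on a toy model and a toy dataset rather than on the notebook's `model` and `test_dataloader`.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def dataset_accuracy(model, dataloader, classes):
    # Hypothetical helper: accuracy over every sample the dataloader yields
    correct, total = 0, 0
    model.eval()
    with torch.no_grad():
        for images, labels in dataloader:
            target = labels == classes[1]        # True for the "positive" class
            predicted = model(images)[:, 0] > 0.5
            correct += (predicted == target).sum().item()
            total += labels.shape[0]
    return correct / total

class MeanBrightness(nn.Module):
    # Toy model: the "probability" is simply the mean pixel value
    def forward(self, x):
        return x.mean(dim=(1, 2, 3)).unsqueeze(1)

classes = [1, 8]
# Four toy 2x2 images: all black, all white, all black, all white
images = torch.tensor([0.0, 1.0, 0.0, 1.0]).reshape(4, 1, 1, 1) * torch.ones(4, 1, 2, 2)
labels = torch.tensor([1, 8, 8, 1])
toy_loader = DataLoader(TensorDataset(images, labels), batch_size=2)
toy_model = MeanBrightness()

print(dataset_accuracy(toy_model, toy_loader, classes))  # 0.5
```

With the notebook's objects, the corresponding call would be `dataset_accuracy(model, test_dataloader, classes)`.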

Finally, we store the model weights and the classes, so that later parts only need to reload them instead of re-training.

In [33]:
import pickle

In [34]:
with open('models/mnist_binary_classes.txt','wb') as f: # a binary pickle file despite the .txt extension; the models/ directory must already exist
    pickle.dump(classes,f)

In [35]:
torch.save(model.state_dict(),'models/mnist_binary_model.pth')
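Reloading works in the opposite direction: unpickle the classes, build a fresh model, and load the saved state dict into it. The sketch below shows the round trip on a stand-in `nn.Linear` model inside a temporary directory, so that it runs without the trained model or the `models/` folder. The file names and the stand-in model are assumptions for illustration; in the notebook, the fresh model would be `mnist_binary_cnn()`.

```python
import os
import pickle
import tempfile

import torch
import torch.nn as nn

classes = [1, 8]
stand_in = nn.Linear(4, 1)  # stand-in for the trained model

with tempfile.TemporaryDirectory() as tmp:
    classes_path = os.path.join(tmp, 'classes.pkl')  # illustrative file names
    model_path = os.path.join(tmp, 'model.pth')

    # Save: same calls as in the cells above
    with open(classes_path, 'wb') as f:
        pickle.dump(classes, f)
    torch.save(stand_in.state_dict(), model_path)

    # Reload: note the 'rb' mode for the pickle file
    with open(classes_path, 'rb') as f:
        loaded_classes = pickle.load(f)
    restored = nn.Linear(4, 1)  # must have the same architecture as the saved model
    restored.load_state_dict(torch.load(model_path))
    restored.eval()  # switch to inference mode before making predictions

print(loaded_classes)
```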