# Training a Convolutional Neural Network for digit recognition

This notebook was built following a tutorial in the wonderful book *Machine Learning with PyTorch and Scikit-Learn* by Raschka, Liu and Mirjalili (2022, Packt Publishing).  In it, we do the following:

1. Use the PyTorch library (https://pytorch.org/) to construct and train a convolutional neural network (CNN) on the MNIST handwritten digit database (http://yann.lecun.com/exdb/mnist/).
2. Deploy the trained model as an interactive web app using the Gradio library (https://gradio.app/).

The Gradio app can be found on my Hugging Face page: https://huggingface.co/spaces/etweedy/digits

## Import libraries

In [1]:
import torch
import torch.nn as nn
from torch import optim
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import Subset, DataLoader
image_path = 'data'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Load the MNIST dataset and create PyTorch DataLoaders

We first load the datasets and create training and validation DataLoaders with a batch size of 64 samples (the MNIST test split, loaded with train=False, serves as our validation set).
Some remarks:
1. The MNIST images are grayscale (i.e. 1 channel) 28x28 pixel images.  By default, datasets.MNIST() loads each sample as a tuple (image,label), where image is a PIL image and label is the ground truth value of the digit.
2. The ToTensor() transform does two things:
    - Converts the image in each tuple to a PyTorch float tensor of shape (1,28,28), i.e. (channels, height, width).
    - Rescales the pixel values from the integer range [0,255] to floats in the interval [0,1].

In [10]:
transform = transforms.Compose([transforms.ToTensor()])

In [11]:
data_train = datasets.MNIST(root=image_path, transform=transform, train=True, download=True)
data_val = datasets.MNIST(root=image_path, transform=transform, train=False, download=False)

In [12]:
batch_size=64
torch.manual_seed(1)
dl_train = DataLoader(data_train,batch_size,shuffle=True)
dl_val = DataLoader(data_val,batch_size,shuffle=True)
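
To see what such DataLoaders yield without downloading MNIST, here is a minimal sketch using a synthetic stand-in dataset (random tensors with the same per-sample shape; the sample count of 130 is arbitrary and chosen only for illustration). Each iteration gives a batch of images of shape (batch, 1, 28, 28) together with a batch of labels, and the final batch is smaller when the dataset size is not a multiple of batch_size.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(1)

# Synthetic stand-in for MNIST: 130 random "images" and random digit labels
images = torch.rand(130, 1, 28, 28)
labels = torch.randint(0, 10, (130,))
fake_data = TensorDataset(images, labels)

dl = DataLoader(fake_data, batch_size=64, shuffle=True)

batch_sizes = []
for x_batch, y_batch in dl:
    batch_sizes.append(x_batch.shape[0])
    print(x_batch.shape, y_batch.shape)

print(batch_sizes)  # [64, 64, 2]
```

By the same arithmetic, the 60,000 MNIST training images split into 938 batches of 64, the last of which holds 32 samples.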

## Construct CNN model

We build a CNN with two convolutional layers, each followed by batch normalization, ReLU activation, and 2x2 max-pooling.  These are followed by a flattening layer and two linear layers with a dropout layer between them.

In [64]:
model = nn.Sequential()
model.add_module(
    'conv1',
    nn.Conv2d(
        in_channels=1,out_channels=32,
        kernel_size=5,padding=2
    ),
)
model.add_module('bn1',nn.BatchNorm2d(32))
model.add_module('relu1',nn.ReLU())
model.add_module('pool1',nn.MaxPool2d(kernel_size=2))
model.add_module(
    'conv2',
    nn.Conv2d(
        in_channels=32,out_channels=64,
        kernel_size=5,padding=2
    ),
)
model.add_module('bn2',nn.BatchNorm2d(64))
model.add_module('relu2',nn.ReLU())
model.add_module('pool2',nn.MaxPool2d(kernel_size=2))
model.add_module('flatten',nn.Flatten())
model.add_module('fc1',nn.Linear(3136,1024))
model.add_module('dropout',nn.Dropout(p=0.5))
model.add_module('fc2',nn.Linear(1024,10))

model.to(device)

Sequential(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=3136, out_features=1024, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=1024, out_features=10, bias=True)
)
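
The input size of 3136 for fc1 comes from tracking the feature map size through the layers. The sketch below uses the standard output-size formula for convolution and pooling layers, floor((n + 2p - k)/s) + 1, where n is the input size, k the kernel size, p the padding and s the stride. The helper name out_size is ours, introduced only for illustration. Batch normalization and ReLU do not change the shape, so they are omitted.

```python
def out_size(n, kernel_size, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (no dilation)."""
    return (n + 2 * padding - kernel_size) // stride + 1

n = 28                                      # MNIST images are 28x28
n = out_size(n, kernel_size=5, padding=2)   # conv1 -> 28
n = out_size(n, kernel_size=2, stride=2)    # pool1 -> 14
n = out_size(n, kernel_size=5, padding=2)   # conv2 -> 14
n = out_size(n, kernel_size=2, stride=2)    # pool2 -> 7

channels = 64                               # out_channels of conv2
flat_features = channels * n * n
print(n, flat_features)  # 7 3136
```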

In [65]:
loss_fn = nn.CrossEntropyLoss()
opt = optim.Adam(model.parameters(),lr=0.002)
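
Note that the model ends with the linear layer fc2 and has no softmax layer. This works because nn.CrossEntropyLoss expects raw scores (logits) and applies log-softmax internally. A small sketch with made-up logits illustrates the equivalence:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Made-up logits for a batch of 4 samples over 10 digit classes
logits = torch.randn(4, 10)
targets = torch.tensor([3, 7, 0, 9])

ce = nn.CrossEntropyLoss()(logits, targets)

# Same quantity by hand: mean negative log-probability of the true class
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(4), targets].mean()

print(torch.allclose(ce, manual))  # True
```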

## Training the CNN

We now construct our training function, which records the average training and validation loss and accuracy for each epoch.  Accuracy values are printed as we progress through training.

In [66]:
def train(model,num_epochs,dl_train,dl_val):
    # Per-epoch histories of average loss and accuracy
    loss_hist_train = [0]*num_epochs
    acc_hist_train = [0]*num_epochs
    loss_hist_val = [0]*num_epochs
    acc_hist_val = [0]*num_epochs
    for epoch in range(num_epochs):
        # Training mode: dropout is active, batch norm uses batch statistics
        model.train()
        for x_batch,y_batch in dl_train:
            x_batch=x_batch.to(device)
            y_batch=y_batch.to(device)
            pred = model(x_batch)
            loss = loss_fn(pred,y_batch)
            loss.backward()
            opt.step()
            opt.zero_grad()
            # loss is a batch mean, so weight it by the batch size
            loss_hist_train[epoch] += loss.item()*y_batch.size(0)
            is_correct=(torch.argmax(pred,dim=1) == y_batch).float()
            # .item() stores a Python float instead of a tensor on the device
            acc_hist_train[epoch] += is_correct.sum().item()

        loss_hist_train[epoch] /= len(dl_train.dataset)
        acc_hist_train[epoch] /= len(dl_train.dataset)

        # Evaluation mode: dropout is disabled, batch norm uses running statistics
        model.eval()

        with torch.no_grad():
            for x_batch,y_batch in dl_val:
                x_batch=x_batch.to(device)
                y_batch=y_batch.to(device)
                pred = model(x_batch)
                loss = loss_fn(pred,y_batch)
                loss_hist_val[epoch] += loss.item()*y_batch.size(0)
                is_correct=(torch.argmax(pred,dim=1) == y_batch).float()
                acc_hist_val[epoch] += is_correct.sum().item()
            loss_hist_val[epoch] /= len(dl_val.dataset)
            acc_hist_val[epoch] /= len(dl_val.dataset)

        print(f' Epoch {epoch+1} ---- train accuracy: {acc_hist_train[epoch]:.4f} ---- val accuracy: {acc_hist_val[epoch]:.4f}')

    return loss_hist_train,loss_hist_val,acc_hist_train,acc_hist_val
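
The function multiplies each batch loss by the batch size before accumulating because nn.CrossEntropyLoss returns the mean over a batch by default, and the last batch may be smaller than the rest. A sketch in plain Python, with made-up per-sample losses, shows why the weighting matters:

```python
# Made-up per-sample losses for a dataset of 5 samples, batch size 2
sample_losses = [0.9, 0.7, 0.5, 0.3, 0.1]
batch_size = 2
batches = [sample_losses[i:i + batch_size]
           for i in range(0, len(sample_losses), batch_size)]

# What a loss function with mean reduction would report for each batch
batch_means = [sum(b) / len(b) for b in batches]

# Weighted accumulation, as in train(): recovers the exact dataset average
weighted = sum(m * len(b) for m, b in zip(batch_means, batches)) / len(sample_losses)

# Naive average of batch means: over-weights the small final batch
naive = sum(batch_means) / len(batch_means)

true_mean = sum(sample_losses) / len(sample_losses)
print(round(weighted, 4), round(naive, 4), round(true_mean, 4))
```

Here the weighted average matches the true mean of 0.5, while the naive average of batch means comes out near 0.4333.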

In [67]:
torch.manual_seed(1)
num_epochs=20

After 20 epochs, we are able to reach a validation accuracy of 99.18%!

In [68]:
hist = train(model,num_epochs,dl_train,dl_val)

 Epoch 1 ---- train accuracy: 0.9371 ---- val accuracy: 0.9852
 Epoch 2 ---- train accuracy: 0.9832 ---- val accuracy: 0.9870
 Epoch 3 ---- train accuracy: 0.9864 ---- val accuracy: 0.9916
 Epoch 4 ---- train accuracy: 0.9895 ---- val accuracy: 0.9823
 Epoch 5 ---- train accuracy: 0.9894 ---- val accuracy: 0.9894
 Epoch 6 ---- train accuracy: 0.9903 ---- val accuracy: 0.9876
 Epoch 7 ---- train accuracy: 0.9901 ---- val accuracy: 0.9884
 Epoch 8 ---- train accuracy: 0.9905 ---- val accuracy: 0.9914
 Epoch 9 ---- train accuracy: 0.9919 ---- val accuracy: 0.9904
 Epoch 10 ---- train accuracy: 0.9920 ---- val accuracy: 0.9919
 Epoch 11 ---- train accuracy: 0.9931 ---- val accuracy: 0.9895
 Epoch 12 ---- train accuracy: 0.9930 ---- val accuracy: 0.9895
 Epoch 13 ---- train accuracy: 0.9934 ---- val accuracy: 0.9879
 Epoch 14 ---- train accuracy: 0.9937 ---- val accuracy: 0.9896
 Epoch 15 ---- train accuracy: 0.9946 ---- val accuracy: 0.9903
 Epoch 16 ---- train accuracy: 0.9944 ---- val ac

## Saving the model

In [90]:
torch.save(model,'mnist_model.pth')

In [91]:
torch.save(model.state_dict(),'mnist_model_weights.pth')
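
The first call pickles the entire model object, while the second saves only the parameters and buffers (the state_dict). To reuse saved weights, one rebuilds the same architecture and loads the state_dict into it. The following sketch shows the round trip on a tiny stand-in model so that it runs on its own; the helper make_model and the temporary file name are hypothetical and not part of this notebook.

```python
import os
import tempfile

import torch
import torch.nn as nn

def make_model():
    # Tiny stand-in architecture; in the notebook this would be the CNN above
    return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

torch.manual_seed(1)
original = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.pth")
    torch.save(original.state_dict(), path)

    # Rebuild the architecture, then load the saved parameters into it
    restored = make_model()
    restored.load_state_dict(torch.load(path, map_location="cpu"))

original.eval()
restored.eval()

x = torch.rand(2, 1, 28, 28)
with torch.no_grad():
    same = torch.equal(original(x), restored(x))
print(same)  # True
```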

## Gradio app implementation

The following code creates an interactive Gradio app, which asks the user to draw a digit on an in-browser sketchpad and then guesses the digit using the model we've trained.  See this link for an implementation hosted on my Hugging Face account: https://huggingface.co/spaces/etweedy/digits

There are several steps to this implementation:
1. Write a prediction function which will take in an image from the Gradio sketchpad and make a prediction of the digit using our model.
2. Write the code that launches the Gradio interface.

In [None]:
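def predict(img):
    # Add batch and channel dimensions and rescale pixel values to [0,1]
    # (the sketchpad image is assumed to arrive as a 28x28 array)
    x = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
    # Put the input on the same device as the model
    x = x.to(device)
    # Evaluation mode: dropout disabled, batch norm uses running statistics
    model.eval()
    with torch.no_grad():
        pred = model(x)[0]
    return int(pred.argmax())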
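
A possible variation, sketched below, returns a confidence for each digit instead of a single guess; Gradio's label output can display a dictionary mapping labels to confidences. The function name predict_proba and the untrained stand-in model are assumptions made so the sketch runs on its own, and the input is assumed to be a 28x28 array with values in 0-255.

```python
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(1)

# Untrained stand-in for the trained CNN, so the sketch is self-contained
stand_in_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
stand_in_model.eval()

def predict_proba(img):
    # Add batch and channel dimensions and rescale pixel values to [0,1]
    x = torch.tensor(img, dtype=torch.float32).unsqueeze(0).unsqueeze(0) / 255.
    with torch.no_grad():
        probs = torch.softmax(stand_in_model(x)[0], dim=0)
    # Map each digit label to its predicted probability
    return {str(d): float(probs[d]) for d in range(10)}

# Hypothetical blank sketch as input
confidences = predict_proba(np.zeros((28, 28), dtype=np.uint8))
print(len(confidences), round(sum(confidences.values()), 4))
```

With the trained model in place of stand_in_model, the highest-confidence entry corresponds to the digit returned by predict.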

Note that if you're running this notebook in the cloud (Google Colab, Kaggle, Paperspace, etc.), the link created by the code below may not work.  The code creates a locally hosted version of the web app, which you can open and play with in your browser if you are running this notebook on your local machine.

It's easy to share your machine learning project as a Gradio Space on Hugging Face! More info: https://huggingface.co/blog/gradio-spaces

In [None]:
import gradio as gr

title = "Guess that digit"
description = "Draw your favorite base-10 digit (0-9) and click submit - I'll try to guess what you drew! I do a bit better if you're not too messy and your digit is fairly centered."
gr.Interface(
    fn=predict,
    inputs="sketchpad",
    outputs="label",
    title=title,
    description=description,
).launch()