This notebook is based on an exercise from Chapter 4 of the book [Deep Learning for Coders with Fastai and PyTorch: AI Applications Without a PhD](https://www.amazon.com/Deep-Learning-Coders-fastai-PyTorch/dp/1492045527). The chapter details the implementation of a simple neural network for binary classification, using a subset of the MNIST dataset.

Here we're extending that example to multi-class classification using the complete MNIST dataset. We'll use the same architectures presented in Chapter 4, which, while not state-of-the-art for the complete MNIST classification problem, serve as a valuable learning experience.

For a more accurate approach, see for instance [Beginners guide to MNIST with fast.ai](https://www.kaggle.com/code/christianwallenwein/beginners-guide-to-mnist-with-fast-ai).

In [None]:
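#hide
!pip install -Uqq fastai nbdev

# fastai.vision.all re-exports what this notebook uses
# (untar_data, tensor, Image, DataLoader, DataLoaders, Learner, SGD, show_image, ...)
from fastai.vision.all import *

import math
import torch
import torch.nn.functional as F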

We start by downloading and unzipping the full MNIST dataset. The dataset is split into two folders, `training` and `testing`. Under each, there is one folder per digit, `0` to `9`.

In [None]:
path = untar_data(URLs.MNIST)
Path.BASE_PATH = path
sorted((path/'training').ls())

We'll create a list of tensors, one per digit, where each tensor contains all the images of that digit. Each tensor has shape [N,28,28], where N is the number of images for the digit and each image is a 28x28 matrix of pixels. To do so, we'll create a couple of auxiliary functions:

`path_to_tensor` returns a tensor with all the images in a given path.

`path_to_tensor_list` returns a list of tensors, one for each directory under `path`. Since we sort the result of `path.ls()`, the tensor list is also sorted from digit 0 to 9; that is, the first list element contains the tensor with all the images of the digit 0, and so on.
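As a minimal sketch of what `path_to_tensor` produces, the snippet below uses random `uint8` tensors as hypothetical stand-ins for the image files of one digit: `torch.stack` turns a list of N 28x28 matrices into a single [N,28,28] tensor. It also shows why sorting gives digit order: the folder names are the single characters `'0'` to `'9'`, whose string order matches numeric order.

```python
import torch

# Hypothetical stand-ins for three 28x28 grayscale images of the same digit
# (random uint8 pixels instead of files read from disk)
images = [torch.randint(0, 256, (28, 28), dtype=torch.uint8) for _ in range(3)]

# torch.stack adds a new leading dimension, giving shape [N, 28, 28]
digit_tensor = torch.stack(images)
print(digit_tensor.shape)  # torch.Size([3, 28, 28])

# Sorting single-character folder names as strings yields '0', '1', ..., '9'
dir_names = ['7', '2', '0', '9', '1', '5', '3', '8', '4', '6']
print(sorted(dir_names))
```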

In [None]:
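def path_to_tensor(path):
    # Read every image file in `path` and stack them into one [N,28,28] tensor
    return torch.stack([tensor(Image.open(o)) for o in path.ls()])

def path_to_tensor_list(path):
    # One tensor per sub-directory; sorting puts the digits in order 0..9
    return list(map(path_to_tensor, sorted(path.ls())))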

In [None]:
train_tensors = path_to_tensor_list(path/'training')
test_tensors = path_to_tensor_list(path/'testing')

train_tensors[0].shape, test_tensors[0].shape

In [None]:
show_image(train_tensors[7][0])
show_image(test_tensors[7][0])

We'll now concatenate the per-digit tensors into a single train tensor and a single test tensor. We also reshape them so that each image is a row of 784 (28*28) pixels, and we scale the pixel values from the 0-255 range to [0, 1].
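A small self-contained sketch of the same concatenate, flatten and scale steps, with two hypothetical per-digit tensors of random pixels in place of the MNIST data. Because `view` reads each 28x28 image row by row, calling `view(28, 28)` on a flattened row recovers the original image, which is what displaying `train_x[13000].view(28,28)` relies on.

```python
import torch

# Two hypothetical per-digit tensors holding 3 and 2 uint8 images
digit_a = torch.randint(0, 256, (3, 28, 28), dtype=torch.uint8)
digit_b = torch.randint(0, 256, (2, 28, 28), dtype=torch.uint8)

# Concatenate along the first dimension: shape [5, 28, 28]
stacked = torch.cat([digit_a, digit_b])

# Flatten each image into a row of 784 pixels and scale to [0, 1]
x = stacked.view(-1, 28 * 28).float() / 255
print(x.shape)  # torch.Size([5, 784])

# Row i of x is image i read row by row, so view(28, 28) recovers it
restored = x[3].view(28, 28)
print(torch.equal(restored, stacked[3].float() / 255))  # True
```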

In [None]:
train_x = torch.cat(train_tensors).view(-1, 28*28).float()/255
test_x = torch.cat(test_tensors).view(-1, 28*28).float()/255

train_x.shape, test_x.shape

In [None]:
show_image(train_x[13000].view(28,28))

Now let's create the target tensors using one-hot encoding. Each target tensor has 10 columns, one for each class (the digits 0 to 9). For that we'll use the PyTorch `one_hot` function.
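A minimal illustration of `F.one_hot` on three example labels. Passing `num_classes=10` is optional here, since without it the width is inferred as the largest label plus one, but it keeps the targets at 10 columns even for a set of labels that doesn't contain the digit 9.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 3, 9])  # class indices (int64)

# Each label becomes a row of zeros with a single 1 in the label's column
targets = F.one_hot(labels, num_classes=10).float()
print(targets.shape)  # torch.Size([3, 10])

# argmax recovers the original class index from each one-hot row
print(targets.argmax(dim=1))  # tensor([0, 3, 9])
```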

As a first step we'll create a list of `(digit, number of images)` tuples, one for each digit.

In [None]:
train_labels = list(enumerate(map(lambda t: t.shape[0], train_tensors)))
train_labels

`labels_to_target` receives the tuple list and returns the target tensors.
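For illustration, here is how such a tuple list expands into a label vector, using hypothetical counts instead of the real MNIST ones. The first version mirrors the `torch.full`/`torch.cat` approach of `labels_to_target`; `torch.repeat_interleave` is an equivalent alternative (not what this notebook uses) that repeats each digit by its count in one call.

```python
import torch

# Hypothetical (digit, count) pairs, shaped like the output of enumerate(map(...))
labels = [(0, 2), (1, 1), (2, 3)]

# Approach used in this notebook: one torch.full block per digit, then cat
via_full = torch.cat([torch.full((size,), label, dtype=torch.long)
                      for label, size in labels])

# Equivalent alternative: repeat each digit 'count' times
digits = torch.tensor([label for label, _ in labels])
counts = torch.tensor([size for _, size in labels])
via_repeat = torch.repeat_interleave(digits, counts)

print(via_full)                         # tensor([0, 0, 1, 2, 2, 2])
print(torch.equal(via_full, via_repeat))  # True
```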

In [None]:
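def labels_to_target(labels):
    # Expand each (digit, count) pair into `count` copies of the digit label...
    res = torch.cat([torch.full((size,), label, dtype=torch.long) for label, size in labels])
    # ...then one-hot encode into 10 columns, as floats for the loss function
    res = F.one_hot(res, num_classes=10).float()
    return res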
    

In [None]:
train_y = labels_to_target(train_labels)

train_y[torch.randperm(len(train_y))[:5]]

In [None]:
train_y[13000]

Now the same for the test data:

In [None]:
test_labels = list(enumerate(map(lambda t: t.shape[0], test_tensors)))
test_labels

In [None]:
test_y = labels_to_target(test_labels)

test_y[torch.randperm(len(test_y))[:5]]

We can now create the datasets and data loaders. We'll use a splitter to split the training set into training and validation sets.
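The splitter only has to return two disjoint lists of indexes. As a rough sketch of that idea, the hypothetical helper `random_split_indexes` below (not part of fastai) shuffles the indexes with a seeded `torch.randperm` and holds out 20% for validation. It will not reproduce the exact indexes that `TrainTestSplitter` chooses, only the same 48,000/12,000 proportions for the 60,000 training images.

```python
import torch

def random_split_indexes(n, valid_pct=0.2, seed=42):
    """Hypothetical helper: shuffle 0..n-1 and hold out valid_pct of them."""
    g = torch.Generator().manual_seed(seed)   # seeded for reproducibility
    perm = torch.randperm(n, generator=g)
    n_valid = int(n * valid_pct)
    # First n_valid shuffled indexes go to validation, the rest to training
    return perm[n_valid:].tolist(), perm[:n_valid].tolist()

train_idx, valid_idx = random_split_indexes(60000)
print(len(train_idx), len(valid_idx))  # 48000 12000
```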

In [None]:
dset = list(zip(train_x,train_y))
x,y = dset[0]
x.shape, y

In [None]:
split = TrainTestSplitter(test_size=0.2, random_state=42)
train_dset_indexes, val_dset_indexes = split(dset)
len(train_dset_indexes), len(val_dset_indexes)

In [None]:
train_dset = [dset[i] for i in train_dset_indexes]
valid_dset = [dset[i] for i in val_dset_indexes]

show_image(train_dset[0][0].view(28,28))
print(train_dset[0][1])

In [None]:
train_dl = DataLoader(train_dset, batch_size=256, shuffle=True)
xb,yb = first(train_dl)
xb.shape,yb.shape

In [None]:
valid_dl = DataLoader(valid_dset, batch_size=256, shuffle=False)  # no need to shuffle for evaluation
xb,yb = first(valid_dl)
xb.shape,yb.shape

## SGD

The training loop, based on the code from Chapter 4:

In [None]:
from torch.nn import init
    
def init_lin_params(in_features, out_features, std=1.0): 
    w = (torch.randn((in_features, out_features))*std).requires_grad_()
    b = (torch.randn(out_features)*std).requires_grad_()
    return w, b


# This replicates the Kaiming parameter initialization implemented in nn.Linear.
# nn.Linear stores its weight as [out, in], while w here is [in, out] (so that a
# layer computes xb@w + b); hence the fan-in is read from w.T in both calls.
def init_lin_params_k(in_features, out_features):
    w = torch.empty((in_features, out_features)).requires_grad_()
    b = torch.empty(out_features).requires_grad_()
    init.kaiming_uniform_(w.T, a=math.sqrt(5))
    fan_in, _ = init._calculate_fan_in_and_fan_out(w.T)  # fan_in == in_features
    bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
    init.uniform_(b, -bound, bound)
    return w, b

def calc_grad(xb, yb, model, lossf):
    preds = model(xb)
    loss = lossf(preds, yb)
    loss.backward()
    
def train_epoch(model, lr, params, lossf):
    for xb,yb in train_dl:
        calc_grad(xb, yb, model, lossf)
        for p in params:
            p.data -= p.grad.data * lr
            p.grad.zero_()
            
def batch_accuracy(xb, yb):
    _, preds = torch.max(xb, dim=1)
    _, target = torch.max(yb, dim=1)
    return torch.tensor(torch.sum(preds == target).item() / len(preds))

def validate_epoch(model):
    accs = [batch_accuracy(model(xb), yb) for xb,yb in valid_dl]
    return round(tensor(accs).mean().item(), 4)

def train(model, params, lossf=F.cross_entropy, epochs=50, lr=1):
    for i in range(epochs):
        train_epoch(model, lr, params, lossf)
        print(validate_epoch(model), end=' ')
        
def test(model):
    acc = batch_accuracy(model(test_x), test_y).item()
    # acc is a fraction in [0, 1]; the '%' format spec multiplies it by 100
    return f"{acc:.2%}"

## Logistic regression

First let's train a simple logistic regression model. Note that we're using the Kaiming parameter initialization (`init_lin_params_k`), because it converges faster than the plain random initialization of `init_lin_params`.
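To see what "Kaiming initialization as in `nn.Linear`" means concretely, here is a sketch that re-implements the initializer for the `[in, out]` weight layout used in this notebook and checks the resulting ranges. With `a=math.sqrt(5)`, the Kaiming-uniform bound works out to 1/sqrt(fan_in), and `nn.Linear` uses the same 1/sqrt(fan_in) bound for the bias, where fan_in is the number of input features. Because the weight here is stored transposed relative to `nn.Linear`, the fan-in has to be read from `w.T`. The helper name `init_linear_like_torch` is hypothetical.

```python
import math
import torch
import torch.nn as nn
from torch.nn import init

def init_linear_like_torch(in_features, out_features):
    """Hypothetical helper: nn.Linear-style init for weights stored as [in, out]."""
    w = torch.empty(in_features, out_features)
    b = torch.empty(out_features)
    # w.T has nn.Linear's [out, in] layout, so kaiming_uniform_ sees fan_in = in_features
    init.kaiming_uniform_(w.T, a=math.sqrt(5))
    fan_in, _ = init._calculate_fan_in_and_fan_out(w.T)
    bound = 1 / math.sqrt(fan_in)
    init.uniform_(b, -bound, bound)
    # Enable gradients only after the in-place initialization
    return w.requires_grad_(), b.requires_grad_()

w, b = init_linear_like_torch(28 * 28, 10)
bound = 1 / math.sqrt(28 * 28)  # = 1/28
print(w.shape, b.shape)  # torch.Size([784, 10]) torch.Size([10])
print(w.abs().max().item() <= bound + 1e-6, b.abs().max().item() <= bound + 1e-6)  # True True

# nn.Linear's own weights fall in the same range
layer = nn.Linear(28 * 28, 10)
print(layer.weight.abs().max().item() <= bound + 1e-6)  # True
```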

In [None]:
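def log_reg(xb): 
    res = xb@w1 + b1
    # F.cross_entropy (the default loss in `train`) applies log_softmax itself,
    # so these sigmoid outputs act as scores squeezed into [0, 1]
    return res.sigmoid()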

w1, b1 = init_lin_params_k(28*28,10)

train(log_reg, (w1, b1))

Let's check accuracy with the test set:

In [None]:
print(test(log_reg))

We get above 90% with a simple logistic regression model, not bad! Let's see if we get a better result using a model with two layers.

## Two layer net

In [None]:
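def two_layer_net(xb): 
    res = xb@w2 + b2
    res = res.max(tensor(0.0))   # ReLU: element-wise max(x, 0)
    res = res@w3 + b3
    # F.cross_entropy applies log_softmax on top of this softmax, so softmax is
    # effectively applied twice; returning the raw scores `res` is the more usual choice
    res = F.softmax(res, dim=1)
    return res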

w2, b2 = init_lin_params_k(28*28, 32)
w3, b3 = init_lin_params_k(32, 10)

train(two_layer_net, (w2, b2, w3, b3), lossf=F.cross_entropy, epochs=50)

With this model we get above 95% validation accuracy. This is in line with the results obtained with a similar model in [Multi-class classification with MNIST](https://colab.research.google.com/github/google/eng-edu/blob/main/ml/cc/exercises/multi-class_classification_with_MNIST.ipynb?hl=en#scrollTo=pedD5GhlDC-y). Let's check with the test set.

In [None]:
print(test(two_layer_net))

## Using a Learner / PyTorch Modules

Using a Learner and PyTorch modules, we get essentially the same results.
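The `nn.Linear` modules compute the same affine map as the hand-written layers, only with the weight stored transposed. A quick check with random inputs (an assumption for illustration) that `xb@w + b` followed by `max(0)` matches `nn.Linear` followed by `nn.ReLU` when `w` is `layer.weight.T`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(28 * 28, 32)   # weight shape: [32, 784]
xb = torch.rand(5, 28 * 28)      # 5 random "images"

with torch.no_grad():
    # Hand-written version, as in two_layer_net: weights stored as [784, 32]
    w, b = layer.weight.T, layer.bias
    manual = (xb @ w + b).max(torch.tensor(0.0))
    # Module version, as in two_layer_nn
    module = nn.ReLU()(layer(xb))

print(w.shape)                                     # torch.Size([784, 32])
print(torch.allclose(manual, module, atol=1e-5))   # True
```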

In [None]:
import torch.nn as nn

two_layer_nn = nn.Sequential(
    nn.Linear(28*28, 32),
    nn.ReLU(),
    nn.Linear(32,10),
    nn.Softmax(dim=1)  # mirrors two_layer_net; F.cross_entropy also applies log_softmax
)

dls = DataLoaders(train_dl, valid_dl)

learn = Learner(dls, two_layer_nn, opt_func=SGD,
                loss_func=F.cross_entropy, metrics=batch_accuracy)

In [None]:
learn.fit(50, 1)

In [None]:
print(test(two_layer_nn))

## References

[Deep Learning for Coders with Fastai and PyTorch - Chapter 4](https://github.com/fastai/fastbook/blob/master/04_mnist_basics.ipynb)

[Multi-class classification with MNIST](https://colab.research.google.com/github/google/eng-edu/blob/main/ml/cc/exercises/multi-class_classification_with_MNIST.ipynb?hl=en#scrollTo=pedD5GhlDC-y)
