# Hello World: Training a Simple Model on the MNIST dataset

In this notebook we are going to get familiar with using a deep learning library like PyTorch to train a simple neural network. The network will be trained on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/), which contains small images of handwritten digits. By the end of training, the model should be able to classify images of digits accurately.

Training a network on the MNIST dataset has become the 'hello world' of machine learning. 

The following is based on the [PyTorch Quickstart](https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html); have a look at these pages for more!

## Jupyter (Locally)

The recommended way is to clone the whole repo. You will need `pytorch`, `torchvision`, `numpy` and `matplotlib` installed. The git commands from the Google Colab section below also work locally: run them in a terminal inside the repository, without the leading `!`.

## Google Colab: Two Workflows

### 1. Clone the repo inside your Google Drive

For this, you need to mount your drive to the machine, like so:

```python
from google.colab import drive
drive.mount('/content/drive')

# change directory using the os module
import os
os.chdir('drive/My Drive/')
os.listdir()             # shows the contents of the current dir, you can use chdir again after that
# os.mkdir("IS53055B-DMLCP") # creating a directory
# os.chdir("IS53055B-DMLCP") # moving to this directory
# os.getcwd()            # printing the current directory
```

You can use git in Colab:
```python
!git clone https://github.com/jchwenger/DMLCP
```

To pull updates from the upstream repository without losing your work:
```python
!git stash     # temporarily store your local changes
!git pull      # fetch and merge the updates from GitHub
!git stash pop # reapply your changes and delete the stash
```

### 2. Using this notebook as a standalone file

On Google Colab you will need to download things:

```python
!curl -O https://raw.githubusercontent.com/jchwenger/DMLCP/main/python/images/3.png
!curl -O https://raw.githubusercontent.com/jchwenger/DMLCP/main/python/images/4.png
# for FashionMNIST
!curl -O https://raw.githubusercontent.com/jchwenger/DMLCP/main/python/images/handbag.png
!mkdir -p images
!mv 3.png 4.png handbag.png images
```

But to use the model created by this notebook in another notebook, you will need to either manually download/upload the model file (the left sidebar has a file explorer), or set up your notebook to mount (= connect to) a Google Drive (using the code above).

## Tutorials & references

Python:
- [Working With Files](https://realpython.com/working-with-files-in-python/)
- [Python's pathlib Module: Taming the File System](https://realpython.com/python-pathlib/)

Colab:
- [Loading/Saving data on Google Colab](https://colab.research.google.com/notebooks/io.ipynb)

Pytorch:
- [Transforming and augmenting images](https://pytorch.org/vision/main/transforms.html)
- [Getting started with transforms v2](https://pytorch.org/vision/main/auto_examples/transforms/plot_transforms_getting_started.html#sphx-glr-auto-examples-transforms-plot-transforms-getting-started-py)



## Imports

In [None]:
import pathlib
from PIL import Image

import numpy as np
import matplotlib.pyplot as plt

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

import torchvision as tv
from torchvision.transforms import v2

# Get cpu, gpu or mps device for training
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

## Data processing: walkthrough

MNIST contains images of single digits, so there are 10 classes, from 0 to 9.

All images are 28 by 28 pixels, in grayscale (1 channel).

In [None]:
# Model / data parameters

NUM_CLASSES = 10
INPUT_SHAPE = (1,28,28)

# fixed directory structure -------------
DATASETS_DIR = pathlib.Path("datasets")
DATASETS_DIR.mkdir(exist_ok=True)

MODELS_DIR = pathlib.Path("models")
MODELS_DIR.mkdir(exist_ok=True)
# ----------------------------------------

MODEL_NAME = "dense_mnist" # change accordingly

MNIST_DIR = MODELS_DIR / MODEL_NAME
MNIST_DIR.mkdir(exist_ok=True) # the directory where this model will be saved

`ToImage` and `ToDtype` (see [here](https://pytorch.org/vision/stable/transforms.html#performance-considerations)) together convert a PIL image or NumPy `ndarray` into a `float32` tensor and scale the pixel intensity values to the range [0., 1.].

In [None]:
transforms = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True)
])

train_data = tv.datasets.MNIST(
    root=DATASETS_DIR,
    train=True,
    download=True,
    transform=transforms,
)

test_data = tv.datasets.MNIST(
    root=DATASETS_DIR,
    train=False,
    download=True,
    transform=transforms,
)

Why do we split our data like this? That is because we want to see how well our model performs on data *it was not trained on* (the test set)!

### A look at our datasets

In [None]:
print(train_data)
print()
print(test_data)

In [None]:
print("\n".join(train_data.classes)) # join a list of strings into one string

In [None]:
print(train_data.data.shape, train_data.data.dtype)
print(train_data.targets.shape, train_data.targets.dtype)

Note the difference between the original data type and range, and what you get when you index the dataset to extract a sample (as the data loader will do): the transforms are only applied at that point.

In [None]:
print(f"Original data type:    {train_data.data.dtype}")
print(f"range:                 [{train_data.data.min().item()}: {train_data.data.max().item()}]")
print(f"Transformed data type: {train_data[0][0].dtype}")
print(f"range:                 [{train_data[0][0].min().item()}: {train_data[0][0].max().item()}]")
print()

print(f"Original label shape:    {train_data.targets.shape} (60k integers)")
print(f"dtype:                   {train_data.targets.dtype}")

In [None]:
torch.set_printoptions(linewidth=150) # prevent wrapping
print(f"This should be a {train_data.classes[test_data.targets[0]]}...")
print()
print(test_data.data[0])

In [None]:
test_data[0][1]

In [None]:
# We can also display the array as an image with matplotlib!
def display_sample(dataset, index, pred=None):
    plt.figure()
    if pred is not None:
        plt.title(f"Label: {dataset.classes[dataset.targets[index]]}, Prediction: {dataset.classes[pred]} | {'CORRECT' if dataset.targets[index] == pred else 'WRONG'}")
    else:
        plt.title(f"Label: {dataset.classes[dataset.targets[index]]}")
    plt.imshow(dataset.data[index], cmap='gray')
    plt.axis("off")
    plt.show()

display_sample(test_data, 0)

### Note: One-Hot Encoding

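In PyTorch, the cross-entropy loss we are using accepts integer class labels (since PyTorch 1.10 it also accepts class probabilities, such as one-hot vectors), whereas some other implementations, such as `categorical_crossentropy` in Keras, expect one-hot vectors. Here is how you could convert an integer label into a one-hot vector: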

In [None]:
target_transform = tv.transforms.Lambda(
    lambda y: torch.zeros(NUM_CLASSES, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1)
)
x = 3
print(x)
print(target_transform(x)) # one-hot representation
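As a quick check of the note above, here is a sketch with random logits (made-up numbers, not model outputs). `F.one_hot` is a built-in alternative to the `scatter_` trick, and `nn.CrossEntropyLoss` should give the same value whether it receives integer labels or the equivalent one-hot vectors (probability targets require PyTorch 1.10 or later).

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)           # a fake batch: 4 samples, 10 classes
labels = torch.tensor([3, 0, 9, 3])   # integer class indices

one_hot = F.one_hot(labels, num_classes=10).float()  # shape (4, 10), a single 1 per row

ce = nn.CrossEntropyLoss()
loss_int = ce(logits, labels)         # integer targets
loss_onehot = ce(logits, one_hot)     # probability (one-hot) targets
print(loss_int.item(), loss_onehot.item())
```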

## Data Loaders

In [None]:
BATCH_SIZE = 64

# Create data loaders (the training set is reshuffled at every epoch)
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE)

for X, y in test_dataloader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

## Our Workflow

1. **Model definition**: what kind of model do we want? Create a blueprint.
2. **Loss & Optimizer**: choose how to measure the model's errors (the loss) and how to update its parameters (the optimizer).
3. **Define our Training loop**    
   Also: _Test before training_ (optional): how lousy are we before we start?
4. **Training**: aka 'fitting' the model to the data
5. **Testing**: how good are we now?
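The five steps above can be sketched end to end on a toy problem. Everything here is made up for illustration (random features, label = whether the first feature is positive, a deliberately tiny model); it only shows the shape of the workflow, not the notebook's MNIST run.

```python
import torch
from torch import nn

torch.manual_seed(0)

# made-up data: 256 samples with 20 features; the label is 1 when the first feature is positive
X_toy = torch.randn(256, 20)
y_toy = (X_toy[:, 0] > 0).long()

# 1. model definition
toy_model = nn.Sequential(nn.Linear(20, 16), nn.ReLU(), nn.Linear(16, 2))

# 2. loss & optimizer
toy_loss_fn = nn.CrossEntropyLoss()
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-2)

def evaluate():
    # testing (on the same toy data, for brevity): no gradients needed
    toy_model.eval()
    with torch.no_grad():
        logits = toy_model(X_toy)
        loss = toy_loss_fn(logits, y_toy).item()
        acc = (logits.argmax(1) == y_toy).float().mean().item()
    return loss, acc

# 3. (optional) test before training
loss_before, acc_before = evaluate()

# 4. training: 200 full-batch steps
toy_model.train()
for _ in range(200):
    loss = toy_loss_fn(toy_model(X_toy), y_toy)
    loss.backward()
    toy_optimizer.step()
    toy_optimizer.zero_grad()

# 5. testing after training
loss_after, acc_after = evaluate()
print(f"before: loss {loss_before:.3f}, accuracy {acc_before:.2f}")
print(f"after:  loss {loss_after:.3f}, accuracy {acc_after:.2f}")
```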

## 1. Model definition

[Model layers](https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html#nn-relu)

In [None]:
# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten() # [N, 1, 28, 28] -> [N, 28*28] (N: batch size)
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(INPUT_SHAPE[1] * INPUT_SHAPE[2], 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, NUM_CLASSES)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

model = NeuralNetwork().to(device)

## A look at our model

In [None]:
print(f"Model structure:")
print(model)
print()

print("Layers:")
for name, param in model.named_parameters():
    print(f" - {name} | Shape: {param.shape}")

print()
print(f"Our model has {sum(p.numel() for p in model.parameters())} parameters.")
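The parameter count printed above can be checked by hand: a `Linear(n_in, n_out)` layer has `n_in * n_out` weights plus `n_out` biases. The sketch below rebuilds the same layer sizes as a standalone `nn.Sequential` (a copy for illustration, not the notebook's `model` object) and compares both counts, which come to 109,386.

```python
import torch
from torch import nn

# same layer sizes as NeuralNetwork: 784 -> 128 -> 64 -> 10
sizes = [28 * 28, 128, 64, 10]

# weights (n_in * n_out) plus biases (n_out) for each Linear layer
manual = sum(n_in * n_out + n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))

stack = nn.Sequential(
    nn.Linear(784, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, 10),
)
counted = sum(p.numel() for p in stack.parameters())
print(manual, counted)
```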

## 2. Loss & Optimizer

The **loss** is how we measure how good (or bad) our performance is. Here we use the cross-entropy:
- **cross-entropy**: in probability theory, the cross-entropy measures how much two probability distributions differ. It calculates the 'distance' between our predictions (a probability distribution, obtained by applying a softmax to the model's raw outputs, the logits; `nn.CrossEntropyLoss` does this internally, which is why our model returns logits) and our labels (*also* a probability distribution, with a 1 where the ground truth is, and zero everywhere else).
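For one example with logits `z` and correct class `y`, the cross-entropy boils down to `-log(softmax(z)[y])`: the negative log of the probability given to the right answer, averaged over the batch. The sketch below (random made-up logits and labels) computes this by hand and compares it with `F.cross_entropy`, which does the same in a numerically stabler way.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 10)         # made-up raw outputs: 3 samples, 10 classes
labels = torch.tensor([2, 7, 7])    # made-up ground-truth classes

probs = torch.softmax(logits, dim=1)                 # each row now sums to 1
picked = probs[torch.arange(len(labels)), labels]    # probability of the correct class
manual = -torch.log(picked).mean()                   # average over the batch
builtin = F.cross_entropy(logits, labels)            # same quantity, via log_softmax
print(manual.item(), builtin.item())
```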

The **optimizer** takes this loss and changes the parameters of the network in order to improve its performance. Below we use stochastic gradient descent (SGD) with momentum. The [Adam](https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam) optimizer also usually works well out of the box (although it uses more memory, as it stores two running averages per parameter). You can try different [optimizers](https://pytorch.org/docs/stable/optim.html) from the PyTorch API.
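What `optimizer.step()` does is easiest to see with plain SGD (no momentum) on a single hand-made parameter: each step moves the weights against the gradient, `w <- w - lr * grad`. The tensor `w` and the toy loss below are made up for illustration.

```python
import torch

w = torch.tensor([1.0, -2.0], requires_grad=True)  # a made-up parameter vector
opt = torch.optim.SGD([w], lr=0.1)                  # plain SGD, no momentum

loss = (w ** 2).sum()     # toy loss; its gradient is 2 * w = [2., -4.]
loss.backward()

expected = (w - 0.1 * w.grad).detach()   # w <- w - lr * grad
opt.step()
print(w.detach(), expected)
```

With `momentum=0.9`, PyTorch's SGD additionally adds 0.9 times the previous update to each new one, which smooths the trajectory of the weights.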

[nn.CrossEntropyLoss()](https://pytorch.org/tutorials/beginner/basics/optimization_tutorial.html#loss-function) does not require one-hot labels (unlike e.g. `categorical_crossentropy` in Keras/TensorFlow): integer class indices are enough!

In [None]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9) # also try without momentum

# other optimizers are available
# optimizer = torch.optim.RMSprop(model.parameters(), lr=1e-3)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

## 3. Define our training loop

In [None]:
def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    losses, accs = [], []
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        
        # 0: data & target to device
        X, y = X.to(device), y.to(device)

        # 1: prediction
        pred = model(X)

        # 2: loss
        loss = loss_fn(pred, y)
    
        # 3: 'backward' | Backpropagation!
        loss.backward()

        # 4: 'step'
        optimizer.step()

        # 5: 'zero grad' (otherwise the gradients accumulate across batches)
        optimizer.zero_grad()

        # Logging & saving history

        # save losses
        losses.append(loss.item())
        # save our accuracy
        accs.append((pred.argmax(1) == y).type(torch.float).mean().item())
        
        if batch % 100 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

    print()
    return losses, accs
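Why step 5 matters: PyTorch *adds* new gradients to whatever is already stored in `.grad`. This sketch uses a single made-up scalar parameter to show the accumulation, and then `zero_grad()` clearing it (recent PyTorch versions set the gradient to `None` rather than to zero).

```python
import torch

w = torch.tensor(3.0, requires_grad=True)   # a made-up scalar parameter
toy_opt = torch.optim.SGD([w], lr=0.1)

(2 * w).backward()          # d(2w)/dw = 2
first = w.grad.item()

(2 * w).backward()          # no zero_grad in between: the gradients add up
second = w.grad.item()

toy_opt.zero_grad()         # clears the stored gradient
cleared = w.grad is None or w.grad.item() == 0.0
print(first, second, cleared)
```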

In [None]:
def test(dataloader, model, loss_fn):
    
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    losses, accs = [], []
    
    # no gradients: we are not training!
    model.eval()
    with torch.no_grad():
        for X, y in dataloader:
            # to device
            X, y = X.to(device), y.to(device)
            # prediction
            pred = model(X)
            # accumulate loss
            t_l = loss_fn(pred, y).item()
            test_loss += t_l

            # accumulate our accuracy
            a = (pred.argmax(1) == y).type(torch.float)
            correct += a.sum().item()

            # save loss and acc
            losses.append(t_l)
            accs.append(a.mean().item())
    
    # average loss & results
    test_loss /= num_batches
    correct /= size
    
    print("Test Error:")
    print(f"Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}")
    print()
    
    return losses, accs, correct

### Before training: how good (bad) is our untrained model?

In [None]:
_, _, _ = test(test_dataloader, model, loss_fn)

In [None]:
model.eval()
n = torch.randint(0, len(test_data), (1,)).item()
x, y = test_data[n][0], test_data[n][1] 
with torch.no_grad():
    x = x.to(device)
    pred = model(x)
    predicted, actual = pred[0].argmax(0), y
    display_sample(test_data, n, pred=predicted.item())

## 4. Training!

Two parameters matter here: the batch size (`BATCH_SIZE`, set when we created the data loaders) and the number of `epochs`.

The number of `epochs` defines how many full passes we make over the training set. The more epochs we train for, the longer training takes; this often (but not always) leads to better performance, as training for too long can lead to overfitting.

The batch size defines how many samples we process in parallel during training. A bigger batch size speeds up training (within the limits of your computer's memory). Since the weights are updated based on the loss averaged over the whole batch, training is more stable than if we updated the weights after each single example. Bigger is not automatically better, though: the noise that comes from estimating the gradient on small batches acts as a form of *regularisation*, something that will come up again and again with different tricks for getting the best performance out of training.
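The link between dataset size, batch size and the number of weight updates is simple arithmetic. This sketch uses MNIST's 60,000 training images with this notebook's batch size of 64 and 5 epochs, and cross-checks the count against a `DataLoader` over a dummy dataset of the same length (the dummy tensor is just a placeholder).

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

n_train, batch_size, n_epochs = 60_000, 64, 5   # MNIST training set, this notebook's settings

steps_per_epoch = math.ceil(n_train / batch_size)             # one optimizer step per batch
last_batch = n_train - (steps_per_epoch - 1) * batch_size      # the final batch is smaller
total_steps = steps_per_epoch * n_epochs

# cross-check: a DataLoader over a dummy dataset of the same length
loader = DataLoader(TensorDataset(torch.zeros(n_train)), batch_size=batch_size)
print(steps_per_epoch, last_batch, total_steps, len(loader))
```

With these numbers, the `batch % 100 == 0` check in `train` prints 10 times per epoch (batches 0, 100, ..., 900).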

In [None]:
epochs = 5
train_losses, train_accs, test_losses, test_accs = [], [], [], []
for t in range(epochs):
    print(f"Epoch {t+1}")
    print("-------------------------------")
    train_l, train_a = train(train_dataloader, model, loss_fn, optimizer)
    test_l, test_a, _ = test(test_dataloader, model, loss_fn)
    # save history
    train_losses.extend(train_l)
    train_accs.extend(train_a)
    test_losses.extend(test_l)
    test_accs.extend(test_a)
print("Done!")

## 5. After training: evaluating again (for real)

In [None]:
_ = test(test_dataloader, model, loss_fn)

## Using our model (with an actual input image)


See [Compose](https://pytorch.org/vision/main/generated/torchvision.transforms.v2.Compose.html) and [Transforming and Augmenting Images](https://pytorch.org/vision/stable/transforms.html).

In [None]:
img = Image.open('images/3.png') # try also images/4.png

transforms = v2.Compose([
    v2.Grayscale(num_output_channels=1),
    v2.Resize(size=(28,28), antialias=True),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True), # from [0,255] to [0,1]
])

input = transforms(img)
input = input.to(device)

print(f"Input shape: {input.shape}")

In [None]:
model.eval()
with torch.no_grad():
    predictions = nn.Softmax(dim=-1)(model(input)).cpu().numpy()
print(f"Our predictions (shape: {predictions.shape})")
print(predictions)

We can plot our predictions using a [bar chart](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.bar.html) (sometimes the net is so confident that you will really see just one bar, the other numbers being so small!)
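Why a single bar can dominate: the softmax exponentiates the logits, so even a moderate gap becomes a huge ratio. In this sketch with made-up logits, a gap of 10 between one class and the nine others already gives that class about 99.96% of the probability.

```python
import torch

# made-up logits: one class 10 units above the other nine
logits = torch.tensor([10.0] + [0.0] * 9)
probs = torch.softmax(logits, dim=0)
print(probs[0].item(), probs[1].item(), probs.sum().item())
```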

In [None]:
plt.figure(figsize=(14,5))
plt.title("Predictions")
xs = train_data.classes   # class names ('0 - zero', ...) on the x axis; the heights are our predictions
plt.bar(xs, predictions[0])             # a bar chart
plt.xticks(xs)
plt.show()

In [None]:
# note that predictions is still *batched* (shape: (1,10)), we need to fetch the first array
predicted = np.argmax(predictions[0]) # argmax: the *index* of the highest prediction

plt.figure()
plt.title(f'Predicted number: {train_data.classes[predicted]}') # use the predicted category in the title
plt.imshow(img, cmap="gray")
plt.axis("off")
plt.show()

## Saving & Loading Models

```python
# save (reload using torch.jit.load)
torch.jit.save(torch.jit.script(model), MNIST_DIR / f"{MODEL_NAME}_scripted.pt")

# save (reload using model.load_state_dict, requires the model class!)
torch.save(model.state_dict(), MNIST_DIR / f"{MODEL_NAME}.pt")
print(f"Saved PyTorch Model State to {MNIST_DIR / MODEL_NAME}.pt")

# instantiate then load (you need to have defined NeuralNetwork)!
model_reloaded = NeuralNetwork().to(device)
model_reloaded.load_state_dict(torch.load(MNIST_DIR / f"{MODEL_NAME}.pt", weights_only=True))
```
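The `state_dict` route only stores the weights, so the architecture has to be rebuilt before loading. This sketch runs the round trip in a temporary directory, with a hypothetical `make_model` function standing in for `NeuralNetwork`, and checks that the reloaded copy produces exactly the same outputs.

```python
import pathlib
import tempfile
import torch
from torch import nn

def make_model():
    # hypothetical small model standing in for NeuralNetwork
    return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))

torch.manual_seed(0)
demo_model = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = pathlib.Path(tmp) / "demo_model.pt"
    torch.save(demo_model.state_dict(), path)

    reloaded = make_model()   # rebuild the architecture first...
    reloaded.load_state_dict(torch.load(path, weights_only=True))  # ...then load the weights

x = torch.rand(2, 1, 28, 28)
demo_model.eval()
reloaded.eval()
same = torch.equal(demo_model(x), reloaded(x))
print(same)
```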

---

# Next steps

First of all, try testing your model with your own images of digits (or images pulled from the web)!

## Plotting the evolution of your model

Here's a function that plots the loss and accuracy histories collected during training (the lists returned by `train` and `test`, one value per batch).

```python
def plot_training(train_losses, train_accuracies, test_losses=None, test_accuracies=None):

    fig, axes = plt.subplots(2, 2, figsize=(10,6))
    
    # loss
    axes[0, 0].set_title("Training Loss")
    axes[0, 0].plot(train_losses, label="loss", c="c")
    axes[0, 0].legend()  

    # accuracy
    axes[1, 0].set_title("Train Accuracy")
    axes[1, 0].plot(train_accuracies, label="accuracy", c="m")
    axes[1, 0].legend()     

    if test_losses is not None:
        axes[0, 1].set_title("Test Loss")
        axes[0, 1].plot(test_losses, label="loss", c="c")
        axes[0, 1].legend()
    else:
        axes[0, 1].axis("off")

    if test_accuracies is not None:
        axes[1, 1].set_title("Test Accuracy")
        axes[1, 1].plot(test_accuracies, label="accuracy", c="m")
        axes[1, 1].legend()
    else:
        axes[1, 1].axis("off")

    plt.show()

plot_training(train_losses, train_accs, test_losses, test_accs)
```

## Test on another dataset: FashionMNIST

Try **Fashion MNIST** instead, which works exactly the same way, but with items of clothing instead of numbers! (Can you modify the `matplotlib` code to display the correct class name?)

```python
train_data = tv.datasets.FashionMNIST(
    root=DATASETS_DIR,
    train=True,
    download=True,
    transform=transforms
)
test_data = tv.datasets.FashionMNIST(
    root=DATASETS_DIR,
    train=False,
    download=True,
    transform=transforms
)
# have a look at the classes:
print(train_data.classes)
```

Many more datasets [here](https://pytorch.org/vision/main/datasets.html#image-classification).

## Note: validation and test sets

Technically, if you wanted to tune your network (choosing hyperparameters, deciding when to stop training, etc.), you would not use `test_data` in the training loop, and would instead split `train_data` once more, like so:

```python
partial_train_data, validation_data = torch.utils.data.random_split(train_data, [.9,.1])
print(len(partial_train_data), len(validation_data))
```

Then you would create three `DataLoader`s, and reserve the test data only for testing your model at the very end (a bit like what we did with the saved picture).

```python
partial_train_dataloader = torch.utils.data.DataLoader(partial_train_data, batch_size=BATCH_SIZE, shuffle=True)
validation_dataloader = torch.utils.data.DataLoader(validation_data, batch_size=BATCH_SIZE)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE)
```
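`random_split` draws a new random split each time it is called. Passing a seeded `torch.Generator` makes the split reproducible across runs. This sketch uses a dummy dataset of 60,000 items (a placeholder with the same length as MNIST's training set) and an arbitrary seed of 42.

```python
import torch
from torch.utils.data import TensorDataset, random_split

# dummy stand-in for train_data: 60,000 items, like MNIST's training set
dummy = TensorDataset(torch.arange(60_000))

def split(seed):
    g = torch.Generator().manual_seed(seed)
    return random_split(dummy, [0.9, 0.1], generator=g)

train_a, val_a = split(42)
train_b, val_b = split(42)   # same seed, same split
print(len(train_a), len(val_a), val_a.indices == val_b.indices)
```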

## Extra: A Mini ConvNet

For those who feel like exploring the Deep, here's how you would go about replacing the fully connected network above by a small ConvNet:

A ConvNet does not need images to be flattened, but it does need a channel dimension.

1. Change your model definition:

```python
class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()

        # 1 input channel, 32 output channels, 3x3 kernel, (default: stride 1, padding 0)
        self.conv1 = nn.Conv2d(INPUT_SHAPE[0], 32, 3)
        self.pool1 = nn.MaxPool2d(2)
        
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.pool2 = nn.MaxPool2d(2)

        # this is not automatic: either you build your net
        # gradually and print the shapes, or you use the conv & maxpool formulas...
        self.flat_dim = 64 * 5 * 5 # 64 filters, channels of 5x5
        
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(self.flat_dim, NUM_CLASSES)

    def forward(self, x, verbose=False):
        x = F.relu(self.conv1(x))
        if verbose: print(x.size())
        x = self.pool1(x)
        if verbose: print(x.size())
        
        x = F.relu(self.conv2(x))
        if verbose: print(x.size())
        x = self.pool2(x)
        if verbose: print(x.size())
        
        x = x.view(-1, self.flat_dim) # this works for (1,28,28) or (1,1,28,28) 
        if verbose: print(x.size())
        x = self.dropout(x)
        
        x = self.fc(x)
        if verbose: print(x.size())
        return x

model = ConvNet().to(device)

# passing random data through our net allows us to print intermediate sizes
x = model(torch.randn(INPUT_SHAPE).to(device), verbose=True)
```

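Voilà! And lastly:

2. Re-create the optimizer (e.g. `optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)`), since the old one still holds the parameters of the dense network, then train and test as before. The `ConvNet` is designed so that you can pass either a single image `(1,28,28)` or a batch `(N,1,28,28)` without a problem.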

3. Save with a different name:

With JIT:
```python
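MODEL_NAME = "convnet_mnist"
MNIST_DIR = MODELS_DIR / MODEL_NAME # a separate directory for this model
MNIST_DIR.mkdir(exist_ok=True)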
# save (reload using torch.jit.load)
torch.jit.save(torch.jit.script(model), MNIST_DIR / f"{MODEL_NAME}_scripted.pt")
```

Or the weights:
```python
# save (reload using model.load_state_dict, requires the model class!)
torch.save(model.state_dict(), MNIST_DIR / f"{MODEL_NAME}.pt")
print(f"Saved PyTorch Model State to {MNIST_DIR / MODEL_NAME}.pt")
```
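One pitfall when swapping in a new model: an optimizer is bound to the parameter tensors it was created with. This sketch uses two made-up `nn.Linear` layers standing in for the dense network and the ConvNet, and shows that an optimizer created for the old model leaves the new model untouched, while a re-created one updates it.

```python
import torch
from torch import nn

torch.manual_seed(0)
old_model = nn.Linear(4, 2)   # stands in for the dense network
new_model = nn.Linear(4, 2)   # stands in for the ConvNet
opt = torch.optim.SGD(old_model.parameters(), lr=0.1)  # created for the OLD model

before = new_model.weight.detach().clone()
loss = new_model(torch.randn(8, 4)).pow(2).mean()
loss.backward()               # new_model now has gradients...
opt.step()                    # ...but this optimizer only knows old_model's parameters
unchanged = torch.equal(before, new_model.weight.detach())

opt = torch.optim.SGD(new_model.parameters(), lr=0.1)  # re-created for the new model
opt.step()                    # uses the gradients computed above
changed = not torch.equal(before, new_model.weight.detach())
print(unchanged, changed)
```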

### ConvNet notes

`nn.Conv2d` ([docs](https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html)) takes as arguments:

- `in_channels`: in our case 1, but could be 3 or 4 for colour images.
- `out_channels`: how many kernels/filters we want.
- `kernel_size`: defines your kernel, aka filter, by specifying the height and width of the matrix 'window' we slide over the image to detect features. Changing these sizes will affect the size of the next layer!
- Other arguments include `stride`, `padding`, `padding_mode`...

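`nn.MaxPool2d` ([docs](https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html)) also slides a window over the input, but at each step keeps only the maximum value inside the window. This is used to downsample: with the default stride (equal to the kernel size), `nn.MaxPool2d(2)` halves the height and width (rounding down).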
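The `flat_dim = 64 * 5 * 5` in `ConvNet` comes from the output-size formula given in the `nn.Conv2d` and `nn.MaxPool2d` docs: `floor((size + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1`. This sketch applies it layer by layer (`out_size` is a helper written for illustration) and cross-checks the result against the actual layers.

```python
import torch
from torch import nn

def out_size(size, kernel, stride=1, padding=0, dilation=1):
    # output-size formula from the nn.Conv2d / nn.MaxPool2d docs (floor division)
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

s = 28
s = out_size(s, kernel=3)            # conv1: 28 -> 26
s = out_size(s, kernel=2, stride=2)  # pool1: 26 -> 13 (stride defaults to kernel_size)
s = out_size(s, kernel=3)            # conv2: 13 -> 11
s = out_size(s, kernel=2, stride=2)  # pool2: 11 -> 5 (rounded down)
flat_dim = 64 * s * s
print(s, flat_dim)

# cross-check with the real layers on a random batch of one image
layers = nn.Sequential(nn.Conv2d(1, 32, 3), nn.MaxPool2d(2), nn.Conv2d(32, 64, 3), nn.MaxPool2d(2))
out = layers(torch.randn(1, 1, 28, 28))
print(out.shape)
```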

A deep CNN has convolutional layers stacked on top of each other. Each layer is made up of lots of different feature extractors, responding to different kinds of patterns. The output(s) of one layer becomes the input(s) to the next one.

- Here the flattening happens using [`view()`](https://pytorch.org/docs/stable/generated/torch.Tensor.view.html), which is efficient and very idiomatic in PyTorch. We flatten the output of the convolutional layers (of shape `(batch_size, channels, height, width)`) to create a single long feature vector per image `(batch_size, features)`.
- [`Dropout`](https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html#torch.nn.Dropout) randomly sets input elements to 0 with probability `p` (here 0.5) during training, and scales the remaining ones by `1/(1-p)`; it does nothing in `eval()` mode. This helps prevent overfitting.
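A quick sketch of `nn.Dropout(p=0.5)` on a tensor of ones (made-up input): in training mode each element becomes either 0 or `1 / (1 - 0.5) = 2.0`, while in evaluation mode the input passes through unchanged. This is one reason to call `model.eval()` before testing.

```python
import torch
from torch import nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
y_train = drop(x)   # random zeros, survivors scaled by 2.0
drop.eval()
y_eval = drop(x)    # identity
print(sorted(set(y_train.tolist())), torch.equal(y_eval, x))
```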

## Extra Extra: Training on a custom dataset!

Provided that you have images in a folder like this:
```bash
main_directory/
...class_a/
......image_1.jpg
......image_2.jpg
...class_b/
......image_1.jpg
......image_2.jpg
```

You can then replace the data loading with:

```python
transforms = v2.Compose([
    v2.Grayscale(num_output_channels=1),
    v2.Resize(size=(28,28), antialias=True),
    v2.ToImage(),                           # PIL image -> tensor (needed for batching)
    v2.ToDtype(torch.float32, scale=True),  # from [0,255] to [0,1]
])

custom_data = tv.datasets.ImageFolder(
    DATASETS_DIR / "custom_dataset",
    transform=transforms,
)

print(custom_data)
print("\n".join(custom_data.classes)) # should show the folder names

# Model / data parameters: one class per subfolder (re-create your model after this!)
NUM_CLASSES = len(custom_data.classes)

train_data, test_data = torch.utils.data.random_split(custom_data, [.9,.1])
print(len(train_data), len(test_data))

train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=BATCH_SIZE)
```

See [the documentation](https://pytorch.org/vision/main/generated/torchvision.datasets.ImageFolder.html#torchvision.datasets.ImageFolder).

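Checking the contents, as well as training and testing your net, should work as before. One caveat: `display_sample` relies on the `.data` attribute, which `ImageFolder` does not have (and the subsets returned by `random_split` have none of `.data`, `.targets` or `.classes`), so you will need to adapt it.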