# MLP in PyTorch to classify MNIST

# Setup

## Imports



- [`torch`](https://pytorch.org/docs/stable/torch.html) is the basic PyTorch library, and one of the world's few software packages to have a [documentary](https://www.youtube.com/watch?v=rgP_LBtaUEc) about it.
- The [`torch.nn`](https://pytorch.org/docs/stable/nn.html) module ("module" in the Python sense of a bunch of files containing definitions and statements, not in the PyTorch sense of a [Module](https://pytorch.org/docs/stable/generated/torch.nn.Module.html)) provides classes for layers like `Linear`, activation functions like `ReLU`, and loss functions like `CrossEntropyLoss`.
- The [`torch.optim`](https://pytorch.org/docs/stable/optim.html) module contains optimization algorithms like `Adam`.
- The [`torchvision.datasets`](https://pytorch.org/vision/stable/datasets.html) subpackage makes it easy to download a number of common datasets, including MNIST.
- The [`torchvision.transforms`](https://pytorch.org/vision/0.9/transforms.html#torchvision-transforms) subpackage makes it easy to pre-process data, such as by converting it to tensors and normalizing.
- The [`torch.utils.data.DataLoader`](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) class batches data.
- [`tqdm`](https://tqdm.github.io/) helps make progress bars, e.g. for training and evaluation loops.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
```

## Device

A classifier with >97% test accuracy can be trained in under 5 minutes on a CPU. But if a GPU is available, we might as well use it.

```python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
```

# Hyperparameters


We:

- use a learning rate of `0.001`,
- train for `10` epochs,
- load `64` image-label pairs in a single batch, and
- have two hidden layers producing `128`- and `64`-long activation vectors, respectively, in line with the common rule of thumb to decrease gradually from input size to output size.

```python
learning_rate = 0.001
num_epochs = 10
batch_size = 64
hidden_sizes = [128, 64]
```

# Transforming MNIST data


- First, we use `ToTensor()` to convert each image to a tensor. (`torchvision` parses MNIST's [original binary format](https://yann.lecun.com/exdb/mnist/#:~:text=FILE%20FORMATS%20FOR%20THE%20MNIST%20DATABASE) and hands each image to the transform as a PIL image.) `ToTensor()` produces a float tensor of shape `(1, 28, 28)` and linearly scales pixel values from integers in [0, 255] to floats in [0, 1].
- Next, we apply [normalization](https://www.datacamp.com/tutorial/normalization-in-machine-learning): we subtract the mean pixel value of the rescaled MNIST training set (`0.1307`) and divide by the standard deviation of its pixel values (`0.3081`).

### Aside: why we normalize input data

[Z-score normalizing](https://en.wikipedia.org/wiki/Standard_score) input data helps training converge faster. To see how, consider a simple linear layer in a neural network, $y = Wx + b$, where the weight $W_{ji}$ connects input $x_i$ to output $y_j$. The gradient of the loss $L$ with respect to $W_{ji}$ is

$$\frac{\partial L}{\partial W_{ji}} = \frac{\partial L}{\partial y_j} x_i,$$

where

- $x_i$ is the $i$th input (recall that in an MLP-based MNIST classifier, each image $x$ is flattened into a single vector whose entries $x_i$ are pixel values; this destroys the image's 2D structure, an unfortunate loss of valuable information that [CNNs](https://en.wikipedia.org/wiki/Convolutional_neural_network) were designed to avoid),
- $y_j$ is the $j$th output of the layer, and
- $\frac{\partial L}{\partial y_j}$ is the gradient flowing back from the subsequent layers.


The problem is that most $x_i$s are zero: a picture of a digit is mostly blank space, with only a handful of positive-valued pixels. You can see this both in the picture below and in the mean pixel value of `0.1307`, which is much closer to 0 than to 1.

<img src="https://drive.google.com/uc?export=view&id=1qmcSeHEdDzhkb1VDjp9SVW9uvy2i0LS6" width="400" alt="Visualizing a few digits from the MNIST dataset">

*Image credit: [Activeloop](https://datasets.activeloop.ai/docs/ml/datasets/mnist/)*

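For a pixel with $x_i = 0$, the gradient of every weight attached to it is exactly zero, so those weights get no gradient signal from that image. For pixels close to 0, the gradients are small relative to those of brighter pixels (assuming the same $\frac{\partial L}{\partial y_j}$). After normalization, background pixels take the value $(0 - 0.1307)/0.3081 \approx -0.4242$ instead of 0, so the difference in gradient magnitudes between background pixels and digit pixels is less extreme.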
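To make the gradient formula concrete, here is a small plain-Python sketch (not PyTorch code) that computes each weight's gradient as the upstream gradient $\frac{\partial L}{\partial y_j}$ times its input $x_i$, for a hypothetical 4-pixel input with two background pixels and two digit pixels, before and after normalizing with the MNIST statistics. The upstream gradients in `upstream` are made-up values for illustration only.

```python
MEAN, STD = 0.1307, 0.3081  # rescaled MNIST training-set statistics


def weight_grads(upstream, x):
    """Gradient of L w.r.t. W for y = Wx + b: entry [j][i] is dL/dy_j * x_i."""
    return [[g_j * x_i for x_i in x] for g_j in upstream]


# Hypothetical 4-pixel input: two background pixels, two bright digit pixels
x_raw = [0.0, 0.0, 0.9, 1.0]
x_norm = [(v - MEAN) / STD for v in x_raw]

# Hypothetical gradients flowing back into a 2-output layer
upstream = [0.5, -0.2]

grads_raw = weight_grads(upstream, x_raw)
grads_norm = weight_grads(upstream, x_norm)

for name, grads in [("raw", grads_raw), ("normalized", grads_norm)]:
    print(name)
    for row in grads:
        print("  ", [round(g, 4) for g in row])
```

With raw inputs, the weights attached to the two background pixels get exactly zero gradient; after normalization they get a nonzero one, and the digit-pixel gradients are only about 6 to 7 times larger in magnitude.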

```python
# These constants represent statistics about MNIST pixel values, *after* the
# values have been rescaled to be in the range [0, 1] by ToTensor()
RESCALED_MNIST_MEAN = 0.1307
RESCALED_MNIST_STD = 0.3081

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((RESCALED_MNIST_MEAN,), (RESCALED_MNIST_STD,))
])
```
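`Normalize` applies $(x - \mu)/\sigma$ to every pixel. A quick arithmetic check in plain Python (mirroring the formula, not calling the torchvision transform) shows where the two extremes of the [0, 1] range end up:

```python
RESCALED_MNIST_MEAN = 0.1307
RESCALED_MNIST_STD = 0.3081


def normalize(pixel):
    """Same per-pixel formula as transforms.Normalize: (x - mean) / std."""
    return (pixel - RESCALED_MNIST_MEAN) / RESCALED_MNIST_STD


background = normalize(0.0)  # a blank pixel after ToTensor()
full_ink = normalize(1.0)    # a fully bright digit pixel after ToTensor()
print(f"background -> {background:.4f}, full ink -> {full_ink:.4f}")
```

This prints `background -> -0.4242, full ink -> 2.8215`; the first value is exactly the background value that appears when the first training image is printed.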

# Loading data


A few notes:

- `shuffle` is set to `True` for training data, meaning that data is reshuffled at every epoch, and `False` for test data. This is because we don't want the model to learn any spurious patterns in the *order* in which training data is presented. Test data is not shuffled to keep it consistent across evaluations.
- Setting `download` to `True` downloads the *entire* MNIST dataset, so we don't need to pass `True` to the `download` argument for test data.
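MNIST has 60,000 training images and 10,000 test images. With `batch_size = 64`, each loader yields $\lceil N / 64 \rceil$ batches per pass, and since `DataLoader` keeps the final partial batch by default (`drop_last=False`), the last batch is smaller. A small arithmetic sketch:

```python
import math

batch_size = 64
split_sizes = {"train": 60_000, "test": 10_000}  # standard MNIST split sizes

batch_counts = {}
for split, n in split_sizes.items():
    num_batches = math.ceil(n / batch_size)
    last_batch = n - (num_batches - 1) * batch_size  # size of the final, partial batch
    batch_counts[split] = (num_batches, last_batch)
    print(f"{split}: {num_batches} batches, last batch has {last_batch} images")
```

The 938 training batches match the `/938` in the training log.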

```python
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
```

```
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 17698622.93it/s]
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 489217.74it/s]
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 4421244.17it/s]
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 13725164.82it/s]
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
```
Let's see what the first training image looks like as a tensor. It happens to be the digit `5`, and we can see that most pixels are the same negative value (`-0.4242`), indicating a background pixel.

```python
print("Label first image:", train_dataset[0][1])
print("Shape of first image:", train_dataset[0][0].shape)
print("Tensor representation of first image (z-score normalized):\n", train_dataset[0][0])
```

```
Label first image: 5
Shape of first image: torch.Size([1, 28, 28])
Tensor representation of first image (z-score normalized):
 tensor([[[-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242],
         [-0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242,
          -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4242, -0.4
```

# Defining the neural network

We create a simple fully-connected neural network with `len(hidden_sizes)` hidden layers, i.e. `len(hidden_sizes) + 1` linear layers in total. A `ReLU` activation function is applied after every linear layer except the last.

Each weight matrix acts like an "adapter" that shrinks a vector: first from a `28 * 28 = 784`-long input vector to a `128`-long vector, then to a `64`-long vector, and finally to a `10`-long output vector.
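The sizes of these "adapters" follow directly from `input_size` and `hidden_sizes`. A plain-Python sketch that lists each linear layer's weight shape (in the `(out_features, in_features)` layout that `nn.Linear` uses) and counts the trainable parameters, weights plus biases:

```python
input_size = 28 * 28
hidden_sizes = [128, 64]
num_classes = 10

# Consecutive pairs of sizes give each linear layer's (in_features, out_features)
sizes = [input_size] + hidden_sizes + [num_classes]
total_params = 0
for in_f, out_f in zip(sizes[:-1], sizes[1:]):
    weights = out_f * in_f  # weight matrix of shape (out_f, in_f)
    biases = out_f          # one bias per output unit
    total_params += weights + biases
    print(f"Linear({in_f}, {out_f}): weight {(out_f, in_f)}, {weights + biases} params")

print("Total trainable parameters:", total_params)
```

For these hyperparameters the total is 109,386, which should agree with `sum(p.numel() for p in model.parameters())` once the model is built.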

The final output is a vector of logits: one unnormalized score per class (the digits 0-9 for MNIST). We can do a few different things with this logit vector:

- To get the model's "prediction", we can pick the largest logit. In other words, we pick the digit the model has assigned the largest "score" to.
- To translate these scores into the model's probability estimates for each digit, we can apply the [softmax](https://en.wikipedia.org/wiki/Softmax_function) function, $\sigma(\mathbf{x})_i = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}$ for an $\mathbf{x}$ with $N$ elements (exponentiating makes every term positive, and dividing by the sum makes the terms sum to 1). As the linked Wikipedia page explains, "[t]he term 'softmax' derives from the amplifying effects of the exponential on any maxima in the input vector. For example, the standard softmax of $(1,2,8)$ is approximately $(0.001, 0.002, 0.997)$, which amounts to assigning almost all of the total unit weight in the result to the position of the vector's maximal element (of 8)." So if you sampled from this distribution, you would *almost* certainly pick the maximum. However, the softmax is "softer" than simply picking the maximum outright, as 3Blue1Brown [explains](https://www.youtube.com/watch?v=wjZofJX0v4M&t=1342s).
- To get log probabilities, we would apply a [log-softmax](https://datascience.stackexchange.com/questions/40714/what-is-the-advantage-of-using-log-softmax-instead-of-softmax) function.  
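A minimal pure-Python sketch of softmax and log-softmax, written with the common trick of subtracting the maximum logit before exponentiating (the shift cancels in the ratio but avoids overflow). It reproduces the $(1, 2, 8)$ example and checks that the largest logit and the largest probability sit at the same index, since $e^x$ is increasing.

```python
import math


def softmax(logits):
    m = max(logits)  # shift for numerical stability; cancels in the ratio
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(logits):
    m = max(logits)
    # log(sum(exp(z))) computed stably as m + log(sum(exp(z - m)))
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


logits = [1.0, 2.0, 8.0]
probs = softmax(logits)
print([round(p, 3) for p in probs])
print(argmax(logits), argmax(probs))
```

The first line prints `[0.001, 0.002, 0.997]`, matching the Wikipedia example, and both argmax calls return index 2.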

```python
# MLP Model
class MLP(nn.Module):
    def __init__(self, input_size, hidden_sizes, num_classes):
        super().__init__()
        self.layers = nn.ModuleList()

        # First layer: input -> first hidden layer
        self.layers.append(nn.Linear(input_size, hidden_sizes[0]))
        self.layers.append(nn.ReLU())

        # Remaining hidden layers
        for i in range(len(hidden_sizes) - 1):
            self.layers.append(nn.Linear(hidden_sizes[i], hidden_sizes[i+1]))
            self.layers.append(nn.ReLU())

        # Output layer, no ReLU: we want raw logits
        self.layers.append(nn.Linear(hidden_sizes[-1], num_classes))

    def forward(self, x):
        # Flatten the image, setting start_dim to 1 to preserve batch dimension
        out = x.flatten(start_dim=1)
        for layer in self.layers:
            out = layer(out)
        return out
```

```python
input_size = 28 * 28  # MNIST images are 28x28 pixels
num_classes = 10  # MNIST has 10 classes (digits 0-9)
model = MLP(input_size, hidden_sizes, num_classes).to(device)
```

# Loss function and optimizer

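We use the [cross-entropy loss function](https://en.wikipedia.org/wiki/Cross-entropy), which is standard for classification tasks. PyTorch's `nn.CrossEntropyLoss` expects raw logits and applies log-softmax internally, which is why our output layer has no softmax. For optimization we use the [Adam optimizer](https://arxiv.org/abs/1412.6980), which keeps running averages of past gradients (momentum, which smooths the updates) and of past squared gradients (which give each parameter its own effective learning rate).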
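For a single example with logits $\mathbf{z}$ and true class $c$, cross-entropy reduces to $-\log \sigma(\mathbf{z})_c = \log \sum_j e^{z_j} - z_c$; `nn.CrossEntropyLoss` averages this over the batch by default. A pure-Python sketch of the per-example loss, using made-up logits for a hypothetical 3-class problem:

```python
import math


def cross_entropy(logits, target):
    """-log softmax(logits)[target], computed stably from raw logits."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


logits = [1.0, 2.0, 8.0]  # hypothetical scores for 3 classes
for target in range(len(logits)):
    print(f"true class {target}: loss = {cross_entropy(logits, target):.4f}")
```

A confident, correct prediction (true class 2) costs almost nothing, while the same logits are heavily penalized when the true class is 0 or 1.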

```python
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
```
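To make the Adam description concrete, here is a minimal single-parameter sketch of the update rule from the Adam paper, using PyTorch's default hyperparameters (`betas=(0.9, 0.999)`, `eps=1e-8`) and our learning rate. It is an illustration only, not `torch.optim.Adam` itself, which also supports options such as weight decay and AMSGrad.

```python
import math


def adam_steps(grads, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8, theta=0.0):
    """Apply Adam to one scalar parameter for a sequence of gradients."""
    m = v = 0.0
    history = []
    for t, g in enumerate(grads, start=1):
        m = beta1 * m + (1 - beta1) * g        # moving average of gradients (momentum)
        v = beta2 * v + (1 - beta2) * g * g    # moving average of squared gradients
        m_hat = m / (1 - beta1 ** t)           # bias corrections for the zero init
        v_hat = v / (1 - beta2 ** t)
        theta -= lr * m_hat / (math.sqrt(v_hat) + eps)
        history.append(theta)
    return history


# Gradients differing in scale by a factor of 10^8
small = adam_steps([1e-4] * 3)
large = adam_steps([1e4] * 3)
print(small)
print(large)
```

With a constant gradient, each step moves the parameter by almost exactly `lr`, whatever the gradient's scale: dividing by $\sqrt{\hat v}$ is what gives each parameter its own effective step size.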

# Training loop

A few notes:

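- PyTorch accumulates gradients into each parameter's `.grad` attribute on every `backward()` call, so we call `optimizer.zero_grad()` before each backward pass. Otherwise, each update would use the sum of the gradients of all minibatches seen so far.
- We wrap iterables with `tqdm()`, specifying the total number of iterations where necessary, to generate progress bars.
- Every 100 batches, we print the loss on the current minibatch. This is a noisy, single-batch estimate, so it fluctuates from one print to the next.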
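A tiny PyTorch sketch of the accumulation behavior, using a hypothetical scalar parameter `w` and a throwaway `SGD` optimizer (its `zero_grad()` clears gradients the same way Adam's does):

```python
import torch

# A single hypothetical scalar parameter and an optimizer that manages it
w = torch.tensor(2.0, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.1)


def backward_once():
    loss = 3.0 * w   # d(loss)/dw = 3
    loss.backward()  # accumulates into w.grad rather than overwriting it
    return w.grad.item()


first = backward_once()        # 3.0
accumulated = backward_once()  # 6.0: gradients from both calls were summed
opt.zero_grad()                # clears w.grad (to None or zeros, depending on version)
after_reset = backward_once()  # 3.0 again
print(first, accumulated, after_reset)
```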

```python
total_batches = len(train_loader)
for epoch in tqdm(range(num_epochs)):
    for i, (imgs, labels) in tqdm(enumerate(train_loader), total=total_batches):
        imgs = imgs.to(device)
        labels = labels.to(device)

        # Forward pass
        out = model(imgs)
        loss = criterion(out, labels)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/{total_batches}], Loss: {loss.item():.4f}')
```

```
Epoch [1/10], Batch [100/938], Loss: 0.4422
Epoch [1/10], Batch [200/938], Loss: 0.2979
Epoch [1/10], Batch [300/938], Loss: 0.2692
Epoch [1/10], Batch [400/938], Loss: 0.1450
Epoch [1/10], Batch [500/938], Loss: 0.0789
Epoch [1/10], Batch [600/938], Loss: 0.1658
Epoch [1/10], Batch [700/938], Loss: 0.2449
Epoch [1/10], Batch [800/938], Loss: 0.1543
Epoch [1/10], Batch [900/938], Loss: 0.3089

Epoch [2/10], Batch [100/938], Loss: 0.0198
Epoch [2/10], Batch [200/938], Loss: 0.1629
Epoch [2/10], Batch [300/938], Loss: 0.2070
Epoch [2/10], Batch [400/938], Loss: 0.0925
Epoch [2/10], Batch [500/938], Loss: 0.2572
Epoch [2/10], Batch [600/938], Loss: 0.1223
Epoch [2/10], Batch [700/938], Loss: 0.1296
Epoch [2/10], Batch [800/938], Loss: 0.1030
Epoch [2/10], Batch [900/938], Loss: 0.1011

Epoch [3/10], Batch [100/938], Loss: 0.0661
Epoch [3/10], Batch [200/938], Loss: 0.0460
Epoch [3/10], Batch [300/938], Loss: 0.0376
Epoch [3/10], Batch [400/938], Loss: 0.1011
Epoch [3/10], Batch [500/938], Loss: 0.0925
Epoch [3/10], Batch [600/938], Loss: 0.2143
Epoch [3/10], Batch [700/938], Loss: 0.0181
Epoch [3/10], Batch [800/938], Loss: 0.1592
Epoch [3/10], Batch [900/938], Loss: 0.0571

Epoch [4/10], Batch [100/938], Loss: 0.0442
Epoch [4/10], Batch [200/938], Loss: 0.0120
Epoch [4/10], Batch [300/938], Loss: 0.0083
Epoch [4/10], Batch [400/938], Loss: 0.1431
Epoch [4/10], Batch [500/938], Loss: 0.0156
Epoch [4/10], Batch [600/938], Loss: 0.0159
Epoch [4/10], Batch [700/938], Loss: 0.0424
Epoch [4/10], Batch [800/938], Loss: 0.0160
Epoch [4/10], Batch [900/938], Loss: 0.0933

Epoch [5/10], Batch [100/938], Loss: 0.0189
Epoch [5/10], Batch [200/938], Loss: 0.0592
Epoch [5/10], Batch [300/938], Loss: 0.1117
Epoch [5/10], Batch [400/938], Loss: 0.0157
Epoch [5/10], Batch [500/938], Loss: 0.0148
Epoch [5/10], Batch [600/938], Loss: 0.0259
Epoch [5/10], Batch [700/938], Loss: 0.0368
Epoch [5/10], Batch [800/938], Loss: 0.0751
Epoch [5/10], Batch [900/938], Loss: 0.0143

Epoch [6/10], Batch [100/938], Loss: 0.0713
Epoch [6/10], Batch [200/938], Loss: 0.0411
Epoch [6/10], Batch [300/938], Loss: 0.0059
Epoch [6/10], Batch [400/938], Loss: 0.0186
Epoch [6/10], Batch [500/938], Loss: 0.0342
Epoch [6/10], Batch [600/938], Loss: 0.0238
Epoch [6/10], Batch [700/938], Loss: 0.0669
Epoch [6/10], Batch [800/938], Loss: 0.0202
Epoch [6/10], Batch [900/938], Loss: 0.0334

Epoch [7/10], Batch [100/938], Loss: 0.0182
Epoch [7/10], Batch [200/938], Loss: 0.0148
Epoch [7/10], Batch [300/938], Loss: 0.0126
Epoch [7/10], Batch [400/938], Loss: 0.0273
Epoch [7/10], Batch [500/938], Loss: 0.0652
Epoch [7/10], Batch [600/938], Loss: 0.0167
Epoch [7/10], Batch [700/938], Loss: 0.0481
Epoch [7/10], Batch [800/938], Loss: 0.0121
Epoch [7/10], Batch [900/938], Loss: 0.0092

Epoch [8/10], Batch [100/938], Loss: 0.0143
Epoch [8/10], Batch [200/938], Loss: 0.0134
Epoch [8/10], Batch [300/938], Loss: 0.0647
Epoch [8/10], Batch [400/938], Loss: 0.0261
Epoch [8/10], Batch [500/938], Loss: 0.0292
Epoch [8/10], Batch [600/938], Loss: 0.0896
Epoch [8/10], Batch [700/938], Loss: 0.0057
Epoch [8/10], Batch [800/938], Loss: 0.0364
Epoch [8/10], Batch [900/938], Loss: 0.0976

Epoch [9/10], Batch [100/938], Loss: 0.0148
Epoch [9/10], Batch [200/938], Loss: 0.0610
Epoch [9/10], Batch [300/938], Loss: 0.0104
Epoch [9/10], Batch [400/938], Loss: 0.0017
Epoch [9/10], Batch [500/938], Loss: 0.0042
Epoch [9/10], Batch [600/938], Loss: 0.0065
Epoch [9/10], Batch [700/938], Loss: 0.0373
Epoch [9/10], Batch [800/938], Loss: 0.0464
Epoch [9/10], Batch [900/938], Loss: 0.0440

Epoch [10/10], Batch [100/938], Loss: 0.0091
Epoch [10/10], Batch [200/938], Loss: 0.0187
Epoch [10/10], Batch [300/938], Loss: 0.0019
Epoch [10/10], Batch [400/938], Loss: 0.0086
Epoch [10/10], Batch [500/938], Loss: 0.0030
Epoch [10/10], Batch [600/938], Loss: 0.0039
Epoch [10/10], Batch [700/938], Loss: 0.0007
Epoch [10/10], Batch [800/938], Loss: 0.0066
Epoch [10/10], Batch [900/938], Loss: 0.0048
```


# Model evaluation

A few notes:

- Our model has no layers that behave differently during inference than during training (like batch normalization or dropout), but we call [`model.eval()`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.eval) anyway in line with best practice.
- We use the context manager [`no_grad()`](https://pytorch.org/docs/stable/generated/torch.no_grad.html) to disable gradient tracking. Because we are only performing *inference*, autograd does not need to record operations or keep intermediate activations for a backward pass, which saves memory and computation.
- We get logits by running the model's forward pass on the test images with `model(imgs)`, and take the model's prediction to be the class with the largest logit using `torch.argmax()`. We do not need to apply softmax first: it is monotonically increasing, so the largest logit always corresponds to the largest probability, and we don't need the specific probabilities the model assigns to different digits.
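A small PyTorch sketch of both effects on a stand-in `nn.Linear` layer (not the trained MLP): `eval()` changes only the module's `training` flag, while `no_grad()` stops autograd from recording operations, so outputs computed inside it have `requires_grad=False` but identical values.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)  # a stand-in for the MLP; any module behaves the same way
x = torch.randn(3, 4)

layer.eval()  # sets `training` to False on the module and all its submodules
print("training flag:", layer.training)

out_tracked = layer(x)        # autograd records this computation
with torch.no_grad():
    out_untracked = layer(x)  # nothing is recorded, so no backward pass is possible

print(out_tracked.requires_grad, out_untracked.requires_grad)
print(torch.equal(out_tracked, out_untracked))  # same numbers either way
```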

```python
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for imgs, labels in tqdm(test_loader):
        imgs = imgs.to(device)
        labels = labels.to(device)
        out = model(imgs)
        preds = torch.argmax(out, dim=1)
        total += labels.shape[0]
        correct += (preds == labels).sum().item()

    print(f'Accuracy of the model on the 10000 test images: {100 * correct / total}%')
```

```
Accuracy of the model on the 10000 test images: 97.76%
```
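The evaluation loop adds up `labels.shape[0]` rather than assuming every batch holds `batch_size` images, because the last test batch is partial (10,000 = 156 × 64 + 16). A pure-Python sketch of the same bookkeeping, using 9,776 correct predictions (the count implied by the 97.76% accuracy above), shows how assuming full batches would skew the denominator. The helper `batch_sizes` is hypothetical, written for this illustration.

```python
batch_size = 64
num_test = 10_000


def batch_sizes(n, size):
    """Sizes of the batches a DataLoader with drop_last=False would yield."""
    return [min(size, n - start) for start in range(0, n, size)]


sizes = batch_sizes(num_test, batch_size)
correct = 9_776  # implied by 97.76% of 10,000 test images

total = sum(sizes)                     # what `total += labels.shape[0]` accumulates
naive_total = len(sizes) * batch_size  # what assuming full batches would give
print(len(sizes), sizes[-1], total, naive_total)
print(f"accuracy: {100 * correct / total:.2f}% vs naive {100 * correct / naive_total:.2f}%")
```

Counting actual labels gives 97.76%, while the full-batch assumption would understate it as 97.29%.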
