<a href="https://colab.research.google.com/github/gileshall/ML-Biology-Notebooks/blob/main/PyTorch_MNIST_tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Goal: Classify hand-written digits from the MNIST dataset.

The MNIST dataset is a collection of grayscale images of handwritten digits (0-9), each 28x28 pixels. The goal is to train a neural network to accurately identify the digit represented in each image. This problem is fundamental to computer vision and serves as a classic example for beginners to understand basic deep learning concepts like image preprocessing, neural network architecture, and model training.

The MNIST dataset was created by combining two datasets from the National Institute of Standards and Technology (NIST). It was designed as a simple, accessible benchmark for evaluating image recognition models. It quickly became the "Hello World" of deep learning, providing a foundational dataset that is easy to understand but effective for demonstrating neural networks. Despite its simplicity, MNIST played a key role in advancing research and continues to be a popular tool for demonstrating machine learning techniques.

The solution is a simple feed-forward neural network, implemented in PyTorch, that classifies MNIST digits. Its fully connected layers learn to recognize patterns in the digits: during training the network is fed labeled images, and an optimizer adjusts its parameters to reduce a loss that measures how wrong its predictions are.

This is a good toy example because it walks through the core steps of a deep learning workflow at a small scale: loading and preprocessing data, defining a model, training it with a loss function and an optimizer, and evaluating it on held-out test data.

In [None]:
!pip install -q ipycanvas orjson

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import PIL.Image  # explicit submodule import; PIL.Image.fromarray and PIL.Image.new are used below
from IPython.display import display

# Colab-specific: allow third-party widgets such as ipycanvas to render
from google.colab import output
output.enable_custom_widget_manager()

from ipycanvas import Canvas, hold_canvas

In [None]:
# Download and load the MNIST handwritten digit dataset.
# ToTensor() scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to [-1, 1].

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
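As a sketch of what `transform` does to each pixel: `ToTensor()` divides 8-bit intensities (0 to 255) by 255, and `Normalize((0.5,), (0.5,))` then computes `(x - mean) / std` with mean 0.5 and std 0.5, which maps [0, 1] onto [-1, 1]. The helper names below (`to_tensor_scale`, `normalize`, `denormalize`) are hypothetical; they only reproduce the arithmetic in plain Python.

```python
# Plain-Python sketch of the per-pixel arithmetic in the MNIST transform.
# These helpers are illustrative only; they are not torchvision functions.
MEAN, STD = 0.5, 0.5  # the values passed to transforms.Normalize((0.5,), (0.5,))


def to_tensor_scale(raw: int) -> float:
    """ToTensor() maps an 8-bit intensity in 0..255 to a float in [0, 1]."""
    return raw / 255


def normalize(x: float) -> float:
    """Normalize computes (x - mean) / std, sending [0, 1] to [-1, 1]."""
    return (x - MEAN) / STD


def denormalize(y: float) -> float:
    """Inverse of normalize(): recover a [0, 1] value, e.g. before displaying an image."""
    return y * STD + MEAN


for raw in (0, 128, 255):
    x = to_tensor_scale(raw)
    print(f"raw={raw:3d}  after ToTensor={x:.3f}  after Normalize={normalize(x):+.3f}")
```

The inverse mapping is what you need whenever a normalized tensor has to be shown as an image again.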


In [None]:
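# Visualize the first 10 training images with their labels.
# Undo Normalize((0.5,), (0.5,)) first: ToPILImage expects float values in [0, 1].

for i_exp in range(10):
    img, label = train_dataset[i_exp]
    print(f'Label: {label}')
    display(transforms.ToPILImage()(img * 0.5 + 0.5))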

## Defining the Neural Network Model

The model below is a custom neural network with three fully connected layers. Its input is the flattened MNIST image (28 × 28 = 784 pixels), followed by two hidden layers with 128 and 64 nodes, respectively. The output layer has 10 nodes, one per digit class (0-9). The `forward` method defines how data flows through the network.


# Neural Network Model for MNIST Handwritten Digit Classification

This neural network is a **fully connected (feedforward) neural network** implemented in PyTorch. It aims to classify handwritten digits (0–9) from the **MNIST dataset**, which consists of 28x28 grayscale images.

## Model Architecture

- **Input Layer:**
  - Each image in the MNIST dataset is of size **28x28 pixels**.
  - To feed this into the model, it is flattened into a **1D vector of size 784 (28 × 28)** using `x.view(-1, 28 * 28)`.

- **First Hidden Layer (fc1):**
  - Takes the **784-element vector** from the input layer as input.
  - Contains **128 neurons**.
  - Applies the **ReLU (Rectified Linear Unit)** activation function.

- **Second Hidden Layer (fc2):**
  - Takes the **128-element vector** from the first hidden layer as input.
  - Contains **64 neurons**.
  - Uses the **ReLU activation function** again.

- **Output Layer (fc3):**
  - Takes the **64-element vector** from the second hidden layer as input.
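  - Outputs a **10-element vector** of raw scores (logits), where each element corresponds to one of the **10 classes (digits 0-9)**. No softmax is applied here because `nn.CrossEntropyLoss` applies log-softmax internally.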
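The layer sizes above determine how many trainable parameters the network has. A minimal sketch of the count, assuming only that each `nn.Linear(in_features, out_features)` stores an `out_features x in_features` weight matrix plus `out_features` biases (the helper `linear_params` is hypothetical):

```python
# Count the trainable parameters of the fully connected network described above.
# nn.Linear(in_features, out_features) holds out_features * in_features weights
# plus out_features bias terms.
layer_sizes = [28 * 28, 128, 64, 10]  # input, fc1, fc2, fc3 widths


def linear_params(n_in: int, n_out: int) -> int:
    return n_in * n_out + n_out


per_layer = {
    f"fc{i + 1}": linear_params(n_in, n_out)
    for i, (n_in, n_out) in enumerate(zip(layer_sizes, layer_sizes[1:]))
}
total = sum(per_layer.values())
print(per_layer, total)
```

Once the model exists, `sum(p.numel() for p in model.parameters())` should report the same total; almost all of it comes from `fc1`, which connects every pixel to every first-layer neuron.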

Here is a visualization of a fully connected network.

![A fully connected network](http://www.gabormelli.com/RKB/images/4/44/2layersFCNN.png "A fully connected net")

Here is a comparison of the sigmoid and ReLU activation functions:

![sigmoid and ReLU](https://miro.medium.com/v2/resize:fit:1400/format:webp/1*aEVZlqkcakVySV6ETqgfEg.png "Sigmoid and ReLU")
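The plot can be reproduced numerically. This plain-Python sketch evaluates both functions and their derivatives at a few points; it shows that ReLU passes positive inputs through with gradient 1, while the sigmoid gradient never exceeds 0.25 and shrinks toward 0 for large |z|, one common reason ReLU is preferred in hidden layers. Using 0 as the ReLU derivative at exactly z = 0 is a convention chosen for this sketch.

```python
import math


def relu(z: float) -> float:
    """ReLU(z) = max(0, z): negative inputs become 0, positive inputs pass through."""
    return max(0.0, z)


def sigmoid(z: float) -> float:
    """sigmoid(z) = 1 / (1 + e^(-z)): squashes any input into (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))


def relu_grad(z: float) -> float:
    """Derivative of ReLU: 1 for z > 0, else 0 (0 at z == 0 by convention here)."""
    return 1.0 if z > 0 else 0.0


def sigmoid_grad(z: float) -> float:
    """Derivative of sigmoid: s(z) * (1 - s(z)), largest (0.25) at z == 0."""
    s = sigmoid(z)
    return s * (1.0 - s)


for z in (-4.0, -1.0, 0.0, 1.0, 4.0):
    print(f"z={z:+.1f}  relu={relu(z):.3f}  sigmoid={sigmoid(z):.3f}  "
          f"relu'={relu_grad(z):.3f}  sigmoid'={sigmoid_grad(z):.3f}")
```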

In [None]:
# Device configuration: Determines whether the model will run on GPU or CPU.
# If a CUDA-compatible GPU is available, it will use 'cuda' for faster computation;
# otherwise, it will fall back to 'cpu'.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Neural Network Model
class NeuralNet(nn.Module):
    def __init__(self):
        super(NeuralNet, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Instantiate the new model
model = NeuralNet().to(device)
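To see what `forward` computes, here is a plain-Python sketch of the same steps (flatten, linear layer, ReLU, linear layer) on a hypothetical 2x2 "image" with made-up weights. It uses one hidden layer instead of two to keep the numbers small; the real model learns its 784x128, 128x64 and 64x10 weight matrices during training.

```python
# Hypothetical toy network: a 2x2 "image" -> 3 hidden units -> 2 output logits.
# The weights are made up for illustration, not taken from a trained model.


def flatten(image):
    """Row-major flattening, the order x.view(-1, 28 * 28) uses for each image."""
    return [p for row in image for p in row]


def linear(x, weight, bias):
    """y = W x + b, with weight stored as (out_features, in_features) like nn.Linear."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def relu(v):
    return [max(0.0, z) for z in v]


image = [[1.0, 0.0],
         [0.0, 1.0]]
W1 = [[1.0, 1.0, 1.0, 1.0],
      [1.0, -1.0, -1.0, 1.0],
      [-1.0, 0.0, 0.0, -1.0]]
b1 = [0.0, 0.0, 0.5]
W2 = [[1.0, 0.0, 1.0],
      [0.0, 1.0, -1.0]]
b2 = [0.0, 0.1]

x = flatten(image)           # [1.0, 0.0, 0.0, 1.0]
h = relu(linear(x, W1, b1))  # like F.relu(self.fc1(x)); the third unit is -1.5 before ReLU
logits = linear(h, W2, b2)   # like the final layer: raw scores, no softmax
print(x, h, logits)
```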

In [None]:
# Hyperparameters: These are key parameters that control the learning process of the model.
# Adjusting them can significantly affect the model's performance.

# Number of samples processed before the model updates its weights.
batch_size = 32

# Step size used by the optimizer to update model parameters.
learning_rate = 0.01

# Total number of complete passes through the entire dataset.
epochs = 15

# build the data loaders
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Loss and optimizer
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

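# Training loop
#
# Mode      Dropout     BatchNorm            Gradients
# ------------------------------------------------------------------
# train()   Active      Batch statistics     Unaffected by the mode
# eval()    Inactive    Running statistics   Unaffected by the mode
#
# Gradient tracking is switched off separately with torch.no_grad() (see the test cell).
# This fully connected model has no dropout or BatchNorm layers, so train()/eval() does not
# change its output; the CNN defined later does use dropout.

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        # Forward pass
        outputs = model(data)
        loss = loss_function(outputs, target)

        # Backward pass and optimization
        optimizer.zero_grad()  # clear gradients left over from the previous batch
        loss.backward()        # compute gradients of the loss w.r.t. every parameter
        optimizer.step()       # plain SGD update: p <- p - learning_rate * p.grad

        running_loss += loss.item()

    # Report the mean loss over all batches in the epoch, not just the last batch
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {running_loss / len(train_loader):.4f}')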
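The training step combines two ideas: `nn.CrossEntropyLoss` (log-softmax followed by the negative log-likelihood of the true class, averaged over the batch by default) and an SGD update. The plain-Python sketch below shows both on a hypothetical "model" whose only parameters are 10 logits. It uses the standard result that the gradient of softmax cross-entropy with respect to the logits is `softmax(logits) - one_hot(target)`. In PyTorch, `loss.backward()` adds gradients into each parameter's `.grad`, which is why `optimizer.zero_grad()` is called before it; this sketch simply recomputes `grad` each step.

```python
import math


def log_softmax(logits):
    """log(softmax(logits)), computed stably by subtracting the max logit first."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum_exp for z in logits]


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1 (what F.softmax does per row)."""
    return [math.exp(v) for v in log_softmax(logits)]


def cross_entropy(batch_logits, targets):
    """Mean over the batch of -log p(target), like nn.CrossEntropyLoss(reduction='mean')."""
    losses = [-log_softmax(row)[t] for row, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)


# Hypothetical toy setup: the "parameters" are the 10 logits themselves.
logits = [0.0] * 10
target = 3
lr = 0.5
initial_loss = cross_entropy([logits], [target])  # uniform prediction gives log(10)

for step in range(20):
    # Role of loss.backward(): gradient of the loss w.r.t. the logits.
    grad = [p - (1.0 if i == target else 0.0) for i, p in enumerate(softmax(logits))]
    # Role of optimizer.step() for plain SGD: parameter <- parameter - lr * gradient.
    logits = [z - lr * g for z, g in zip(logits, grad)]

final_loss = cross_entropy([logits], [target])
print(f"loss before: {initial_loss:.4f}  after 20 SGD steps: {final_loss:.4f}")
```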


In [None]:
# Test the model
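model.eval()  # inference behavior for dropout/BatchNorm layers (matters for the CNN)
correct = 0
total = 0
with torch.no_grad():  # no gradient tracking is needed for evaluation
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        outputs = model(data)
        predicted = outputs.argmax(dim=1)  # index of the highest score is the predicted digit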
        total += target.size(0)
        correct += (predicted == target).sum().item()

print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')
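The test cell sums correct predictions over all batches and divides once at the end, so the result is exact even though the last batch is smaller (10,000 test images split into batches of 32 leaves a final batch of 16). A plain-Python sketch of the same computation on hypothetical scores for three images:

```python
def argmax(row):
    """Index of the largest score (the first one wins on ties), i.e. the predicted class."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy_percent(batch_logits, targets):
    """Share of rows whose highest-scoring class equals the target, as a percentage."""
    predicted = [argmax(row) for row in batch_logits]
    correct = sum(p == t for p, t in zip(predicted, targets))
    return 100 * correct / len(targets)


# Hypothetical scores for 3 images over 3 classes (not real model output).
batch_logits = [
    [0.1, 2.0, 0.3],  # predicted class 1
    [1.5, 0.2, 0.1],  # predicted class 0
    [0.0, 0.0, 3.0],  # predicted class 2
]
targets = [1, 0, 1]
print(f"{accuracy_percent(batch_logits, targets):.2f}%")
```

Averaging per-batch accuracies instead would give the small final batch the same weight as a full one.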


In [None]:
width = 255
height = 255

canvas = Canvas(width=width, height=height, sync_image_data=True)

drawing = False
position = None  # last mouse position, used to connect stroke segments

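def recrop(img):
    """Crop to the drawn strokes and pad to a centered square, roughly mimicking MNIST's framing."""
    bbox = img.getbbox()  # box around non-black pixels; None on an empty canvas (crop then keeps everything)
    img = img.crop(bbox)
    sq_len = max(img.size) + 32  # 16-pixel margin on each side of the longer dimension
    left_offset = (sq_len - img.size[0]) // 2
    top_offset = (sq_len - img.size[1]) // 2
    pad_img = PIL.Image.new(img.mode, (sq_len, sq_len), color=0)
    pad_img.paste(img, (left_offset, top_offset))
    return pad_img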

def predict():
    # Grab the canvas pixels (RGBA) and convert them to single-channel grayscale
    image = PIL.Image.fromarray(canvas.get_image_data())
    image = image.convert('L')
    image = recrop(image)
    # Resize to 28x28, convert to a tensor, and normalize exactly like the training data
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    image = transform(image).unsqueeze(0).to(device)
    # Show what is going into the network (undo the normalization for display)
    display(transforms.ToPILImage()(image[0].cpu() * 0.5 + 0.5))

    # Predict
    model.eval()
    with torch.no_grad():
        output = model(image)

    probabilities = F.softmax(output, dim=1)[0].cpu().numpy()
    predicted = output.argmax(dim=1).item()
    formatted = [f'[{i}: {p * 100:.2f}%]' for i, p in enumerate(probabilities)]
    print(f'Probabilities: {" ".join(formatted)}')
    print(f'Predicted: {predicted}')

def init_canvas():
    canvas.fill_style = '#000000'
    canvas.fill_rect(0, 0, canvas.width, canvas.height)
    canvas.fill_style = '#FFFFFF'
    canvas.stroke_style = "#FFFFFF"
    canvas.line_width = 14  # match the diameter of the 7-pixel-radius brush circles

def on_mouse_down(x, y):
    global drawing, position
    drawing = True
    position = (x, y)

def on_mouse_move(x, y):
    global position
    if not drawing:
        return

    with hold_canvas():
        # Join the previous point to the current one so fast strokes don't leave gaps
        canvas.stroke_line(position[0], position[1], x, y)
        canvas.fill_circle(x, y, 7)
    position = (x, y)

def on_mouse_up(x, y):
    global drawing
    drawing = False

    with hold_canvas():
        canvas.fill_circle(x, y, 7)

def on_keyboard_event(key, shift_key, ctrl_key, meta_key):
    if key.lower() == 'c':
        canvas.clear()
        init_canvas()
    elif key.lower() == 'p':
        predict()

canvas.on_key_down(on_keyboard_event)
canvas.on_mouse_down(on_mouse_down)
canvas.on_mouse_move(on_mouse_move)
canvas.on_mouse_up(on_mouse_up)

init_canvas()
print("Draw a digit with your mouse.\nPress 'p' to predict.\nPress 'c' to clear.")
canvas
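`recrop` crops the drawing to its bounding box, pads it to a square with a 16-pixel margin on each side of the longer dimension, and `predict` then resizes the result to 28x28. MNIST digits were size-normalized to fit a 20x20 box and centered in the 28x28 frame, so the margin roughly imitates that framing. The plain-Python sketch below reproduces the bounding-box and padding arithmetic on a hypothetical tiny canvas; `bbox` and `square_pad` are illustrative helpers, not PIL functions.

```python
def bbox(pixels):
    """(left, top, right, bottom) of the nonzero pixels in a 2D list, with right/bottom
    exclusive (the convention PIL's Image.getbbox() uses); None if every pixel is zero."""
    rows = [y for y, row in enumerate(pixels) if any(row)]
    if not rows:
        return None
    cols = [x for x in range(len(pixels[0])) if any(row[x] for row in pixels)]
    return (min(cols), min(rows), max(cols) + 1, max(rows) + 1)


def square_pad(width, height, margin=32):
    """Side of the padded square and the offsets that center a width x height crop,
    mirroring the arithmetic in recrop()."""
    side = max(width, height) + margin
    return side, (side - width) // 2, (side - height) // 2


# A tall, thin stroke on a hypothetical 6x5 canvas.
canvas_pixels = [
    [0, 0, 0, 0, 0],
    [0, 0, 255, 0, 0],
    [0, 0, 255, 0, 0],
    [0, 255, 255, 0, 0],
    [0, 0, 255, 0, 0],
    [0, 0, 0, 0, 0],
]
print(bbox(canvas_pixels))

# A 40x120-pixel drawing is padded to 152x152; after resizing to 28x28 it is about 22 pixels tall.
side, x_off, y_off = square_pad(40, 120)
print(side, x_off, y_off, round(120 * 28 / side, 1))
```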

## Exercise: Try a Convolutional Neural Network (CNN)

Here is an alternative architecture for you to try: a **Convolutional Neural Network (CNN)**. CNNs are especially effective for image classification tasks like recognizing handwritten digits from the MNIST dataset.
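The two building blocks of the CNN can be sketched in plain Python. Like `nn.Conv2d`, the sketch computes a cross-correlation (the kernel is slid over the zero-padded input without being flipped), followed by ReLU and 2x2 max pooling, the same order the `CNN` class uses. The 3x3 kernel here is a hand-picked, hypothetical vertical-line detector; in the CNN, the 32 and 64 kernels are learned during training.

```python
def conv2d_single(image, kernel, padding=1):
    """Single-channel 2D cross-correlation, stride 1, zero padding: the sum nn.Conv2d
    computes for one input/output channel pair (bias omitted)."""
    k = len(kernel)
    w = len(image[0])
    zero_row = [0.0] * (w + 2 * padding)
    padded = ([zero_row[:] for _ in range(padding)]
              + [[0.0] * padding + list(row) + [0.0] * padding for row in image]
              + [zero_row[:] for _ in range(padding)])
    out_h = len(padded) - k + 1
    out_w = len(padded[0]) - k + 1
    return [[sum(kernel[i][j] * padded[y + i][x + j] for i in range(k) for j in range(k))
             for x in range(out_w)]
            for y in range(out_h)]


def max_pool2d(fmap, size=2):
    """Non-overlapping max pooling (kernel_size == stride == size), like nn.MaxPool2d(2, 2)."""
    return [[max(fmap[y + i][x + j] for i in range(size) for j in range(size))
             for x in range(0, len(fmap[0]) - size + 1, size)]
            for y in range(0, len(fmap) - size + 1, size)]


# A 4x4 image with a vertical line in column 1, and a hypothetical vertical-line kernel.
image = [[0.0, 1.0, 0.0, 0.0] for _ in range(4)]
kernel = [[-1.0, 2.0, -1.0] for _ in range(3)]

fmap = conv2d_single(image, kernel)                   # padding=1 keeps the 4x4 size
activated = [[max(0.0, v) for v in row] for row in fmap]  # ReLU
pooled = max_pool2d(activated)                        # 4x4 -> 2x2
print(fmap)
print(pooled)
```

The strongest responses sit in the column containing the line, and pooling keeps that response while halving the resolution.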

In this exercise, you can:
- Implement this CNN and compare its performance with the fully connected network.
- Experiment with the architecture (e.g., change the number of filters, dropout rates, or add more layers).
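- Train the model on the MNIST dataset and compare its test accuracy with the fully connected network. Convolutional layers slide small shared filters across the image, so they exploit its 2D spatial structure, which the fully connected network does not exploit once each image is flattened.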

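To use this model, run the cell below: it replaces `model` with a new `CNN` instance. Then re-run the training and test cells. The training cell rebuilds the optimizer from `model.parameters()`, so it will train the CNN. To switch back, re-run the cell that defines and instantiates `NeuralNet`.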

In [None]:
# Define the Convolutional Neural Network
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 3x3 convolutions with padding=1 keep height and width unchanged
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)   # 1x28x28 -> 32x28x28
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)  # 32x14x14 -> 64x14x14
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)  # halves height and width
        self.dropout = nn.Dropout(0.25)  # active only in model.train() mode
        self.fc1 = nn.Linear(64 * 7 * 7, 128)  # 64 channels of 7x7 after two poolings
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))  # -> 32x14x14
        x = self.pool(torch.relu(self.conv2(x)))  # -> 64x7x7
        x = self.dropout(x)
        x = x.view(-1, 64 * 7 * 7)  # Flatten to (batch_size, 3136)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)  # raw scores (logits) for the 10 digits
        return x


# Replace the fully connected model with the CNN, then re-run the training and test cells
model = CNN().to(device)
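The `64 * 7 * 7` input size of `fc1` follows from the output-size formula in the PyTorch `Conv2d`/`MaxPool2d` documentation. The plain-Python sketch below traces a 28-pixel side through the layers and counts parameters per layer (the helper names are hypothetical). It suggests that this CNN is not smaller than the fully connected network (about 421,600 parameters versus about 109,400), and that roughly 95% of its parameters sit in `fc1`; the convolutional layers themselves are compact because each filter is shared across the whole image.

```python
def out_size(size, kernel_size, stride=1, padding=0, dilation=1):
    """Output height/width of nn.Conv2d or nn.MaxPool2d, per the PyTorch docs:
    floor((size + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)."""
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def conv_params(c_in, c_out, k):
    """Weights (c_out x c_in x k x k) plus one bias per output channel."""
    return c_out * c_in * k * k + c_out


def linear_params(n_in, n_out):
    return n_in * n_out + n_out


size = 28
size = out_size(size, kernel_size=3, padding=1)  # conv1: 28 -> 28
size = out_size(size, kernel_size=2, stride=2)   # pool:  28 -> 14
size = out_size(size, kernel_size=3, padding=1)  # conv2: 14 -> 14
size = out_size(size, kernel_size=2, stride=2)   # pool:  14 -> 7
flat_features = 64 * size * size                 # input width of fc1

cnn_params = {
    "conv1": conv_params(1, 32, 3),             # 320
    "conv2": conv_params(32, 64, 3),            # 18,496
    "fc1": linear_params(flat_features, 128),   # 401,536
    "fc2": linear_params(128, 10),              # 1,290
}
print(size, flat_features, cnn_params, sum(cnn_params.values()))
```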