## Image Classification on the MNIST dataset 😼

### Import libraries

In [None]:
import numpy as np
import torch
import torchvision
from torchvision.transforms import v2

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

### Create the image preprocessing transform

In [None]:
# Preprocessing pipeline, following the torchvision v2 classification recipe:
# https://pytorch.org/vision/stable/auto_examples/transforms/plot_transforms_getting_started.html#i-just-want-to-do-image-classification
#
#   1. PILToTensor - convert the PIL image to a tensor
#   2. ToDtype     - cast to float32 and rescale the values (scale=True)
#   3. Normalize   - subtract 0.5 and divide by 0.5 on the single channel

transforms = v2.Compose(
    [
        v2.PILToTensor(),
        v2.ToDtype(dtype=torch.float32, scale=True),
        v2.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

### Get the dataset & pass it into `DataLoader`

In [None]:
# Fetch MNIST (downloaded into ./data when not already present) as separate
# training and validation datasets that share one preprocessing pipeline.
training_set = torchvision.datasets.MNIST(
    root="./data", train=True, transform=transforms, download=True
)
validation_set = torchvision.datasets.MNIST(
    root="./data", train=False, transform=transforms, download=True
)

# Display name for each class index: the digits "0" through "9".
classes = tuple("0123456789")


# Wrap both datasets in loaders yielding mini-batches of 4 samples; only the
# training loader shuffles.
training_loader = torch.utils.data.DataLoader(
    dataset=training_set, batch_size=4, shuffle=True
)
validation_loader = torch.utils.data.DataLoader(
    dataset=validation_set, batch_size=4, shuffle=False
)
print("Training set has {} instances".format(len(training_set)))
print("Validation set has {} instances".format(len(validation_set)))

### Visualize data

In [None]:
import matplotlib.pyplot as plt
import numpy as np


# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    """Draw an image tensor on the current matplotlib axes.

    Args:
        img: Image tensor laid out channels-first; it is converted with
            ``.numpy()``, so it is expected to live on the CPU.
        one_channel: When true, average over the channel axis and draw the
            result as a 2-D image with the "Greys" colormap. Otherwise the
            channel axis is moved last and the image is drawn as-is.
    """
    if one_channel:
        img = img.mean(dim=0)

    # Undo Normalize(mean=0.5, std=0.5) so the values are back in display range.
    pixels = (img / 2 + 0.5).numpy()

    if not one_channel:
        plt.imshow(np.transpose(pixels, (1, 2, 0)))
        return
    plt.imshow(pixels, cmap="Greys")


# Pull one mini-batch from the training loader as a visual sanity check.
data_iter = iter(training_loader)
images, labels = next(data_iter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
# Iterate over the labels actually present in the batch instead of a
# hard-coded range(4), so the printout stays correct if the loader's
# batch_size is changed.
print("  ".join(classes[label] for label in labels.tolist()))

In [None]:
# Inspect the Python type and the shape of the sampled batch and of its first
# element (the `=` specifier prints each expression next to its value).
print(f"{type(images)=}, {type(images[0])=}")
print(f"{images.shape=}, {images[0].shape=}")

---

### Let's start working on the classification model 🦅

In [None]:
import torch.nn as nn
import torch.nn.functional as F


# PyTorch models inherit from torch.nn.Module
class MNISTDigitClassifier(nn.Module):
    """Convolutional classifier mapping 1x28x28 images to 10 class scores.

    Two convolution + max-pool stages are followed by three fully connected
    layers. The last layer returns raw scores; no softmax is applied here.
    """

    def __init__(self):
        super().__init__()
        # Spatial sizes per sample for a 28x28 input:
        #   conv1: 1x28x28 -> 6x24x24   (28 - 5 + 1 = 24)
        #   pool : 6x24x24 -> 6x12x12   (24 / 2 = 12)
        #   conv2: 6x12x12 -> 16x8x8    (12 - 5 + 1 = 8)
        #   pool : 16x8x8  -> 16x4x4    (8 / 2 = 4)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # The flattened 16x4x4 feature map feeds the dense layers.
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return one row of 10 class scores per input image."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Collapse channel and spatial axes into one feature vector per sample.
        flat = features.view(-1, 16 * 4 * 4)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


# Build the network, then move its parameters onto the selected device.
model = MNISTDigitClassifier()
model = model.to(device)

In [None]:
# Cross-entropy over the raw class scores produced by the model.
loss_fn = torch.nn.CrossEntropyLoss()

# Stochastic gradient descent with momentum over all model parameters.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)

In [None]:
def train_one_epoch(epoch_index, report_every=1000):
    """Train the module-level ``model`` for one pass over ``training_loader``.

    Uses the module-level ``optimizer``, ``loss_fn`` and ``device``.

    Args:
        epoch_index: Index of the current epoch. Accepted for call-site
            compatibility; the function does not use it.
        report_every: Number of batches per progress report. The mean loss
            of each complete window is printed.

    Returns:
        The mean per-batch loss of the most recent complete reporting window.
        If the loader finished before a single window completed, the mean
        over the batches that were processed is returned instead of the
        ``0.0`` placeholder; ``0.0`` is returned only when the loader yielded
        no batches at all.

    Raises:
        ValueError: If ``report_every`` is smaller than 1.
    """
    if report_every < 1:
        raise ValueError("report_every must be a positive integer")

    running_loss = 0.0
    last_loss = 0.0
    batches_in_window = 0
    reported = False

    # enumerate() gives the batch index needed for intra-epoch reporting.
    for i, (inputs, labels) in enumerate(training_loader):
        inputs = inputs.to(device=device)
        labels = labels.to(device=device)

        # Gradients accumulate across backward() calls, so reset them first.
        optimizer.zero_grad()

        # The outputs are already on the model's device; no extra move needed.
        outputs = model(inputs)

        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # .item() yields a plain Python float for the running total.
        running_loss += loss.item()
        batches_in_window += 1
        if batches_in_window == report_every:
            last_loss = running_loss / report_every  # loss per batch
            print("  batch {} loss: {}".format(i + 1, last_loss))
            running_loss = 0.0
            batches_in_window = 0
            reported = True

    # Short epochs never fill a window; report what was actually measured
    # rather than returning the 0.0 placeholder as if it were a loss.
    if not reported and batches_in_window:
        last_loss = running_loss / batches_in_window

    return last_loss

In [None]:
# Run-level bookkeeping. NOTE: these are reset in the same cell as the loop,
# so re-running this cell restarts the epoch counter and the best-loss
# tracker; move them into their own cell to add epochs to an earlier run.
epoch_number = 0

EPOCHS = 5

# Lowest average validation loss seen so far (starts at a large sentinel).
best_v_loss = 1_000_000.0

print(f"model will be using {device =}")

for epoch in range(EPOCHS):
    print("EPOCH {}:".format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number)

    # Switch to evaluation mode for the validation pass.
    model.eval()

    running_v_loss = 0.0
    num_v_batches = 0

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for v_inputs, v_labels in validation_loader:
            v_outputs = model(v_inputs.to(device=device))
            v_loss = loss_fn(v_outputs, v_labels.to(device=device))
            # Accumulate a Python float, as train_one_epoch does, instead of
            # a device tensor, so the report below prints a plain number.
            running_v_loss += v_loss.item()
            num_v_batches += 1

    # Count batches explicitly rather than relying on a leaked loop index,
    # which would be undefined (or stale) for an empty validation loader.
    avg_v_loss = running_v_loss / num_v_batches if num_v_batches else float("nan")
    print("LOSS train {} valid {}".format(avg_loss, avg_v_loss))

    # Keep the best validation loss up to date (it was previously never set).
    best_v_loss = min(best_v_loss, avg_v_loss)

    epoch_number += 1

## Visualize model predictions 🚀

In [None]:
def visualize_model_predictions():
    """Plot five random validation images with actual and predicted labels.

    Uses the module-level ``model``, ``validation_set`` and ``device``. Each
    image is shown in its own figure; the title is green when the prediction
    matches the label and red otherwise. Call ``torch.manual_seed(...)``
    beforehand for a reproducible selection.
    """
    my_idx = torch.randint(high=len(validation_set), size=(5,))
    print(f"my random indexes are: {my_idx}")

    # Switch to evaluation mode for inference.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for i in my_idx.tolist():
            # validation_set was built with transform=transforms, so the
            # sample is already a normalised tensor; applying `transforms`
            # again here would normalise it a second time.
            curr_image, actual_label = validation_set[i]

            # Add the batch dimension and move the input to the model's
            # device; a CPU input fails when the model lives on CUDA.
            batch = curr_image.unsqueeze(0).to(device=device)
            prediction = model(batch)
            predicted_label = prediction.argmax(dim=1).item()

            color = "green" if actual_label == predicted_label else "red"

            # Move the channel axis last for matplotlib.
            plt.imshow(curr_image.permute(1, 2, 0))
            plt.title(f"{actual_label=}; {predicted_label=}", color=color)
            plt.axis('off')
            plt.show()

In [None]:
# Show the model's predictions for five randomly chosen validation images.
visualize_model_predictions()