# Image Classification with MNIST Dataset

In this Jupyter Notebook, we will explore the task of image classification using the MNIST dataset. The MNIST dataset is a widely-used example in the field of machine learning and computer vision, consisting of a collection of 28x28 grayscale images of handwritten digits (0 through 9). Each image is labeled with the corresponding digit it represents, making it an ideal dataset for training and testing image classification algorithms.

In this notebook, we will build and train a multilayer perceptron (MLP).
An MLP is not the go-to architecture for image classification (convolutional networks are the usual choice), but it serves as a simple example.

In [None]:
import matplotlib.pyplot as plt
%config InlineBackend.figure_format = 'svg'
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from util import show_sample, show_predictions

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Preprocessing applied to every image, in order:
#   1. ToTensor  - convert the loaded image to a tensor
#   2. Normalize - normalize the single channel with mean 0.5 and std 0.5
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

In [None]:
# Fetch MNIST into ./data (downloaded if needed) and build the training split
# (train=True) and the test split (train=False), both using `transform`.
train_dataset, test_dataset = (
    datasets.MNIST(root="./data", train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

In [None]:
# Bare expression: the notebook shows the training dataset's repr as the cell output.
train_dataset

In [None]:
# Bare expression: the notebook shows the test dataset's repr as the cell output.
test_dataset

In [None]:
# Index of the training example to inspect.
ex = 50
# Index the dataset only once: every access re-runs the transform pipeline,
# and each item is an (image, label) pair.
sample_ex, label_ex = train_dataset[ex]
image_ex = sample_ex[0]  # first channel of the image tensor

show_sample(image_ex, label_ex)

In [None]:
# Mini-batch loaders. Training batches are reshuffled each epoch; the test
# loader keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)
# No drop_last here: dropping the final partial batch would exclude those
# test images from the accuracy computation.
test_loader = DataLoader(dataset=test_dataset, batch_size=512, shuffle=False)

In [None]:
# Define the model
class MLP(nn.Module):
    """Fully connected classifier with two hidden ReLU layers.

    The default sizes give the 784 -> 128 -> 64 -> 10 network used in this
    notebook for 28x28 images and 10 digit classes.

    Args:
        in_features: Number of values per flattened input sample.
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        num_classes: Number of output scores (one per class).
    """

    def __init__(
        self,
        in_features: int = 28 * 28,
        hidden1: int = 128,
        hidden2: int = 64,
        num_classes: int = 10,
    ) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden2, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class scores (no softmax applied) for a batch.

        Args:
            x: Input tensor; it is flattened to rows of ``in_features``
                values, so e.g. a ``(N, 1, 28, 28)`` batch becomes
                ``(N, 784)`` with the default sizes.

        Returns:
            Tensor of shape ``(N, num_classes)``.
        """
        # Flatten input images. reshape (unlike view) also accepts
        # non-contiguous tensors, e.g. transposed or permuted images.
        x = x.reshape(-1, self.fc1.in_features)
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        x = self.fc3(x)
        return x

In [None]:
# Build the network with its default layer sizes, then move it to `device`.
model = MLP()
model = model.to(device)

In [None]:
# Cross-entropy loss for classification. It is applied directly to the model's
# raw output scores: MLP.forward returns the last linear layer's output without
# a softmax.
crossentropy = nn.CrossEntropyLoss()  # Cross Entropy Loss for classification

In [None]:
# Adam optimizer over all model parameters with learning rate 0.001.
# Suggested experiment: train for some epochs with lr=0.001, then re-create
# the optimizer with lr=0.0001 and train some more (the number of epochs is
# set by num_epochs in the training cell).
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Put the model in training mode before running the training loop.
model.train()

In [None]:
num_epochs = 5
loss_history = []   # per-batch loss values (Python floats), kept only for plotting

# Ensure training mode even if this cell is re-run after the evaluation cell
# has switched the model to eval mode.
model.train()

for epoch in range(num_epochs):                              #   loop over epochs:
    epoch_loss_sum = 0.0                                     #       running sum of per-sample losses
    epoch_samples = 0                                        #       number of samples seen this epoch
    for imgs, labels in train_loader:                        #       loop over batches:  -> (images, labels)
        imgs, labels = imgs.to(device), labels.to(device)    #            * copy batch to the device
        optimizer.zero_grad()                                #            * reset the gradients accumulated so far
        outputs = model(imgs)                                #            * evaluate the model on the batch
        loss = crossentropy(outputs, labels)                 #            * evaluate the loss function with the obtained outputs and labels
        loss.backward()                                      #            * backpropagation -> gradients
        optimizer.step()                                     #            * update weights with the gradients
        # [not part of the training] keep values for plotting and reporting
        batch_loss = loss.item()                             #            * scalar loss as a plain Python float
        loss_history.append(batch_loss)
        epoch_loss_sum += batch_loss * labels.size(0)
        epoch_samples += labels.size(0)

    # Report the sample-weighted mean loss over the whole epoch rather than
    # the last batch's loss; max(..., 1) avoids dividing by zero when the
    # loader yields no batches.
    mean_loss = epoch_loss_sum / max(epoch_samples, 1)
    print(f'Epoch [{epoch + 1:2d}/{num_epochs}] Loss {mean_loss:0.4f}')

In [None]:
# Plot the per-step training loss on the current axes.
ax = plt.gca()
ax.plot(loss_history, c='red')
ax.set_xlabel('Training steps')
ax.set_ylabel('Loss')
ax.grid(ls=':')   # dotted grid lines
plt.show()

In [None]:
# Evaluate the trained model on the batches produced by test_loader.
model.eval()            # switch to evaluation mode
correct, total = 0, 0   # correctly classified samples / samples seen

with torch.no_grad():   # gradients are not needed for evaluation
    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)
        outputs = model(imgs)
        probabilities = F.softmax(outputs, dim=1)        # scores -> class probabilities
        predicted_labels = probabilities.argmax(dim=1)   # most probable class per sample

        total += labels.shape[0]
        correct += int((predicted_labels == labels).sum())

accuracy = correct / total
print(f'Test Accuracy: {accuracy * 100:.2f}%')

In [None]:
# `imgs` and `predicted_labels` still hold the values from the last batch of
# the evaluation loop, so this displays predictions for that batch only.
show_predictions(imgs, predicted_labels)