# PyTorch MNIST Classifier

In this notebook I implement a deep convolutional neural network (two convolution + max-pooling blocks followed by two fully connected layers) to classify the handwritten digits of the MNIST dataset.

Import all the necessary packages:

In [16]:
import torch
import torchvision
from torch.utils.data import DataLoader, Subset
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

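Load MNIST and define the data subsets. The first 10,000 images of the official training split are held out for validation, and the remaining 50,000 are used for training. The official test split is loaded separately: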

In [17]:
image_path = './'
# The first N_VALID images of the official training split are held out for
# validation (a fixed, non-random slice); the rest are used for training.
N_VALID = 10_000

# ToTensor converts each PIL image to a float tensor scaled to [0, 1].
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

# download=True lets a fresh "Restart & Run All" succeed on a machine that does
# not yet have the files under image_path; torchvision skips the download when
# the files are already present.
mnist_dataset = torchvision.datasets.MNIST(
    root=image_path,
    train=True,
    transform=transform,
    download=True
)

mnist_valid_dataset = Subset(mnist_dataset, torch.arange(N_VALID))
mnist_train_dataset = Subset(
    mnist_dataset,
    torch.arange(N_VALID, len(mnist_dataset))
)

mnist_test_dataset = torchvision.datasets.MNIST(
    root=image_path,
    train=False,
    transform=transform,
    download=True
)

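Set up the data loaders. Both use mini-batches of 64 images; the training data is reshuffled every epoch and the validation data keeps a fixed order: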

In [18]:
# Seed torch's global RNG so the shuffled batch order (and any later random
# draws in this session) is repeatable.
batch_size = 64
torch.manual_seed(1)

# Training batches are reshuffled each epoch; validation order stays fixed.
train_dl = DataLoader(mnist_train_dataset, batch_size=batch_size, shuffle=True)
valid_dl = DataLoader(mnist_valid_dataset, batch_size=batch_size, shuffle=False)

In [19]:
class DeepConvNet(nn.Module):
    """CNN mapping 1-channel 28x28 images to 10 unnormalised class scores.

    Each 5x5 convolution uses padding=2, so it preserves spatial size. Each
    2x2 max-pool halves it: 28x28 -> 14x14 -> 7x7. That is why the first
    fully connected layer expects 64 * 7 * 7 = 3136 inputs. There is no
    softmax at the end; the output is raw logits.
    """

    def __init__(self):
        super().__init__()
        # (name, module) pairs in forward order. The names become the keys of
        # the state_dict entries (e.g. "model.conv1.weight").
        named_layers = (
            ('conv1', nn.Conv2d(in_channels=1, out_channels=32,
                                kernel_size=5, padding=2)),
            ('relu1', nn.ReLU()),
            ('pool1', nn.MaxPool2d(kernel_size=2)),
            ('conv2', nn.Conv2d(in_channels=32, out_channels=64,
                                kernel_size=5, padding=2)),
            ('relu2', nn.ReLU()),
            ('pool2', nn.MaxPool2d(kernel_size=2)),
            ('flatten', nn.Flatten()),
            ('fc1', nn.Linear(3136, 1024)),
            ('relu3', nn.ReLU()),
            ('dropout', nn.Dropout(p=0.5)),
            ('fc2', nn.Linear(1024, 10)),
        )
        self.model = nn.Sequential()
        for layer_name, layer in named_layers:
            self.model.add_module(layer_name, layer)

    def forward(self, x):
        """Return the logits produced by the layer stack for batch ``x``."""
        logits = self.model(x)
        return logits

In [20]:
# Instantiate the network; its parameters are handed to the optimizer in the next cell.
my_net = DeepConvNet()

In [21]:
# The network outputs raw logits (no softmax layer), which is what
# CrossEntropyLoss expects; Adam optimises all of my_net's parameters.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=my_net.parameters(), lr=0.001)

In [None]:
def _run_epoch(model, data_loader, criterion, opt=None):
    """Make one pass over ``data_loader`` and return ``(mean_loss, accuracy)``.

    If ``opt`` is given, every batch also performs an optimisation step.
    Otherwise the model is only evaluated; wrap the call in
    ``torch.no_grad()`` to skip gradient tracking. Both metrics are averaged
    over ``len(data_loader.dataset)`` samples. ``nan`` is returned for an
    empty dataset instead of dividing by zero.
    """
    n_samples = len(data_loader.dataset)
    if n_samples == 0:
        return float('nan'), float('nan')

    total_loss = 0.0
    n_correct = 0
    for x_batch, y_batch in data_loader:
        prediction = model(x_batch)
        loss_value = criterion(prediction, y_batch)
        if opt is not None:
            opt.zero_grad()
            loss_value.backward()
            opt.step()
        # Assumes criterion returns a batch mean (the default reduction of
        # nn.CrossEntropyLoss). Re-weighting by batch size makes the epoch
        # value a true per-sample mean even when the last batch is smaller.
        total_loss += loss_value.item() * y_batch.size(0)
        n_correct += (torch.argmax(prediction, dim=1) == y_batch).sum().item()

    return total_loss / n_samples, n_correct / n_samples


def train(model, n_epochs, train_dl, valid_dl, criterion=None, opt=None):
    """Train ``model`` for ``n_epochs`` epochs and record per-epoch metrics.

    Each epoch runs one optimisation pass over ``train_dl`` in train mode. It
    then runs one evaluation pass over ``valid_dl`` in eval mode, without
    gradient tracking. Training metrics come from the forward pass that
    precedes each optimizer step, so they lag slightly behind the model's
    end-of-epoch state.

    Args:
        model: the ``nn.Module`` to train; its parameters are updated in place.
        n_epochs: number of passes over the training data.
        train_dl: DataLoader of ``(inputs, class_labels)`` batches used for
            optimisation.
        valid_dl: DataLoader of ``(inputs, class_labels)`` batches used for
            evaluation only.
        criterion: loss function called as ``criterion(logits, labels)``.
            Defaults to the notebook-level ``loss_fn``.
        opt: optimizer over ``model``'s parameters. Defaults to the
            notebook-level ``optimizer``.

    Returns:
        A tuple of four lists of length ``n_epochs`` containing Python floats:
        training loss, validation loss, training accuracy and validation
        accuracy. Losses are per-sample means; accuracies are fractions in
        [0, 1].
    """
    # Fall back to the notebook globals so existing 4-argument calls still work.
    if criterion is None:
        criterion = loss_fn
    if opt is None:
        opt = optimizer

    loss_history_train, loss_history_valid = [], []
    accuracy_history_train, accuracy_history_valid = [], []

    for _ in range(n_epochs):
        model.train()  # enables dropout
        train_loss, train_acc = _run_epoch(model, train_dl, criterion, opt)

        model.eval()  # disables dropout for evaluation
        with torch.no_grad():
            valid_loss, valid_acc = _run_epoch(model, valid_dl, criterion)

        loss_history_train.append(train_loss)
        loss_history_valid.append(valid_loss)
        accuracy_history_train.append(train_acc)
        accuracy_history_valid.append(valid_acc)

    return (loss_history_train, loss_history_valid,
            accuracy_history_train, accuracy_history_valid)