# Training MNIST Classifier (MLP)

* [What is torch.nn really?](https://pytorch.org/tutorials/beginner/nn_tutorial.html)
* [Build the Neural Network](https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html#)
* [Building Models with PyTorch](https://pytorch.org/tutorials/beginner/introyt/modelsyt_tutorial.html)

In [7]:
from typing import (
    Callable
)
from torch.nn.init import (
    xavier_normal_,
    kaiming_normal_,
)
from torch.nn import (
    Module,
    Flatten,
    Linear,
    ReLU,
    GELU,
    Dropout,
    Conv2d,
    Softmax,
    CrossEntropyLoss
)
from torch.optim import (
    SGD
)
import torch
from torch.utils.data import (
    Dataset,
    DataLoader,
    random_split,
)
import torch.nn.functional as F
from torch.nn.functional import (
    cross_entropy
)
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

In [2]:
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using mps device


# Data

In [3]:
training_data: Dataset = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data: Dataset = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9912422/9912422 [00:24<00:00, 411631.69it/s]


Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 46832.05it/s]


Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1648877/1648877 [00:13<00:00, 120014.69it/s]


Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<00:00, 28924.43it/s]


Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw



In [4]:
train_size = int(0.8 * len(training_data))
validation_size = len(training_data) - train_size
training_data, validation_data = random_split(training_data, [train_size, validation_size], generator=torch.Generator().manual_seed(42))

In [5]:
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True, drop_last=True)
validation_dataloader = DataLoader(validation_data, batch_size=64, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True, drop_last=True)

# Model 

In [6]:
NUM_CLASSES: int = 10           # 10 digits
NUM_FEATURES: int = 28 * 28     # Number of features in the input image

In [8]:
class MNIST(Module):
    def __init__(self):
        super().__init__()
        
        self.input = Flatten()
        self.fc01 = Linear(in_features=NUM_FEATURES, out_features=512)
        kaiming_normal_(self.fc01.weight, a=0, mode='fan_in', nonlinearity='relu')
        self.act01 = ReLU()

        self.fc02 = Linear(in_features=512, out_features=512)
        kaiming_normal_(self.fc02.weight, a=0, mode='fan_in', nonlinearity='relu')
        self.act02 = ReLU()

        self.fc03 = Linear(in_features=512, out_features=NUM_CLASSES)
        xavier_normal_(self.fc03.weight)
        # nn.CrossEntropyLoss expects unnormalized logits for each class
        # self.output = Softmax()

    def forward(self, x):
        x = self.input(x)
        x = self.act01(self.fc01(x))
        x = self.act02(self.fc02(x))
        # x = self.output(self.fc03(x))
        x = self.fc03(x)

        return x

In [10]:
model = MNIST()
model.to(device)

MNIST(
  (input): Flatten(start_dim=1, end_dim=-1)
  (fc01): Linear(in_features=784, out_features=512, bias=True)
  (act01): ReLU()
  (fc02): Linear(in_features=512, out_features=512, bias=True)
  (act02): ReLU()
  (fc03): Linear(in_features=512, out_features=10, bias=True)
)

# Training Loop

In [17]:
NUM_EPOCHS = 20
LR: float = 1e-2

optimizer = SGD(model.parameters(), lr=LR)
loss_fn = CrossEntropyLoss()

In [29]:
def train(
    model: Module,
    loss_fn: Callable,
    num_epochs: int,
    train_dataloader: DataLoader,
    validation_dataloader: DataLoader,
    device: str
):
    running_loss: float = 0.0
    chunk_size: int = int(len(train_dataloader) / 5)
    
    for epoch in range(num_epochs):
        print(f"{'-' * 80}\nEpoch:[{epoch}]\n{'-' * 80}")
        
        # --------------------------------------------------------------------------------
        # Training:
        # Set the layers to the training mode. Only effective to certain layers such as 
        # BatchNorm, DropOut which behave differently between trainintg and infereence.
        # --------------------------------------------------------------------------------
        model.train(mode=True)
        for count, (x, y) in enumerate(train_dataloader):
            y_pred = model(x.to(device))
            loss = loss_fn(y_pred, y.to(device))
    
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
            # print statistics
            running_loss += loss.item()
            if count % chunk_size == chunk_size-1:
                print(f'loss:{running_loss / 100:.3f}')
                running_loss = 0.0
    
        # --------------------------------------------------------------------------------
        # Evaluation:
        # Set the layers to the inference mode. Only effective to certain layers.
        # --------------------------------------------------------------------------------
        if validation_dataloader:
            model.train(False)   # same with model.eval()
            with torch.no_grad():
                validation_loss = sum(loss_fn(model(_x.to(device)), _y.to(device)) for _x, _y in validation_dataloader)
                print(f"validation loss: {(validation_loss / len(validation_dataloader)).cpu().item()}")

train(
    model=model,
    loss_fn=loss_fn,
    num_epochs=NUM_EPOCHS,
    train_dataloader=train_dataloader,
    validation_dataloader=validation_dataloader,
    device=device
)

--------------------------------------------------------------------------------
Epoch:[0]
--------------------------------------------------------------------------------
loss:0.058
loss:0.054
loss:0.050
loss:0.062
loss:0.057
validation loss: 0.09368696808815002
--------------------------------------------------------------------------------
Epoch:[1]
--------------------------------------------------------------------------------
loss:0.053
loss:0.051
loss:0.050
loss:0.061
loss:0.059
validation loss: 0.09195054322481155
--------------------------------------------------------------------------------
Epoch:[2]
--------------------------------------------------------------------------------
loss:0.048
loss:0.055
loss:0.055
loss:0.057
loss:0.051
validation loss: 0.09285635501146317
--------------------------------------------------------------------------------
Epoch:[3]
--------------------------------------------------------------------------------
loss:0.047
loss:0.063
loss:0.050
los

# Test

In [33]:
def test(
    model: torch.nn.Module, 
    test_dataloader: DataLoader,
    device: str
):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_dataloader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += loss_fn(output, target, ).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_dataloader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_dataloader.dataset),
        100. * correct / len(test_dataloader.dataset)
    ))

test(model=model,test_dataloader=test_dataloader, device=device)


Test set: Average loss: 0.0011, Accuracy: 9755/10000 (98%)



# CNN

In [41]:
class MNISTCNN(Module):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2d(1, 32, 3, 1)
        self.conv2 = Conv2d(32, 64, 3, 1)
        self.dropout1 = Dropout(0.25)
        self.dropout2 = Dropout(0.5)
        self.fc1 = Linear(9216, 128)
        self.fc2 = Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return x

In [42]:
model = MNISTCNN()
model.to(device)

MNISTCNN(
  (conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (dropout1): Dropout(p=0.25, inplace=False)
  (dropout2): Dropout(p=0.5, inplace=False)
  (fc1): Linear(in_features=9216, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=10, bias=True)
)

In [None]:
optimizer = SGD(model.parameters(), lr=1e-3)
loss_fn = CrossEntropyLoss()

train(
    model=model,
    loss_fn=loss_fn,
    num_epochs=NUM_EPOCHS,
    train_dataloader=train_dataloader,
    validation_dataloader=validation_dataloader,
    device=device
)

--------------------------------------------------------------------------------
Epoch:[0]
--------------------------------------------------------------------------------
loss:0.307
loss:0.314
loss:0.311
loss:0.292


In [None]:
test(model=model,test_dataloader=test_dataloader, device=device)