In [4]:
from torch import torch
from torchvision import datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
from torchvision import transforms
import matplotlib.pyplot as plt


In [68]:
# Training hyper-parameters used by the cells below:
#   lr     - SGD learning rate
#   bs     - training batch size (the evaluation loaders use bs*2)
#   epochs - number of passes over the training split
lr, bs, epochs = 0.02, 64, 10

Load Data

In [158]:
from pathlib import Path

# Candidate locations for an existing MNIST download, relative to the working directory.
data_dirs = [Path("./"), Path("../")]

tf = transforms.Compose([
    transforms.ToTensor(),
    # 0.1307 is the mean of the MNIST dataset, 0.3081 is the standard deviation
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Reuse an existing copy if one is found; the for/else branch only runs when no
# candidate directory matched (no `break`), in which case we download into ./
for data_dir in data_dirs:
    if (data_dir / "MNIST").exists():
        train_data = datasets.MNIST(data_dir, train=True, transform=tf)
        test_data = datasets.MNIST(data_dir, train=False, transform=tf)
        break
else:
    train_data = datasets.MNIST("./", train=True, download=True, transform=tf)
    test_data = datasets.MNIST("./", train=False, download=True, transform=tf)

# Fixed-seed generator so the train/validation split is identical on every run.
g = torch.Generator().manual_seed(42)

# Hold out 10,000 training images for validation (50,000 / 10,000 for the full MNIST
# training split); the train size is derived so the split still works on a subset.
n_val = 10000
train_data, val_data = random_split(train_data, [len(train_data) - n_val, n_val], generator=g)

train_loader = DataLoader(train_data, batch_size=bs, shuffle=True)
# Evaluation loaders are not shuffled: a fixed order keeps the validation metrics
# reproducible from run to run (the final batch is smaller than the rest).
val_loader = DataLoader(val_data, batch_size=bs*2, shuffle=False)
test_loader = DataLoader(test_data, batch_size=bs*2)

print(next(iter(train_loader))[0].shape)  # (batch_size, channels, height, width)

torch.Size([64, 1, 28, 28])


Model


In [159]:
import numpy as np
from torch import Tensor

# Mean cross-entropy over a batch of logits; used both for training and evaluation.
loss_func = F.cross_entropy

def accuracy(input: Tensor, target: Tensor):
    """Fraction of rows whose highest-scoring class equals `target`, as a 0-dim float tensor."""
    hits = input.argmax(dim=1).eq(target)
    return hits.float().mean()

In [None]:

class CNN(nn.Module):
    """Two-convolution MNIST classifier producing 10 class logits.

    Shape trace for a (batch, 1, 28, 28) input:
    conv1 5x5 -> 64x24x24 -> max-pool 2 -> 64x12x12
    conv2 3x3 -> 128x10x10 -> max-pool 2 -> 128x5x5 -> flatten (3200) -> fc1 -> 10
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1)

        # 128 channels x 5 x 5 spatial positions remain after the second pooling (28x28 inputs).
        self.fc1 = nn.Linear(128*5*5, 10)

    def forward(self, x):
        """Return unnormalized logits of shape (batch, 10)."""
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)

        # Flatten per sample (keeping the batch dim) so a wrongly sized input fails in fc1
        # instead of being silently regrouped into a different number of rows.
        x = torch.flatten(x, 1)
        x = self.fc1(x)

        return x

    def evaluate(self, loader: DataLoader):
        """
        Evaluate the model on the given data.

        Loss and accuracy are weighted by batch size, so a smaller final batch does not
        skew the average. The module's previous train/eval mode is restored afterwards.

        return: (loss, accuracy) averaged over every sample the loader yields
        raises: ValueError if the loader yields no samples
        """
        was_training = self.training
        self.eval()
        total_loss = 0.0
        total_correct = 0.0
        n_samples = 0
        try:
            with torch.no_grad():
                for x_batch, y_batch in loader:
                    preds = self(x_batch)
                    n = y_batch.size(0)
                    # loss_func / accuracy return batch means; scale back to sums so the
                    # final division is a true per-sample average.
                    total_loss += loss_func(preds, y_batch) * n
                    total_correct += accuracy(preds, y_batch) * n
                    n_samples += n
        finally:
            self.train(was_training)
        if n_samples == 0:
            raise ValueError("evaluate() received an empty loader")
        return total_loss / n_samples, total_correct / n_samples


In [161]:
def get_model():
    """Return a freshly initialised (model, optimizer) pair.

    The optimizer is SGD over the new model's parameters, using the notebook-level
    learning rate `lr` and momentum 0.9.
    """
    net = CNN()
    sgd = optim.SGD(net.parameters(), momentum=0.9, lr=lr)
    return net, sgd

Train

In [162]:
def fit(model, optimizer, train_loader, val_loader, epochs, loss_fn=None):
    """Train `model` with `optimizer` for `epochs` passes over `train_loader`.

    After each epoch the model is scored on `val_loader` via `model.evaluate`, and the
    validation loss/accuracy are printed.

    loss_fn: criterion(preds, targets) -> scalar loss; defaults to the notebook's `loss_func`.
    """
    if loss_fn is None:
        loss_fn = loss_func
    for epoch in range(epochs):
        model.train()  # evaluate() switches the model to eval mode
        for x_batch, y_batch in train_loader:
            # Clear gradients first so stale ones from any earlier use are never applied.
            optimizer.zero_grad()
            preds = model(x_batch)  # __call__ rather than .forward() so module hooks run
            loss = loss_fn(preds, y_batch)
            loss.backward()
            optimizer.step()

        val_loss, val_acc = model.evaluate(val_loader)
        print(f"epoch {epoch+1} loss: {val_loss:.4f}, accuracy: {val_acc:.4f}")

# Seed torch's global RNG so weight initialisation and the training-loader shuffle
# order are reproducible each time this cell is run.
torch.manual_seed(42)
model, optimizer = get_model()
fit(model, optimizer, train_loader, val_loader, epochs)

epoch 1 loss: 0.0758, accuracy: 0.9787
epoch 2 loss: 0.0630, accuracy: 0.9808
epoch 3 loss: 0.0428, accuracy: 0.9861
epoch 4 loss: 0.0449, accuracy: 0.9870
epoch 5 loss: 0.0412, accuracy: 0.9885
epoch 6 loss: 0.0449, accuracy: 0.9890
epoch 7 loss: 0.0393, accuracy: 0.9889
epoch 8 loss: 0.0408, accuracy: 0.9897
epoch 9 loss: 0.0499, accuracy: 0.9873
epoch 10 loss: 0.0404, accuracy: 0.9902


Test

In [163]:
# Final check on the MNIST test split, which training and validation never touched.
loss, acc = model.evaluate(loader=test_loader)
print(f"loss: {loss:.4f}, accuracy: {acc:.4f}")

loss: 0.0286, accuracy: 0.9916


In [None]:
# save model: persist only the learned parameters (state_dict), not the module object
torch.save(model.state_dict(), "model5.pth")

# load model into a fresh CNN and re-evaluate; the metrics should match the previous cell
model = CNN()
# weights_only=True restricts unpickling to tensors and plain containers, so loading a
# checkpoint cannot execute arbitrary code; a state_dict needs nothing more.
model.load_state_dict(torch.load("model5.pth", weights_only=True))
loss, acc = model.evaluate(test_loader)
print(f"loss: {loss:.4f}, accuracy: {acc:.4f}")

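- Conv: convolution layers (output channels per layer, with kernel size and pooling where noted)
- FC: hidden fully connected layer widths ([] = none; the final 10-way output layer is not listed)
- Opt: optimizer
- LR: learning rate
- M: momentum
- E: epochs
- Accuracy: accuracy on the MNIST test set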

|No.|Conv|FC|Opt|LR|M|E|Accuracy|
|-|-|-|-|-|-|-|-|
| 0 | 16->16->10 -> AvgPool (GAP) | [] | SGD | 0.2 | 0.9 | 20 | 96.69% |
| 1 | 16->16->10 | [] | SGD | 0.01 | 0 |20| 97.54% |
| 2 | 16->16->10 | [16] | SGD | 0.2 | 0 |20| 98.03% |
| 3 | 16->16->10 | [16] | SGD | 0.02 | 0.9 |20| 97.91% |
| 4 | 32->maxpool(2)->64->maxpool(2) | [16] | SGD | 0.02 | 0.9 |10| 99.05% |
| 5 | 64(5x5)->maxpool(2)->128(3x3)->maxpool(2) | [] | SGD | 0.02 | 0.9 |10| 99.16% |