In [1]:
from torch import torch
from torchvision import datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, random_split
from torchvision import transforms
import matplotlib.pyplot as plt


Load Data

In [2]:
from pathlib import Path

data_dirs = [Path("./"), Path("../")]

tf = transforms.Compose([
    transforms.ToTensor(),
    # 0.1307 is the mean of the MNIST dataset, 0.3081 is the standard deviation
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Reuse the first candidate directory that already holds an "MNIST" folder;
# if none does, download the dataset into the current directory.
data_root = next((d for d in data_dirs if (d / "MNIST").exists()), None)
needs_download = data_root is None
if needs_download:
    data_root = Path("./")

train_data = datasets.MNIST(data_root, train=True, download=needs_download, transform=tf)
test_data = datasets.MNIST(data_root, train=False, download=needs_download, transform=tf)

# Hold out VAL_SIZE training images for validation and train on the rest.
# The train size is derived from the dataset length rather than hard-coded.
VAL_SIZE = 10_000
g = torch.Generator().manual_seed(42)

train_data, val_data = random_split(
    train_data, [len(train_data) - VAL_SIZE, VAL_SIZE], generator=g
)

Model


In [3]:
import numpy as np
from torch import Tensor

loss_func = F.cross_entropy

def accuracy(input:Tensor, target:Tensor):
    """Share of rows whose highest-scoring column (dim 1) matches `target`.

    input: class scores with classes along dim 1 (presumably (batch, num_classes) logits).
    target: integer class labels, compared element-wise with the per-row argmax.
    Returns a 0-dim float tensor in [0, 1].
    """
    correct = input.argmax(dim=1).eq(target)
    return correct.float().mean()

In [4]:
def fit(model, optimizer, train_loader, val_loader, epochs):
    """Train `model` for `epochs` passes over `train_loader`, printing validation metrics after each.

    model: an nn.Module that also provides `evaluate(loader) -> (loss, accuracy)`, as `CNN` does.
    optimizer: an optimizer already bound to `model.parameters()`.
    train_loader / val_loader: iterables yielding (inputs, targets) batches.
    epochs: number of training epochs.
    """
    for epoch in range(epochs):
        # evaluate() may leave the model in eval mode, so re-enable training each epoch.
        model.train()
        for x_batch, y_batch in train_loader:
            # Clear gradients before computing new ones so nothing stale accumulates.
            optimizer.zero_grad()
            # Call the module (not .forward) so nn.Module hooks run.
            preds = model(x_batch)
            loss = loss_func(preds, y_batch)
            loss.backward()
            optimizer.step()

        val_loss, val_acc = model.evaluate(val_loader)
        print(f"epoch {epoch+1} loss: {val_loss:.4f}, accuracy: {val_acc:.4f}")


In [41]:

class CNN(nn.Module):
    """Two conv -> ReLU -> max-pool stages followed by one linear layer with 10 outputs.

    The classifier's input size (64 * 5 * 5) assumes 1-channel 28x28 images such as MNIST:
    28 -conv3-> 26 -pool2-> 13 -conv3-> 11 -pool2-> 5.
    """

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        self.classifier = nn.Sequential(
            nn.Flatten(start_dim=1),
            # 64 channels x 5 x 5 spatial positions for a 28x28 input (see class docstring).
            nn.Linear(64*5*5, 10),
        )

    def forward(self, x):
        """Return raw class scores of shape (batch, 10)."""
        x = self.features(x)
        x = self.classifier(x)
        return x

    def evaluate(self, loader:DataLoader):
        """
        Evaluate the model on the given data.

        Loss and accuracy are averaged over samples, not batches, so a smaller final
        batch is weighted correctly. The model's previous train/eval mode is restored.

        return: (loss, accuracy) as 0-dim tensors
        raises: ValueError if `loader` yields no samples
        """
        was_training = self.training
        self.eval()
        total_loss = 0.0
        total_correct = 0.0
        n_samples = 0
        try:
            with torch.no_grad():
                for x_batch, y_batch in loader:
                    preds = self(x_batch)
                    batch_size = y_batch.size(0)
                    # loss_func and accuracy return batch means; re-weight by batch size.
                    total_loss += loss_func(preds, y_batch) * batch_size
                    total_correct += accuracy(preds, y_batch) * batch_size
                    n_samples += batch_size
        finally:
            self.train(was_training)
        if n_samples == 0:
            raise ValueError("evaluate() got an empty loader")
        return total_loss / n_samples, total_correct / n_samples

In [42]:
lr = 0.02
bs = 64
epochs = 10

# Seed the global RNG so weight initialization and training-batch shuffling are reproducible.
torch.manual_seed(42)

train_loader = DataLoader(train_data, batch_size=bs, shuffle=True)
# Shuffling evaluation data adds nothing and only makes reported metrics vary between runs.
val_loader = DataLoader(val_data, batch_size=bs*2, shuffle=False)
test_loader = DataLoader(test_data, batch_size=bs*2)

def get_model():
    """Create a fresh CNN and an SGD optimizer over its parameters (global `lr`, momentum 0.9).

    return: (model, optimizer)
    """
    net = CNN()
    sgd = optim.SGD(net.parameters(), lr=lr, momentum=0.9)
    return net, sgd

Train

In [None]:
model, optimizer = get_model()
# Printing an nn.Module shows its layer structure.
print(model)
print(f"Params: {sum(p.numel() for p in model.parameters())}")

fit(model, optimizer, train_loader, val_loader, epochs)

- Conv: convolutional stages, written as `out_channels(kernel)` and joined with `->` (e.g. `32(3x3)->maxpool(2)`)
- FC: widths of the hidden fully connected layers before the 10-class output layer (`[]` = output layer only)
- Opt: optimizer
- LR: learning rate
- M: momentum
- E: epochs
- Params: number of parameters
- Acc: Accuracy on validation dataset

|No.|Conv|FC|Opt|LR|M|E|Params|Acc|
|-|-|-|-|-|-|-|-|-|
| 4 | 32(3x3)->maxpool(2)->64(3x3)->maxpool(2) | [16] | SGD | 0.02 | 0.9 |10| 44602 | 98.72% |
| 5 | 64(5x5)->maxpool(2)->128(3x3)->maxpool(2) | [] | SGD | 0.02 | 0.9 |10|107530| 98.90% |
| 6 | 32(5x5)->maxpool(2)->64(3x3)->maxpool(2) | [16] | SGD | 0.02 | 0.9 |10|45114 | 98.88% |
| 7 | 32(5x5)->maxpool(2)->64(3x3)->maxpool(2) | [] | SGD | 0.02 | 0.9 |10|35338 | 98.98% |
| 8 | 32(3x3)->maxpool(2)->64(3x3)->maxpool(2) | [] | SGD | 0.02 | 0.9 |10|34826 | 98.94% |
| 9 | 16(5x5)->maxpool(2)->32(3x3)->maxpool(2) | [] | SGD | 0.02 | 0.9 |10|13066 | 98.61% |

Save & Test

In [None]:
# save model
# Single quotes inside the f-string keep this valid on Python < 3.12 (PEP 701).
model_name = input('>> ').strip()
if not model_name:
    raise ValueError("model name must not be empty")
model_path = Path(f"{model_name}.pth")
if not model_path.exists():
    print(model)
    print(f"Params: {sum(p.numel() for p in model.parameters())}")
    torch.save(model.state_dict(), model_path)
else:
    # Don't overwrite; the evaluation below then tests the previously saved weights.
    print("File already exists")

# load model
model = CNN()

# NOTE(review): torch.load unpickles the file — only load checkpoints you trust
# (newer PyTorch versions accept weights_only=True to restrict this).
model.load_state_dict(torch.load(model_path))
loss, acc = model.evaluate(test_loader)
print(f"loss: {loss:.4f}, accuracy: {acc:.4f}")