In [1]:
from torchvision.datasets.mnist import MNIST
from torchvision.transforms import ToTensor, Compose, Normalize

# Convert PIL images to tensors, then standardise pixels with a single-channel
# mean/std pair (presumably the MNIST training-set statistics — confirm source).
transform = Compose([
    ToTensor(),
    Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Both splits share root, transform and download settings; only `train` differs.
train_dataset, test_dataset = (
    MNIST(root='data', train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

In [2]:
from torch import Tensor
from torch import nn

class Perceptron(nn.Module):
    """Three-layer MLP classifier producing log-probabilities over 10 classes.

    Each input sample is flattened and must contain 784 values (e.g. a
    1x28x28 MNIST image). Layout: 784 -> 512 -> 128 -> 10, with Dropout(0.2)
    and ReLU after each of the first two linear layers, and LogSoftmax over
    dim 1 at the end (pairs with ``NLLLoss``).
    """

    def __init__(self):
        super().__init__()
        # Sub-module names and positions are part of the state_dict keys
        # (e.g. ``input_layer.1.weight``), so they are kept stable.
        self.input_layer = nn.Sequential(
            nn.Flatten(), nn.Linear(784, 512), nn.Dropout(0.2), nn.ReLU(),
        )
        self.hidden_layer = nn.Sequential(
            nn.Linear(512, 128), nn.Dropout(0.2), nn.ReLU(),
        )
        self.output_layer = nn.Sequential(
            nn.Linear(128, 10), nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        """Map a batch of inputs to per-class log-probabilities of shape (batch, 10)."""
        return self.output_layer(self.hidden_layer(self.input_layer(x)))

In [43]:
from typing import Protocol

import torch
from torch.optim import Optimizer

class Module(Protocol):

    def __call__(self, *args, **kwargs) -> Tensor:
        ...

class Criterion(Protocol):

    def __call__(self, output : Tensor, target : Tensor) -> Tensor:
        ...

class Model:
    def __init__(self, main : Module, criterion : Criterion, optimizer : Optimizer):
        self.module = torch.compile(main)
        self.criterion = criterion
        self.optimizer = optimizer

    def fit(self, input : Tensor, target : Tensor) -> float:
        self.optimizer.zero_grad(set_to_none=True)
        output = self.module(input)
        loss = self.criterion(output, target)
        loss.backward()
        self.optimizer.step()
        return loss.item()
    
    @torch.no_grad
    def predict(self, input : Tensor) -> Tensor:
        return self.module(input)

In [47]:
import time
from typing import Protocol
from typing import Iterator
from typing import Tuple

class Data(Protocol):
    def __iter__(self) -> Iterator[Tuple[Tensor, Tensor]]:
        ...        

def train(model : Model, data : Data, device : str = 'cuda'):
    for batch, (input, target) in enumerate(data, start=1):
        input = input.to(device)
        target = target.to(device)
        loss = model.fit(input, target)
        print(f'Batch {batch}, Loss {loss:.6f}')

@torch.no_grad
def test(model : Model, data : Data, device : str = 'cuda'):
    for batch, (input, target) in enumerate(data, start=1):
        input = input.to(device)
        target = target.to(device)
        output = model.predict(input)

In [48]:
import torch
from torch.optim import Adam
from torch.nn import NLLLoss
from torch.utils.data import DataLoader

# Seed every device's RNG so weight initialisation, Dropout masks and
# DataLoader shuffling are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

torch.set_float32_matmul_precision('high')
module = Perceptron().to('cuda')
criterion = NLLLoss().to('cuda')
# NOTE(review): lr=0.01 is 10x Adam's default of 1e-3; the loss jump at
# batch 2 of the training log suggests it may be too high — consider 1e-3.
optimizer = Adam(params=module.parameters(), lr=0.01)

model = Model(
    main=module, 
    criterion=criterion, 
    optimizer=optimizer
)

def make_dataloader(dataset, shuffle : bool) -> DataLoader:
    """Build a DataLoader with the batching/pinning/worker settings shared by both splits."""
    return DataLoader(
        dataset=dataset,
        batch_size=64,
        shuffle=shuffle,
        # Pinned host memory enables asynchronous copies to the GPU.
        pin_memory=True,
        pin_memory_device='cuda',
        num_workers=5,
    )

train_dataloader = make_dataloader(train_dataset, shuffle=True)
test_dataloader = make_dataloader(test_dataset, shuffle=False)

In [49]:
# Train for a few epochs, evaluating on the held-out test split after each one
# (evaluating on the training split would overstate accuracy).
NUM_EPOCHS = 3
for epoch in range(NUM_EPOCHS):
    print(f'Epoch {epoch + 1}/{NUM_EPOCHS}')
    train(model, train_dataloader)
    test(model, test_dataloader)

Batch 1, Loss 2.314561
Batch 2, Loss 5.035521
Batch 3, Loss 2.876099
Batch 4, Loss 2.041653
Batch 5, Loss 2.128000
Batch 6, Loss 1.494551
Batch 7, Loss 1.453058
Batch 8, Loss 1.305101
Batch 9, Loss 1.035574
Batch 10, Loss 1.097038
Batch 11, Loss 1.161539
Batch 12, Loss 0.966613
Batch 13, Loss 1.228911
Batch 14, Loss 1.250885
Batch 15, Loss 0.933560
Batch 16, Loss 1.240801
Batch 17, Loss 0.706799
Batch 18, Loss 0.753451
Batch 19, Loss 0.663032
Batch 20, Loss 0.508619
Batch 21, Loss 0.723610
Batch 22, Loss 0.676208
Batch 23, Loss 0.789115
Batch 24, Loss 0.851832
Batch 25, Loss 0.872076
Batch 26, Loss 0.567173
Batch 27, Loss 0.503984
Batch 28, Loss 0.809376
Batch 29, Loss 0.664953
Batch 30, Loss 0.809007
Batch 31, Loss 0.487351
Batch 32, Loss 0.493864
Batch 33, Loss 0.458856
Batch 34, Loss 0.616037
Batch 35, Loss 0.683431
Batch 36, Loss 0.645175
Batch 37, Loss 0.413145
Batch 38, Loss 0.397842
Batch 39, Loss 0.711137
Batch 40, Loss 0.327342
Batch 41, Loss 0.249991
Batch 42, Loss 0.522269
B