In [None]:
from torch import cuda

# Select the compute device: the CUDA GPU when one is usable, otherwise the CPU.
if cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'device: {device}')

In [None]:
from torchvision.datasets import EMNIST
from torchvision.transforms import ToTensor

def _load_emnist_digits(train: bool) -> EMNIST:
    """Load one partition of the EMNIST 'digits' split from ./data (downloading if needed).

    Images are converted to tensors with ``ToTensor``.
    """
    return EMNIST(root = 'data', split = 'digits', train = train, download = True, transform = ToTensor())

# training partition first, then the held-out test partition
train_data = _load_emnist_digits(train = True)
test_data = _load_emnist_digits(train = False)

In [None]:
from torch.utils.data import DataLoader

# 2 ** 14 = 16384 samples per batch for both training and evaluation.
batch_size = 2 ** 14

# Only the training set is shuffled (so each epoch sees batches in a new order).
# Evaluation metrics do not depend on sample order, so the test loader keeps a
# fixed order, which makes the evaluation pass deterministic and reproducible.
train_dataloader = DataLoader(train_data, batch_size = batch_size, shuffle = True)
test_dataloader = DataLoader(test_data, batch_size = batch_size, shuffle = False)

In [None]:
import numpy as np

# Peek at the first test batch to confirm the shapes and dtypes the model will receive.
# DataLoader yields torch Tensors (not numpy arrays), so the batch is unpacked directly
# rather than being mis-annotated as np.ndarray.
x, y = next(iter(test_dataloader))
print(f'shape of x [n, c, h, w]: {x.shape} {x.dtype}')
print(f'shape of y: {y.shape} {y.dtype}')

In [None]:
from torch import Tensor
from torch.nn import Conv2d
from torch.nn import Linear
from torch.nn import LogSoftmax
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ReLU
from torch import flatten

class LigmaNetwork(Module):
    """LeNet-style CNN: two (conv -> ReLU -> max-pool) stages followed by two dense layers.

    ``forward`` returns per-class log-probabilities (``LogSoftmax`` over dim 1), which is
    the input format ``NLLLoss`` expects.

    Args:
        in_channels: number of channels in the input images (1 for grayscale EMNIST).
        image_size: side length, in pixels, of the square input images.
        num_classes: number of output classes.

    Raises:
        ValueError: if ``image_size`` is too small to survive both conv/pool stages.
    """

    def __init__(self, in_channels: int = 1, image_size: int = 28, num_classes: int = 10):
        super().__init__()

        self.conv1 = Conv2d(in_channels = in_channels, out_channels = 20, kernel_size = 5, stride = 1)
        self.relu1 = ReLU()
        self.pool1 = MaxPool2d(kernel_size = 2, stride = 2)

        self.conv2 = Conv2d(in_channels = 20, out_channels = 50, kernel_size = 5, stride = 1)
        self.relu2 = ReLU()
        self.pool2 = MaxPool2d(kernel_size = 2, stride = 2)

        # Derive the flattened feature size from the layers' own hyper-parameters instead of
        # hard-coding it, so it stays correct if image_size or a kernel/stride changes.
        win_size = image_size
        for conv, pool in ((self.conv1, self.pool1), (self.conv2, self.pool2)):
            win_size = self._window_out(win_size, conv.kernel_size[0], conv.stride[0])
            win_size = self._window_out(win_size, pool.kernel_size, pool.stride)
            if win_size < 1:
                raise ValueError(f'image_size={image_size} is too small for this network')
        flat_size = win_size * win_size * self.conv2.out_channels

        self.dense1 = Linear(in_features = flat_size, out_features = 500)
        self.relu3 = ReLU()
        self.dense2 = Linear(in_features = 500, out_features = num_classes)

        self.log_prob = LogSoftmax(dim = 1)

    @staticmethod
    def _window_out(size: int, kernel: int, stride: int) -> int:
        """Output length along one axis of an unpadded, undilated conv/pool window (floor mode)."""
        return (size - kernel) // stride + 1

    def forward(self, x: Tensor) -> Tensor:
        """Map a batch of images to per-class log-probabilities.

        Expects ``x`` shaped (N, in_channels, image_size, image_size); returns (N, num_classes).
        """
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        # keep the batch dimension, flatten channels x height x width into one feature axis
        x = flatten(x, 1)
        x = self.dense1(x)
        x = self.relu3(x)
        x = self.dense2(x)

        return self.log_prob(x)

In [None]:
# Build the network, place its parameters on the selected device, and print the layer summary.
model = LigmaNetwork()
model = model.to(device)
print(model)

In [None]:
from torch.nn import NLLLoss
from torch.optim import Adam

# NLLLoss consumes the log-probabilities produced by the model's LogSoftmax output.
loss_fn = NLLLoss()

# Adam over all model parameters with a fixed learning rate.
learning_rate = 1e-3
optimizer = Adam(params = model.parameters(), lr = learning_rate)

In [None]:
from torch import Tensor
from tqdm import tqdm

# Train for a fixed number of epochs, with one progress-bar tick per optimizer step.
max_epochs = 10
n_batch = len(train_dataloader)
total_steps = max_epochs * n_batch

i: int = 0
x: Tensor
y: Tensor
y_pred: Tensor
loss: Tensor
with tqdm(total = total_steps) as pbar:
    model.train()
    for epoch in range(max_epochs):
        for batch, (x, y) in enumerate(train_dataloader):
            # global 1-based step number, derived from the bar's completed count
            i = pbar.n + 1

            # transfer the batch to the training device
            x = x.to(device)
            y = y.to(device)

            # clear gradients left from the previous step, then forward pass and loss
            optimizer.zero_grad()
            y_pred = model(x)
            loss = loss_fn(y_pred, y)

            # backpropagate and apply the parameter update
            loss.backward()
            optimizer.step()

            # report progress
            postfix = {'i': i, 'epoch': epoch + 1, 'batch': batch + 1, 'loss': loss.item()}
            pbar.set_postfix(postfix)
            pbar.update()

In [None]:
import torch

# Evaluate on the test set: collect predicted and actual labels (for the confusion
# matrix) and accumulate the loss so it can be reported instead of discarded.
predicted_digits: list = []
actual_digits: list = []
total_loss: float = 0.0

# inference mode; this architecture has no dropout/batch-norm, but it is the correct practice
model.eval()
with torch.no_grad():
    x: Tensor
    y: Tensor
    y_pred: Tensor
    loss: Tensor
    for batch, (x, y) in enumerate(test_dataloader):
        # move data to device
        x, y = x.to(device), y.to(device)

        # make the predictions and calculate the validation loss
        y_pred = model(x)
        loss = loss_fn(y_pred, y)

        # NLLLoss averages over the batch by default; re-weight by batch size so the
        # final figure is a true per-sample mean even when the last batch is smaller
        total_loss += loss.item() * y.size(0)

        # move labels to cpu (no .detach() needed inside torch.no_grad())
        predicted_digits += y_pred.argmax(1).cpu().tolist()
        actual_digits += y.cpu().tolist()

n_test = len(actual_digits)
if n_test:
    test_loss = total_loss / n_test
    test_accuracy = sum(p == a for p, a in zip(predicted_digits, actual_digits)) / n_test
else:
    test_loss = test_accuracy = float('nan')
print(f'test loss: {test_loss:.4f}  accuracy: {test_accuracy:.2%}')

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Confusion matrix of test-set predictions (rows: true digit, columns: predicted digit).
# Named `cm_display` rather than `display` so IPython's built-in display() is not shadowed.
# labels are pinned to 0-9 so the matrix is always 10x10, even if a digit never occurs.
digit_labels = list(range(10))
cm_display = ConfusionMatrixDisplay(
    confusion_matrix = confusion_matrix(actual_digits, predicted_digits, labels = digit_labels),
    display_labels = digit_labels)
cm_display.plot()
cm_display.ax_.set_title('EMNIST digits: test-set confusion matrix');