In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import v2


In [66]:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html includes a normalization; I don't think we want that for images
input_transform = v2.Compose([
    v2.ToImage(),
    v2.ToDtype(dtype=torch.float32),
])

train_set = torchvision.datasets.MNIST(
    "./data",
    download=True,
    transform=input_transform,
    train=True,
)
test_set = torchvision.datasets.MNIST(
    "./data",
    download=True,
    transform=input_transform,
    train=False,
)
print(f"Training set size: {len(train_set)}")
print(f"Test set size: {len(test_set)}")

Training set size: 60000
Test set size: 10000


In [34]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 16, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()


In [None]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [42]:
batch_size = 4

train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)

for epoch in range(50):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(train_loader):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % 1000 == 999:    # print every 1000 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

print('Finished Training')

[1,  1000] loss: 0.006
[1,  2000] loss: 0.002
[1,  3000] loss: 0.014
[2,  1000] loss: 0.009
[2,  2000] loss: 0.008
[2,  3000] loss: 0.003
[3,  1000] loss: 0.002
[3,  2000] loss: 0.003
[3,  3000] loss: 0.004
[4,  1000] loss: 0.003
[4,  2000] loss: 0.002
[4,  3000] loss: 0.006
[5,  1000] loss: 0.002
[5,  2000] loss: 0.003
[5,  3000] loss: 0.004
[6,  1000] loss: 0.005
[6,  2000] loss: 0.003
[6,  3000] loss: 0.002
[7,  1000] loss: 0.004
[7,  2000] loss: 0.004
[7,  3000] loss: 0.003
[8,  1000] loss: 0.004
[8,  2000] loss: 0.003
[8,  3000] loss: 0.011
[9,  1000] loss: 0.007
[9,  2000] loss: 0.007
[9,  3000] loss: 0.010
[10,  1000] loss: 0.009
[10,  2000] loss: 0.014
[10,  3000] loss: 0.016
[11,  1000] loss: 0.006
[11,  2000] loss: 0.006
[11,  3000] loss: 0.009
[12,  1000] loss: 0.004
[12,  2000] loss: 0.006
[12,  3000] loss: 0.008
[13,  1000] loss: 0.006
[13,  2000] loss: 0.005
[13,  3000] loss: 0.006
[14,  1000] loss: 0.004
[14,  2000] loss: 0.004
[14,  3000] loss: 0.004
[15,  1000] loss: 0

In [None]:
torch.save(net.state_dict(), "../src/model/centered_digit_model_weights.pth")


In [None]:
import numpy as np

test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=1,
    shuffle=False,
    num_workers=2,
)

correct = 0
total = 0
failed_ids = []
for i, data in enumerate(test_loader):
    inputs, labels = data
    outputs = net(inputs)
    # Batch size of 0
    label = labels[0]
    output = outputs[0]

    total += 1
    if np.argmax(F.softmax(output, dim=0).tolist()) == label:
        correct += 1
    else:
        failed_ids.append(i)

print(f"Accuracy: {correct}/{total} ({(float(correct) / total):2.1%})")

In [92]:
# Prints a random failed example
import random

failed_id = random.choice(failed_ids)
(example, label) = test_set[failed_id]
inputs = torch.stack([example])
output = F.softmax(net(inputs), dim=1)[0]
labels = sorted(enumerate(output), key=lambda x: x[1], reverse=True)
print(f"Expected label: {label}")
print(f"Top labels: {", ".join(f"{label} ({prob:2.1%})" for label, prob in labels[0:3])}")
v2.ToPILImage()(test_set.data[failed_id]).show()


Expected label: 9
Top labels: 4 (86.9%), 9 (10.2%), 1 (2.0%)
