In [18]:
import os

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using device:', device)

Using device: cuda


## Load preprocess test data

In [20]:
# Where to load test data
DATA_BASE_PATH = os.getenv('PROJECT_DATA_BASE_DIR')
data_path = os.path.abspath(os.path.join(DATA_BASE_PATH, 'EMNIST'))

In [21]:
# Convert images to tensors and normalise (implicitly) in range [0, 1]
transform = transforms.Compose([transforms.ToTensor()])

# Using updated link for dataset: https://cloudstor.aarnet.edu.au/plus/s/ZNmuFiuQTqZlu9W/download
testset = torchvision.datasets.EMNIST(root=data_path, split='byclass', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1000,
                                         shuffle=False, num_workers=2)

classes = [str(i) for i in range(10)]
classes += list(map(chr, range(65, 91)))
classes += list(map(chr, range(97, 123)))

## Reload saved model

In [22]:
# TODO: Replace the copying of model with just importing it from a python file!
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 100)
        self.fc3 = nn.Linear(100, 62)
        self.pool = nn.MaxPool2d(2, 2)
#         self.drop_out = nn.Dropout()

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

In [24]:
model_path = 'emnist_centralised.pth'

net = Net().to(device)

checkpoint = torch.load(model_path)
net.load_state_dict(checkpoint['model_state_dict'])
net.eval()

Net(
  (conv1): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(32, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=256, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=100, bias=True)
  (fc3): Linear(in_features=100, out_features=62, bias=True)
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

## Analyse Overall Performance

In [25]:
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

Accuracy of the network on the 10000 test images: 83 %


## Analyse Per-Class performance

In [26]:
class_correct = list(0. for i in range(len(classes)))
class_total = list(0. for i in range(len(classes)))
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(len(predicted)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(len(classes)):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

Accuracy of     0 : 50 %
Accuracy of     1 : 96 %
Accuracy of     2 : 93 %
Accuracy of     3 : 97 %
Accuracy of     4 : 96 %
Accuracy of     5 : 92 %
Accuracy of     6 : 97 %
Accuracy of     7 : 98 %
Accuracy of     8 : 95 %
Accuracy of     9 : 97 %
Accuracy of     A : 93 %
Accuracy of     B : 85 %
Accuracy of     C : 96 %
Accuracy of     D : 84 %
Accuracy of     E : 89 %
Accuracy of     F : 87 %
Accuracy of     G : 75 %
Accuracy of     H : 85 %
Accuracy of     I : 45 %
Accuracy of     J : 74 %
Accuracy of     K : 28 %
Accuracy of     L : 90 %
Accuracy of     M : 93 %
Accuracy of     N : 97 %
Accuracy of     O : 78 %
Accuracy of     P : 87 %
Accuracy of     Q : 77 %
Accuracy of     R : 89 %
Accuracy of     S : 86 %
Accuracy of     T : 87 %
Accuracy of     U : 96 %
Accuracy of     V : 74 %
Accuracy of     W : 80 %
Accuracy of     X : 69 %
Accuracy of     Y : 72 %
Accuracy of     Z : 53 %
Accuracy of     a : 88 %
Accuracy of     b : 79 %
Accuracy of     c :  0 %
Accuracy of     d : 94 %
