In [19]:
import os

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torchvision.models.googlenet import GoogLeNet

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using device:', device)

Using device: cpu


## Load preprocess test data

In [16]:
# Where to load test data
DATA_BASE_PATH = os.getenv('PROJECT_DATA_BASE_DIR')
data_path = os.path.abspath(os.path.join(DATA_BASE_PATH, 'EMNIST'))

In [17]:
# Convert images to tensors and normalise (implicitly) in range [0, 1]
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1745,), (0.3223,))
])

# Using updated link for dataset: https://cloudstor.aarnet.edu.au/plus/s/ZNmuFiuQTqZlu9W/download
testset = torchvision.datasets.EMNIST(root=data_path, split='byclass', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=1000,
                                         shuffle=True, num_workers=2)

classes = [str(i) for i in range(10)]
classes += list(map(chr, range(65, 91)))
classes += list(map(chr, range(97, 123)))

## Reload saved model

In [20]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 100)
        self.fc3 = nn.Linear(100, len(classes))
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1)

class EMNISTGoogLeNet(GoogLeNet):
    def __init__(self):
        super(EMNISTGoogLeNet, self).__init__(num_classes=len(classes), aux_logits=False)
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3)

    def forward(self, x):
        return F.log_softmax(super(EMNISTGoogLeNet, self).forward(x), dim=1)

In [23]:
model_path = '../saved_models/emnist_centralised_inception.pth'

net = EMNISTGoogLeNet().to(device)

if torch.cuda.is_available():
    checkpoint = torch.load(model_path)
else:
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

net.load_state_dict(checkpoint['model_state_dict'])
net.eval()

EMNISTGoogLeNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (conv2): BasicConv2d(
    (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv3): BasicConv2d(
    (conv): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (inception3a): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (branch2): Sequential(
      (0): BasicConv2d(
        (conv): Conv2d(192, 96, kernel_size=(1,

## Analyse Overall Performance

In [26]:
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy of the network on the 10000 test images: {accuracy:.2f}')

Accuracy of the network on the 10000 test images: 87.83


## Analyse Per-Class performance

In [28]:
class_correct = list(0. for i in range(len(classes)))
class_total = list(0. for i in range(len(classes)))
with torch.no_grad():
    for data in testloader:
        images, labels = data[0].to(device), data[1].to(device)
        outputs = net(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(len(predicted)):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(len(classes)):
    accuracy = 100 * class_correct[i] / class_total[i]
    print(f'Accuracy of    {classes[i]}: {accuracy:.2f}%')

Accuracy of    0: 74.45%
Accuracy of    1: 88.33%
Accuracy of    2: 98.21%
Accuracy of    3: 99.61%
Accuracy of    4: 98.33%
Accuracy of    5: 93.20%
Accuracy of    6: 98.88%
Accuracy of    7: 99.56%
Accuracy of    8: 99.04%
Accuracy of    9: 98.93%
Accuracy of    A: 98.78%
Accuracy of    B: 97.53%
Accuracy of    C: 97.76%
Accuracy of    D: 88.58%
Accuracy of    E: 99.18%
Accuracy of    F: 88.12%
Accuracy of    G: 93.06%
Accuracy of    H: 97.50%
Accuracy of    I: 52.44%
Accuracy of    J: 88.34%
Accuracy of    K: 89.53%
Accuracy of    L: 94.07%
Accuracy of    M: 98.52%
Accuracy of    N: 99.11%
Accuracy of    O: 65.01%
Accuracy of    P: 94.42%
Accuracy of    Q: 95.88%
Accuracy of    R: 98.64%
Accuracy of    S: 97.21%
Accuracy of    T: 93.97%
Accuracy of    U: 97.10%
Accuracy of    V: 75.50%
Accuracy of    W: 86.97%
Accuracy of    X: 81.02%
Accuracy of    Y: 81.08%
Accuracy of    Z: 70.26%
Accuracy of    a: 94.22%
Accuracy of    b: 88.86%
Accuracy of    c: 0.93%
Accuracy of    d: 97.92%
A