In [None]:
import torch
from torch import nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

In [None]:
class_names = range(10)  # the ten digit classes, 0 through 9


class NeuralNetwork(nn.Module):
    """Fully connected digit classifier.

    The input is flattened to 28*28 features per sample, passed through two
    hidden layers (64 units with Tanh, 128 units with Sigmoid, each followed
    by 20% dropout) and mapped to one softmax probability per entry of
    ``class_names``.
    """

    def __init__(self):
        super().__init__()
        n_inputs = 28 * 28
        n_classes = len(class_names)

        self.flatten = nn.Flatten()

        # The position of every layer inside the Sequential is kept as-is:
        # state_dict keys ("linear_stack.0.weight", "linear_stack.3.weight", ...)
        # are derived from these indices.
        layers = [
            nn.Linear(n_inputs, 64),
            nn.Tanh(),
            nn.Dropout(.2),
            nn.Linear(64, 128),
            nn.Sigmoid(),
            nn.Dropout(.2),
            nn.Linear(128, n_classes),
            nn.Softmax(dim=1),
        ]
        self.linear_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-class probabilities for a batch ``x``.

        The stack ends in ``Softmax(dim=1)``, so every output row sums to 1
        (these are probabilities, not raw logits).
        """
        return self.linear_stack(self.flatten(x))

In [None]:
# Rebuild the architecture and restore the trained weights from disk.
model = NeuralNetwork()
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only machine (the rest of this notebook runs on the CPU).
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
model.load_state_dict(torch.load("PyTorch-model.pth", map_location="cpu"))
# Switch to evaluation mode: without this the Dropout layers stay active and
# randomly zero activations during inference. The trailing semicolon keeps the
# module repr out of the cell output.
model.eval();

In [None]:
# Test dataset: the MNIST test split (train=False), downloaded into ./data
# if it is not already there, with each image converted to a tensor.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=transforms.ToTensor()
)

# Dataloader
# Size the batch from the dataset itself instead of hard-coding 10000, so a
# single batch always spans the whole test split.
batch_size = len(test_data)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

In [None]:
# Get data (the first batch produced by the loader)
test_images, test_labels = next(iter(test_dataloader))

# Recognize digits
# Make sure dropout is disabled, and skip autograd bookkeeping: no gradients
# are needed for inference, and tracking them over the whole batch only costs
# memory.
model.eval()
with torch.no_grad():
    prediction_result = model(test_images)

# Get predicted labels: index of the highest-scoring class for every image
predicted_labels = prediction_result.argmax(1)

In [None]:
# Get randomly selected image for preview
# np.random.randint excludes the upper bound, so pass the batch size itself;
# subtracting 1 would make the last image impossible to select.
preview_image_index = np.random.randint(0, test_images.shape[0])

# Show the image without ticks or grid
plt.figure()
plt.xticks([])
plt.yticks([])
plt.grid(False)
# Index [0] selects the first (channel) axis of the chosen image tensor.
plt.imshow(test_images[preview_image_index][0].numpy(), cmap=plt.cm.binary)

# Caption the preview with the true and the predicted label
plt.xlabel(f"Actual: {test_labels[preview_image_index]} \n Predicted: {predicted_labels[preview_image_index]}", fontsize=20);

In [None]:
from sklearn import metrics
import seaborn as sns

In [None]:
# Text report comparing the true test labels with the model's predictions
print(
    metrics.classification_report(
        y_true=test_labels,
        y_pred=predicted_labels,
    )
)

In [None]:
# Confusion matrix of true vs. predicted labels, ordered by class_names
confusion_matrix = metrics.confusion_matrix(
    y_true=test_labels,
    y_pred=predicted_labels,
    labels=class_names,
)

# Draw the matrix as an annotated heatmap with the class names on both axes
heatmap_options = dict(
    xticklabels=class_names,
    yticklabels=class_names,
    annot=True,
    fmt='g',
)
plt.figure(figsize=(10, 8))
sns.heatmap(confusion_matrix, **heatmap_options)

plt.ylabel('True label', fontsize=20)
plt.xlabel('Predicted label', fontsize=20);