## Imports

In [None]:
import os
from sklearn import metrics
import torch
import torchvision

## Device

In [None]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

## Paths

In [None]:
# Relative locations (resolved from the notebook's working directory) of the
# test image folder and of the directory holding the saved model state dicts.
testing_data_directory, model_directory = "../grouped-data/test/", "../models/"

## Load Data

In [None]:
# Preprocessing applied to every test image: resize to 224x224, convert to a
# tensor, then normalize each channel with the given mean/std (the commonly
# used ImageNet statistics).
preprocess = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        ),
    ]
)
# ImageFolder derives class labels from the sub-directory names under the root.
testing_data = torchvision.datasets.ImageFolder(
    root=testing_data_directory, transform=preprocess
)
print("Testing dataset size:", len(testing_data))

## Create Confusion Matrices (one per model)

In [None]:
# Evaluate every saved model on the test set and collect one confusion matrix
# per model. Files are processed in sorted order so the index used when
# plotting is reproducible (os.listdir returns entries in arbitrary order).
output_feature_count = 5  # number of classes each saved ResNet-18 head predicts
evaluation_batch_size = 64  # mini-batches keep memory bounded on large test sets

# The loader depends only on the test set, so build it once instead of per model.
testing_loader = torch.utils.data.DataLoader(
    testing_data, batch_size=evaluation_batch_size, shuffle=False
)

model_filenames = sorted(
    filename
    for filename in os.listdir(model_directory)
    # Skip hidden entries and sub-directories (e.g. .ipynb_checkpoints),
    # which torch.load cannot read as checkpoints.
    if not filename.startswith(".")
    and os.path.isfile(os.path.join(model_directory, filename))
)

confusion_matrices = []
for model_filename in model_filenames:
    # weights=None: load_state_dict below overwrites every parameter, so
    # downloading the pretrained ImageNet weights would be wasted work.
    model = torchvision.models.resnet18(weights=None)
    model.fc = torch.nn.Linear(model.fc.in_features, output_feature_count)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    state_dict = torch.load(
        os.path.join(model_directory, model_filename), map_location=device
    )
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    all_targets = []
    all_predictions = []
    with torch.no_grad():
        for inputs, targets in testing_loader:
            outputs = model(inputs.to(device))
            _, batch_predictions = torch.max(outputs, 1)
            all_predictions.append(batch_predictions.cpu())
            all_targets.append(targets)

    confusion_matrix = metrics.confusion_matrix(
        torch.cat(all_targets).numpy(),
        torch.cat(all_predictions).numpy(),
        # Fix the label set so every matrix is output_feature_count x
        # output_feature_count, even if some class never occurs.
        labels=list(range(output_feature_count)),
    )
    confusion_matrices.append(confusion_matrix)

## Show Confusion Matrices (one per model)

In [None]:
import matplotlib.pyplot as pyplot

# Plot one confusion matrix per evaluated model and print its accuracies.
if not confusion_matrices:
    print("No confusion matrices to show; is model_directory empty?")
else:
    model_count = len(confusion_matrices)
    # One row per model. squeeze=False keeps axs 2-D even for a single model,
    # so axs[i, 0] works for any number of models (the original hard-coded 2).
    fig, axs = pyplot.subplots(
        model_count, squeeze=False, figsize=(15, 7.5 * model_count)
    )
    for i, confusion_matrix in enumerate(confusion_matrices):
        # Fraction of all test samples that were classified correctly.
        overall_accuracy = confusion_matrix.trace() / confusion_matrix.sum()
        # Mean per-class recall (rows are true classes). Classes with no true
        # samples are skipped instead of contributing a 0/0 NaN.
        class_totals = confusion_matrix.sum(axis=1)
        present = class_totals > 0
        average_accuracy = (
            confusion_matrix.diagonal()[present] / class_totals[present]
        ).mean()

        ax = axs[i, 0]
        metrics.ConfusionMatrixDisplay(confusion_matrix=confusion_matrix).plot(ax=ax)
        ax.set_title(i)

        print("Overall accuracy {}: {:.2f}%".format(i, overall_accuracy * 100))
        print("Average accuracy {}: {:.2f}%".format(i, average_accuracy * 100))

    pyplot.show()