In [1]:
import lightning as L
import torch
from shared_utilities_2 import PyTorchMLP, LightningModel, MNISTDataModule

In [2]:
pytorch_model = PyTorchMLP(num_features=784, num_classes=10)

lightning_model = LightningModel.load_from_checkpoint(
    checkpoint_path="model.ckpt", model=pytorch_model)

In [3]:
dm = MNISTDataModule()
dm.setup(stage="test")

In [4]:
import torchmetrics


test_dataloader = dm.test_dataloader()
acc = torchmetrics.Accuracy(task="multiclass", num_classes=10)

for batch in test_dataloader:
    features, true_labels = batch

    with torch.inference_mode():
        logits = lightning_model(features)

    predicted_labels = torch.argmax(logits, dim=1)
    acc(predicted_labels, true_labels)

predicted_labels[:5]

tensor([1, 2, 3, 4, 5])

In [5]:
test_acc = acc.compute()
print(f'Test accuracy: {test_acc:.4f} ({test_acc*100:.2f}%)')

Test accuracy: 0.9558 (95.58%)


In [6]:
import matplotlib
import mlextend
from mlxtend.plotting import plot_confusion_matrix
from torchmetrics import ConfusionMatrix

cmat = ConfusionMatrix(task="multiclass", num_classes=len(class_dict))

for x, y in dm.test_dataloader():

    with torch.inference_mode():
        pred = lightning_model(x)
    cmat(pred, y)

cmat_tensor = cmat.compute()
cmat = cmat_tensor.numpy()

fig, ax = plot_confusion_matrix(
    conf_mat=cmat,
    class_names=class_dict.values(),
    norm_colormap=matplotlib.colors.LogNorm()
)
plt.xticks(rotation=45, ha="right", rotation_mode="anchor")
plt.show()

ModuleNotFoundError: No module named 'mlextend'