In [None]:
import torch
from sklearn.metrics import confusion_matrix
from torch import nn, optim

from data import MNISTDataset, CIFAR10Dataset
from model import FeedforwardNeuralNet
from trainer import Trainer

# Run configuration: everything a reader might want to tune for this experiment.
epochs = 1
lr = 0.01
train_batch_size = 64
eval_batch_size = 512
seed = 0  # fixed so that Restart Kernel -> Run All reproduces the same run

# Seed torch's global RNG *before* building the dataset and model, so that
# weight initialisation (and, presumably, any shuffling done by the dataset's
# loaders via torch's default generator -- TODO confirm in data.py) is
# identical across kernel restarts.
torch.manual_seed(seed)

dataset = MNISTDataset(train_batch_size, eval_batch_size)
model = FeedforwardNeuralNet(dataset.input_size, dataset.output_size)
# NOTE(review): nn.CrossEntropyLoss applies log-softmax itself, so this assumes
# FeedforwardNeuralNet returns raw logits (no final softmax) -- confirm in model.py.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=lr)

trainer = Trainer(model, dataset, criterion, optimizer)
trainer.train(epochs)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

def plot_heatmap(conf_matrix: np.ndarray, labels=None) -> None:
    """Render a confusion matrix as an annotated heatmap and display it.

    Args:
        conf_matrix: Count matrix as produced by
            ``sklearn.metrics.confusion_matrix``: rows are true labels,
            columns are predicted labels.
        labels: Optional sequence of class names used as both the x and y
            tick labels. ``None`` (the default) keeps seaborn's ``"auto"``
            tick labels, i.e. the row/column indices, as before.

    The seaborn "darkgrid" style and 1.4x font scale are applied only for the
    duration of this call, so they no longer leak into figures drawn later
    in the notebook (a bare ``sns.set(...)`` mutates global rcParams).
    """
    tick_labels = "auto" if labels is None else labels

    # font_scale is only honoured when a named context is given, hence "notebook"
    # (the context sns.set() used to apply globally).
    with sns.axes_style("darkgrid"), sns.plotting_context("notebook", font_scale=1.4):
        _, ax = plt.subplots(figsize=(10, 7))
        sns.heatmap(
            conf_matrix,
            annot=True,
            annot_kws={"size": 16},
            fmt="g",  # general format: counts render as plain integers
            cmap="Blues",
            cbar=False,
            xticklabels=tick_labels,
            yticklabels=tick_labels,
            ax=ax,
        )
        ax.set(xlabel="Predicted labels", ylabel="True labels", title="Confusion Matrix")
        # Show while the scoped style is still active so anything matplotlib
        # reads from rcParams at draw time matches what built the figure.
        plt.show()

# Pin the label set to every class index the model can emit, so the matrix is
# always output_size x output_size and row/column i always means class i, even
# if some class never occurs in either the true labels or the predictions.
# Assumes labels are integer class ids in [0, output_size), which is what
# nn.CrossEntropyLoss requires of its targets.
conf_matrix = confusion_matrix(
    trainer.true_labels,
    trainer.predictions,
    labels=list(range(dataset.output_size)),
)
plot_heatmap(conf_matrix)