# Computing performance evaluation metrics
Get some model predictions that we can then compute performance metrics for.

## 1. Load a trained model
To save time we will load a previously trained model.

In [4]:
from xai.constants import MODEL_DIR
from xai.models.simple_cnn import CNNClassifier

In [8]:
MODEL_FNAME = 'simple_cnn_test.pth'
MODEL_FPATH = MODEL_DIR / MODEL_FNAME

In [11]:
model = CNNClassifier()
model.load(MODEL_FPATH)

In [12]:
model

CNNClassifier(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5, inplace=False)
  (fc1): Linear(in_features=320, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
)

## 2. Load some test data

In [13]:
from xai.data_handlers.mnist import load_mnist

In [44]:
# Load corpus and test inputs
corpus_loader = load_mnist(subset_size=100, train=True, batch_size=100) # MNIST train loader
test_loader = load_mnist(subset_size=20, train=True, batch_size=20) # MNIST test loader
corpus_inputs, corpus_labels = next(iter(corpus_loader)) # A tensor of corpus inputs
test_inputs, test_labels = next(iter(test_loader)) # A set of inputs to explain

In [45]:
test_inputs.shape

torch.Size([20, 1, 28, 28])

In [46]:
test_labels.shape

torch.Size([20])

## 3. Get model predictions

In [47]:
import torch

In [48]:
output_probs = model.probabilities(test_inputs)
output_probs

tensor([[7.8536e-03, 4.8643e-02, 4.0603e-01, 3.5693e-02, 2.5306e-01, 3.7165e-02,
         3.3393e-02, 3.4064e-02, 5.8645e-02, 8.5451e-02],
        [1.0314e-01, 7.4720e-03, 3.3994e-02, 8.4979e-02, 5.3462e-02, 1.1687e-01,
         3.3820e-01, 6.9494e-02, 6.4233e-02, 1.2816e-01],
        [1.0356e-01, 1.0982e-01, 9.5850e-02, 1.7512e-01, 3.9650e-02, 1.7393e-01,
         4.6157e-02, 1.1098e-01, 8.8882e-02, 5.6060e-02],
        [5.0249e-03, 1.8694e-02, 3.8534e-02, 9.3481e-02, 1.9923e-01, 3.5183e-02,
         1.9083e-02, 1.5222e-01, 9.7385e-02, 3.4117e-01],
        [9.6585e-05, 8.5960e-05, 2.6702e-04, 1.6311e-03, 8.4505e-01, 5.9033e-02,
         7.7575e-03, 7.4896e-03, 5.1214e-02, 2.7372e-02],
        [1.7376e-02, 2.0326e-02, 7.8640e-02, 4.7556e-02, 2.9112e-02, 1.7138e-02,
         5.8611e-03, 7.1290e-01, 4.9932e-02, 2.1161e-02],
        [2.0887e-01, 3.4129e-02, 1.3556e-01, 4.1763e-02, 6.6740e-02, 1.2614e-01,
         1.3249e-01, 2.3718e-02, 1.1915e-01, 1.1144e-01],
        [7.0997e-02, 3.9057

In [49]:
output_probs.shape

torch.Size([20, 10])

In [53]:
predicted_classes = torch.argmax(output_probs, dim=1)
predicted_classes

tensor([2, 6, 3, 9, 4, 7, 0, 8, 1, 5, 2, 0, 0, 6, 3, 6, 2, 7, 4, 2])

In [54]:
test_labels

tensor([2, 5, 3, 9, 4, 7, 9, 9, 1, 8, 2, 0, 0, 6, 3, 6, 2, 9, 4, 2])

## 4. Calculate performance metrics
- accuracy
- auc
- f1

In [61]:
from sklearn.metrics import accuracy_score, auc, f1_score, roc_auc_score

In [57]:
accuracy_score(test_labels, predicted_classes)

0.75

In [60]:
f1_score(test_labels, predicted_classes, average='micro')

0.75

In [66]:
output_probs

tensor([[7.8536e-03, 4.8643e-02, 4.0603e-01, 3.5693e-02, 2.5306e-01, 3.7165e-02,
         3.3393e-02, 3.4064e-02, 5.8645e-02, 8.5451e-02],
        [1.0314e-01, 7.4720e-03, 3.3994e-02, 8.4979e-02, 5.3462e-02, 1.1687e-01,
         3.3820e-01, 6.9494e-02, 6.4233e-02, 1.2816e-01],
        [1.0356e-01, 1.0982e-01, 9.5850e-02, 1.7512e-01, 3.9650e-02, 1.7393e-01,
         4.6157e-02, 1.1098e-01, 8.8882e-02, 5.6060e-02],
        [5.0249e-03, 1.8694e-02, 3.8534e-02, 9.3481e-02, 1.9923e-01, 3.5183e-02,
         1.9083e-02, 1.5222e-01, 9.7385e-02, 3.4117e-01],
        [9.6585e-05, 8.5960e-05, 2.6702e-04, 1.6311e-03, 8.4505e-01, 5.9033e-02,
         7.7575e-03, 7.4896e-03, 5.1214e-02, 2.7372e-02],
        [1.7376e-02, 2.0326e-02, 7.8640e-02, 4.7556e-02, 2.9112e-02, 1.7138e-02,
         5.8611e-03, 7.1290e-01, 4.9932e-02, 2.1161e-02],
        [2.0887e-01, 3.4129e-02, 1.3556e-01, 4.1763e-02, 6.6740e-02, 1.2614e-01,
         1.3249e-01, 2.3718e-02, 1.1915e-01, 1.1144e-01],
        [7.0997e-02, 3.9057

In [68]:
roc_auc_score(test_labels, output_probs.detach(), multi_class='ovr')

0.9510051169590643

In [None]:
from xai.evaluation_metrics.performance.classification_metrics