# Computing performance evaluation metrics
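This notebook loads a previously trained model, generates predictions for a small batch of MNIST digits, and computes classification performance metrics (accuracy, F1 and ROC AUC) for those predictions.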

## 1. Load a trained model
To save time we will load a previously trained model.

In [1]:
from xai.constants import MODEL_DIR
from xai.models.simple_cnn import CNNClassifier

In [2]:
# File name of the saved checkpoint (the name suggests 50 training epochs — TODO confirm).
MODEL_FNAME = 'simple_cnn_50_epochs.pth'
# Full path to the checkpoint inside the project's MODEL_DIR.
MODEL_FPATH = MODEL_DIR / MODEL_FNAME

In [3]:
# Build the classifier, then load the saved checkpoint into it
# (presumably restores the trained weights — see CNNClassifier.load).
model = CNNClassifier()
model.load(MODEL_FPATH)

In [4]:
# Display the module summary to check the layer layout of the loaded network.
model

CNNClassifier(
  (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2_drop): Dropout2d(p=0.5, inplace=False)
  (fc1): Linear(in_features=320, out_features=50, bias=True)
  (fc2): Linear(in_features=50, out_features=10, bias=True)
)

## 2. Load some test data

In [5]:
from xai.data_handlers.mnist import load_mnist

In [6]:
# Load a corpus batch and a small batch of inputs to evaluate.
# batch_size equals subset_size for both loaders, so presumably a single batch
# holds the whole subset and one next(iter(...)) call retrieves it.
corpus_loader = load_mnist(subset_size=100, train=True, batch_size=100) # MNIST train loader
# NOTE(review): train=True means this "test" batch is drawn from the MNIST *training* split,
# so the metrics computed below are not a held-out evaluation. Confirm whether train=False
# was intended; if it is changed, check that the 20 samples still cover all 10 classes,
# otherwise the multiclass ROC AUC computed later in this notebook may fail.
test_loader = load_mnist(subset_size=20, train=True, batch_size=20) # drawn from the train split (see note above)
corpus_inputs, corpus_labels = next(iter(corpus_loader)) # First batch of corpus inputs and labels
test_inputs, test_labels = next(iter(test_loader)) # First batch: inputs and labels evaluated below

In [7]:
# Sanity check: shape of the batch of inputs that will be fed to the model.
test_inputs.shape

torch.Size([20, 1, 28, 28])

In [8]:
# Sanity check: shape of the labels that accompany test_inputs.
test_labels.shape

torch.Size([20])

## 3. Get model predictions

In [9]:
import torch

In [10]:
# Put the network in inference mode: it contains a Dropout2d layer, which must be
# disabled for the predicted probabilities to be deterministic. This is a no-op if
# the model is already in eval mode.
model.eval()
# Gradients are not needed for evaluation. Disabling autograd tracking avoids building
# a graph and yields a plain tensor that can be passed straight to scikit-learn.
with torch.no_grad():
    output_probs = model.probabilities(test_inputs)
output_probs

tensor([[1.2896e-04, 3.5487e-03, 1.8493e-04, 1.1995e-05, 2.4939e-03, 4.9365e-02,
         4.7751e-01, 4.3563e-06, 4.6675e-01, 5.2533e-07],
        [3.3614e-08, 9.9988e-01, 1.1901e-05, 2.1136e-07, 7.8844e-05, 4.2648e-09,
         1.9139e-06, 3.6523e-06, 7.5206e-06, 1.6808e-05],
        [3.1638e-08, 1.3226e-06, 3.6393e-09, 8.9207e-07, 5.7577e-04, 6.2532e-08,
         1.2539e-09, 3.2449e-06, 1.0210e-06, 9.9942e-01],
        [1.7335e-02, 3.1937e-02, 2.6165e-01, 1.0273e-01, 1.3522e-03, 7.0981e-03,
         5.3454e-03, 1.6308e-01, 3.9781e-01, 1.1662e-02],
        [9.9994e-01, 5.8391e-11, 1.0658e-07, 5.4150e-08, 1.1030e-10, 3.7997e-05,
         2.0767e-05, 2.3085e-08, 1.8821e-06, 3.6488e-06],
        [8.0252e-05, 5.4530e-06, 9.1252e-01, 8.0681e-02, 1.3748e-06, 4.2844e-05,
         2.0577e-04, 2.0001e-04, 6.0600e-03, 2.0390e-04],
        [4.2411e-04, 4.4300e-07, 3.8390e-06, 1.2032e-09, 1.0454e-04, 1.4812e-08,
         9.9947e-01, 1.6225e-10, 1.2440e-08, 3.1007e-08],
        [1.8851e-07, 9.4675

In [11]:
# Sanity check: shape of the probability tensor returned by the model.
output_probs.shape

torch.Size([20, 10])

In [12]:
# The predicted class for each input is the index of its largest probability (taken along dim 1).
predicted_classes = output_probs.argmax(dim=1)
predicted_classes

tensor([6, 1, 9, 8, 0, 2, 6, 2, 4, 9, 8, 1, 5, 0, 3, 1, 5, 4, 2, 0])

In [13]:
# Ground-truth labels, shown for side-by-side comparison with the predicted classes.
test_labels

tensor([8, 1, 9, 8, 0, 2, 6, 2, 4, 9, 8, 1, 5, 0, 3, 1, 5, 4, 7, 0])

## 4. Calculate performance metrics
- accuracy
- ROC AUC (one-vs-rest, computed from the predicted class probabilities)
- F1 (micro-averaged)

In [14]:
from sklearn.metrics import accuracy_score, auc, f1_score, roc_auc_score

In [15]:
# Fraction of inputs whose predicted class matches the true label.
accuracy_score(y_true=test_labels, y_pred=predicted_classes)

0.9

In [16]:
# Micro-averaged F1: true/false positives and negatives are pooled across all classes
# before the score is computed.
f1_score(y_true=test_labels, y_pred=predicted_classes, average='micro')

0.9

In [17]:
# Show the class probabilities again; these are the scores used for the ROC AUC calculation.
output_probs

tensor([[1.2896e-04, 3.5487e-03, 1.8493e-04, 1.1995e-05, 2.4939e-03, 4.9365e-02,
         4.7751e-01, 4.3563e-06, 4.6675e-01, 5.2533e-07],
        [3.3614e-08, 9.9988e-01, 1.1901e-05, 2.1136e-07, 7.8844e-05, 4.2648e-09,
         1.9139e-06, 3.6523e-06, 7.5206e-06, 1.6808e-05],
        [3.1638e-08, 1.3226e-06, 3.6393e-09, 8.9207e-07, 5.7577e-04, 6.2532e-08,
         1.2539e-09, 3.2449e-06, 1.0210e-06, 9.9942e-01],
        [1.7335e-02, 3.1937e-02, 2.6165e-01, 1.0273e-01, 1.3522e-03, 7.0981e-03,
         5.3454e-03, 1.6308e-01, 3.9781e-01, 1.1662e-02],
        [9.9994e-01, 5.8391e-11, 1.0658e-07, 5.4150e-08, 1.1030e-10, 3.7997e-05,
         2.0767e-05, 2.3085e-08, 1.8821e-06, 3.6488e-06],
        [8.0252e-05, 5.4530e-06, 9.1252e-01, 8.0681e-02, 1.3748e-06, 4.2844e-05,
         2.0577e-04, 2.0001e-04, 6.0600e-03, 2.0390e-04],
        [4.2411e-04, 4.4300e-07, 3.8390e-06, 1.2032e-09, 1.0454e-04, 1.4812e-08,
         9.9947e-01, 1.6225e-10, 1.2440e-08, 3.1007e-08],
        [1.8851e-07, 9.4675

In [18]:
# One-vs-rest multiclass ROC AUC computed from the per-class probabilities.
# .detach() hands scikit-learn the scores without autograd tracking.
roc_auc_score(y_true=test_labels, y_score=output_probs.detach(), multi_class='ovr')

0.9814327485380117

In [19]:
from xai.evaluation_metrics.performance.classification_metrics import calculate_accuracy_metrics

In [20]:
# Project helper that returns the metrics as a single dict; its values should agree
# with the individual scikit-learn results computed in the cells above.
calculate_accuracy_metrics(test_labels, predicted_classes, output_probs)

{'accuracy': 0.9, 'f1': 0.9, 'auc': 0.9814327485380117}