# Performance Metrics and Calibration Curves

This week, we're going to evaluate our semi-supervised model from Week 5, taking a closer look at how it performs on each class and at how well calibrated its confidence scores are.

To do this, we're going to follow these steps:
1. Initialise the notebook by loading necessary libraries.
2. Load the test data for evaluation. 
3. Initialise our student network and load the pretrained weights from last week.
4. Test the student network on the test data, collecting all predictions, the associated confidence, and the ground-truth class.
5. Create a confidence calibration curve.
6. Investigate the precision and recall of each class.

# 1) Initialise the notebook with libraries

Let's load the libraries we will use in this notebook.

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import glob
from PIL import Image

import tqdm

# 2) Load the labelled test data and prepare for evaluation

## 2a) Load the test dataset

Last week, we saw that we can load data in this format by:
1. Defining the transformations to apply. The most basic are [transforms.ToTensor()](https://pytorch.org/vision/stable/generated/torchvision.transforms.ToTensor.html), [transforms.Resize()](https://pytorch.org/vision/stable/generated/torchvision.transforms.Resize.html) and [transforms.Normalize()](https://pytorch.org/vision/stable/generated/torchvision.transforms.Normalize.html).
2. Loading the dataset with [torchvision.datasets.ImageFolder](https://pytorch.org/vision/stable/generated/torchvision.datasets.ImageFolder.html).

**If this is unfamiliar, please review the Week 4 practical.**

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224, 224), antialias = True),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

test_dataset = torchvision.datasets.ImageFolder('../Week 5/stanford_dogs_semi-supervised/labelled/Test/', transform = transform)

class_labels = test_dataset.classes
num_classes = len(class_labels)

print(f'Dataset has {num_classes} classes, which are: {class_labels}')
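
To make the effect of the `Normalize` step above concrete, here is a minimal NumPy sketch. It assumes pixel values already scaled to [0, 1] (as `ToTensor()` produces) and uses made-up values; the names `pixels` and `normalised` are illustrative only and are not part of the practical.

```python
import numpy as np

# Hypothetical pixel values, already scaled to [0, 1] as ToTensor() would do
pixels = np.array([0.0, 0.25, 0.5, 1.0])

# Normalize(mean, std) computes (x - mean) / std for each channel;
# the practical uses mean = 0.5 and std = 0.5 for all three channels
mean, std = 0.5, 0.5
normalised = (pixels - mean) / std

print(normalised)  # the values now lie in [-1, 1]
```

With a mean and standard deviation of 0.5, the transform simply rescales each channel from [0, 1] to [-1, 1].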

## 2b) Initialise the data loaders

**Your turn:** Using a batch size of 8, initialise a data loader for the ```test_dataset``` below and store it in a variable called ```testloader```.

If you are confused, you can check the Week 3 and Week 4 practical sheets or review the IFN680 practical support sheet.

In [None]:
######### Your code goes here ################

# 3) Initialise and load weights for our student network

We can use the ```create_classifier``` function from last week's practical to create our network. We can then load in our existing weights for this model.

If this is unfamiliar, you can review the Week 4 and Week 5 practical sheets or read more in the IFN680 Practical Support Sheet.

In [None]:
def create_classifier(nc):
    #load the model and initialise with pre-trained weights
    model = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)
    
    #adapt the architecture to the correct number of classes
    in_features = model.fc.in_features
    model.fc = nn.Linear(in_features, nc)
    
    #move the model to the GPU
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') #this line checks if we have a GPU available
    model = model.to(device)
    
    return model

model = create_classifier(num_classes)
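#I am loading in my weights from the Week 5 practical -- you can change this to your own if you'd like.
#map_location='cpu' lets a checkpoint that was saved on a GPU load on a machine without one
model.load_state_dict(torch.load('Week5_ResNet_student_best.pth', map_location='cpu'))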


# 4) Test the student network on the test data, collecting all predictions, the associated confidence, and the ground-truth class

This will follow a very similar process to what we typically do when testing on the validation dataset during training.

One thing that we need to do differently here is normalise our class scores. The ResNet architecture outputs unnormalised scores (logits) that haven't yet gone through a Softmax layer. We can remedy this by applying the [torch.softmax()](https://pytorch.org/docs/stable/generated/torch.nn.functional.softmax.html#torch.nn.functional.softmax) function to the output from the model. Please review the Week 3 lecture slides (slide 51) and the IFN680 support material for more information on the Softmax layer.
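
As a rough illustration of what the softmax step does, the sketch below re-implements it in NumPy on made-up scores. The `softmax` helper and the `logits` values are assumptions for illustration only; the practical itself uses `torch.softmax()`.

```python
import numpy as np

def softmax(scores):
    """Row-wise softmax of a 2D array of raw class scores (illustrative helper)."""
    # Subtracting each row's maximum avoids overflow and does not change the result
    shifted = scores - scores.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# Hypothetical unnormalised scores for two images and three classes
logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 0.5, 0.5]])

probs = softmax(logits)
pred_class = probs.argmax(axis=1)  # class with the highest confidence
pred_conf = probs.max(axis=1)      # the confidence of that class

print(probs.sum(axis=1))  # each row sums to 1
print(pred_class, pred_conf)
```

Softmax preserves the ordering of the scores, so the predicted class is unchanged; it only rescales the scores so that each row sums to 1 and can be read as a confidence.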

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') #this line checks if we have a GPU available
model.eval() #put layers such as batch norm into evaluation mode

all_pred_conf = []
all_pred_class = []
all_gt_class = []

#we are only evaluating, so PyTorch does not need to track gradients
with torch.no_grad():
    for inputs, labels in tqdm.tqdm(testloader, total = len(testloader)):

        #A. move the inputs to the GPU if available
        inputs = inputs.to(device)

        #B. forward pass to find the outputs
        outputs = model(inputs)

        #use the softmax function to convert the class scores into normalised confidence scores
        outputs = torch.softmax(outputs, dim = 1)

        #we need to know which label is predicted by our model. This is the class with the highest class score.
        predicted_class = torch.argmax(outputs, dim = 1)

        #we also need to know the confidence predicted by our model
        predicted_confidence, _ = torch.max(outputs, dim = 1)

        #convert all our important information to list format, and store in the respective lists for processing later
        all_pred_conf += predicted_confidence.cpu().tolist()
        all_pred_class += predicted_class.cpu().tolist()
        all_gt_class += labels.tolist()


# 5) Create a confidence calibration curve

As discussed in the Week 6 lecture, calibration curves show how well the confidence scores predicted by a classification model align with the model's actual accuracy. This is critical in applications where the uncertainty of a prediction matters as well as the label. For a well-calibrated model, the predicted confidence should be close to the true probability that the prediction is correct: of all predictions made with around 80% confidence, for example, roughly 80% should be correct.
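
As a toy illustration of this idea, the sketch below compares the mean confidence with the accuracy for a single group of predictions. The data is made up for illustration, and the variable names are assumptions rather than anything from the lecture.

```python
import numpy as np

# Hypothetical results: ten predictions that all report 90% confidence,
# of which only six turned out to be correct
confidence = np.full(10, 0.9)
correct = np.array([1, 1, 1, 1, 1, 1, 0, 0, 0, 0], dtype=bool)

mean_confidence = confidence.mean()
accuracy = correct.mean()
gap = mean_confidence - accuracy  # positive means over-confident, negative means under-confident

print(f'confidence: {mean_confidence:.2f}, accuracy: {accuracy:.2f}, gap: {gap:.2f}')
# confidence: 0.90, accuracy: 0.60, gap: 0.30
```

Here the model claims 90% confidence but is right only 60% of the time, so for this group of predictions it is over-confident. A calibration curve repeats this comparison for each confidence bin.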

Below, we're creating a calibration curve which tests confidence calibration at 10 intervals between 0-100%. Please see the Week 6 lecture slides for more information.

In [None]:
#create a variable that holds the confidence intervals we will check on a confidence calibration curve
conf_ranges = [[0, 10], [10, 20], [20, 30], [30, 40], [40, 50], [50, 60], [60, 70], [70, 80], [80, 90], [90, 100]] 

#convert our previously collected lists into numpy arrays so that we can easily manipulate them
all_pred_conf = np.array(all_pred_conf)
all_pred_class = np.array(all_pred_class)
all_gt_class = np.array(all_gt_class)

actual_accuracy = []
conf_level = []
for conf_int in conf_ranges:
    lower = conf_int[0]/100 #convert between 0-1
    upper = conf_int[1]/100 #convert between 0-1

    #create a mask that will collect predictions in the confidence interval -- it must be at or above the lower thresh AND below the upper thresh
    if upper < 1:
        mask = (all_pred_conf >= lower) & (all_pred_conf < upper)
    else:
        mask = all_pred_conf >= lower #the last bin also keeps predictions with a confidence of exactly 1.0
    
    #collect all predictions and GT data within the range using the mask
    preds = all_pred_class[mask]
    gt = all_gt_class[mask]
    
    #find the accuracy of this bin by checking how many correct/total
    correct = np.sum(preds == gt)
    total = len(preds)
    accuracy = correct/total if total > 0 else np.nan #avoid dividing by zero if no predictions fall in this bin
    actual_accuracy += [accuracy] #save the accuracy for this bin to plot later
    conf_level += [(upper + lower)/2] #this is the midpoint of the confidence interval (not necessarily the average confidence of the predictions in the bin), we will use this for plotting later
    
plt.figure()
plt.bar(conf_level, actual_accuracy, width = 0.09)
plt.plot([0, 1], [0, 1], 'r--') #our well-calibrated line
plt.xlabel('Confidence')
plt.ylabel('Accuracy')
plt.title('Confidence Calibration Curve')
plt.show()

## Consider:

Would you consider this model to be well-calibrated? If not, is it over-confident or under-confident? What's going on for confidences below 0.3?

**Your turn:** Add some code to the above cell to create another plot that shows the number of data points that were used to calculate the accuracy of each bin. This may give you more information about how reliable certain parts of the confidence calibration curve are.

In [None]:
#### Your code goes below


# 6) Investigate the precision and recall of each class

To calculate precision and recall for each class, we need three counts per class: true positives (actual examples from that class that were correctly classified), false positives (examples not belonging to that class that were incorrectly classified as it), and false negatives (actual examples from that class that were incorrectly classified as a different class). I've done that below.
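
To check your understanding of these three counts, here is a small sketch on made-up labels for three classes. The helper `count_tp_fp_fn` and the toy arrays are hypothetical and only mirror the masking logic described above; the counts are small enough to verify by hand.

```python
import numpy as np

def count_tp_fp_fn(gt, pred, cls_idx):
    """Return (TP, FP, FN) counts for one class, treating it as the positive class."""
    gt_pos = gt == cls_idx      # samples that truly belong to the class
    pred_pos = pred == cls_idx  # samples that were predicted as the class
    tp = int(np.sum(gt_pos & pred_pos))
    fp = int(np.sum(~gt_pos & pred_pos))
    fn = int(np.sum(gt_pos & ~pred_pos))
    return tp, fp, fn

# Hypothetical ground-truth and predicted labels for six samples
toy_gt = np.array([0, 0, 1, 1, 2, 2])
toy_pred = np.array([0, 1, 1, 1, 2, 0])

for cls_idx in range(3):
    print(cls_idx, count_tp_fp_fn(toy_gt, toy_pred, cls_idx))
# 0 (1, 1, 1)
# 1 (2, 1, 0)
# 2 (1, 0, 1)
```

Note that every misclassified sample is counted twice: once as a false negative for its true class and once as a false positive for the class it was predicted as.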

**Your turn:** Use the TP, FP and FN counts to calculate the precision and recall for the class. Review the Week 6 lecture notes for the precision and recall formulas.

In [None]:
for cls_idx in range(num_classes):
    
    mask_gt_pos = all_gt_class == cls_idx #mask of the samples that belong to the current positive class
    mask_pred_pos = all_pred_class == cls_idx #mask of the samples that were predicted as the current positive class
    
    tp_mask = mask_gt_pos & mask_pred_pos #belongs to the pos class AND was predicted as the pos class
    tp_count = np.sum(tp_mask) #the count of TP for this class
    
    fp_mask = ~mask_gt_pos & mask_pred_pos #does not belong to the pos class AND was predicted as the pos class -- note the use of ~ to find the inverse mask
    fp_count = np.sum(fp_mask) #the count of FP for this class
    
    fn_mask = mask_gt_pos & ~mask_pred_pos #belongs to the pos class AND was not predicted as the pos class -- note the use of ~ to find the inverse mask
    fn_count = np.sum(fn_mask) #the count of FN for this class
    
    ### Your code goes below
    precision = ...
    recall = ...
    
    
    print(f'For class {class_labels[cls_idx]}: ')
    print(f'              Precision: {precision}')
    print(f'              Recall: {recall}')
    