## Model Evaluation

This notebook contains a wide variety of methods to evaluate a trained model that was logged to this public [W&B Experiment](https://wandb.ai/mikasenghaas/bsc?workspace=user-mikasenghaas). 

In [None]:
import sys

# Put the relative directory "../src" first on the module search path so that
# modules located there can be imported by the cells below.
sys.path[:0] = ["../src"]

In [None]:
import os
import random

import torch
from tqdm import tqdm

# custom scripts
from config import *
from utils import *
from model import MODELS, FinetunedImageClassifier
from transform import ImageTransformer
from data import ImageDataset

## Load Trained Classifier

In [None]:
# Choose the architecture and the artifact version that should be evaluated.
MODEL, VERSION = "resnet18", "v1"

# Fail early when the chosen architecture is not a key of MODELS.
assert MODEL in MODELS, f"Specified model has to be one of {list(MODELS.keys())}"

In [None]:
# fetch the logged model artifact from the wandb server
import wandb

# local target directory for the artifact files (`path` is bound to the same value)
SAVE_PATH = os.path.join(BASEPATH, "artifacts", f"{MODEL}:{VERSION}")
path = SAVE_PATH

# look the artifact up through the public API and download its files
api = wandb.Api()
artifact = api.artifact(f'mikasenghaas/bsc/{MODEL}:{VERSION}', type='model')
relative_path = artifact.download(root=SAVE_PATH)

print(f"{MODEL}:{VERSION} downloaded to {SAVE_PATH}")

In [None]:
# Locations of the weights, the config and the pickled transform inside the
# downloaded artifact directory.
model_path, config_path, transforms_path = (
    os.path.join(SAVE_PATH, filename)
    for filename in (f"{MODEL}.pt", "config.json", "transforms.pkl")
)

In [None]:
# load transform
# NOTE(review): `load_pickle` presumably unpickles the file — confirm; if so, only
# load artifacts from a trusted source, since unpickling can execute arbitrary code.
transform = load_pickle(transforms_path)

In [None]:
# load model
# The config stored next to the weights provides the class-to-id mapping and the
# keyword arguments for the classifier constructor.
config = load_json(config_path)
class2id = config['class2id']
id2class = {i: c for c, i in class2id.items()}  # inverse mapping: id -> class name

model = FinetunedImageClassifier(**config)
# map_location="cpu" lets a checkpoint that was saved from a GPU be deserialised on
# a machine without CUDA; load_state_dict then copies the values into the model.
model.load_state_dict(torch.load(model_path, map_location="cpu"))

# switch to evaluation mode (the return value is displayed as the cell output)
model.eval()

## Prediction Examples on Test Split

We sample a random batch of `16` frames from the dataset and visualise the true and predicted label.

_Note: As of now, the data is split at the level of randomly shuffled frames, which means that the model may have seen frames during training that are very similar to frames in the test set._

In [None]:
# build the test split and a loader that yields it in batches of 16
from torch.nn.functional import softmax
from torch.utils.data import DataLoader

test_data = ImageDataset(
    split="test",
    include_classes=list(class2id.keys()),
    ratio=1.0,
)
test_loader = DataLoader(test_data, batch_size=16)

In [None]:
# pick one random batch from the test loader
test_list = list(iter(test_loader))
# randrange excludes the upper bound; random.randint(0, len(test_list)) could return
# len(test_list) and raise an IndexError
idx = random.randrange(len(test_list))
images, labels = test_list[idx]

# predict (no gradients are needed for inference)
with torch.no_grad():
    logits = model(transform(images))
probs = softmax(logits, 1)
max_probs, preds = torch.max(probs, 1)

# show images with the true class, the predicted class and the predicted probability
show_images(
    images,
    titles=[
        f"True: {id2class[labels[i].item()]}\nPred: {id2class[preds[i].item()]} ({round(100 * max_probs[i].item(), 1)}%)"
        for i in range(len(preds))
    ],
    show=True,
)

In [None]:
# run the model on the first image of the current batch only
image = images[0]
label = labels[0]

# unsqueeze(0) adds a leading batch dimension of size 1 before the forward pass
logits = model(transform(image).unsqueeze(0))
probs = softmax(logits, dim=1)
prob, pred = probs.max(dim=1)

show_image(
    image,
    title=f"Label: {id2class[label.item()]}\nPred: {id2class[pred.item()]} ({round(prob.item() * 100,1)}%)",
    show=True,
)

## Evaluation Metrics on Test Split

We predict on all samples in the test split and measure common metrics for classification experiments, such as accuracy, precision, recall and F1 score.

In [None]:
# Run the model over the whole test split. All true/predicted class ids are
# collected, and mispredicted samples are additionally stored with their image.
y_true, y_pred = [], []
images_mispred, y_true_mispred, y_pred_mispred = [], [], []

with torch.no_grad():
    for images, labels in tqdm(test_loader):
        logits = model(transform(images))  # B, C
        preds = logits.argmax(dim=-1)  # B

        for image, true, pred in zip(images, labels.tolist(), preds.tolist()):
            y_true.append(true)
            y_pred.append(pred)
            if true == pred:
                continue
            # keep the misclassified sample for later inspection
            images_mispred.append(image)
            y_true_mispred.append(true)
            y_pred_mispred.append(pred)

In [None]:
# classification report
import pandas as pd
from sklearn.metrics import classification_report

# Order the class names by their integer id and pass the ids explicitly, so that
# every row is labelled with the right class even if the mapping is not stored in
# id order or a class does not occur in the test split.
class_ids = sorted(id2class)
labels = [id2class[i] for i in class_ids]
report = classification_report(
    y_true, y_pred, labels=class_ids, target_names=labels, output_dict=True
)
pd.DataFrame(report).T

In [None]:
# confusion matrix
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix

# Passing all class ids explicitly keeps one row/column per class, in id order,
# even when a class never occurs in y_true or y_pred; otherwise the matrix would
# have fewer rows/columns than there are class names for the tick labels.
conf_matrix = confusion_matrix(y_true, y_pred, labels=sorted(id2class))

In [None]:
# visualise confusion matrix with absolute counts
# (the figure is bound to `fig` rather than `_`, which IPython uses for the last output)
fig, ax = plt.subplots(figsize=(12, 12))
sns.heatmap(conf_matrix, annot=True, fmt='g', xticklabels=labels, yticklabels=labels, ax=ax)
# rows of sklearn's confusion matrix are the true classes, columns the predicted ones
ax.set(
    xlabel="Predicted label",
    ylabel="True label",
    title=f"Confusion matrix of {MODEL}:{VERSION} on the test split",
);

## Mispredictions

Mispredictions can be informative to investigate how to further improve a machine learning model.

In [None]:
# true class ids (first line) and predicted class ids (second line) of all mispredictions
print(y_true_mispred, y_pred_mispred, sep="\n")

In [None]:
# get up to 16 random mispredictions
# random.sample raises a ValueError when asked for more items than available, so
# the sample size is capped at the number of mispredictions.
idxs = random.sample(range(len(images_mispred)), min(16, len(images_mispred)))

if idxs:
    show_images(
        # stack only the selected images into a single batch tensor
        torch.stack([images_mispred[i] for i in idxs]),
        titles=[f"True: {id2class[y_true_mispred[i]]}\nPred: {id2class[y_pred_mispred[i]]}"
                for i in idxs],
        show=True,
    )
else:
    print("No mispredictions on the test split.")

We can identify the following typical error sources:

- Inherently difficult to predict frames (white wall, close-up of bookshelf)
- Similarities of locations (e.g. coloured areas on the different floors)
- Transitions between areas

## Predict on Video Clips

We run real-time inference on video clips, similar to how the final model is deployed on mobile devices, to get a feel for the consistency of the predictions.

In [None]:
import torchvision
from IPython.display import HTML
# `animation` and `transforms` are used below but not imported explicitly in any
# earlier cell; importing them here keeps the cell from depending on star imports.
from matplotlib import animation
from torchvision import transforms

# get random video path
path = os.path.join(RAW_DATA_PATH)
clip = random.choice(os.listdir(path))
video_path = os.path.join(path, clip, "video.mov")
print(f"Predicting on video {video_path}")

# load the first 10 seconds of the video; with output_format="TCHW" the frames are
# indexed along the first dimension
video = torchvision.io.read_video(video_path, start_pts=0, end_pts=10, pts_unit="sec", output_format="TCHW")[0]

fig, ax = plt.subplots()
ax.set_title(f"{MODEL}:{VERSION}") # type: ignore
ax.set_xticks([]) # type: ignore
ax.set_yticks([]) # type: ignore

# one converter instance is reused for every frame
to_pil = transforms.ToPILImage()
img = ax.imshow(to_pil(video[0])) # type: ignore

def animate(i):
    """Classify frame `i` of `video` and draw that frame into the image artist.

    The prediction and its confidence are printed on a single, overwritten line.
    Returns the list of updated artists, as FuncAnimation expects with blit=True.
    """
    frame = video[i]

    # no gradients are needed for inference
    with torch.no_grad():
        logits = model(transform(frame).unsqueeze(0))
    probs = softmax(logits, 1)
    prob, pred = torch.max(probs, 1)
    prob, pred = prob.item(), pred.item()

    print(f"Prediction: {id2class[pred]} (Confidence: {round(prob * 100, 2)}%)", end="\r")

    img.set_data(to_pil(frame)) # type: ignore

    return [img]

a = animation.FuncAnimation(fig, animate, frames=len(video), interval=1, blit=True)

# Render the animation once, close the figure so the static first frame is not shown
# a second time, and make the HTML the last expression so the notebook displays it
# (previously the HTML object was created but discarded).
html = a.to_jshtml()
plt.close(fig)
HTML(html)

In [None]:
import cv2
import torchvision

# get random video path
path = os.path.join(RAW_DATA_PATH)
clip = random.choice(os.listdir(path))
video_path = os.path.join(path, clip, "video.mov")
print(f"Predicting on video {video_path}")

# set up video capture
cap = cv2.VideoCapture(video_path)

try:
    while True:
        # read next frame; `ret` is False (and `frame` is None) once no further
        # frame can be read, so this has to be checked before the frame is used
        ret, frame = cap.read()
        if not ret:
            break

        frame_tensor = torch.tensor(frame).permute(2, 0, 1) # C, H, W
        frame_tensor = frame_tensor[[2, 1, 0], :, :] # change channel to RGB from BGR

        # predict frame (no gradients are needed for inference)
        with torch.no_grad():
            logits = model(transform(frame_tensor).unsqueeze(0))
        probs = softmax(logits, 1)
        prob, pred = torch.max(probs, 1)
        prob, pred = prob.item(), pred.item()
        class_label = id2class[pred]

        text = f"{class_label} ({round(100 * prob, 1)}%)"

        # overlay the prediction on the frame
        cv2.putText(frame, text, (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 255, 0), 1)

        # display the frame with the prediction overlaid
        cv2.imshow(f"{MODEL}:{VERSION}", frame)

        # exit the loop if the 'q' key is pressed
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # release the video capture and close the window, also when an error occurred
    cap.release()
    cv2.destroyAllWindows()

In [None]:
# draw one random sample from the test split; randrange excludes the upper bound,
# whereas random.randint(0, len(test_data)) could return an out-of-range index
image, label = test_data[random.randrange(len(test_data))]

# predict the class of the single image and show it as the title
pred = model(transform(image).unsqueeze(0)).argmax(-1)
show_image(image, title=id2class[pred.item()], show=True)

In [None]:
# shape check: append the current image (with an added leading dimension) to 15 random tensors
torch.cat([torch.rand(15, 3, 224, 224), image[None]], dim=0).shape

In [None]:
# take the next batch from a fresh iterator over the test loader and show every
# image with its predicted class as the title
images, labels = next(iter(test_loader))

preds = model(transform(images)).argmax(dim=-1)
show_images(
    images,
    titles=[id2class[class_id] for class_id in preds.tolist()],
    show=True,
)

In [None]:
# forward pass on a random tensor of shape (16, 3, 224, 224)
model(torch.rand(16, 3, 224, 224))