<img src="https://www.comet.com/images/logo_comet_light.png" width="300px"/>

# Interactive Confusion Matrix for Image Classification

## Introduction

Object Detection is one of the most popular applications of Machine Learning for Computer Vision. A detection model predicts both the class and location of each distinct object in an image. Object Detection models have a wide range of applications including manufacturing, surveillance, health care, and more. 

<a href="https://www.comet.com/anmorgan24/interactive-confusion-matrix/view/new/panels/?utm_source=Medium&utm_medium=referral&utm_content=TV_object_detection_blog"><img src="https://s11.gifyu.com/images/Untitled-design-35f8c53e6f7cae4c6.png" alt="Untitled-design-35f8c53e6f7cae4c6.png" border="0" width="850"/></a>

TorchVision is a Python package that extends the PyTorch framework for computer vision use cases. In TorchVision’s detection module, developers can find [pre-trained object detection models](https://pytorch.org/vision/stable/models.html) that are ready to be fine-tuned on their own datasets. Pre-trained object detection models can also be found in the [PyTorch Hub](https://pytorch.org/hub/), [HuggingFace Hub](https://huggingface.co/docs/hub/index), and [other model zoos](https://docs.openvino.ai/latest/omz_demos.html), so computer vision engineers have a wide selection of models to choose from! But how can you systematically find the best model for a particular use case? In this tutorial, we'll explore how to use an [experiment tracking](https://heartbeat.comet.ml/how-experiment-management-makes-it-easier-to-build-better-models-faster-ef99b6fee164) tool like [Comet](https://www.comet.com/signup/?utm_source=Kaggle_ANM&utm_medium=referral&utm_content=confusion_matrix_with_images_blog) to visually compare and evaluate object detection models.

Check out **[the public project here](https://colab.research.google.com/drive/1Fyrk6Br3EtahbFttIQh1cMgT1ztHKrzs#scrollTo=WXMuIndNl3-p)!**
A full article is also coming soon!

## The Code

In [None]:
!pip install comet_ml --root-user-action=ignore --quiet
import comet_ml

To instantiate your Comet Experiment, you'll need to grab your API key from your [account settings](https://www.comet.com/account-settings/profile/?utm_source=Kaggle_ANM&utm_medium=referral&utm_content=confusion_matrix_with_images_blog). If you don't already have an account, [create one here for free](https://www.comet.com/signup/?utm_source=Kaggle_ANM&utm_medium=referral&utm_content=confusion_matrix_with_images_blog).

In [None]:
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
YOUR_COMET_API_KEY = user_secrets.get_secret("YOUR-COMET-API-KEY")

In [None]:
comet_ml.init(api_key=YOUR_COMET_API_KEY,
              # uncomment and edit to log to a specific project and workspace
              #project_name="interactive-confusion-matrix-kaggle",
              #workspace="anmorgan24",
              )
experiment = comet_ml.Experiment()
experiment.set_name("mask_rcnn")

Let's organize the folders and import a [modified version of the TorchVision module](https://github.com/anmorgan24/vision) that will generate the desired outputs of our experiment:

In [None]:
%%shell
mv train/train/* train/; rm -rf train/train
mv valid/valid/* valid/; rm -rf valid/valid

pip install tqdm cython torchmetrics --quiet

# Download TorchVision repo with edits
git clone https://github.com/anmorgan24/vision.git

# copy relevant files to working directory
cd vision
cp references/detection/utils.py ../
cp references/detection/transforms.py ../
cp references/detection/coco_eval.py ../
cp references/detection/engine.py ../
cp references/detection/coco_utils.py ../

Import the necessary packages:

In [None]:
import os
import numpy as np
import pandas as pd
from joblib import dump
from PIL import Image, ImageDraw
from skimage import measure
import json
from random import randrange

import matplotlib.pyplot as plt
import matplotlib.patches as patches

import pycocotools

import torch
import torch.nn as nn
import torch.utils.data
from torch.optim.lr_scheduler import StepLR
from torchmetrics.detection.mean_ap import MeanAveragePrecision
import torchvision

#models
from torchvision import models
from torchvision.models.detection import faster_rcnn, RetinaNet, FCOS
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from torchvision.models.detection.anchor_utils import AnchorGenerator, DefaultBoxGenerator

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

Load the labels files as dictionaries:

In [None]:
with open('train_annotations') as json_file:
    train_labels = json.load(json_file)
with open('valid_annotations') as json_file:
    valid_labels = json.load(json_file)

Let's display a random example image from the validation set, with the ground truth bounding box, to make sure the image files have successfully loaded.

In [None]:
# show example image
# generate a random valid index into valid_labels, convert to string and pad with zeros
# (randrange excludes its upper bound, so no +1 is needed)
number = str(randrange(len(valid_labels))).zfill(3)

bbox_coords = valid_labels[int(number)]['bbox']
im = Image.open('./valid/image_id_'+number+'.jpg')
fig, ax = plt.subplots()
ax.imshow(im)
rect = patches.Rectangle((bbox_coords[0], bbox_coords[1]), bbox_coords[2], bbox_coords[3], linewidth=2, edgecolor='deeppink', facecolor='none')
ax.add_patch(rect)

plt.axis('off')
plt.show()

Import some of the functions and modules we'll be using from TorchVision directly to our working directory:

In [None]:
from engine import train_one_epoch, evaluate
import utils
import transforms as T

Here we define our hyperparameters as a dictionary and log them to [Comet](https://www.comet.com/signup/?utm_source=Kaggle_ANM&utm_medium=referral&utm_content=confusion_matrix_with_images_blog). Note that the number of classes equals the number of object categories to identify plus the background class. Because we have two label classes ("penguin" and "turtle"), we specify three total classes.

Comet will keep track of the sets of hyperparameters used with each model, for each experiment, so that we can easily debug and reproduce our results. To learn more about tuning hyperparameters in Comet, check out [this great article](https://heartbeat.comet.ml/hyperparameter-tuning-in-comet-e7aa637f124c) from Dr. Angelica Lo Duca (she literally wrote the book on Comet!).

In [None]:
hyper_params = {"lr": 0.005,
                "momentum" : 0.9,
                "weight_decay" : 0.0005,
                "step_size" : 3,
                "gamma" : 0.1,
                "num_epochs" : 3,
                # num_classes = num of objects to identify + background class
                "num_classes" : 3,
                "model_name": "mask_rcnn",
                "feature_extract": False}

experiment.log_parameters(hyper_params) 

We define a label dictionary to encode our [categorical labels](https://heartbeat.comet.ml/feature-engineering-for-categorical-data-897e98caea35). Remember that by default, the first class (class "0") is the background.

In [None]:
label_dict = {1: "penguin",
              2: "turtle"}
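
As a small illustration of the background-class convention, the sketch below derives the class count from a label dictionary and shifts a 1-based detection label to a 0-based index, which is the adjustment needed when a confusion matrix only lists the real categories. The helper name `to_matrix_index` is hypothetical and exists only for this example.

```python
# Stand-in copy of the label dictionary; class 0 is reserved for the background
label_dict = {1: "penguin", 2: "turtle"}

# Detection heads need one output per category plus one for the background
num_classes = len(label_dict) + 1


def to_matrix_index(class_id):
    """Shift a 1-based detection label to a 0-based confusion-matrix index."""
    if class_id not in label_dict:
        raise ValueError("unknown class id: {}".format(class_id))
    return class_id - 1


# Labels in the order the confusion matrix would display them
matrix_labels = list(label_dict.values())

print(num_classes)
print(matrix_labels[to_matrix_index(2)])
```

Keeping the shift in one place makes it harder to mix up 1-based model labels with 0-based matrix positions.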

In [None]:
def calculate_f1_score(precision, recall):
    """
    Calculate f1 score from precision and recall values
    """
    if precision + recall > 0:
        f1_score = 2* ((precision * recall)/(precision + recall))

    else:
        f1_score = 0

    return f1_score

def get_transform(train):

    """Returns composed Torch transforms"""

    transforms = []
    transforms.append(T.ToTensor())
    if train:
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

def my_log_func(logger):
    """
    Parses out the data collected in COCO_evaluator's metric_logger and logs it to Comet.
      Parameters:
          logger (metric_logger): The COCO_evaluator metric_logger object to be parsed out.
    """
    met_logger = str(logger).split('  ')
    for met in met_logger:
        # each entry looks like "name: value ..."; keep the first number after the name
        temp_list = met.split(': ')
        experiment.log_metric(temp_list[0], float(temp_list[1].split(' ')[0]))

def make_coco_boxes(tensor_bboxes):

    """Convert torch tensor Pascal VOC bboxes to COCO format for Comet annotations"""

    list_boxes=torch.Tensor.tolist(tensor_bboxes)
    coco_boxes = [[list_boxes[0], list_boxes[1], (list_boxes[2]-list_boxes[0]), (list_boxes[3]-list_boxes[1])]]
    return coco_boxes

def make_ped_points(binary_mask):

    """Converts binary mask labels to polygon point list"""

    contours = measure.find_contours(binary_mask, 0.5)
    ped_points=[]
    for contour in contours:
        contour = np.flip(contour, axis=1)
        segmentation = contour.ravel().tolist()
        ped_points.append(segmentation)
    return ped_points

def make_annotations(prediction):

    """ Parses out the COCO evaluator outputs into appropriately formatted lists for Comet's annotations. """

    # prediction[0][2] holds the model outputs; prediction[0][1] is the ground truth target
    if len(prediction[0][2]['boxes']) == 0:
        return None

    annotations = [{
      "name": "image id: {}".format(prediction[0][0]),
      "data": []
    }]

    for i in range(len(prediction[0][2]['boxes'])):
        annotations[0]["data"].append({
          "label" : label_dict[torch.Tensor.tolist(prediction[0][2]['labels'])[i]], 
          "score": round((torch.Tensor.tolist(prediction[0][2]['scores'])[i]*100),2),
          "boxes" : make_coco_boxes(prediction[0][2]['boxes'][i]),
          #"boxes": [torch.Tensor.tolist(prediction[0][2]['boxes'])[i]], 
          # uncomment the following to log segmentation masks to Comet
          # use only with models that return mask segmentation predictions, else will throw error
          #"points": make_ped_points(prediction[0][2]['masks'][i].numpy().squeeze())
          "points": None
        })
    return annotations

def add_PIL_bboxes(img, bbox):
    """Draws a bounding box on a PIL image in place"""
    draw = ImageDraw.Draw(img)
    draw.rectangle(bbox, fill=None, outline="deeppink", width=2)

ToPILImage = torchvision.transforms.ToPILImage()
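
To make the string parsing in `my_log_func` easier to follow, here is a standalone sketch of the same idea that returns a dictionary instead of logging to Comet. It assumes the logger prints double-space-separated `name: value (average)` fields, which is what the helper above relies on; the function name `parse_metric_line` and the sample line are made up for illustration and do not come from a real run.

```python
def parse_metric_line(line, delimiter="  "):
    """Split 'name: value (average)' fields and keep the first number of each."""
    metrics = {}
    for field in line.split(delimiter):
        name, _, rest = field.partition(": ")
        if not rest:
            continue  # skip fragments that do not look like 'name: value'
        metrics[name] = float(rest.split(" ")[0])
    return metrics


# Hypothetical logger output
sample = "lr: 0.005000  loss: 0.3421 (0.4012)  loss_box_reg: 0.1105 (0.1290)"
parsed = parse_metric_line(sample)
print(parsed)
```

Each key in the returned dictionary could then be passed to `experiment.log_metric` one at a time.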

We'll provide two options for models to use in our experiment: Mask R-CNN and Faster R-CNN (selected below with the `mask_rcnn` and `fast_rcnn` keys). Note that Mask R-CNN is based on Faster R-CNN but with an extra component to detect and predict segmentation masks. We will not be using segmentation masks in this experiment, but feel free to use this tutorial on a dataset of your own that does!

In [None]:
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
  
    """ Returns the specified model with the specified parameters """

    if model_name =='mask_rcnn':

        # Mask R-CNN (Faster R-CNN plus a mask head) with ResNet50 backbone

        # load an instance segmentation model pre-trained on COCO
        model = torchvision.models.detection.maskrcnn_resnet50_fpn(weights='MaskRCNN_ResNet50_FPN_Weights.DEFAULT')
         
        # get the number of input features for the classifier
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        # replace the pre-trained head with a new one
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

        # now get the number of input features for the mask classifier
        in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
        hidden_layer = 256
        # and replace the mask predictor with a new one
        model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask,
                                                         hidden_layer,
                                                         num_classes)
        return model

    if model_name =='fast_rcnn':

        # Faster R-CNN with ResNet50 backbone (selected with the 'fast_rcnn' key)

        # load an object detection model pre-trained on COCO
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights='FasterRCNN_ResNet50_FPN_Weights.DEFAULT')

        # get the number of input features for the classifier
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        # replace the pre-trained head with a new one
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

        return model

    # fail loudly instead of silently returning None for an unknown name
    raise ValueError("Unknown model_name: {}".format(model_name))

We need to define a [custom PyTorch Dataset](https://pytorch.org/tutorials/beginner/data_loading_tutorial.html) class to pre-process and index our images and metadata to retrieve later. Note that while we don't use all of the keys created in the `target` dictionary below, our model and evaluator expect them to be there, so we define them but leave them empty.

In [None]:
class PenguinsVsTurtles(torch.utils.data.Dataset):
  
    def __init__(self, root, annotations_list, transforms=None):
        self.root = root
        self.transforms = transforms
        # load all image files, sorting them to ensure that they are aligned
        self.imgs = list(sorted(os.listdir(os.path.join(root))))
        self.labels = [dct['category_id'] for dct in annotations_list]
        self.annotations_list = annotations_list

    def __getitem__(self, idx):
        # load images and masks
        img_path = os.path.join(self.root, self.imgs[idx])
        label_bbox = self.annotations_list[idx]['bbox']
        img = Image.open(img_path).convert("RGB")

        masks = np.zeros((3,640,640))

        # get bounding box coordinates for each mask
        # note that we return bounding boxes in pascal voc format, as that's
        # what our model accepts. The COCO evaluator will convert to COCO format later.
        num_objs = 1
        boxes = []
        boxes.append([label_bbox[0], label_bbox[1], (label_bbox[0]+label_bbox[2]), (label_bbox[1]+label_bbox[3])])

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        # there is only one labeled object per image
        labels = torch.as_tensor((self.annotations_list[idx]['category_id'],), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # suppose all instances are not crowd
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.imgs)
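
The box arithmetic in `__getitem__` converts COCO-style `[x, y, width, height]` annotations to Pascal VOC-style `[x_min, y_min, x_max, y_max]` corners. A minimal sketch of that conversion and its inverse, using plain lists and made-up coordinates (the function names are hypothetical):

```python
def coco_to_voc(box):
    """[x, y, width, height] -> [x_min, y_min, x_max, y_max]"""
    x, y, w, h = box
    return [x, y, x + w, y + h]


def voc_to_coco(box):
    """[x_min, y_min, x_max, y_max] -> [x, y, width, height]"""
    x_min, y_min, x_max, y_max = box
    return [x_min, y_min, x_max - x_min, y_max - y_min]


# Hypothetical annotation: top-left corner at (100, 50), 200 wide, 120 tall
coco_box = [100, 50, 200, 120]
voc_box = coco_to_voc(coco_box)

# Same area formula as the dataset class: (y_max - y_min) * (x_max - x_min)
area = (voc_box[3] - voc_box[1]) * (voc_box[2] - voc_box[0])

print(voc_box)
print(area)
```

Converting back with `voc_to_coco` recovers the original annotation, which is a quick sanity check when the two formats get mixed up.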

We now define our [dataloaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html), which generate our inputs, apply transformations to our images, and [collate](https://plainenglish.io/blog/understanding-collate-fn-in-pytorch-f9d1742647d3) our batches.

In [None]:
# use our dataset and defined transformations
train_dataset = PenguinsVsTurtles('./train', annotations_list=train_labels, transforms=get_transform(train=True))
valid_dataset = PenguinsVsTurtles('./valid', annotations_list=valid_labels, transforms=get_transform(train=False))

# define training and validation data loaders
train_data_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=2, shuffle=True, num_workers=2,
    collate_fn=utils.collate_fn)

valid_data_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=1, shuffle=False, num_workers=2,
    collate_fn=utils.collate_fn)
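
Detection batches cannot be stacked into a single tensor because each image may have a different size and number of boxes, so the collate function keeps the samples as tuples. The sketch below assumes `utils.collate_fn` behaves like the one in the upstream TorchVision detection references, which simply transposes the batch; the modified repo used here may differ. Strings and dictionaries stand in for image tensors and targets.

```python
def collate_fn(batch):
    """Turn a list of (image, target) pairs into (images, targets) tuples."""
    return tuple(zip(*batch))


# Stand-ins for (image tensor, target dict) samples
batch = [
    ("img_a", {"labels": [1]}),
    ("img_b", {"labels": [2]}),
]

images, targets = collate_fn(batch)
print(images)
print(targets)
```

The model then receives a sequence of images and a matching sequence of target dictionaries rather than one stacked tensor.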

We instantiate our model and log the model graph to [Comet](https://www.comet.com/signup/?utm_source=Kaggle_ANM&utm_medium=referral&utm_content=confusion_matrix_with_images_blog). We unfreeze our model layers for training, and pass our hyperparameter values as defined above in our hyperparameter dictionary.

In [None]:
# get the model using our helper function
model = initialize_model(hyper_params["model_name"], hyper_params["num_classes"], hyper_params["feature_extract"])
#log model graph
experiment.set_model_graph(model)

model.to(device)

params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, 
                            lr=hyper_params["lr"],
                            momentum=hyper_params["momentum"], 
                            weight_decay=hyper_params["weight_decay"])

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                               step_size= hyper_params["step_size"],
                                               gamma=hyper_params["gamma"])

num_epochs = hyper_params["num_epochs"]

In [None]:
image_id = []
image_map_results = []

In [None]:
# train and validate
for epoch in range(hyper_params['num_epochs']):

    # train for one epoch, printing every 50 iterations
    metric_logger = train_one_epoch(model, optimizer, train_data_loader, device, epoch, print_freq=50)
    my_log_func(logger=metric_logger)

    # update the learning rate
    lr_scheduler.step()

    # Evaluate with validation dataset
    coco_evaluator, predictions= evaluate(model, valid_data_loader, device)

    # log images and annotations to Comet
    for prediction in predictions:
        image = ToPILImage(prediction[0][3])
        name = 'image id: {}'.format(str(prediction[0][0]))
        annotations = make_annotations(prediction)
        experiment.log_image(image, name=name, annotations= annotations)
      
        # image-level mAP (computed on the final epoch only)
        if epoch == hyper_params['num_epochs']-1:
            metric = MeanAveragePrecision(iou_type="bbox")
            metric.update([prediction[0][2]], [prediction[0][1]])
            result = metric.compute()
            image_map_results.append([x.item() for x in result.values()])
            image_id.append(name)

    # log epoch mAP at IoU threshold = 0.50 and maxDets = 100
    epoch_map = coco_evaluator.coco_eval['bbox'].stats[1]
    experiment.log_metric("epoch mAP", epoch_map)
    # log epoch mAR with 0.5 <= IoU thresh <= 0.95 and maxDets = 100
    epoch_mar = coco_evaluator.coco_eval['bbox'].stats[8]
    experiment.log_metric("epoch mAR", epoch_mar)

    #calculate and log epoch f1
    epoch_f1 = calculate_f1_score(epoch_map, epoch_mar)
    experiment.log_metric("epoch f1", epoch_f1)

    #labels - 1 to remove background class
    ground_truth_labels = [predictions[idx][0][1]['labels'].item()-1 for idx in range(len(predictions))]
    predicted_labels = [predictions[idx][0][2]['labels'][0].item()-1 for idx in range(len(predictions))]
    image_list = [ToPILImage(predictions[idx][0][3]) for idx in range(len(predictions))]
    for i, pil_img in enumerate(image_list):
        add_PIL_bboxes(pil_img, predictions[i][0][2]['boxes'].tolist()[0])

    experiment.log_confusion_matrix(
        ground_truth_labels,
        predicted_labels,
        images=image_list,
        row_label="Actual category",
        column_label = "Predicted category",
        title="Confusion Matrix: Penguins vs Turtles",
        file_name="confusion-matrix-epoch-{}.json".format(epoch),
        labels = list(label_dict.values())
        )

# create pandas DataFrame and log to Comet Data Panel
columns = [k for k in result.keys()]
results_dict = dict(zip(image_id, image_map_results))
experiment.log_table('image_level_map.csv', pd.DataFrame.from_dict(results_dict, orient='index', columns=columns))

In [None]:
experiment.end()
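
For reference, the training loop reads two entries from the COCO evaluator's `stats` array and combines them into an F1-style score. The sketch below lists the order in which pycocotools' standard bounding-box summary reports its twelve values and reproduces the calculation on made-up numbers (the `stats` values are hypothetical, not results from this experiment).

```python
# Order of the summary values in pycocotools' COCOeval.stats for bounding boxes
COCO_STATS = [
    "AP@[.50:.95]", "AP@.50", "AP@.75",
    "AP small", "AP medium", "AP large",
    "AR maxDets=1", "AR maxDets=10", "AR maxDets=100",
    "AR small", "AR medium", "AR large",
]


def f1(precision, recall):
    """Harmonic mean of precision and recall, 0.0 when both are zero."""
    total = precision + recall
    return 2 * precision * recall / total if total > 0 else 0.0


# Hypothetical stats; -1.0 marks a size bucket with no objects
stats = [0.62, 0.90, 0.70, -1.0, 0.55, 0.65, 0.60, 0.68, 0.75, -1.0, 0.60, 0.78]

epoch_map = stats[1]   # AP at IoU = 0.50
epoch_mar = stats[8]   # AR over IoU 0.50:0.95 with maxDets = 100
epoch_f1 = f1(epoch_map, epoch_mar)

print(COCO_STATS[1], epoch_map)
print(COCO_STATS[8], epoch_mar)
print(round(epoch_f1, 4))
```

Because the two inputs use different IoU settings, the resulting value is best read as a convenient summary for comparing epochs rather than a strict F1 score.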

Now let's head on over to the Comet UI to explore our results and see how we might improve our model based on insights gained from the confusion matrices we logged.

## Using the Confusion Matrix

Select the experiment you’d like to view, then find the ‘Confusion Matrix’ tab on the left-hand sidebar. We can add multiple matrices to the same view, or switch between confusion matrices by selecting them from the drop-down menu at the top. By hovering over the different cells of the confusion matrix, you’ll see a quick breakdown of the samples from that cell. By clicking on a cell, we can also see specific instances where the model misclassified an image. By default, a maximum of 25 example images is uploaded per cell, but this can be reconfigured with the API.

[![accessing_conf_matrices.gif](https://s12.gifyu.com/images/accessing_conf_matrices.gif)](https://www.comet.com/anmorgan24/interactive-confusion-matrix/view/1oobhPiRdETUwTnWr9AGhOK0b/panels/?utm_source=Kaggle_ANM&utm_medium=referral&utm_content=confusion_matrix_with_images_blog)

Because we trained our model for three epochs and logged one matrix per epoch, here we can watch how our model improves over time. Are there particular images our model tends to struggle with? How can we use this information to augment our training data and improve our model’s performance?

### View Specific Instances

In the example below, the model seems to get confused by images of white turtles, so maybe we can add some more examples in a future run. In any event, we can see that our model clearly makes fewer mistakes over time, eventually classifying all of the images correctly.

[![Screenshot-2023-05-01-at-7.29.50-PM.png](https://s12.gifyu.com/images/Screenshot-2023-05-01-at-7.29.50-PM.png)](https://www.comet.com/anmorgan24/interactive-confusion-matrix/view/1oobhPiRdETUwTnWr9AGhOK0b/panels/?utm_source=Kaggle_ANM&utm_medium=referral&utm_content=confusion_matrix_with_images_blog)

We can also click on individual images to examine them more closely. This can be especially helpful in object detection use cases, where visualizing the bounding box location can help us understand where the model is going wrong.

[![image-misclassification-instances.gif](https://s12.gifyu.com/images/image-misclassification-instances.gif)](https://www.comet.com/anmorgan24/interactive-confusion-matrix/view/1oobhPiRdETUwTnWr9AGhOK0b/panels/?utm_source=Kaggle_ANM&utm_medium=referral&utm_content=confusion_matrix_with_images_blog)

When examining specific instances of misclassifications, we can see that the model sometimes categorizes large boulders as turtles, and tends to get confused by one particularly unique breed of penguin. We could choose to augment our training data with images containing similar examples to improve performance.

### Aggregating Values

We can also choose three different methods of aggregating the cells in our confusion matrices: by count, percent by row, and percent by column. 

[![augmenting-conf-mat-values2.gif](https://s11.gifyu.com/images/augmenting-conf-mat-values2.gif)](https://www.comet.com/anmorgan24/interactive-confusion-matrix/view/1oobhPiRdETUwTnWr9AGhOK0b/panels/?utm_source=Kaggle_ANM&utm_medium=referral&utm_content=confusion_matrix_with_images_blog)
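
To make the three views concrete, here is a minimal sketch of how row and column percentages can be derived from raw counts. It is an illustration of the arithmetic only, not Comet's implementation, and the counts are made up. Rows are actual categories and columns are predicted categories, matching the labels used when logging the matrix.

```python
# Hypothetical counts: rows = actual (penguin, turtle), columns = predicted
counts = [
    [18, 2],
    [3, 17],
]


def percent_by_row(matrix):
    """Share of each actual class that went to each predicted class."""
    return [[100 * v / (sum(row) or 1) for v in row] for row in matrix]


def percent_by_column(matrix):
    """Share of each predicted class that came from each actual class."""
    col_totals = [sum(col) or 1 for col in zip(*matrix)]
    return [[100 * v / col_totals[j] for j, v in enumerate(row)] for row in matrix]


row_pct = percent_by_row(counts)
col_pct = percent_by_column(counts)

print(row_pct)
print([[round(v, 2) for v in row] for row in col_pct])
```

Read this way, the diagonal of the row view is a per-class recall and the diagonal of the column view is a per-class precision.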

Thanks for making it all the way to the end, and we hope you found this tutorial useful!
