In [1]:
import kagglehub

# Download the "andrewmvd/dog-and-cat-detection" dataset through kagglehub and
# print the local path returned by `dataset_download`.
data_dir = kagglehub.dataset_download("andrewmvd/dog-and-cat-detection")
print("Path to downloaded dataset:", data_dir)

  from .autonotebook import tqdm as notebook_tqdm


Path to downloaded dataset: /home/hongong/.cache/kagglehub/datasets/andrewmvd/dog-and-cat-detection/versions/1


In [2]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import patches
from collections import Counter
import cv2
from glob import glob
from tqdm import tqdm
from termcolor import colored

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

import albumentations as A
from albumentations.pytorch import ToTensorV2

# Data preparation

In [3]:
class CustomVOCDataset(torchvision.datasets.VOCDetection):
    """VOC-style detection dataset that yields YOLOv1 grid targets.

    Each item is ``(image, label_matrix)`` where ``label_matrix`` has shape
    ``(S, S, C + 5 * B)``. For a grid cell that contains an object centre:

    * ``[0:C]``        one-hot class vector
    * ``[C]``          objectness flag (1 when the cell owns an object)
    * ``[C+1:C+5]``    ``x_cell, y_cell, width_cell, height_cell``

    The remaining slots (second box) are left at zero.
    """

    def __init__(
        self, class_mapping, S=7, B=2, C=20, custom_transforms=None, **voc_kwargs
    ):
        """Store the YOLO configuration and optionally initialise the parent.

        Parameters:
            class_mapping (dict): mapping of class names to integer class ids.
            S (int): grid size (the image is split into S x S cells).
            B (int): number of boxes predicted per cell.
            C (int): number of classes.
            custom_transforms (callable | None): transform called with
                ``image=``, ``bboxes=`` and ``labels=`` keyword arguments and
                returning a dict with the same keys.
            **voc_kwargs: forwarded to ``VOCDetection.__init__`` (for example
                ``root``, ``year``, ``image_set``, ``download``). When omitted
                the parent initialiser is NOT called, which preserves the
                previous behaviour; the caller is then responsible for
                providing the attributes the parent ``__getitem__`` reads.
        """
        if voc_kwargs:
            super().__init__(**voc_kwargs)

        self.S = S  # Grid size S x S
        self.B = B  # Number of bounding boxes
        self.C = C  # Number of classes
        self.class_mapping = class_mapping  # Mapping of class names to class indices
        self.custom_transforms = custom_transforms

    def __getitem__(self, index):
        """Return ``(image, label_matrix)`` for the sample at ``index``."""
        # get an image and its target (annotations) from the VOC dataset
        image, target = super().__getitem__(index)
        img_width, img_height = image.size

        # convert target annotations to YOLO format bounding boxes; the
        # converter is a module-level function, not a method of this class
        boxes = convert_to_yolo_format(
            target, img_width, img_height, self.class_mapping
        )
        # reshape so that an image without objects still yields a (0, 5) array
        boxes = np.asarray(boxes).reshape(-1, 5)
        just_boxes = boxes[:, 1:]
        labels = boxes[:, 0]

        if self.custom_transforms:
            sample = self.custom_transforms(
                image=np.array(image), bboxes=just_boxes, labels=labels
            )
            image = sample["image"]
            just_boxes = sample["bboxes"]
            labels = sample["labels"]
        else:
            # No transform: convert the PIL image to a channel-first array,
            # the same layout ToTensorV2 produces in the transformed path.
            image = np.array(image).transpose(2, 0, 1)

        # create an empty label matrix for YOLO ground truth
        label_matrix = torch.zeros((self.S, self.S, self.C + 5 * self.B))
        just_boxes = torch.tensor(just_boxes, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.float32)
        image = torch.as_tensor(image, dtype=torch.float32)

        # channel offsets derived from C so the layout is valid for any C
        obj_idx = self.C
        box_slice = slice(self.C + 1, self.C + 5)
        last_cell = self.S - 1

        # iterate through each bounding box in YOLO format
        for box, class_label in zip(just_boxes, labels):
            x, y, width, height = box.tolist()
            class_label = int(class_label)

            # row i / column j of the responsible cell; clamp so a centre lying
            # exactly on the right/bottom border (x or y == 1.0) stays in range
            i = min(max(int(self.S * y), 0), last_cell)
            j = min(max(int(self.S * x), 0), last_cell)
            x_cell, y_cell = self.S * x - j, self.S * y - i

            # width and height of the box relative to the grid cell
            width_cell, height_cell = width * self.S, height * self.S

            # only the first object that lands in cell (i, j) is recorded
            if label_matrix[i, j, obj_idx] == 0:
                label_matrix[i, j, obj_idx] = 1
                label_matrix[i, j, box_slice] = torch.tensor(
                    [x_cell, y_cell, width_cell, height_cell]
                )
                # one-hot encoding for the class label
                label_matrix[i, j, class_label] = 1

        return image, label_matrix

In [4]:
def convert_to_yolo_format(target, img_width, img_height, class_mapping):
    """
    Convert annotation data from VOC format to YOLO format.

    Parameters:
        target (dict): annotation data from the VOCDetection dataset.
        img_width (int): width of the loaded image; used only when the
            annotation does not carry a usable (non-zero) width.
        img_height (int): height of the loaded image; used only when the
            annotation does not carry a usable (non-zero) height.
        class_mapping (dict): mapping of class names to integer ids.
            Names missing from the mapping fall back to id 0.

    Returns:
        numpy.ndarray: array of shape [N, 5] for N bounding boxes, each row
            being [class_id, x_center, y_center, width, height] with the
            coordinates normalised to [0, 1]. N may be 0.
    """
    annotation = target["annotation"]

    # an image without any object may lack the "object" key entirely
    annotations = annotation.get("object", [])

    # ensure that annotations is a list, even if there's only one object
    if not isinstance(annotations, list):
        annotations = [annotations]

    # prefer the size recorded in the annotation; fall back to the size of the
    # loaded image when it is missing or zero (which would divide by zero)
    size = annotation.get("size", {})
    real_width = int(size.get("width", 0)) or img_width
    real_height = int(size.get("height", 0)) or img_height

    boxes = []
    for anno in annotations:
        bndbox = anno["bndbox"]
        xmin = int(bndbox["xmin"]) / real_width
        xmax = int(bndbox["xmax"]) / real_width
        ymin = int(bndbox["ymin"]) / real_height
        ymax = int(bndbox["ymax"]) / real_height

        # centre coordinates, width and height of the bounding box
        x_center = (xmin + xmax) / 2
        y_center = (ymin + ymax) / 2
        width = xmax - xmin
        height = ymax - ymin

        # map the class name to an integer id (0 when the name is unknown)
        class_id = class_mapping.get(anno["name"], 0)

        boxes.append([class_id, x_center, y_center, width, height])

    # reshape so that an empty result is (0, 5) and column slicing still works
    return np.array(boxes, dtype=np.float64).reshape(-1, 5)

In [5]:
def intersection_over_union(boxes_preds, boxes_labels, box_format="midpoint"):
    """
    Calculate intersection over union (Jaccard overlap) between prediction and target boxes.

    Parameters:
        boxes_preds (torch.Tensor): predicted boxes, last dimension of size 4
        boxes_labels (torch.Tensor): ground-truth boxes, last dimension of size 4
        box_format (str): "midpoint" for (x, y, w, h) or "corners" for (x1, y1, x2, y2)

    Returns:
        torch.Tensor: intersection over union, with a last dimension of size 1

    Raises:
        ValueError: if ``box_format`` is neither "midpoint" nor "corners".
    """

    if box_format == "midpoint":
        # top-left (x1, y1) and bottom-right (x2, y2) corners of the predicted boxes
        box1_x1 = boxes_preds[..., 0:1] - boxes_preds[..., 2:3] / 2
        box1_y1 = boxes_preds[..., 1:2] - boxes_preds[..., 3:4] / 2
        box1_x2 = boxes_preds[..., 0:1] + boxes_preds[..., 2:3] / 2
        box1_y2 = boxes_preds[..., 1:2] + boxes_preds[..., 3:4] / 2

        # same conversion for the ground-truth boxes
        box2_x1 = boxes_labels[..., 0:1] - boxes_labels[..., 2:3] / 2
        box2_y1 = boxes_labels[..., 1:2] - boxes_labels[..., 3:4] / 2
        box2_x2 = boxes_labels[..., 0:1] + boxes_labels[..., 2:3] / 2
        box2_y2 = boxes_labels[..., 1:2] + boxes_labels[..., 3:4] / 2

    elif box_format == "corners":
        # the boxes already store their corners; slicing keeps the last dimension
        box1_x1 = boxes_preds[..., 0:1]
        box1_y1 = boxes_preds[..., 1:2]
        box1_x2 = boxes_preds[..., 2:3]
        box1_y2 = boxes_preds[..., 3:4]

        box2_x1 = boxes_labels[..., 0:1]
        box2_y1 = boxes_labels[..., 1:2]
        box2_x2 = boxes_labels[..., 2:3]
        box2_y2 = boxes_labels[..., 3:4]

    else:
        # fail with a clear message instead of an UnboundLocalError further down
        raise ValueError(
            f"box_format must be 'midpoint' or 'corners', got {box_format!r}"
        )

    # coordinates of the intersection rectangle
    x1 = torch.max(box1_x1, box2_x1)
    y1 = torch.max(box1_y1, box2_y1)
    x2 = torch.min(box1_x2, box2_x2)
    y2 = torch.min(box1_y2, box2_y2)

    # clamp(0) turns a negative extent (no overlap) into an empty intersection
    intersection = (x2 - x1).clamp(0) * (y2 - y1).clamp(0)

    # areas of the prediction and ground-truth boxes
    box1_area = abs((box1_x2 - box1_x1) * (box1_y2 - box1_y1))
    box2_area = abs((box2_x2 - box2_x1) * (box2_y2 - box2_y1))

    # union area, with a small epsilon to avoid division by zero
    return intersection / (box1_area + box2_area - intersection + 1e-6)

In [6]:
def non_max_suppression(boxes, iou_threshold, threshold, box_format="corners"):
    """
    Perform Non-Maximum Suppression on a list of bounding boxes.

    Parameters:
        boxes (list): list of lists, one per bounding box, each specified as
            [class_pred, prob_score, x1, y1, x2, y2]
        iou_threshold (float): a box is dropped when its IoU with an
            already chosen box of the same class is at or above this value
        threshold (float): boxes whose prob_score is not above this value are
            removed first (independent of IoU)
        box_format (str): "midpoint" or "corners", forwarded to
            intersection_over_union

    Returns:
        list: bounding boxes kept after NMS, in descending score order
    """

    # check the data type of the input parameter
    assert isinstance(boxes, list)

    # keep confident boxes only, highest probability first
    candidates = sorted(
        (box for box in boxes if box[1] > threshold),
        key=lambda box: box[1],
        reverse=True,
    )

    bboxes_after_nms = []

    while candidates:
        # take the remaining box with the highest probability
        chosen_box = candidates.pop(0)

        # keep boxes of another class, and same-class boxes that overlap the
        # chosen box by less than the IoU threshold
        candidates = [
            box
            for box in candidates
            if box[0] != chosen_box[0]
            or intersection_over_union(
                torch.tensor(chosen_box[2:]),
                torch.tensor(box[2:]),
                box_format=box_format,
            )
            < iou_threshold
        ]

        bboxes_after_nms.append(chosen_box)

    return bboxes_after_nms

In [7]:
def mean_average_precision(
    pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint", num_classes=20
):
    """
    Calculate mean average precision (mAP) for object detection.

    Both inputs are lists of boxes indexed as
    [image_idx, class_pred, prob_score, b0, b1, b2, b3]: index 0 groups boxes
    by image, index 1 is the class, index 2 the score used for ranking and
    indices 3.. the four box coordinates passed to intersection_over_union.

    Parameters:
        pred_boxes (list): predicted boxes in the layout above
        true_boxes (list): ground-truth boxes in the layout above
        iou_threshold (float): IoU above which a prediction counts as correct
        box_format (str): "midpoint" or "corners" used to specify bboxes
        num_classes (int): number of classes

    Returns:
        torch.Tensor: scalar mAP averaged over the classes that have at least
            one ground-truth box; 0.0 when no class has any.
    """

    # list to store the average precision of each class
    average_precisions = []

    # small epsilon value to avoid division by zero
    epsilon = 1e-6

    for c in range(num_classes):
        # index 1 is the class; index 0 is the image index used for grouping below
        detections = [det for det in pred_boxes if det[1] == c]
        ground_truths = [gt for gt in true_boxes if gt[1] == c]

        total_true_bboxes = len(ground_truths)

        # if there are no ground truth boxes for this class, it can be safely skipped
        if total_true_bboxes == 0:
            continue

        # group this class' ground truths by image once, keeping their order
        gts_by_image = {}
        for gt in ground_truths:
            gts_by_image.setdefault(gt[0], []).append(gt)

        # one "already matched" flag per ground-truth box, e.g. image 0 with
        # 3 boxes and image 1 with 5 gives {0: zeros(3), 1: zeros(5)}
        matched = {img: torch.zeros(len(gts)) for img, gts in gts_by_image.items()}

        # rank detections by their probability value (index 2), best first
        detections.sort(key=lambda det: det[2], reverse=True)
        TP = torch.zeros(len(detections))
        FP = torch.zeros(len(detections))

        for detection_idx, detection in enumerate(detections):
            # only consider ground truth boxes from the same image as the prediction
            ground_truth_img = gts_by_image.get(detection[0], [])

            best_iou = 0
            best_gt_idx = -1

            for idx, gt in enumerate(ground_truth_img):
                iou = intersection_over_union(
                    torch.tensor(detection[3:]),
                    torch.tensor(gt[3:]),
                    box_format=box_format,
                )

                if iou > best_iou:
                    best_iou = iou
                    best_gt_idx = idx

            # true positive only for the first detection matching a ground truth;
            # low IoU or an already matched ground truth is a false positive
            if best_iou > iou_threshold and matched[detection[0]][best_gt_idx] == 0:
                TP[detection_idx] = 1
                matched[detection[0]][best_gt_idx] = 1
            else:
                FP[detection_idx] = 1

        # cumulative precision and recall over the ranked detections
        TP_cumsum = torch.cumsum(TP, dim=0)
        FP_cumsum = torch.cumsum(FP, dim=0)
        recalls = TP_cumsum / (total_true_bboxes + epsilon)
        precisions = torch.divide(TP_cumsum, (TP_cumsum + FP_cumsum + epsilon))

        # prepend the (recall=0, precision=1) start point; float literals keep
        # the dtype equal to that of the cumulative sums
        precisions = torch.cat((torch.tensor([1.0]), precisions))
        recalls = torch.cat((torch.tensor([0.0]), recalls))

        # area under the precision-recall curve by trapezoidal integration
        average_precisions.append(torch.trapz(precisions, recalls))

    # no class had ground truth: return a tensor so callers can still use .item()
    if not average_precisions:
        return torch.tensor(0.0)

    return sum(average_precisions) / len(average_precisions)

# Modeling

In [8]:
"""
Information about the architectural configuration:
A tuple is structured as (kernel_size, number of filters, stride, padding).
"M" simply represents max-pooling with a 2x2 kernel and a stride of 2.
A list holds the convolutional tuples of one repeated block and ends with an integer giving the number of repetitions.
"""

# describing convolutional and max-pooling layers, as well as the number of repetitions of convolutional blocks.

# Layer-by-layer description consumed by Yolov1._create_conv_layers:
#   tuple -> one CNNBlock, read as (kernel_size, out_channels, stride, padding)
#   "M"   -> one nn.MaxPool2d with a 2x2 kernel and stride 2
#   list  -> two conv tuples followed by an int: the pair is repeated that many times
architecture_config = [
    (7, 64, 2, 3),  # convolutional block 1
    "M",  # max-pooling layer 1
    (3, 192, 1, 1),  # convolutional block 2
    "M",  # max-pooling layer 2
    (1, 128, 1, 0),  # convolutional block 3
    (3, 256, 1, 1),  # convolutional block 4
    (1, 256, 1, 0),  # convolutional block 5
    (3, 512, 1, 1),  # convolutional block 6
    "M",  # max-pooling layer 3
    [(1, 256, 1, 0), (3, 512, 1, 1), 4],  # convolutional block 7 (repeated 4 times)
    (1, 512, 1, 0),  # convolutional block 8
    (3, 1024, 1, 1),  # convolutional block 9
    "M",  # max-pooling layer 4
    [(1, 512, 1, 0), (3, 1024, 1, 1), 2],  # convolutional block 10 (repeated 2 times)
    (3, 1024, 1, 1),  # convolutional block 11
    (3, 1024, 2, 1),  # convolutional block 12 (stride 2)
    (3, 1024, 1, 1),  # convolutional block 13
    (3, 1024, 1, 1),  # convolutional block 14
]


# a convolutional block is defined with Conv2d, BatchNorm2d, and LeakyReLU layers.
class CNNBlock(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> LeakyReLU(0.1) building block.

    Parameters:
        in_channels (int): channels of the incoming feature map.
        out_channels (int): number of convolution filters.
        **kwargs: forwarded unchanged to ``nn.Conv2d`` (kernel_size, stride,
            padding, ...).
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        # the convolution carries no bias term; batch norm follows it directly
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.leakyrelu = nn.LeakyReLU(0.1)

    def forward(self, x):
        """Apply convolution, batch normalisation and the activation in order."""
        convolved = self.conv(x)
        normalised = self.batchnorm(convolved)
        return self.leakyrelu(normalised)

In [9]:
# The YOLOv1 model is defined with convolutional layers and fully connected layers.
class Yolov1(nn.Module):
    """YOLOv1 network: a convolutional backbone followed by a fully connected head.

    Parameters:
        in_channels (int): number of channels of the input image.
        **kwargs: ``split_size``, ``num_boxes`` and ``num_classes``, forwarded
            to ``_create_fcs``.
    """

    def __init__(self, in_channels=3, **kwargs):
        super().__init__()
        self.architecture = architecture_config
        self.in_channels = in_channels
        self.darknet = self._create_conv_layers(self.architecture)
        self.fcs = self._create_fcs(**kwargs)

    def forward(self, x):
        """Run the backbone, flatten everything but the batch axis, apply the head."""
        features = self.darknet(x)
        flat_features = torch.flatten(features, start_dim=1)
        return self.fcs(flat_features)

    def _create_conv_layers(self, architecture):
        """Translate an architecture description into an ``nn.Sequential``.

        A tuple entry becomes one CNNBlock, a string entry one 2x2 max-pool
        with stride 2, and a list entry a pair of CNNBlocks repeated the
        number of times given by its third element.
        """
        modules = []
        channels = self.in_channels

        for spec in architecture:
            spec_type = type(spec)

            if spec_type is tuple:
                modules.append(
                    CNNBlock(
                        channels,
                        spec[1],
                        kernel_size=spec[0],
                        stride=spec[2],
                        padding=spec[3],
                    )
                )
                channels = spec[1]

            elif spec_type is str:
                modules.append(nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))

            elif spec_type is list:
                first, second, repeats = spec[0], spec[1], spec[2]

                for _ in range(repeats):
                    modules.extend(
                        [
                            CNNBlock(
                                channels,
                                first[1],
                                kernel_size=first[0],
                                stride=first[2],
                                padding=first[3],
                            ),
                            CNNBlock(
                                first[1],
                                second[1],
                                kernel_size=second[0],
                                stride=second[2],
                                padding=second[3],
                            ),
                        ]
                    )
                    # the next repetition consumes the second block's output
                    channels = second[1]

        return nn.Sequential(*modules)

    def _create_fcs(self, split_size, num_boxes, num_classes):
        """Build the head mapping 1024 * S * S features to S * S * (C + 5 * B) outputs."""
        grid, boxes_per_cell, classes = split_size, num_boxes, num_classes
        flat_in = 1024 * grid * grid
        flat_out = grid * grid * (classes + boxes_per_cell * 5)

        head = [
            nn.Flatten(),
            nn.Linear(flat_in, 4096),
            nn.Dropout(0.0),
            nn.LeakyReLU(0.1),
            nn.Linear(4096, flat_out),
        ]
        return nn.Sequential(*head)

In [10]:
class YoloLoss(nn.Module):
    """Sum-squared-error YOLOv1 loss.

    The last dimension of both predictions and targets is laid out as
    ``[0:C]`` class scores, ``[C]`` confidence of box 1, ``[C+1:C+5]`` box 1,
    ``[C+5]`` confidence of box 2, ``[C+6:C+10]`` box 2. The target only
    fills the first box. The index arithmetic in ``forward`` handles exactly
    two predicted boxes per cell.

    Parameters:
        S (int): grid size of the image (7)
        B (int): number of bounding boxes per cell (2)
        C (int): number of classes (20 for the VOC dataset)
    """

    def __init__(self, S=7, B=2, C=20):
        super(YoloLoss, self).__init__()
        self.mse = nn.MSELoss(reduction="sum")

        self.S = S
        self.B = B
        self.C = C

        # these are YOLO-specific constants, representing the weights
        # for no object loss (lambda_noobj), and box coordinate loss (lambda_coord)
        self.lambda_noobj = 0.5
        self.lambda_coord = 5

    def forward(self, predictions, target):
        """Return the scalar loss for a batch.

        Parameters:
            predictions (torch.Tensor): network output, reshaped here to
                (BATCH_SIZE, S, S, C + B * 5)
            target (torch.Tensor): label tensor of shape
                (BATCH_SIZE, S, S, C + B * 5)
        """
        # channel offsets derived from C, so the loss is valid for any class count
        num_classes = self.C
        conf1 = slice(num_classes, num_classes + 1)
        box1 = slice(num_classes + 1, num_classes + 5)
        conf2 = slice(num_classes + 5, num_classes + 6)
        box2 = slice(num_classes + 6, num_classes + 10)

        # reshape the flat predictions to (BATCH_SIZE, S, S, C + B * 5)
        predictions = predictions.reshape(-1, self.S, self.S, self.C + self.B * 5)

        # IoU of each of the two predicted boxes with the single target box
        iou_b1 = intersection_over_union(predictions[..., box1], target[..., box1])
        iou_b2 = intersection_over_union(predictions[..., box2], target[..., box1])
        ious = torch.cat([iou_b1.unsqueeze(0), iou_b2.unsqueeze(0)], dim=0)

        # bestbox is 0 or 1, indicating which predicted box has the higher IoU
        _, bestbox = torch.max(ious, dim=0)

        # this represents Iobj_i in the paper: 1 where a cell contains an object
        exists_box = target[..., num_classes].unsqueeze(3)

        # =================== #
        # FOR BOX COORDINATES #
        # =================== #

        # zero out cells without an object and pick the responsible prediction
        box_predictions = exists_box * (
            bestbox * predictions[..., box2] + (1 - bestbox) * predictions[..., box1]
        )
        box_targets = exists_box * target[..., box1]

        # compare square roots of width/height as in the paper; sign() and abs()
        # keep the root defined when the network predicts a negative size
        box_predictions[..., 2:4] = torch.sign(box_predictions[..., 2:4]) * torch.sqrt(
            torch.abs(box_predictions[..., 2:4] + 1e-6)
        )
        box_targets[..., 2:4] = torch.sqrt(box_targets[..., 2:4])

        box_loss = self.mse(
            torch.flatten(box_predictions, end_dim=-2),
            torch.flatten(box_targets, end_dim=-2),
        )

        # ================== #
        # FOR OBJECT LOSS    #
        # ================== #

        # confidence score of the box with the highest IoU
        pred_box = bestbox * predictions[..., conf2] + (1 - bestbox) * predictions[..., conf1]
        object_loss = self.mse(
            torch.flatten(exists_box * pred_box),
            torch.flatten(exists_box * target[..., conf1]),
        )

        # ==================== #
        # FOR NO OBJECT LOSS   #
        # ==================== #

        # both predicted confidences are penalised in cells without an object
        no_object_loss = self.mse(
            torch.flatten((1 - exists_box) * predictions[..., conf1], start_dim=1),
            torch.flatten((1 - exists_box) * target[..., conf1], start_dim=1),
        )
        no_object_loss += self.mse(
            torch.flatten((1 - exists_box) * predictions[..., conf2], start_dim=1),
            torch.flatten((1 - exists_box) * target[..., conf1], start_dim=1),
        )

        # ================== #
        # FOR CLASS LOSS     #
        # ================== #

        class_loss = self.mse(
            torch.flatten(exists_box * predictions[..., :num_classes], end_dim=-2),
            torch.flatten(exists_box * target[..., :num_classes], end_dim=-2),
        )

        # calculate the final loss by combining the above losses
        loss = (
            self.lambda_coord * box_loss
            + object_loss
            + self.lambda_noobj * no_object_loss
            + class_loss
        )

        return loss

# Training

In [11]:
# set the random seed for reproducibility
seed = 123
torch.manual_seed(seed)

# Hyperparameters and configurations
# Learning rate for the optimizer
LEARNING_RATE = 2e-5

# Train on the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Originally 64 in the research paper, but using a smaller batch size due to GPU limitations
BATCH_SIZE = 16

# Number of training epochs
EPOCHS = 300

# Number of worker processes for data loading
NUM_WORKERS = 2

# If True, DataLoader will pin memory to transfer data to the GPU faster
PIN_MEMORY = True

# If False, the training process will not load a pre-trained model
LOAD_MODEL = False

# Specify the file name for the pre-trained model if LOAD_MODEL is True
LOAD_MODEL_FILE = "yolov1.pth.tar"

In [12]:
# Target input resolution in pixels, used by the A.Resize steps of the
# transform builders defined in this cell.
WIDTH = 448
HEIGHT = 448


def get_train_transforms():
    """Build the albumentations pipeline used for training samples.

    The pipeline applies colour jitter (one of hue/saturation/value or
    brightness/contrast), an occasional grayscale conversion, random
    horizontal and vertical flips, a resize to HEIGHT x WIDTH and finally a
    conversion to a tensor. Bounding boxes are expected in "yolo" format and
    their class ids in the "labels" field.

    Returns:
        A.Compose: the composed transform.
    """
    return A.Compose(
        [
            A.OneOf(
                [
                    # NOTE(review): the shift limits are 0.2 here; confirm this
                    # magnitude is intended for these parameters.
                    A.HueSaturationValue(
                        hue_shift_limit=0.2,
                        sat_shift_limit=0.2,
                        val_shift_limit=0.2,
                        p=0.9,
                    ),
                    A.RandomBrightnessContrast(
                        brightness_limit=0.2, contrast_limit=0.2, p=0.9
                    ),
                ],
                p=0.9,
            ),
            A.ToGray(p=0.01),
            A.HorizontalFlip(p=0.2),
            A.VerticalFlip(p=0.2),
            # use HEIGHT for the height so both constants are honoured
            A.Resize(height=HEIGHT, width=WIDTH, p=1),
            ToTensorV2(p=1.0),
        ],
        p=1.0,
        bbox_params=A.BboxParams(
            format="yolo", min_area=0, min_visibility=0, label_fields=["labels"]
        ),
    )


def get_valid_transforms():
    """Build the albumentations pipeline used for validation samples.

    Only resizes to HEIGHT x WIDTH and converts to a tensor; no random
    augmentation is applied. Bounding boxes are expected in "yolo" format and
    their class ids in the "labels" field.

    Returns:
        A.Compose: the composed transform.
    """
    return A.Compose(
        [
            # use HEIGHT for the height so both constants are honoured
            A.Resize(height=HEIGHT, width=WIDTH, p=1.0),
            ToTensorV2(p=1.0),
        ],
        p=1.0,
        bbox_params=A.BboxParams(
            format="yolo", min_area=0, min_visibility=0, label_fields=["labels"]
        ),
    )

In [13]:
# Map each class name to its integer index; the index is the position of the
# name in the list below (0 for "aeroplane" up to 19 for "tvmonitor").
class_mapping = {
    class_name: class_index
    for class_index, class_name in enumerate(
        [
            "aeroplane",
            "bicycle",
            "bird",
            "boat",
            "bottle",
            "bus",
            "car",
            "cat",
            "chair",
            "cow",
            "diningtable",
            "dog",
            "horse",
            "motorbike",
            "person",
            "pottedplant",
            "sheep",
            "sofa",
            "train",
            "tvmonitor",
        ]
    )
}

In [14]:
def train_fn(train_loader, model, optimizer, loss_fn, epoch):
    """Train the model for one epoch and return the mean per-batch mAP.

    Parameters:
        train_loader: iterable of (x, y) batches that also supports len()
        model: network called as model(x)
        optimizer: optimizer updating the model parameters
        loss_fn: callable returning the loss for (predictions, targets)
        epoch (int): current epoch number, used only for logging

    Returns:
        float: average of the mAP values computed on each training batch

    Relies on DEVICE, EPOCHS and get_bboxes_training being defined elsewhere
    in the notebook.
    """
    mean_loss = []
    mean_mAP = []

    total_batches = len(train_loader)

    # log about five times per epoch; the lower bound of 1 avoids a
    # modulo-by-zero when the loader has fewer than five batches
    display_interval = max(1, total_batches // 5)

    for batch_idx, (x, y) in enumerate(train_loader):
        x, y = x.to(DEVICE), y.to(DEVICE)

        out = model(x)
        loss = loss_fn(out, y)

        # a single zero_grad right before backward is sufficient
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pred_boxes, true_boxes = get_bboxes_training(
            out, y, iou_threshold=0.5, threshold=0.4
        )
        mAP = mean_average_precision(
            pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint"
        )

        mean_loss.append(loss.item())
        mean_mAP.append(mAP.item())

        if batch_idx % display_interval == 0 or batch_idx == total_batches - 1:
            print(
                f"Epoch: [{epoch}/{EPOCHS}] \t Iter: [{batch_idx}/{total_batches}] \t Loss: {loss.item():.4f} \t mAP: {mAP.item():.4f}"
            )

    avg_loss = sum(mean_loss) / len(mean_loss)
    avg_mAP = sum(mean_mAP) / len(mean_mAP)
    print(colored(f"Train \t loss: {avg_loss:.4f} \t mAP: {avg_mAP:.4f}", "green"))

    return avg_mAP

In [16]:
def test_fn(test_loader, model, loss_fn, epoch):
    """Evaluate the model on a loader and return the mean per-batch mAP.

    The model is switched to eval mode for the evaluation and back to train
    mode before returning.

    Parameters:
        test_loader: iterable of (x, y) batches
        model: network called as model(x)
        loss_fn: callable returning the loss for (predictions, targets)
        epoch (int): accepted for symmetry with train_fn; not used here

    Returns:
        float: average of the mAP values computed on each batch

    Relies on DEVICE and get_bboxes_training being defined elsewhere in the
    notebook.
    """
    model.eval()
    mean_loss = []
    mean_mAP = []

    # no gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every batch
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            out = model(x)
            loss = loss_fn(out, y)

            pred_boxes, true_boxes = get_bboxes_training(
                out, y, iou_threshold=0.5, threshold=0.4
            )
            mAP = mean_average_precision(
                pred_boxes, true_boxes, iou_threshold=0.5, box_format="midpoint"
            )

            mean_loss.append(loss.item())
            mean_mAP.append(mAP.item())

    avg_loss = sum(mean_loss) / len(mean_loss)
    avg_mAP = sum(mean_mAP) / len(mean_mAP)
    print(colored(f"Test \t loss: {avg_loss:3.10f} \t mAP: {avg_mAP:3.10f}", "yellow"))

    # restore training mode for the next epoch
    model.train()

    return avg_mAP