In [1]:
import torch
import torch.nn as nn
import torchvision
import math
import torch.nn.functional as F
import os
import random
import numpy as np
import cv2
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def get_iou(box1, box2):
    """
    Compute Intersection over Union (IoU) between two sets of boxes.
    :param box1: (Tensor of shape N x 4)
    :param box2: (Tensor of shape M x 4)
    :return: IoU matrix of shape N x M
    """
    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])  # N
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])  # M

    x1 = torch.max(box1[:, None, 0], box2[:, 0])  # N x M
    y1 = torch.max(box1[:, None, 1], box2[:, 1])  # N x M
    x2 = torch.min(box1[:, None, 2], box2[:, 2])  # N x M
    y2 = torch.min(box1[:, None, 3], box2[:, 3])  # N x M

    inter_area = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)  # N x M
    union_area = area1[:, None] + area2 - inter_area  # N x M
    return inter_area / union_area  # N x M


def boxes_to_transformation_targets(gt_boxes, anchor_boxes):
    """
    Compute transformation targets (tx, ty, tw, th) for anchor boxes based on ground truth boxes.
    :param gt_boxes: (N, 4) Ground truth bounding boxes (x1, y1, x2, y2)
    :param anchor_boxes: (N, 4) Anchor boxes (x1, y1, x2, y2)
    :return: (N, 4) Transformation targets (tx, ty, tw, th) for the anchor boxes
    """

    # Get center_x, center_y, width, height for anchor boxes
    anchor_widths = anchor_boxes[:, 2] - anchor_boxes[:, 0]
    anchor_heights = anchor_boxes[:, 3] - anchor_boxes[:, 1]
    anchor_center_x = anchor_boxes[:, 0] + 0.5 * anchor_widths
    anchor_center_y = anchor_boxes[:, 1] + 0.5 * anchor_heights

    # Get center_x, center_y, width, height for ground truth boxes
    gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0]
    gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1]
    gt_center_x = gt_boxes[:, 0] + 0.5 * gt_widths
    gt_center_y = gt_boxes[:, 1] + 0.5 * gt_heights

    # Compute transformation targets (tx, ty, tw, th)
    dx = (gt_center_x - anchor_center_x) / anchor_widths
    dy = (gt_center_y - anchor_center_y) / anchor_heights
    dw = torch.log(gt_widths / anchor_widths)
    dh = torch.log(gt_heights / anchor_heights)

    # Stack the targets
    transformation_targets = torch.stack((dx, dy, dw, dh), dim=1)
    return transformation_targets

def apply_regression_pred_to_anchors(box_transform_pred, anchors):
    """
    Apply predicted transformations to anchors to get predicted bounding boxes.
    :param box_transform_pred: (num_anchors, num_classes, 4) Transformation predictions (dx, dy, dw, dh)
    :param anchors: (num_anchors, 4) Anchor boxes (x1, y1, x2, y2)
    :return: pred_boxes: (num_anchors, num_classes, 4) Predicted bounding boxes (x1, y1, x2, y2)
    """
    box_transform_pred = box_transform_pred.reshape(
        box_transform_pred.size(0), -1, 4)

    # Get width, height, and center of anchors
    w = anchors[:, 2] - anchors[:, 0]
    h = anchors[:, 3] - anchors[:, 1]
    center_x = anchors[:, 0] + 0.5 * w
    center_y = anchors[:, 1] + 0.5 * h

    dx = box_transform_pred[..., 0]
    dy = box_transform_pred[..., 1]
    dw = box_transform_pred[..., 2]
    dh = box_transform_pred[..., 3]

    # Clamp the dw, dh values to avoid large values
    dw = torch.clamp(dw, max=math.log(1000.0 / 16))
    dh = torch.clamp(dh, max=math.log(1000.0 / 16))

    # Compute predicted center and dimensions
    pred_center_x = dx * w[:, None] + center_x[:, None]
    pred_center_y = dy * h[:, None] + center_y[:, None]
    pred_w = torch.exp(dw) * w[:, None]
    pred_h = torch.exp(dh) * h[:, None]

    # Convert predicted center and dimensions to (x1, y1, x2, y2)
    pred_box_x1 = pred_center_x - 0.5 * pred_w
    pred_box_y1 = pred_center_y - 0.5 * pred_h
    pred_box_x2 = pred_center_x + 0.5 * pred_w
    pred_box_y2 = pred_center_y + 0.5 * pred_h

    # Stack to get final predicted boxes
    pred_boxes = torch.stack((pred_box_x1, pred_box_y1, pred_box_x2, pred_box_y2), dim=2)
    return pred_boxes


def sample_positive_negative(labels, positive_count, total_count):
    # Sample positive and negative proposals for training
    positive = torch.where(labels >= 1)[0]
    negative = torch.where(labels == 0)[0]
    num_pos = positive_count
    num_pos = min(positive.numel(), num_pos)
    num_neg = total_count - num_pos
    num_neg = min(negative.numel(), num_neg)
    perm_positive_idxs = torch.randperm(positive.numel(),
                                        device=positive.device)[:num_pos]
    perm_negative_idxs = torch.randperm(negative.numel(),
                                        device=negative.device)[:num_neg]
    pos_idxs = positive[perm_positive_idxs]
    neg_idxs = negative[perm_negative_idxs]
    sampled_pos_idx_mask = torch.zeros_like(labels, dtype=torch.bool)
    sampled_neg_idx_mask = torch.zeros_like(labels, dtype=torch.bool)
    sampled_pos_idx_mask[pos_idxs] = True
    sampled_neg_idx_mask[neg_idxs] = True
    return sampled_neg_idx_mask, sampled_pos_idx_mask


def clamp_boxes_to_image_boundary(boxes, image_shape):
    boxes_x1 = boxes[..., 0]
    boxes_y1 = boxes[..., 1]
    boxes_x2 = boxes[..., 2]
    boxes_y2 = boxes[..., 3]
    height, width = image_shape[-2:]
    boxes_x1 = boxes_x1.clamp(min=0, max=width)
    boxes_x2 = boxes_x2.clamp(min=0, max=width)
    boxes_y1 = boxes_y1.clamp(min=0, max=height)
    boxes_y2 = boxes_y2.clamp(min=0, max=height)
    boxes = torch.cat((
        boxes_x1[..., None],
        boxes_y1[..., None],
        boxes_x2[..., None],
        boxes_y2[..., None]),
        dim=-1)
    return boxes

class RegionProposalNetwork(nn.Module):
    """
    RPN with following layers on the feature map
        1. 3x3 conv layer followed by Relu
        2. 1x1 classification conv with num_anchors(num_scales x num_aspect_ratios) output channels
        3. 1x1 classification conv with 4 x num_anchors output channels

    Classification is done via one value indicating probability of foreground
    with sigmoid applied during inference
    """

    def __init__(self, in_channels, scales, aspect_ratios):
        super(RegionProposalNetwork, self).__init__()
        self.scales = scales
        self.aspect_ratios = aspect_ratios
        self.low_iou_threshold = 0.2
        self.high_iou_threshold = 0.85
        self.rpn_nms_threshold = 0.2
        self.rpn_batch_size = 1
        self.rpn_pos_count = int(1)
        self.rpn_topk = 15 if self.training else 10
        self.rpn_prenms_topk = 15 if self.training \
            else 10
        self.num_anchors = len(self.scales) * len(self.aspect_ratios)

        # 3x3 conv layer
        self.rpn_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

        # 1x1 classification conv layer
        self.cls_layer = nn.Conv2d(in_channels, self.num_anchors, kernel_size=1, stride=1)

        # 1x1 regression
        self.bbox_reg_layer = nn.Conv2d(in_channels, self.num_anchors * 4, kernel_size=1, stride=1)


    def generate_anchors(self, image, feat):
      """
      Generate anchor boxes for a feature map based on predefined scales and aspect ratios.
      First, zero-centered anchors are created and then shifted across the feature map.
      The anchor centers are at the top-left corners of each feature map cell.

      :param image: (N, C, H, W) Tensor representing the image dimensions.
      :param feat: (N, C_feat, H_feat, W_feat) Tensor representing the feature map dimensions.
      :return: Anchors of shape (H_feat * W_feat * num_anchors_per_location, 4)
      """
      grid_h, grid_w = feat.shape[-2:]
      image_h, image_w = image.shape[-2:]

      # Compute stride for each grid cell
      stride_h = image_h // grid_h
      stride_w = image_w // grid_w

      # Convert scales and aspect ratios to tensors
      scales = torch.as_tensor(self.scales, dtype=feat.dtype, device=feat.device)
      aspect_ratios = torch.as_tensor(self.aspect_ratios, dtype=feat.dtype, device=feat.device)

      # Calculate height and width ratios based on aspect ratios
      h_ratios = torch.sqrt(aspect_ratios)
      w_ratios = 1 / h_ratios

      # Compute anchor widths and heights based on scales and aspect ratios
      ws = (w_ratios[:, None] * scales[None, :]).view(-1)
      hs = (h_ratios[:, None] * scales[None, :]).view(-1)

      # Create zero-centered anchors (x1, y1, x2, y2)
      base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
      base_anchors = base_anchors.round()

      # Compute shifts in x and y directions for the feature map
      shifts_x = torch.arange(0, grid_w, dtype=torch.int32, device=feat.device) * stride_w
      shifts_y = torch.arange(0, grid_h, dtype=torch.int32, device=feat.device) * stride_h

      # Generate a grid of shifts
      shifts_y, shifts_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")

      shifts_x = shifts_x.reshape(-1)
      shifts_y = shifts_y.reshape(-1)

      # Combine x and y shifts
      shifts = torch.stack((shifts_x, shifts_y, shifts_x, shifts_y), dim=1)

      # Add the shifts to the base anchors to create final anchors
      anchors = (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4))
      anchors = anchors.reshape(-1, 4)

      return anchors

    def assign_targets_to_anchors(self, anchors, gt_boxes):
        """
        Assign ground truth boxes to anchors based on IOU and generate classification labels.
        - Label 1 for anchors with IOU above the high threshold (positive matches).
        - Label 0 for anchors with IOU below the low threshold (background).
        - Label -1 for anchors with IOU between the low and high thresholds (ignored).

        :param anchors: (num_anchors, 4) Tensor containing anchor box coordinates.
        :param gt_boxes: (num_gt_boxes, 4) Tensor containing ground truth box coordinates.
        :return:
            labels: (num_anchors) Tensor of labels {-1, 0, 1} for each anchor.
            matched_gt_boxes: (num_anchors, 4) Tensor of assigned ground truth box coordinates for each anchor.
        """
        # Compute the IOU matrix between anchors and ground truth boxes
        iou_matrix = get_iou(gt_boxes, anchors)

        # Find the ground truth box with the highest IOU for each anchor
        best_match_iou, best_match_gt_idx = iou_matrix.max(dim=0)

        # Save the initial GT indices for later adjustments
        best_match_gt_idx_pre_thresholding = best_match_gt_idx.clone()

        # Apply thresholds to classify anchors:
        below_low_threshold = best_match_iou < self.low_iou_threshold
        between_thresholds = (best_match_iou >= self.low_iou_threshold) & (best_match_iou < self.high_iou_threshold)

        # Mark anchors with low IOU as background (-1), and those between thresholds as ignored (-2)
        best_match_gt_idx[below_low_threshold] = -1
        best_match_gt_idx[between_thresholds] = -2

        # Find anchors with the highest IOU for each ground truth box (to handle multiple matches)
        best_anchor_iou_for_gt, _ = iou_matrix.max(dim=1)
        gt_pred_pair_with_highest_iou = torch.where(iou_matrix == best_anchor_iou_for_gt[:, None])

        # Get the indices of anchors that should be updated (positive matches)
        pred_inds_to_update = gt_pred_pair_with_highest_iou[1]

        # Update the GT index for these anchors to the best match
        best_match_gt_idx[pred_inds_to_update] = best_match_gt_idx_pre_thresholding[pred_inds_to_update]

        # Retrieve the matched ground truth boxes for each anchor (background anchors will be assigned the 0th GT box)
        matched_gt_boxes = gt_boxes[best_match_gt_idx.clamp(min=0)]

        # Set classification labels for anchors
        labels = best_match_gt_idx >= 0  # foreground anchors (1)
        labels = labels.to(dtype=torch.float32)

        # Set background anchors (0)
        background_anchors = best_match_gt_idx == -1
        labels[background_anchors] = 0.0

        # Set ignored anchors (-1)
        ignored_anchors = best_match_gt_idx == -2
        labels[ignored_anchors] = -1.0

        return labels, matched_gt_boxes

    def filter_proposals(self, proposals, cls_scores, image_shape):
        """
        This method does three kinds of filtering/modifications
        1. Pre NMS topK filtering
        2. Make proposals valid by clamping coordinates(0, width/height)
        2. Small Boxes filtering based on width and height
        3. NMS
        4. Post NMS topK filtering
        :param proposals: (num_anchors_in_image, 4)
        :param cls_scores: (num_anchors_in_image, 4) these are cls logits
        :param image_shape: resized image shape needed to clip proposals to image boundary
        :return: proposals and cls_scores: (num_filtered_proposals, 4) and (num_filtered_proposals)
        """
        # Pre NMS Filtering
        cls_scores = cls_scores.reshape(-1)
        cls_scores = torch.sigmoid(cls_scores)
        _, top_n_idx = cls_scores.topk(min(self.rpn_prenms_topk, len(cls_scores)))

        cls_scores = cls_scores[top_n_idx]
        proposals = proposals[top_n_idx]

        # Clamp boxes to image boundary
        proposals = clamp_boxes_to_image_boundary(proposals, image_shape)

        # Small boxes based on width and height filtering
        min_size = 16
        ws, hs = proposals[:, 2] - proposals[:, 0], proposals[:, 3] - proposals[:, 1]
        keep = (ws >= min_size) & (hs >= min_size)
        keep = torch.where(keep)[0]
        proposals = proposals[keep]
        cls_scores = cls_scores[keep]

        # NMS based on objectness scores
        keep_mask = torch.zeros_like(cls_scores, dtype=torch.bool)
        keep_indices = torch.ops.torchvision.nms(proposals, cls_scores, self.rpn_nms_threshold)
        keep_mask[keep_indices] = True
        keep_indices = torch.where(keep_mask)[0]
        # Sort by objectness
        post_nms_keep_indices = keep_indices[cls_scores[keep_indices].sort(descending=True)[1]]

        # Post NMS topk filtering
        proposals, cls_scores = (proposals[post_nms_keep_indices[:self.rpn_topk]],
                                 cls_scores[post_nms_keep_indices[:self.rpn_topk]])

        return proposals, cls_scores

    def forward(self, image, feat, target=None):
        """
        Main method for RPN does the following:
        1. Call RPN specific conv layers to generate classification and
            bbox transformation predictions for anchors
        2. Generate anchors for entire image
        3. Transform generated anchors based on predicted bbox transformation to generate proposals
        4. Filter proposals
        5. For training additionally we do the following:
            a. Assign target ground truth labels and boxes to each anchors
            b. Sample positive and negative anchors
            c. Compute classification loss using sampled pos/neg anchors
            d. Compute Localization loss using sampled pos anchors
        :param image:
        :param feat:
        :param target:
        :return:
        """
        # Call RPN layers
        rpn_feat = nn.ReLU()(self.rpn_conv(feat))
        cls_scores = self.cls_layer(rpn_feat)
        box_transform_pred = self.bbox_reg_layer(rpn_feat)

        # Generate anchors
        anchors = self.generate_anchors(image, feat)

        # Reshape classification scores to be (Batch Size * H_feat * W_feat * Number of Anchors Per Location, 1)
        # cls_score -> (Batch_Size, Number of Anchors per location, H_feat, W_feat)
        number_of_anchors_per_location = cls_scores.size(1)
        cls_scores = cls_scores.permute(0, 2, 3, 1)
        cls_scores = cls_scores.reshape(-1, 1)
        # cls_score -> (Batch_Size*H_feat*W_feat*Number of Anchors per location, 1)

        # Reshape bbox predictions to be (Batch Size * H_feat * W_feat * Number of Anchors Per Location, 4)
        # box_transform_pred -> (Batch_Size, Number of Anchors per location*4, H_feat, W_feat)
        box_transform_pred = box_transform_pred.view(
            box_transform_pred.size(0),
            number_of_anchors_per_location,
            4,
            rpn_feat.shape[-2],
            rpn_feat.shape[-1])
        box_transform_pred = box_transform_pred.permute(0, 3, 4, 1, 2)
        box_transform_pred = box_transform_pred.reshape(-1, 4)
        # box_transform_pred -> (Batch_Size*H_feat*W_feat*Number of Anchors per location, 4)

        # Transform generated anchors according to box transformation prediction
        proposals = apply_regression_pred_to_anchors(
            box_transform_pred.detach().reshape(-1, 1, 4),
            anchors)
        proposals = proposals.reshape(proposals.size(0), 4)

        proposals, scores = self.filter_proposals(proposals, cls_scores.detach(), image.shape)
        rpn_output = {
            'proposals': proposals,
            'scores': scores
        }
        if not self.training or target is None:
            # If we are not training no need to do anything
            return rpn_output
        else:
            # Assign gt box and label for each anchor
            labels_for_anchors, matched_gt_boxes_for_anchors = self.assign_targets_to_anchors(
                anchors,
                target['bboxes'][0])

            # Based on gt assignment above, get regression target for the anchors
            # matched_gt_boxes_for_anchors -> (Number of anchors in image, 4)
            # anchors -> (Number of anchors in image, 4)
            regression_targets = boxes_to_transformation_targets(matched_gt_boxes_for_anchors, anchors)

            ####### Sampling positive and negative anchors ####
            # Our labels were {fg:1, bg:0, to_be_ignored:-1}
            sampled_neg_idx_mask, sampled_pos_idx_mask = sample_positive_negative(
                labels_for_anchors,
                positive_count=self.rpn_pos_count,
                total_count=self.rpn_batch_size)

            sampled_idxs = torch.where(sampled_pos_idx_mask | sampled_neg_idx_mask)[0]

            localization_loss = (
                    torch.nn.functional.smooth_l1_loss(
                        box_transform_pred[sampled_pos_idx_mask],
                        regression_targets[sampled_pos_idx_mask],
                        beta=1/9,
                        reduction="sum",
                    )
                    / (sampled_idxs.numel())
            )

            cls_loss = torch.nn.functional.binary_cross_entropy_with_logits(cls_scores[sampled_idxs].flatten(),
                                                                            labels_for_anchors[sampled_idxs].flatten())

            rpn_output['rpn_class_loss'] = cls_loss
            rpn_output['rpn_local_loss'] = localization_loss
            return rpn_output


class ROIHead(nn.Module):
    """
    ROI head on top of ROI pooling layer for generating
    classification and box transformation predictions
    We have two fc layers followed by a classification fc layer
    and a bbox regression fc layer
    """

    def __init__(self, num_classes, in_channels=512):
        super(ROIHead, self).__init__()
        self.num_classes = num_classes
        self.roi_batch_size = 2
        self.roi_pos_count = int(1*self.roi_batch_size)
        self.iou_threshold = 0.85
        self.low_bg_iou = 0.1
        self.nms_threshold = 0.2
        self.topK_detections = 15
        self.low_score_threshold = 0.05
        self.pool_size = 7
        self.fc_inner_dim = 1024

        self.fc6 = nn.Linear(in_channels * self.pool_size * self.pool_size, self.fc_inner_dim)
        self.fc7 = nn.Linear(self.fc_inner_dim, self.fc_inner_dim)
        self.cls_layer = nn.Linear(self.fc_inner_dim, self.num_classes)
        self.bbox_reg_layer = nn.Linear(self.fc_inner_dim, self.num_classes * 4)

        torch.nn.init.normal_(self.cls_layer.weight, std=0.01)
        torch.nn.init.constant_(self.cls_layer.bias, 0)

        torch.nn.init.normal_(self.bbox_reg_layer.weight, std=0.001)
        torch.nn.init.constant_(self.bbox_reg_layer.bias, 0)

    def assign_target_to_proposals(self, proposals, gt_boxes, gt_labels):
        """
        Assign ground truth boxes to proposals based on IOU, and generate classification labels.
        - Labels are assigned as follows:
            - 1 for proposals matching a ground truth box.
            - 0 for proposals classified as background.
            - -1 for proposals to be ignored.

        :param proposals: (number_of_proposals, 4) Tensor of proposed box coordinates.
        :param gt_boxes: (number_of_gt_boxes, 4) Tensor of ground truth box coordinates.
        :param gt_labels: (number_of_gt_boxes) Tensor of class labels for each ground truth box.
        :return:
            labels: (number_of_proposals) Tensor of labels {-1, 0, 1} for each proposal.
            matched_gt_boxes: (number_of_proposals, 4) Tensor of assigned ground truth box coordinates for each proposal.
        """
        # Compute IOU between proposals and ground truth boxes
        iou_matrix = get_iou(gt_boxes, proposals)

        # Find the best matching ground truth box for each proposal
        best_match_iou, best_match_gt_idx = iou_matrix.max(dim=0)

        # Identify proposals with low IOU (background or ignored)
        background_proposals = (best_match_iou < self.iou_threshold) & (best_match_iou >= self.low_bg_iou)
        ignored_proposals = best_match_iou < self.low_bg_iou

        # Mark low IOU proposals as background (-1) or ignored (-2)
        best_match_gt_idx[background_proposals] = -1
        best_match_gt_idx[ignored_proposals] = -2

        # Assign ground truth boxes to proposals (even background proposals)
        matched_gt_boxes_for_proposals = gt_boxes[best_match_gt_idx.clamp(min=0)]

        # Assign class labels based on the best matched ground truth box
        labels = gt_labels[best_match_gt_idx.clamp(min=0)].to(dtype=torch.int64)

        # Set background proposals to label 0
        labels[background_proposals] = 0

        # Set ignored proposals to label -1
        labels[ignored_proposals] = -1

        return labels, matched_gt_boxes_for_proposals

    def forward(self, feat, proposals, image_shape, target):
        """
        Main method for ROI head that does the following:
        1. If training assign target boxes and labels to all proposals
        2. If training sample positive and negative proposals
        3. If training get bbox transformation targets for all proposals based on assignments
        4. Get ROI Pooled features for all proposals
        5. Call fc6, fc7 and classification and bbox transformation fc layers
        6. Compute classification and localization loss

        :param feat:
        :param proposals:
        :param image_shape:
        :param target:
        :return:
        """
        if self.training and target is not None:
            # Add ground truth to proposals
            proposals = torch.cat([proposals, target['bboxes'][0]], dim=0)

            gt_boxes = target['bboxes'][0]
            gt_labels = target['labels'][0]

            labels, matched_gt_boxes_for_proposals = self.assign_target_to_proposals(proposals, gt_boxes, gt_labels)

            sampled_neg_idx_mask, sampled_pos_idx_mask = sample_positive_negative(labels,
                                                                                  positive_count=self.roi_pos_count,
                                                                                  total_count=self.roi_batch_size)

            sampled_idxs = torch.where(sampled_pos_idx_mask | sampled_neg_idx_mask)[0]

            # Keep only sampled proposals
            proposals = proposals[sampled_idxs]
            labels = labels[sampled_idxs]
            matched_gt_boxes_for_proposals = matched_gt_boxes_for_proposals[sampled_idxs]
            regression_targets = boxes_to_transformation_targets(matched_gt_boxes_for_proposals, proposals)
            # regression_targets -> (sampled_training_proposals, 4)
            # matched_gt_boxes_for_proposals -> (sampled_training_proposals, 4)

        # Get desired scale to pass to roi pooling function
        size = feat.shape[-2:]
        possible_scales = []
        for s1, s2 in zip(size, image_shape):
            approx_scale = float(s1) / float(s2)
            scale = 2 ** float(torch.tensor(approx_scale).log2().round())
            possible_scales.append(scale)
        assert possible_scales[0] == possible_scales[1]

        # ROI pooling and call all layers for prediction
        proposal_roi_pool_feats = torchvision.ops.roi_pool(feat, [proposals],
                                                           output_size=self.pool_size,
                                                           spatial_scale=possible_scales[0])
        proposal_roi_pool_feats = proposal_roi_pool_feats.flatten(start_dim=1)
        box_fc_6 = torch.nn.functional.relu(self.fc6(proposal_roi_pool_feats))
        box_fc_7 = torch.nn.functional.relu(self.fc7(box_fc_6))
        cls_scores = self.cls_layer(box_fc_7)
        box_transform_pred = self.bbox_reg_layer(box_fc_7)
        # cls_scores -> (proposals, num_classes)
        # box_transform_pred -> (proposals, num_classes * 4)

        num_boxes, num_classes = cls_scores.shape
        box_transform_pred = box_transform_pred.reshape(num_boxes, num_classes, 4)
        frcnn_output = {}
        if self.training and target is not None:
            classification_loss = torch.nn.functional.cross_entropy(cls_scores, labels)

            # Compute localization loss only for non-background labelled proposals
            fg_proposals_idxs = torch.where(labels > 0)[0]
            # Get class labels for these positive proposals
            fg_cls_labels = labels[fg_proposals_idxs]

            localization_loss = torch.nn.functional.smooth_l1_loss(
                box_transform_pred[fg_proposals_idxs, fg_cls_labels],
                regression_targets[fg_proposals_idxs],
                beta=1/9,
                reduction="sum",
            )
            localization_loss = localization_loss / labels.numel()
            frcnn_output['frcnn_class_loss'] = classification_loss
            frcnn_output['frcnn_local_loss'] = localization_loss

        if self.training:
            return frcnn_output
        else:
            device = cls_scores.device
            # Apply transformation predictions to proposals
            pred_boxes = apply_regression_pred_to_anchors(box_transform_pred, proposals)
            pred_scores = torch.nn.functional.softmax(cls_scores, dim=-1)

            # Clamp box to image boundary
            pred_boxes = clamp_boxes_to_image_boundary(pred_boxes, image_shape)

            # create labels for each prediction
            pred_labels = torch.arange(num_classes, device=device)
            pred_labels = pred_labels.view(1, -1).expand_as(pred_scores)

            # remove predictions with the background label
            pred_boxes = pred_boxes[:, 1:]
            pred_scores = pred_scores[:, 1:]
            pred_labels = pred_labels[:, 1:]

            # pred_boxes -> (number_proposals, num_classes-1, 4)
            # pred_scores -> (number_proposals, num_classes-1)
            # pred_labels -> (number_proposals, num_classes-1)

            # batch everything, by making every class prediction be a separate instance
            pred_boxes = pred_boxes.reshape(-1, 4)
            pred_scores = pred_scores.reshape(-1)
            pred_labels = pred_labels.reshape(-1)

            pred_boxes, pred_labels, pred_scores = self.filter_predictions(pred_boxes, pred_labels, pred_scores)
            frcnn_output['boxes'] = pred_boxes
            frcnn_output['scores'] = pred_scores
            frcnn_output['labels'] = pred_labels
            return frcnn_output

    def filter_predictions(self, pred_boxes, pred_labels, pred_scores):
        """
        Method to filter predictions by applying the following in order:
        1. Filter low scoring boxes
        2. Remove small size boxes
        3. NMS for each class separately
        4. Keep only topK detections
        :param pred_boxes:
        :param pred_labels:
        :param pred_scores:
        :return:
        """
        # remove low scoring boxes
        keep = torch.where(pred_scores > self.low_score_threshold)[0]
        pred_boxes, pred_scores, pred_labels = pred_boxes[keep], pred_scores[keep], pred_labels[keep]

        # Remove small boxes
        min_size = 16
        ws, hs = pred_boxes[:, 2] - pred_boxes[:, 0], pred_boxes[:, 3] - pred_boxes[:, 1]
        keep = (ws >= min_size) & (hs >= min_size)
        keep = torch.where(keep)[0]
        pred_boxes, pred_scores, pred_labels = pred_boxes[keep], pred_scores[keep], pred_labels[keep]

        # Class wise nms
        keep_mask = torch.zeros_like(pred_scores, dtype=torch.bool)
        for class_id in torch.unique(pred_labels):
            curr_indices = torch.where(pred_labels == class_id)[0]
            curr_keep_indices = torch.ops.torchvision.nms(pred_boxes[curr_indices],
                                                          pred_scores[curr_indices],
                                                          self.nms_threshold)
            keep_mask[curr_indices[curr_keep_indices]] = True
        keep_indices = torch.where(keep_mask)[0]
        post_nms_keep_indices = keep_indices[pred_scores[keep_indices].sort(descending=True)[1]]
        keep = post_nms_keep_indices[:self.topK_detections]
        pred_boxes, pred_scores, pred_labels = pred_boxes[keep], pred_scores[keep], pred_labels[keep]
        return pred_boxes, pred_labels, pred_scores


class FasterRCNN(nn.Module):
    def __init__(self, num_classes, scales, aspect_ratios):
        super(FasterRCNN, self).__init__()
        vgg16 = torchvision.models.vgg16(pretrained=True)
        self.backbone = vgg16.features[:-1]
        self.rpn = RegionProposalNetwork(512,
                                         scales=[128, 256, 512],
                                         aspect_ratios=[0.5, 1, 2])
        self.roi_head = ROIHead(num_classes, in_channels=512)
        for layer in self.backbone[:10]:
            for p in layer.parameters():
                p.requires_grad = False

    def forward(self, image, target=None):
        old_shape = image.shape[-2:]

        # Call backbone
        feat = self.backbone(image)

        # Call RPN and get proposals
        rpn_output = self.rpn(image, feat, target)
        proposals = rpn_output['proposals']

        # Call ROI head and convert proposals to boxes
        frcnn_output = self.roi_head(feat, proposals, image.shape[-2:], target)

        return rpn_output, frcnn_output

Training Code tools

In [None]:
def get_image_info(image_directory, annotation_directory, label2idx):
    im_infos = []

    # Iterate over all files in the image directory
    for filename in os.listdir(image_directory):
        if filename.endswith('.jpg') or filename.endswith('.png'):
            img_id = filename.split('.')[0]
            img_path = os.path.join(image_directory, filename)

            # Read image to get dimensions
            image = cv2.imread(img_path)
            height, width, _ = image.shape

            # Initialize detections list
            detections = []

            # Read corresponding annotation file
            annotation_file = os.path.join(annotation_directory, f"{img_id}.txt")
            if os.path.exists(annotation_file):
                with open(annotation_file, 'r') as file:
                    for line in file:
                        parts = line.strip().split()
                        if len(parts) == 5:
                            class_id = int(parts[0])  # Original class ID from annotation

                            # Map class ID to label index
                            if class_id in range(len(label2idx) - 1):  # Check if class_id is valid
                                mapped_label = label2idx[classes[class_id + 1]]  # Adjust index for 'background'

                                x_center = float(parts[1]) * width
                                y_center = float(parts[2]) * height
                                box_width = float(parts[3]) * width
                                box_height = float(parts[4]) * height

                                # Calculate bounding box coordinates
                                x_min = int(x_center - box_width / 2)
                                y_min = int(y_center - box_height / 2)
                                x_max = int(x_center + box_width / 2)
                                y_max = int(y_center + box_height / 2)

                                # Add detection with mapped label
                                detections.append({
                                    'label': mapped_label,
                                    'bbox': [x_min, y_min, x_max, y_max]
                                })

            # Append image info to the list
            im_info = {
                'img_id': img_id,
                'filename': img_path,
                'width': width,
                'height': height,
                'detections': detections
            }
            im_infos.append(im_info)

    return im_infos

In [None]:
im_dir = '/content/drive/MyDrive/Colab Notebooks/ir_images'
label_dir = '/content/drive/MyDrive/Colab Notebooks/ir_labels'


classes = ['person', 'car']
classes = ['background'] + classes  # Background is index 0
label2idx = {classes[idx]: idx for idx in range(len(classes))}
idx2label = {idx: classes[idx] for idx in range(len(classes))}
print(label2idx)

get_image_info(im_dir, label_dir, label2idx)

{'background': 0, 'person': 1, 'car': 2}


[{'img_id': 'video_frame_001061',
  'filename': '/content/drive/MyDrive/Colab Notebooks/ir_images/video_frame_001061.jpg',
  'width': 640,
  'height': 512,
  'detections': [{'label': 1, 'bbox': [89, 248, 108, 303]},
   {'label': 1, 'bbox': [404, 249, 422, 323]},
   {'label': 1, 'bbox': [390, 257, 409, 316]},
   {'label': 1, 'bbox': [468, 284, 482, 309]},
   {'label': 1, 'bbox': [573, 266, 597, 326]},
   {'label': 1, 'bbox': [609, 251, 633, 325]}]},
 {'img_id': 'video_frame_002122',
  'filename': '/content/drive/MyDrive/Colab Notebooks/ir_images/video_frame_002122.jpg',
  'width': 640,
  'height': 512,
  'detections': [{'label': 1, 'bbox': [65, 248, 94, 303]},
   {'label': 1, 'bbox': [109, 259, 126, 302]}]},
 {'img_id': 'video_frame_000001',
  'filename': '/content/drive/MyDrive/Colab Notebooks/ir_images/video_frame_000001.jpg',
  'width': 640,
  'height': 512,
  'detections': [{'label': 1, 'bbox': [147, 254, 164, 304]},
   {'label': 1, 'bbox': [571, 267, 588, 318]},
   {'label': 1, 'bb

Data Analysis Tools

In [None]:
class IRDataset(Dataset):
    def __init__(self, split, im_dir, ann_dir, split_ratio=0.8):
        self.split = split
        self.im_dir = im_dir
        self.ann_dir = ann_dir

        # Define the classes
        classes = ['person', 'car']
        classes = sorted(classes)
        classes = ['background'] + classes  # Background is index 0

        self.label2idx = {classes[idx]: idx for idx in range(len(classes))}
        self.idx2label = {idx: classes[idx] for idx in range(len(classes))}
        print(self.idx2label)

        # Use the function to load image information
        self.images_info = get_image_info(im_dir, ann_dir, self.label2idx)

        '''# Split the dataset into training and validation
        split_index = int(len(self.images_info) * split_ratio)
        if split == 'train':
            self.images_info = self.images_info[:split_index]
        elif split == 'val':
            self.images_info = self.images_info[split_index:]
        else:
            raise ValueError("Split must be either 'train' or 'val'")'''

    def __len__(self):
        return len(self.images_info)

    def __getitem__(self, index):
        im_info = self.images_info[index]
        im = Image.open(im_info['filename'])

        im_tensor = torchvision.transforms.ToTensor()(im)

        targets = {}
        targets['bboxes'] = torch.as_tensor([detection['bbox'] for detection in im_info['detections']], dtype=torch.float32)
        targets['labels'] = torch.as_tensor([detection['label'] for detection in im_info['detections']], dtype=torch.int64)

        return im_tensor, targets, im_info['filename']

In [None]:
# Create datasets
train_dataset = IRDataset('train', im_dir, label_dir)
val_dataset = IRDataset('val', im_dir, label_dir)

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# Print dataset sizes
print(f'Training dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(val_dataset)}')

{0: 'background', 1: 'car', 2: 'person'}
{0: 'background', 1: 'car', 2: 'person'}
Training dataset size: 9
Validation dataset size: 9


In [None]:
faster_rcnn_model = faster_rcnn_model = FasterRCNN(num_classes=3, scales=[64, 128, 256], aspect_ratios=[0.5, 1, 2])


train_dataset = IRDataset('train', im_dir='/content/drive/MyDrive/Colab Notebooks/ir_images',
                  ann_dir='/content/drive/MyDrive/Colab Notebooks/ir_labels')

train_loader = DataLoader(train_dataset,
                            batch_size=1,
                            shuffle=True,
                            num_workers=4)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:06<00:00, 87.9MB/s]


{0: 'background', 1: 'car', 2: 'person'}




In [None]:
def train(model, custom_dataset, train_loader):

    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if device == 'cuda':
        torch.cuda.manual_seed_all(seed)

    model.train()
    model.to(device)

    if not os.path.exists('frcnn'):
        os.mkdir('frcnn')

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    acc_steps = 1
    num_epochs =5
    step_count = 1

    # Lists to store losses
    rpn_classification_loss_history = []
    rpn_localization_loss_history = []
    frcnn_classification_loss_history = []
    frcnn_localization_loss_history = []
    mean_ap_history = []
    all_aps_history = []

    gts = []
    preds = []

    for i in range(num_epochs):
        rpn_classification_losses = []
        rpn_localization_losses = []
        frcnn_classification_losses = []
        frcnn_localization_losses = []
        optimizer.zero_grad()

        for im, target, fname in tqdm(train_loader):
            im = im.float().to(device)

            #target_boxes = target['bboxes'].float().to(device)[0]
            #target_labels = target['labels'].long().to(device)[0]

            target['bboxes'] = target['bboxes'].float().to(device)
            target['labels'] = target['labels'].long().to(device)
            rpn_output, frcnn_output = model(im, target)

            rpn_loss = rpn_output['rpn_class_loss'] + rpn_output['rpn_local_loss']
            frcnn_loss = frcnn_output['frcnn_class_loss'] + frcnn_output['frcnn_local_loss']
            loss = rpn_loss + frcnn_loss

            rpn_classification_losses.append(rpn_output['rpn_class_loss'].item())
            rpn_localization_losses.append(rpn_output['rpn_local_loss'].item())
            frcnn_classification_losses.append(frcnn_output['frcnn_class_loss'].item())
            frcnn_localization_losses.append(frcnn_output['frcnn_local_loss'].item())
            loss = loss / acc_steps
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            step_count += 1

            '''# Evaluate MAP
            model.eval()
            rpn_output, frcnn_output = model(im, None)
            boxes = frcnn_output['boxes']
            labels = frcnn_output['labels']
            scores = frcnn_output['scores']

            pred_boxes = {}
            gt_boxes = {}
            for label_name in custom_dataset.label2idx:
                pred_boxes[label_name] = []
                gt_boxes[label_name] = []

            for idx, box in enumerate(boxes):
                x1, y1, x2, y2 = box.detach().cpu().numpy()
                label = labels[idx].detach().cpu().item()
                score = scores[idx].detach().cpu().item()
                label_name = custom_dataset.idx2label[label]
                pred_boxes[label_name].append([x1, y1, x2, y2, score])
            for idx, box in enumerate(target_boxes):
                x1, y1, x2, y2 = box.detach().cpu().numpy()
                label = target_labels[idx].detach().cpu().item()
                label_name = custom_dataset.idx2label[label]
                gt_boxes[label_name].append([x1, y1, x2, y2])

            gts.append(gt_boxes)
            preds.append(pred_boxes)
            model.train()
        mean_ap, all_aps = compute_map(preds, gts)
        print('Class Wise Average Precisions')
        for idx in range(len(custom_dataset.idx2label)):
            print('AP for class {} = {:.4f}'.format(custom_dataset.idx2label[idx], all_aps[custom_dataset.idx2label[idx]]))
        print('Mean Average Precision : {:.4f}'.format(mean_ap))'''

        print('Finished epoch {}'.format(i))
        optimizer.step()
        optimizer.zero_grad()

        # Calculate average losses for the epoch
        avg_rpn_classification_loss = np.mean(rpn_classification_losses)
        avg_rpn_localization_loss = np.mean(rpn_localization_losses)
        avg_frcnn_classification_loss = np.mean(frcnn_classification_losses)
        avg_frcnn_localization_loss = np.mean(frcnn_localization_losses)

        # Store the average losses
        rpn_classification_loss_history.append(avg_rpn_classification_loss)
        rpn_localization_loss_history.append(avg_rpn_localization_loss)
        frcnn_classification_loss_history.append(avg_frcnn_classification_loss)
        frcnn_localization_loss_history.append(avg_frcnn_localization_loss)

        # STore mAP
        #mean_ap_history.append(mean_ap)
        #all_aps_history.append(all_aps)

        torch.save(model.state_dict(), os.path.join('frcnn',
                                                                'save'))
        loss_output = ''
        loss_output += 'RPN Classification Loss : {:.4f}'.format(np.mean(rpn_classification_losses))
        loss_output += ' | RPN Localization Loss : {:.4f}'.format(np.mean(rpn_localization_losses))
        loss_output += ' | FRCNN Classification Loss : {:.4f}'.format(np.mean(frcnn_classification_losses))
        loss_output += ' | FRCNN Localization Loss : {:.4f}'.format(np.mean(frcnn_localization_losses))
        print(loss_output)

    print('Done Training...')
    return rpn_classification_loss_history, rpn_localization_loss_history, frcnn_classification_loss_history, frcnn_localization_loss_history, mean_ap_history

In [None]:
faster_rcnn_model = faster_rcnn_model = FasterRCNN(num_classes=3, scales=[128, 256, 512], aspect_ratios=[0.5, 1, 2])


train_dataset = IRDataset('train', im_dir='/content/drive/MyDrive/Colab Notebooks/ir_images',
                  ann_dir='/content/drive/MyDrive/Colab Notebooks/ir_labels')

train_loader = DataLoader(train_dataset,
                            batch_size=1,
                            shuffle=True,
                            num_workers=4)

train(faster_rcnn_model, train_dataset, train_loader)

{0: 'background', 1: 'car', 2: 'person'}


100%|██████████| 9/9 [01:32<00:00, 10.30s/it]


Finished epoch 0
RPN Classification Loss : 0.4531 | RPN Localization Loss : 2.9760 | FRCNN Classification Loss : 1.0524 | FRCNN Localization Loss : 0.4221


100%|██████████| 9/9 [01:31<00:00, 10.22s/it]


Finished epoch 1
RPN Classification Loss : 0.2390 | RPN Localization Loss : 2.0839 | FRCNN Classification Loss : 0.7899 | FRCNN Localization Loss : 0.5305


100%|██████████| 9/9 [01:34<00:00, 10.53s/it]


Finished epoch 2
RPN Classification Loss : 0.3071 | RPN Localization Loss : 3.2980 | FRCNN Classification Loss : 0.6061 | FRCNN Localization Loss : 0.1762


100%|██████████| 9/9 [01:31<00:00, 10.15s/it]


Finished epoch 3
RPN Classification Loss : 0.0842 | RPN Localization Loss : 2.0585 | FRCNN Classification Loss : 0.8099 | FRCNN Localization Loss : 0.0406


100%|██████████| 9/9 [01:31<00:00, 10.17s/it]


Finished epoch 4
RPN Classification Loss : 0.2071 | RPN Localization Loss : 1.8741 | FRCNN Classification Loss : 0.4688 | FRCNN Localization Loss : 0.3072
Done Training...


([0.453055905074709,
  0.2389744319435623,
  0.3071127154980786,
  0.08423384444581138,
  0.20712732120106617],
 [2.976001309023963,
  2.083935797214508,
  3.2979558971193104,
  2.0585126479466758,
  1.8740801678763495],
 [1.0524442560142941,
  0.7899185817546418,
  0.6061433888971806,
  0.8099084297815958,
  0.46875843591988087],
 [0.4221344954914659,
  0.530455916059307,
  0.1761951064108871,
  0.040600908855493695,
  0.3072451932756748],
 [])

In [None]:
def getter_iou(det, gt):
    det_x1, det_y1, det_x2, det_y2 = det
    gt_x1, gt_y1, gt_x2, gt_y2 = gt

    x_left = max(det_x1, gt_x1)
    y_top = max(det_y1, gt_y1)
    x_right = min(det_x2, gt_x2)
    y_bottom = min(det_y2, gt_y2)

    if x_right < x_left or y_bottom < y_top:
        return 0.0

    area_intersection = (x_right - x_left) * (y_bottom - y_top)
    det_area = (det_x2 - det_x1) * (det_y2 - det_y1)
    gt_area = (gt_x2 - gt_x1) * (gt_y2 - gt_y1)
    area_union = float(det_area + gt_area - area_intersection + 1E-6)
    iou = area_intersection / area_union
    return iou

def compute_map(det_boxes, gt_boxes, iou_threshold=0.7, method='area'):

    gt_labels = {cls_key for im_gt in gt_boxes for cls_key in im_gt.keys()}
    gt_labels = sorted(gt_labels)
    all_aps = {}
    # average precisions for ALL classes
    aps = []
    for idx, label in enumerate(gt_labels):
        # Get detection predictions of this class
        cls_dets = [
            [im_idx, im_dets_label] for im_idx, im_dets in enumerate(det_boxes)
            if label in im_dets for im_dets_label in im_dets[label]
        ]

        # Sort them by confidence score
        cls_dets = sorted(cls_dets, key=lambda k: -k[1][-1])

        # For tracking which gt boxes of this class have already been matched
        gt_matched = [[False for _ in im_gts[label]] for im_gts in gt_boxes]
        # Number of gt boxes for this class for recall calculation
        num_gts = sum([len(im_gts[label]) for im_gts in gt_boxes])
        tp = [0] * len(cls_dets)
        fp = [0] * len(cls_dets)

        # For each prediction
        for det_idx, (im_idx, det_pred) in enumerate(cls_dets):
            # Get gt boxes for this image and this label
            im_gts = gt_boxes[im_idx][label]
            max_iou_found = -1
            max_iou_gt_idx = -1

            # Get best matching gt box
            for gt_box_idx, gt_box in enumerate(im_gts):
                gt_box_iou = getter_iou(det_pred[:-1], gt_box)
                if gt_box_iou > max_iou_found:
                    max_iou_found = gt_box_iou
                    max_iou_gt_idx = gt_box_idx
            # TP only if iou >= threshold and this gt has not yet been matched
            if max_iou_found < iou_threshold or gt_matched[im_idx][max_iou_gt_idx]:
                fp[det_idx] = 1
            else:
                tp[det_idx] = 1
                # If tp then we set this gt box as matched
                gt_matched[im_idx][max_iou_gt_idx] = True
        # Cumulative tp and fp
        tp = np.cumsum(tp)
        fp = np.cumsum(fp)

        eps = np.finfo(np.float32).eps
        recalls = tp / np.maximum(num_gts, eps)
        precisions = tp / np.maximum((tp + fp), eps)

        if method == 'area':
            recalls = np.concatenate(([0.0], recalls, [1.0]))
            precisions = np.concatenate(([0.0], precisions, [0.0]))

            for i in range(precisions.size - 1, 0, -1):
                precisions[i - 1] = np.maximum(precisions[i - 1], precisions[i])
            # For computing area, get points where recall changes value
            i = np.where(recalls[1:] != recalls[:-1])[0]
            # Add the rectangular areas to get ap
            ap = np.sum((recalls[i + 1] - recalls[i]) * precisions[i + 1])
        elif method == 'interp':
            ap = 0.0
            for interp_pt in np.arange(0, 1 + 1E-3, 0.1):
                # Get precision values for recall values >= interp_pt
                prec_interp_pt = precisions[recalls >= interp_pt]

                # Get max of those precision values
                prec_interp_pt = prec_interp_pt.max() if prec_interp_pt.size > 0.0 else 0.0
                ap += prec_interp_pt
            ap = ap / 11.0
        else:
            raise ValueError('Method can only be area or interp')
        if num_gts > 0:
            aps.append(ap)
            all_aps[label] = ap
        else:
            all_aps[label] = np.nan
    # compute mAP at provided iou threshold
    mean_ap = sum(aps) / len(aps)
    return mean_ap, all_aps

In [None]:
def infer(model, custom_dataset):
    model.eval()
    if not os.path.exists('samples'):
        os.mkdir('samples')

    faster_rcnn_model.roi_head.low_score_threshold = 0.1

    for sample_count in tqdm(range(1)):
        random_idx = random.randint(0, len(custom_dataset))
        im, target, fname = custom_dataset[random_idx]
        im = im.unsqueeze(0).float().to(device)

        gt_im = cv2.imread(fname)
        gt_im_copy = gt_im.copy()

        # Saving images with ground truth boxes
        for idx, box in enumerate(target['bboxes']):
            x1, y1, x2, y2 = box.detach().cpu().numpy()
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)

            cv2.rectangle(gt_im, (x1, y1), (x2, y2), thickness=2, color=[0, 255, 0])
            cv2.rectangle(gt_im_copy, (x1, y1), (x2, y2), thickness=2, color=[0, 255, 0])
            text = custom_dataset.idx2label[target['labels'][idx].detach().cpu().item()]
            text_size, _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_PLAIN, 1, 1)
            text_w, text_h = text_size
            cv2.rectangle(gt_im_copy , (x1, y1), (x1 + 10+text_w, y1 + 10+text_h), [255, 255, 255], -1)
            cv2.putText(gt_im, text=custom_dataset.idx2label[target['labels'][idx].detach().cpu().item()],
                        org=(x1+5, y1+15),
                        thickness=1,
                        fontScale=1,
                        color=[0, 0, 0],
                        fontFace=cv2.FONT_HERSHEY_PLAIN)
            cv2.putText(gt_im_copy, text=text,
                        org=(x1 + 5, y1 + 15),
                        thickness=1,
                        fontScale=1,
                        color=[0, 0, 0],
                        fontFace=cv2.FONT_HERSHEY_PLAIN)
        cv2.addWeighted(gt_im_copy, 0.7, gt_im, 0.3, 0, gt_im)
        cv2.imwrite('samples/output_frcnn_gt_{}.png'.format(sample_count), gt_im)

        # Getting predictions from trained model
        rpn_output, frcnn_output = model(im, None)
        boxes = frcnn_output['boxes']
        labels = frcnn_output['labels']
        scores = frcnn_output['scores']
        im = cv2.imread(fname)
        im_copy = im.copy()

        # Saving images with predicted boxes
        for idx, box in enumerate(boxes):
            x1, y1, x2, y2 = box.detach().cpu().numpy()
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            cv2.rectangle(im, (x1, y1), (x2, y2), thickness=2, color=[0, 0, 255])
            cv2.rectangle(im_copy, (x1, y1), (x2, y2), thickness=2, color=[0, 0, 255])
            text = '{} : {:.2f}'.format(custom_dataset.idx2label[labels[idx].detach().cpu().item()],
                                        scores[idx].detach().cpu().item())
            text_size, _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_PLAIN, 1, 1)
            text_w, text_h = text_size
            cv2.rectangle(im_copy , (x1, y1), (x1 + 10+text_w, y1 + 10+text_h), [255, 255, 255], -1)
            cv2.putText(im, text=text,
                        org=(x1+5, y1+15),
                        thickness=1,
                        fontScale=1,
                        color=[0, 0, 0],
                        fontFace=cv2.FONT_HERSHEY_PLAIN)
            cv2.putText(im_copy, text=text,
                        org=(x1 + 5, y1 + 15),
                        thickness=1,
                        fontScale=1,
                        color=[0, 0, 0],
                        fontFace=cv2.FONT_HERSHEY_PLAIN)
        cv2.addWeighted(im_copy, 0.7, im, 0.3, 0, im)
        cv2.imwrite('samples/output_frcnn_{}.jpg'.format(sample_count), im)


In [None]:
infer(faster_rcnn_model, train_dataset)