<a href="https://colab.research.google.com/github/mralamdari/Computer-Vision-Papers/blob/main/Faster_RCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import cv2
import glob
import tqdm
import torch
import matplotlib
import torchvision
import numpy as np
import pandas as pd
from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from sklearn import model_selection

%load_ext autoreload
%autoreload 2

tf.__version__

'2.12.0'

##Data

In [None]:
os.environ['KAGGLE_CONFIG_DIR'] = '/content/drive/MyDrive'
!kaggle datasets download -d andrewmvd/car-plate-detection
!unzip \*.zip && rm *.zip

In [3]:
IMAGE_PATH = '/content/data/images/'
ANNOTATION_PATH =  '/content/data/annotations/'

os.makedirs('/content/data/', exist_ok=True)
os.replace('/content/images', '/content/data/images')
os.replace('/content/annotations', '/content/data/annotations')

In [4]:
def parse_annotation(data_dir, img_size):

  img_paths  = []
  gdt_bboxes = []
  gdt_classes= []
  img_w, img_h = img_size

  for img_name in os.listdir(data_dir+'images'):

    img_path = os.path.join(data_dir, 'images', img_name)
    annotation_path = os.path.join(data_dir, 'annotations', img_name[:-3]+'xml')

    with open(annotation_path, 'r') as f:
      tree = ET.parse(f)

    root = tree.getroot()
    
    img_paths.append(img_path)
    ann_size = root.find('size')
    orig_w = int(ann_size.find('width').text)
    orig_h = int(ann_size.find('height').text)
    ground_truth_bboxes = []
    ground_truth_classes = []
    
    for box in root.findall('object'):
      box_root = box.find('bndbox')
      xmin = float(box_root.find('xmin').text) * img_w / orig_w
      ymin = float(box_root.find('ymin').text) * img_h / orig_h
      xmax = float(box_root.find('xmax').text) * img_w / orig_w
      ymax = float(box_root.find('ymax').text) * img_h / orig_h
      bbox = torch.Tensor([int(xmin), int(ymin), int(xmax), int(ymax)])

      ground_truth_bboxes.append(bbox.tolist())
      ground_truth_classes.append(int(root.find('segmented').text))

    gdt_bboxes.append(torch.Tensor(ground_truth_bboxes))
    gdt_classes.append(torch.Tensor(ground_truth_classes))
    
  return gdt_bboxes, gdt_classes, img_paths

In [5]:
class ObjectDetectionDataset(torch.utils.data.Dataset):
    '''
    A Pytorch Dataset class to load the images and their corresponding annotations.
    
    Returns
    ------------
    images: torch.Tensor of size (B, C, H, W)
    gt bboxes: torch.Tensor of size (B, max_objects, 4)
    gt classes: torch.Tensor of size (B, max_objects)
    '''
    def __init__(self, data_dir, img_size):
        self.data_dir = data_dir
        self.img_size = img_size
        
        self.img_data_all, self.gdt_bboxes, self.gdt_classes = self.get_data()
        
    def __len__(self):
        return self.img_data_all.size(dim=0)
    
    def __getitem__(self, idx):
        return self.img_data_all[idx], self.gdt_bboxes[idx], self.gdt_classes[idx]
        
    def get_data(self):
        img_data = []
        gdt_idxs = []

        gdt_boxes, gdt_classes, img_paths = parse_annotation(self.data_dir, self.img_size)

        for i, img_path in enumerate(img_paths):
            # skip if the image path is not valid
            if (not img_path) or (not os.path.exists(img_path)):
                continue
                
            # read and resize image
            img = cv2.imread(img_path)
            img = cv2.resize(img, self.img_size)
            # convert image to torch tensor and reshape it so channels come first
            img_tensor = torch.from_numpy(img).permute(2, 0, 1)
            
            # encode class names as integers
            gdt_idx = gdt_classes[i]
            
            img_data.append(img_tensor)
            gdt_idxs.append(gdt_idx)
        # pad bounding boxes and classes so they are of the same size
        gt_bboxes_pad = torch.nn.utils.rnn.pad_sequence(gdt_boxes, batch_first=True, padding_value=-1)
        gt_classes_pad = torch.nn.utils.rnn.pad_sequence(gdt_idxs, batch_first=True, padding_value=-1)
        
        # stack all images
        img_data_stacked = torch.stack(img_data, dim=0)
        
        return img_data_stacked.to(dtype=torch.float32), gt_bboxes_pad, gt_classes_pad

##Preprocessing Utils


Generate Anchor Points

In [18]:
def gen_anc_centers(out_size):
  
    out_h, out_w = out_size
    anc_pts_x = torch.arange(0, out_w) + 0.5
    anc_pts_y = torch.arange(0, out_h) + 0.5
    
    return anc_pts_x, anc_pts_y

Generate Anchor Boxes

In [19]:
def gen_anc_boxes(anc_pts_x, anc_pts_y, anc_scales, anc_ratios, out_size):
    n_anc_boxes = len(anc_scales) * len(anc_ratios)
    anc_base = torch.zeros(1, anc_pts_x.size(dim=0) , anc_pts_y.size(dim=0), n_anc_boxes, 4) # shape - [1, Hmap, Wmap, n_anchor_boxes, 4]
    
    for ix, x_center in enumerate(anc_pts_x):
        for jx, y_center in enumerate(anc_pts_y):
            anc_boxes = torch.zeros((n_anc_boxes, 4))
            c = 0
            for i, scale in enumerate(anc_scales):
                for j, ratio in enumerate(anc_ratios):
                    w = scale * ratio
                    h = scale
                    
                    xmin = x_center - w / 2
                    ymin = y_center - h / 2
                    xmax = x_center + w / 2
                    ymax = y_center + h / 2

                    anc_boxes[c, :] = torch.Tensor([xmin, ymin, xmax, ymax])
                    c += 1

            anc_base[:, ix, jx, :] = torchvision.ops.clip_boxes_to_image(anc_boxes, size=out_size)
            
    return anc_base

Positive and Negative Anchor Boxes


In [20]:
def get_iou_mat(batch_size, anc_boxes_all, gdt_bboxes_all):
    
    # flatten anchor boxes
    anc_boxes_flat = anc_boxes_all.reshape(batch_size, -1, 4)

    # create a placeholder to compute IoUs amongst the boxes
    ious_mat = torch.zeros((batch_size, anc_boxes_flat.size(dim=1), gdt_bboxes_all.size(dim=1)))

    # compute IoU of the anc boxes with the gt boxes for all the images
    for i in range(batch_size):
        gt_bboxes = gdt_bboxes_all[i]
        anc_boxes = anc_boxes_flat[i]
        ious_mat[i, :] = torchvision.ops.box_iou(anc_boxes, gt_bboxes)
        
    return ious_mat

Projecting Boxes


In [21]:
def project_bboxes(bboxes, width_scale_factor, height_scale_factor, mode='a2p'):
    assert mode in ['a2p', 'p2a']
    
    batch_size = bboxes.size(dim=0)
    proj_bboxes = bboxes.clone().reshape(batch_size, -1, 4)
    invalid_bbox_mask = (proj_bboxes == -1) # indicating padded bboxes
    
    if mode == 'a2p':
        # activation map to pixel image
        proj_bboxes[:, :, [0, 2]] *= width_scale_factor  #xmin, xmax
        proj_bboxes[:, :, [1, 3]] *= height_scale_factor #ymin, ymax
    else:
        # pixel image to activation map
        proj_bboxes[:, :, [0, 2]] /= width_scale_factor
        proj_bboxes[:, :, [1, 3]] /= height_scale_factor
        
    proj_bboxes.masked_fill_(invalid_bbox_mask, -1) # fill padded bboxes back with -1
    proj_bboxes.resize_as_(bboxes)
    
    return proj_bboxes

Computing Offsets


In [22]:
def calc_gt_offsets(pos_anc_coords, gtd_bbox_mapping):
    pos_anc_coords = torchvision.ops.box_convert(pos_anc_coords, in_fmt='xyxy', out_fmt='cxcywh')
    gtd_bbox_mapping = torchvision.ops.box_convert(gtd_bbox_mapping, in_fmt='xyxy', out_fmt='cxcywh')

    gt_cx, gt_cy, gt_w, gt_h = gtd_bbox_mapping[:, 0], gtd_bbox_mapping[:, 1], gtd_bbox_mapping[:, 2], gtd_bbox_mapping[:, 3]
    anc_cx, anc_cy, anc_w, anc_h = pos_anc_coords[:, 0], pos_anc_coords[:, 1], pos_anc_coords[:, 2], pos_anc_coords[:, 3]

    tx_ = (gt_cx - anc_cx)/anc_w
    ty_ = (gt_cy - anc_cy)/anc_h
    tw_ = torch.log(gt_w / anc_w)
    th_ = torch.log(gt_h / anc_h)

    return torch.stack([tx_, ty_, tw_, th_], dim=-1)

Create Pos/Neg Anchors

In [75]:
def get_req_anchors(anc_boxes_all, gt_bboxes_all, gt_classes_all, pos_thresh=0.7, neg_thresh=0.2):
    '''
    Prepare necessary data required for training
    
    Input
    ------
    anc_boxes_all - torch.Tensor of shape (B, w_amap, h_amap, n_anchor_boxes, 4)
        all anchor boxes for a batch of images
    gt_bboxes_all - torch.Tensor of shape (B, max_objects, 4)
        padded ground truth boxes for a batch of images
    gt_classes_all - torch.Tensor of shape (B, max_objects)
        padded ground truth classes for a batch of images
        
    Returns
    ---------
    positive_anc_ind -  torch.Tensor of shape (n_pos,)
        flattened positive indices for all the images in the batch
    negative_anc_ind - torch.Tensor of shape (n_pos,)
        flattened positive indices for all the images in the batch
    GT_conf_scores - torch.Tensor of shape (n_pos,), IoU scores of +ve anchors
    GT_offsets -  torch.Tensor of shape (n_pos, 4),
        offsets between +ve anchors and their corresponding ground truth boxes
    GT_class_pos - torch.Tensor of shape (n_pos,)
        mapped classes of +ve anchors
    positive_anc_coords - (n_pos, 4) coords of +ve anchors (for visualization)
    negative_anc_coords - (n_pos, 4) coords of -ve anchors (for visualization)
    positive_anc_ind_sep - list of indices to keep track of +ve anchors
    '''
    # get the size and shape parameters
    B, w_amap, h_amap, A, _ = anc_boxes_all.shape
    N = gt_bboxes_all.shape[1] # max number of groundtruth bboxes in a batch
    
    # get total number of anchor boxes in a single image
    tot_anc_boxes = A * w_amap * h_amap
    
    # get the iou matrix which contains iou of every anchor box
    # against all the groundtruth bboxes in an image
    iou_mat = get_iou_mat(B, anc_boxes_all, gt_bboxes_all)
    
    # for every groundtruth bbox in an image, find the iou 
    # with the anchor box which it overlaps the most
    max_iou_per_gt_box, _ = iou_mat.max(dim=1, keepdim=True)
    
    # get positive anchor boxes
    
    # condition 1: the anchor box with the max iou for every gt bbox
    positive_anc_mask = torch.logical_and(iou_mat == max_iou_per_gt_box, max_iou_per_gt_box > 0) 
    # condition 2: anchor boxes with iou above a threshold with any of the gt bboxes
    positive_anc_mask = torch.logical_or(positive_anc_mask, iou_mat > pos_thresh)
    
    positive_anc_ind_sep = torch.where(positive_anc_mask)[0] # get separate indices in the batch
    # combine all the batches and get the idxs of the +ve anchor boxes
    positive_anc_mask = positive_anc_mask.flatten(start_dim=0, end_dim=1)
    positive_anc_ind = torch.where(positive_anc_mask)[0]
    
    # for every anchor box, get the iou and the idx of the
    # gt bbox it overlaps with the most
    max_iou_per_anc, max_iou_per_anc_ind = iou_mat.max(dim=-1)
    max_iou_per_anc = max_iou_per_anc.flatten(start_dim=0, end_dim=1)
    
    # get iou scores of the +ve anchor boxes
    GT_conf_scores = max_iou_per_anc[positive_anc_ind]
    
    # get gt classes of the +ve anchor boxes
    
    # expand gt classes to map against every anchor box
    gt_classes_expand = gt_classes_all.view(B, 1, N).expand(B, tot_anc_boxes, N)
    # for every anchor box, consider only the class of the gt bbox it overlaps with the most
    GT_class = torch.gather(gt_classes_expand, -1, max_iou_per_anc_ind.unsqueeze(-1)).squeeze(-1)
    # combine all the batches and get the mapped classes of the +ve anchor boxes
    GT_class = GT_class.flatten(start_dim=0, end_dim=1)
    GT_class_pos = GT_class[positive_anc_ind]
    
    # get gt bbox coordinates of the +ve anchor boxes
    
    # expand all the gt bboxes to map against every anchor box
    gt_bboxes_expand = gt_bboxes_all.view(B, 1, N, 4).expand(B, tot_anc_boxes, N, 4)
    # for every anchor box, consider only the coordinates of the gt bbox it overlaps with the most
    GT_bboxes = torch.gather(gt_bboxes_expand, -2, max_iou_per_anc_ind.reshape(B, tot_anc_boxes, 1, 1).repeat(1, 1, 1, 4))
    # combine all the batches and get the mapped gt bbox coordinates of the +ve anchor boxes
    GT_bboxes = GT_bboxes.flatten(start_dim=0, end_dim=2)
    GT_bboxes_pos = GT_bboxes[positive_anc_ind]
    
    # get coordinates of +ve anc boxes
    anc_boxes_flat = anc_boxes_all.flatten(start_dim=0, end_dim=-2) # flatten all the anchor boxes
    positive_anc_coords = anc_boxes_flat[positive_anc_ind]
    
    # calculate gt offsets
    GT_offsets = calc_gt_offsets(positive_anc_coords, GT_bboxes_pos)
    
    # get -ve anchors
    
    # condition: select the anchor boxes with max iou less than the threshold
    negative_anc_mask = (max_iou_per_anc < neg_thresh)
    negative_anc_ind = torch.where(negative_anc_mask)[0]
    # sample -ve samples to match the +ve samples
    negative_anc_ind = negative_anc_ind[torch.randint(0, negative_anc_ind.shape[0], (positive_anc_ind.shape[0],))]
    negative_anc_coords = anc_boxes_flat[negative_anc_ind]
    
    return positive_anc_ind, negative_anc_ind, GT_conf_scores, GT_offsets, GT_class_pos, \
         positive_anc_coords, negative_anc_coords, positive_anc_ind_sep


In [58]:
def generate_proposals(anchors, offsets):
   
    # change format of the anchor boxes from 'xyxy' to 'cxcywh'
    anchors = torchvision.ops.box_convert(anchors, in_fmt='xyxy', out_fmt='cxcywh')

    # apply offsets to anchors to create proposals
    proposals_ = torch.zeros_like(anchors)
    proposals_[:,0] = anchors[:,0] + offsets[:,0]*anchors[:,2]
    proposals_[:,1] = anchors[:,1] + offsets[:,1]*anchors[:,3]
    proposals_[:,2] = anchors[:,2] * torch.exp(offsets[:,2])
    proposals_[:,3] = anchors[:,3] * torch.exp(offsets[:,3])

    # change format of proposals back from 'cxcywh' to 'xyxy'
    proposals = torchvision.ops.box_convert(proposals_, in_fmt='cxcywh', out_fmt='xyxy')

    return proposals

Visualization Utils

In [59]:
def display_img(img_data, fig, axes):
    for i, img in enumerate(img_data):
        if type(img) == torch.Tensor:
            img = img.permute(1, 2, 0).numpy()
        axes[i].imshow(np.int64(img))
    
    return fig, axes

In [60]:
def display_bbox(bboxes, fig, ax, classes=None, in_format='xyxy', color='y', line_width=3):
    if type(bboxes) == np.ndarray:
        bboxes = torch.from_numpy(bboxes)
    if classes:
        assert len(bboxes) == len(classes)
    # convert boxes to xywh format
    bboxes = torchvision.ops.box_convert(bboxes, in_fmt=in_format, out_fmt='xywh')
    c = 0
    for box in bboxes:
        x, y, w, h = box.numpy()
        # display bounding box
        rect = matplotlib.patches.Rectangle((x, y), w, h, linewidth=line_width, edgecolor=color, facecolor='none')
        ax.add_patch(rect)
        # display category
        if classes:
            if classes[c] == 'pad':
                continue
            ax.text(x + 5, y + 20, classes[c], bbox=dict(facecolor='yellow', alpha=0.5))
        c += 1
        
    return fig, ax

In [61]:
def display_grid(x_points, y_points, fig, ax, special_point=None):
    # plot grid
    for x in x_points:
        for y in y_points:
            ax.scatter(x, y, color="w", marker='+')
            
    # plot a special point we want to emphasize on the grid
    if special_point:
        x, y = special_point
        ax.scatter(x, y, color="red", marker='+')
        
    return fig, ax

Backbone

In [62]:
class FeatureExtractor(torch.nn.Module):
  def __init__(self):
    super().__init__()
    model = torchvision.models.resnet50(pretrained=True)
    req_layers = list(model.children())[:8] #Ignore AdaptiveAvgPool, Linear classifier Layer
    self.backbone = torch.nn.Sequential(*req_layers)
    for param in self.backbone.named_parameters():
      param[1].requres_grad = True

  def forward(self, img_data):
    return self.backbone(img_data)

In [63]:
class ProposalModule(torch.nn.Module):
  
  def __init__(self, in_features, hidden_dim=512, n_anchors=9, p_dropout=0.3):
    super().__init__()
    self.n_anchors = n_anchors
    self.conv1 = torch.nn.Conv2d(in_features, hidden_dim, kernel_size=3, padding=1)
    self.droput= torch.nn.Dropout(p_dropout)
    self.conf_head = torch.nn.Conv2d(hidden_dim, n_anchors, kernel_size=1)
    self.reg_head = torch.nn.Conv2d(hidden_dim, n_anchors*4, kernel_size=1)
  
  def forward(self, feature_map, pos_anc_ind=None, neg_anc_ind=None, pos_anc_coords=None):
    if pos_anc_ind is None or neg_anc_ind is None or pos_anc_coords is None:
      mode = 'eval'
    else:
      mode = 'train'

    out = self.conv1(feature_map)
    out = self.droput(out)
    out = torch.nn.functional.relu(out)

    reg_offsets_pred = self.reg_head(out)  # (B, A*4, hmap, wmap)
    conf_scores_pred = self.conf_head(out) # (B, A, hmap, wmap)

    if mode=='train':
      #get confidence scores
      conf_scrors_pos = conf_scores_pred.flatten()[pos_anc_ind]
      conf_scrors_neg = conf_scores_pred.flatten()[neg_anc_ind]
      
      #get offsets for positive anchors
      offsets_pos = reg_offsets_pred.contiguous().view(-1, 4)[pos_anc_ind]
      #generate proposals using offsets
      proposals   = generate_proposals(pos_anc_coords, offsets_pos)

      return conf_scrors_pos, conf_scrors_neg, offsets_pos, proposals
    else:
      return conf_scores_pred, reg_offsets_pred

##stage 1 of the detector
##RPN

In [64]:
class RegionProposalNetwork(torch.nn.Module):
    def __init__(self, img_size):
        super().__init__()
        
        self.img_height, self.img_width = img_size

        # scales and ratios for anchor boxes
        self.anc_scales = [2, 4, 6]
        self.anc_ratios = [0.5, 1, 1.5]
        self.n_anc_boxes = len(self.anc_scales) * len(self.anc_ratios)
        
        # IoU thresholds for +ve and -ve anchors
        self.pos_thresh = 0.7
        self.neg_thresh = 0.3
        
        # weights for loss
        self.w_conf = 1
        self.w_reg = 5
        
        self.feature_extractor = FeatureExtractor() #feature_map 
        
    def forward(self, images, gt_bboxes, gt_classes):
      
        batch_size = images.size(dim=0)
        feature_map = self.feature_extractor(images)

        out_h, out_w, out_c = feature_map.size(dim=1), feature_map.size(dim=2), feature_map.size(dim=3)

        # downsampling scale factor 
        width_scale_factor = self.img_width // out_w
        height_scale_factor = self.img_height // out_h 
        
        
        # generate anchors
        anc_pts_x, anc_pts_y = gen_anc_centers(out_size=(out_h, out_w))
        anc_base = gen_anc_boxes(anc_pts_x, anc_pts_y, self.anc_scales, self.anc_ratios, (out_h, out_w))
        anc_boxes_all = anc_base.repeat(batch_size, 1, 1, 1, 1)
        
        # get positive and negative anchors amongst other things
        gt_bboxes_proj = project_bboxes(gt_bboxes, width_scale_factor, height_scale_factor, mode='p2a')
        
        positive_anc_ind, negative_anc_ind, GT_conf_scores, \
        GT_offsets, GT_class_pos, positive_anc_coords, \
        negative_anc_coords, positive_anc_ind_sep = get_req_anchors(anc_boxes_all, gt_bboxes_proj, gt_classes)
        
        # pass through the proposal module
        proposal_module = ProposalModule(out_c, n_anchors=self.n_anc_boxes)
        conf_scores_pos, conf_scores_neg, offsets_pos, proposals = proposal_module(feature_map, positive_anc_ind, negative_anc_ind, positive_anc_coords)
        
        cls_loss = calc_cls_loss(conf_scores_pos, conf_scores_neg, batch_size)
        reg_loss = calc_bbox_reg_loss(GT_offsets, offsets_pos, batch_size)
        
        total_rpn_loss = self.w_conf * cls_loss + self.w_reg * reg_loss
        
        return total_rpn_loss, feature_map, proposals, positive_anc_ind_sep, GT_class_pos, [out_w, out_h, out_c]
    
    def inference(self, images, conf_thresh=0.5, nms_thresh=0.7):
        with torch.no_grad():

            batch_size = images.size(dim=0)
            feature_map = self.feature_extractor(images)
            out_h, out_w, out_c = feature_map.size(dim=1), feature_map.size(dim=2), feature_map.size(dim=3)

            # downsampling scale factor 
            width_scale_factor = self.img_width // out_w
            height_scale_factor = self.img_height // out_h 
            
            # generate anchors
            anc_pts_x, anc_pts_y = gen_anc_centers(out_size=(out_h, out_w))
            anc_base = gen_anc_boxes(anc_pts_x, anc_pts_y, self.anc_scales, self.anc_ratios, (out_h, out_w))
            anc_boxes_all = anc_base.repeat(batch_size, 1, 1, 1, 1)
            anc_boxes_flat = anc_boxes_all.reshape(batch_size, -1, 4)

            # get conf scores and offsets
            proposal_module = ProposalModule(out_c, n_anchors=self.n_anc_boxes)
            conf_scores_pred, offsets_pred = proposal_module(feature_map)
            conf_scores_pred = conf_scores_pred.reshape(batch_size, -1)
            offsets_pred = offsets_pred.reshape(batch_size, -1, 4)

            # filter out proposals based on conf threshold and nms threshold for each image
            proposals_final = []
            conf_scores_final = []
            for i in range(batch_size):
                conf_scores = torch.sigmoid(conf_scores_pred[i])
                offsets = offsets_pred[i]
                anc_boxes = anc_boxes_flat[i]
                proposals = generate_proposals(anc_boxes, offsets)
                # filter based on confidence threshold
                conf_idx = torch.where(conf_scores >= conf_thresh)[0]
                conf_scores_pos = conf_scores[conf_idx]
                proposals_pos = proposals[conf_idx]
                # filter based on nms threshold
                nms_idx = torchvision.ops.nms(proposals_pos, conf_scores_pos, nms_thresh)
                conf_scores_pos = conf_scores_pos[nms_idx]
                proposals_pos = proposals_pos[nms_idx]

                proposals_final.append(proposals_pos)
                conf_scores_final.append(conf_scores_pos)
            
        return proposals_final, conf_scores_final, feature_map, [out_w, out_h, out_c]

Loss Utils

In [65]:
def calc_cls_loss(conf_scores_pos, conf_scores_neg, batch_size):
    target_pos = torch.ones_like(conf_scores_pos)
    target_neg = torch.zeros_like(conf_scores_neg)
    
    target = torch.cat((target_pos, target_neg))
    inputs = torch.cat((conf_scores_pos, conf_scores_neg))
     
    loss = torch.nn.functional.binary_cross_entropy_with_logits(inputs, target, reduction='sum') * 1. / batch_size
    return loss

In [66]:
def calc_bbox_reg_loss(gt_offsets, reg_offsets_pos, batch_size):
    assert gt_offsets.size() == reg_offsets_pos.size()
    loss = torch.nn.functional.smooth_l1_loss(reg_offsets_pos, gt_offsets, reduction='sum') * 1. / batch_size
    return loss

Stage 2

In [67]:
class ClassificationModule(torch.nn.Module):
    def __init__(self, out_channels, n_classes, roi_size, hidden_dim=512, p_dropout=0.3):
        super().__init__()        
        self.roi_size = roi_size
        # hidden network
        self.avg_pool = torch.nn.AvgPool2d(self.roi_size)
        self.fc = torch.nn.Linear(out_channels, hidden_dim)
        self.dropout = torch.nn.Dropout(p_dropout)
        
        # define classification head
        self.cls_head = torch.nn.Linear(hidden_dim, n_classes)
        
    def forward(self, feature_map, proposals_list, gt_classes=None):
        
        if gt_classes is None:
            mode = 'eval'
        else:
            mode = 'train'
        
        # apply roi pooling on proposals followed by avg pooling
        roi_out = torchvision.ops.roi_pool(feature_map, proposals_list, self.roi_size)
        roi_out = self.avg_pool(roi_out)
        
        # flatten the output
        roi_out = roi_out.squeeze(-1).squeeze(-1)
        
        # pass the output through the hidden network
        out = self.fc(roi_out)
        out = torch.nn.functional.relu(self.dropout(out))
        
        # get the classification scores
        cls_scores = self.cls_head(out)
        
        if mode == 'eval':
            return cls_scores
        
        # compute cross entropy loss
        cls_loss = torch.nn.functional.cross_entropy(cls_scores, gt_classes.long())
        
        return cls_loss

wrap it up

In [68]:
class TwoStageDetector(torch.nn.Module):
    def __init__(self, img_size, n_classes, roi_size):
        super().__init__() 
        self.rpn = RegionProposalNetwork(img_size)
        
    def forward(self, images, gt_bboxes, gt_classes):
        total_rpn_loss, feature_map, proposals, \
        positive_anc_ind_sep, GT_class_pos, out_size = self.rpn(images, gt_bboxes, gt_classes)
        out_w, out_h, out_c = out_size
        
        # get separate proposals for each sample
        pos_proposals_list = []
        batch_size = images.size(dim=0)
        for idx in range(batch_size):
            proposal_idxs = torch.where(positive_anc_ind_sep == idx)[0]
            proposals_sep = proposals[proposal_idxs].detach().clone()
            pos_proposals_list.append(proposals_sep)
        
        classifier = ClassificationModule(out_c, n_classes, roi_size)
        cls_loss = classifier(feature_map, pos_proposals_list, GT_class_pos)
        total_loss = cls_loss + total_rpn_loss
        
        return total_loss
    
    def inference(self, images, conf_thresh=0.5, nms_thresh=0.7):
        batch_size = images.size(dim=0)
        proposals_final, conf_scores_final, feature_map, out_size = self.rpn.inference(images, conf_thresh, nms_thresh)
        out_w, out_h, out_c = out_size
        classifier = ClassificationModule(out_c, n_classes, roi_size)
        cls_scores = classifier(feature_map, proposals_final)
        
        # convert scores into probability
        cls_probs = torch.nn.functional.softmax(cls_scores, dim=-1)
        # get classes with highest probability
        classes_all = torch.argmax(cls_probs, dim=-1)
        
        classes_final = []
        # slice classes to map to their corresponding image
        c = 0
        for i in range(batch_size):
            n_proposals = len(proposals_final[i]) # get the number of proposals for each image
            classes_final.append(classes_all[c: c+n_proposals])
            c += n_proposals
            
        return proposals_final, conf_scores_final, classes_final

##Run

In [69]:
img_width = 640
img_height = 480
data_dir = '/content/data/'
# name2idx = {'pad': -1, 'camel': 0, 'bird': 1}
name2idx = {'pad': -1, 'license': 0}
idx2name = {v:k for k, v in name2idx.items()}

In [70]:
od_dataset = ObjectDetectionDataset(data_dir, (img_height, img_width))
od_dataloader = torch.utils.data.DataLoader(od_dataset, batch_size=2)

## Train

In [71]:
# run the image through the backbone
img_size = (img_height, img_width)
n_classes = len(name2idx) - 1 # exclude pad idx
roi_size = (2, 2)

detector = TwoStageDetector(img_size, n_classes, roi_size)



In [73]:
def get_iou_mat(batch_size, anc_boxes_all, gdt_bboxes_all):
    print(22222222222)    
    # flatten anchor boxes
    anc_boxes_flat = anc_boxes_all.reshape(batch_size, -1, 4)

    # create a placeholder to compute IoUs amongst the boxes
    ious_mat = torch.zeros((batch_size, anc_boxes_flat.size(dim=1), gdt_bboxes_all.size(dim=1)))

    # compute IoU of the anc boxes with the gt boxes for all the images
    for i in range(batch_size):
        gt_bboxes = gdt_bboxes_all[i]
        anc_boxes = anc_boxes_flat[i]
        ious_mat[i, :] = torchvision.ops.box_iou(anc_boxes, gt_bboxes)
        
    return ious_mat

In [74]:
for img_batch, gt_bboxes_batch, gt_classes_batch in od_dataloader:
    img_data_all = img_batch
    gt_bboxes_all = gt_bboxes_batch
    gt_classes_all = gt_classes_batch
    break
    
img_data_all = img_data_all[:2]
gt_bboxes_all = gt_bboxes_all[:2]
gt_classes_all = gt_classes_all[:2]

total_loss = detector(img_batch, gt_bboxes_batch, gt_classes_batch)
proposals_final, conf_scores_final, classes_final = detector.inference(img_batch)

22222222222
tensor([[[nan, 0., 0., 0., 0., 0.],
         [nan, 0., 0., 0., 0., 0.],
         [nan, 0., 0., 0., 0., 0.],
         ...,
         [nan, 0., 0., 0., 0., 0.],
         [nan, 0., 0., 0., 0., 0.],
         [nan, 0., 0., 0., 0., 0.]],

        [[nan, 0., 0., 0., 0., 0.],
         [nan, 0., 0., 0., 0., 0.],
         [nan, 0., 0., 0., 0., 0.],
         ...,
         [nan, 0., 0., 0., 0., 0.],
         [nan, 0., 0., 0., 0., 0.],
         [nan, 0., 0., 0., 0., 0.]]])


RuntimeError: ignored

In [None]:
def training_loop(model, learning_rate, train_dataloader, n_epochs):
    
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    
    model.train()
    loss_list = []
    
    for i in tqdm.tqdm(range(n_epochs)):
        total_loss = 0
        for img_batch, gt_bboxes_batch, gt_classes_batch in train_dataloader:
            
            # forward pass
            loss = model(img_batch, gt_bboxes_batch, gt_classes_batch)
            
            # backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
        
        loss_list.append(total_loss)
        
    return loss_list

In [None]:
learning_rate = 1e-3
n_epochs = 10

loss_list = training_loop(detector, learning_rate, od_dataloader, n_epochs)

In [None]:
plt.plot(loss_list)

In [None]:
torch.save(detector.state_dict(), "model.pt")

In [None]:
detector.eval()
proposals_final, conf_scores_final, classes_final = detector.inference(img_batch, conf_thresh=0.99, nms_thresh=0.05)


In [None]:
# project proposals to the image space
prop_proj_1 = project_bboxes(proposals_final[0], width_scale_factor, height_scale_factor, mode='a2p')
prop_proj_2 = project_bboxes(proposals_final[1], width_scale_factor, height_scale_factor, mode='a2p')

# get classes
classes_pred_1 = [idx2name[cls] for cls in classes_final[0].tolist()]
classes_pred_2 = [idx2name[cls] for cls in classes_final[1].tolist()]

In [None]:
nrows, ncols = (1, 2)
fig, axes = plt.subplots(nrows, ncols, figsize=(16, 8))

fig, axes = display_img(img_batch, fig, axes)
fig, _ = display_bbox(prop_proj_1, fig, axes[0], classes=classes_pred_1)
fig, _ = display_bbox(prop_proj_2, fig, axes[1], classes=classes_pred_2)

##Test part

In [None]:
for img_batch, gt_bboxes_batch, gt_classes_batch in od_dataloader:
    img_data_all = img_batch
    gt_bboxes_all = gt_bboxes_batch
    gt_classes_all = gt_classes_batch
    break
    
img_data_all = img_data_all[:2]
gt_bboxes_all = gt_bboxes_all[:2]
gt_classes_all = gt_classes_all[:2]

In [None]:
# get class names
gt_class_1 = gt_classes_all[0].long()
gt_class_1 = [idx2name[idx.item()] for idx in gt_class_1]

gt_class_2 = gt_classes_all[1].long()
gt_class_2 = [idx2name[idx.item()] for idx in gt_class_2]

In [None]:
nrows, ncols = (1, 2)
fig, axes = plt.subplots(nrows, ncols, figsize=(8, 4))

fig, axes = display_img(img_data_all, fig, axes)
fig, _ = display_bbox(gt_bboxes_all[0], fig, axes[0], classes=gt_class_1)
fig, _ = display_bbox(gt_bboxes_all[1], fig, axes[1], classes=gt_class_2)

In [None]:
model = torchvision.models.resnet50(pretrained=True)
req_layers = list(model.children())[:8]
backbone = torch.nn.Sequential(*req_layers)
# unfreeze all the parameters
for param in backbone.named_parameters():
    param[1].requires_grad = True
# run the image through the backbone
out = backbone(img_data_all)
out_c, out_h, out_w = out.size(dim=1), out.size(dim=2), out.size(dim=3)

In [None]:
width_scale_factor = img_width // out_w
height_scale_factor = img_height // out_h
height_scale_factor, width_scale_factor

In [None]:
nrows, ncols = (1, 2)
fig, axes = plt.subplots(nrows, ncols, figsize=(8, 4))

filters_data =[filters[0].detach().numpy() for filters in out[:2]]

fig, axes = display_img(filters_data, fig, axes)

In [None]:
anc_pts_x, anc_pts_y = gen_anc_centers(out_size=(out_h, out_w))

In [None]:
# project anchor centers onto the original image
anc_pts_x_proj = anc_pts_x.clone() * width_scale_factor 
anc_pts_y_proj = anc_pts_y.clone() * height_scale_factor

In [None]:
nrows, ncols = (1, 2)
fig, axes = plt.subplots(nrows, ncols, figsize=(8, 4))
 
# project anchor centers onto the original image

fig, axes = display_img(img_data_all, fig, axes)
fig, _ = display_grid(anc_pts_x_proj, anc_pts_y_proj, fig, axes[0])
fig, _ = display_grid(anc_pts_x_proj, anc_pts_y_proj, fig, axes[1])

In [None]:
anc_scales = [2, 4, 6]
anc_ratios = [0.5, 1, 1.5]
n_anc_boxes = len(anc_scales) * len(anc_ratios) # number of anchor boxes for each anchor point

anc_base = gen_anc_boxes(anc_pts_x, anc_pts_y, anc_scales, anc_ratios, (out_h, out_w))
# since all the images are scaled to the same size
# we can repeat the anchor base for all the images
anc_boxes_all = anc_base.repeat(img_data_all.size(dim=0), 1, 1, 1, 1)

In [None]:
nrows, ncols = (1, 2)
fig, axes = plt.subplots(nrows, ncols, figsize=(16, 8))

fig, axes = display_img(img_data_all, fig, axes)

# project anchor boxes to the image
anc_boxes_proj = project_bboxes(anc_boxes_all, width_scale_factor, height_scale_factor, mode='a2p')

# plot anchor boxes around selected anchor points
sp_1 = [5, 8]
sp_2 = [12, 9]
bboxes_1 = anc_boxes_proj[0][sp_1[0], sp_1[1]]
bboxes_2 = anc_boxes_proj[1][sp_2[0], sp_2[1]]

fig, _ = display_grid(anc_pts_x_proj, anc_pts_y_proj, fig, axes[0], (anc_pts_x_proj[sp_1[0]], anc_pts_y_proj[sp_1[1]]))
fig, _ = display_grid(anc_pts_x_proj, anc_pts_y_proj, fig, axes[1], (anc_pts_x_proj[sp_2[0]], anc_pts_y_proj[sp_2[1]]))
fig, _ = display_bbox(bboxes_1, fig, axes[0])
fig, _ = display_bbox(bboxes_2, fig, axes[1])

In [None]:
nrows, ncols = (1, 2)
fig, axes = plt.subplots(nrows, ncols, figsize=(16, 8))

fig, axes = display_img(img_data_all, fig, axes)

# plot feature grid
fig, _ = display_grid(anc_pts_x_proj, anc_pts_y_proj, fig, axes[0])
fig, _ = display_grid(anc_pts_x_proj, anc_pts_y_proj, fig, axes[1])

# plot all anchor boxes
for x in range(anc_pts_x_proj.size(dim=0)):
    for y in range(anc_pts_y_proj.size(dim=0)):
        bboxes = anc_boxes_proj[0][x, y]
        fig, _ = display_bbox(bboxes, fig, axes[0], line_width=1)
        fig, _ = display_bbox(bboxes, fig, axes[1], line_width=1)

In [None]:
#Get Positive and Negative Anchors

pos_thresh = 0.7
neg_thresh = 0.3

# project gt bboxes onto the feature map
gt_bboxes_proj = project_bboxes(gt_bboxes_all, width_scale_factor, height_scale_factor, mode='p2a')
positive_anc_ind, negative_anc_ind, GT_conf_scores, \
GT_offsets, GT_class_pos, positive_anc_coords, \
negative_anc_coords, positive_anc_ind_sep = get_req_anchors(anc_boxes_all, gt_bboxes_proj, gt_classes_all, pos_thresh, neg_thresh)


In [None]:
# project anchor coords to the image space
pos_anc_proj = project_bboxes(positive_anc_coords, width_scale_factor, height_scale_factor, mode='a2p')
neg_anc_proj = project_bboxes(negative_anc_coords, width_scale_factor, height_scale_factor, mode='a2p')

# grab +ve and -ve anchors for each image separately

anc_idx_1 = torch.where(positive_anc_ind_sep == 0)[0]
anc_idx_2 = torch.where(positive_anc_ind_sep == 1)[0]

pos_anc_1 = pos_anc_proj[anc_idx_1]
pos_anc_2 = pos_anc_proj[anc_idx_2]

neg_anc_1 = neg_anc_proj[anc_idx_1]
neg_anc_2 = neg_anc_proj[anc_idx_2]

In [None]:
nrows, ncols = (1, 2)
fig, axes = plt.subplots(nrows, ncols, figsize=(16, 8))

fig, axes = display_img(img_data_all, fig, axes)

# plot groundtruth bboxes
fig, _ = display_bbox(gt_bboxes_all[0], fig, axes[0])
fig, _ = display_bbox(gt_bboxes_all[1], fig, axes[1])

# plot positive anchor boxes
fig, _ = display_bbox(pos_anc_1, fig, axes[0], color='g')
fig, _ = display_bbox(pos_anc_2, fig, axes[1], color='g')

# plot negative anchor boxes
fig, _ = display_bbox(neg_anc_1, fig, axes[0], color='r')
fig, _ = display_bbox(neg_anc_2, fig, axes[1], color='r')