# Implementation of the RetinaNet model.

(Putting it all together)

In [1]:
from importlib.util import find_spec
# If the `model` package cannot be located on the current import path,
# add the parent directory to sys.path so it can be searched as well.
if find_spec("model") is None:
    import sys
    sys.path.append('..')

In [7]:
from typing import List, Optional
import torch
from torch import Tensor

In [3]:
from base import BaseModel
from model.backbone.resnet import ResNet50
from model.backbone.retina_meta import RetinaNetFPN50, RetinaNetHead
from model.anchor_generator import AnchorBoxGenerator
from model.matcher import Matcher
from model.box_regression import Box2BoxTransform
from utils.box_utils import pairwise_iou

In [10]:
def permute_to_N_KWA_K(tensor: Tensor, K: int):
    """
    Rearrange a 4-D tensor of shape (N, Ai * K, H, W) into (N, H * W * Ai, K).

    The channel axis is split into (Ai, K), the spatial axes are moved in
    front of the anchor axis, and (H, W, Ai) are then flattened together.
    """
    assert tensor.dim() == 4, tensor.shape
    batch, _, height, width = tensor.shape
    # (N, Ai*K, H, W) -> (N, Ai, K, H, W) -> (N, H, W, Ai, K)
    per_location = tensor.view(batch, -1, K, height, width).permute(0, 3, 4, 1, 2)
    return per_location.reshape(batch, -1, K)  # (N, HWA, K)

In [4]:
class RetinaNetModel(BaseModel):
    """
    RetinaNet detector wired together from a ResNet-50 trunk, an FPN and a
    prediction head, plus the anchor generation / matching utilities needed
    to build training targets.
    """

    def __init__(self, num_classes: Optional[int] = 80):
        """
        Args:
            num_classes: number of object classes. The same value doubles as
                the label given to background anchors (see ``label_anchors``).
        """
        super().__init__()

        # Kept on the instance because ``label_anchors`` and ``losses`` use it
        # as the background label.
        self.num_classes = num_classes

        sizes = [32.0, 64.0, 128.0, 256.0, 512.0]
        aspect_ratios = [0.5, 1.0, 2.0]
        scales = [1.0, 2 ** (1 / 3), 2 ** (2 / 3)]
        # NOTE(review): every level is given stride 2 - confirm against
        # AnchorBoxGenerator; the feature maps P3..P7 used in ``forward``
        # presumably have different strides relative to the input.
        strides = [2, 2, 2, 2, 2]

        self.base = ResNet50()
        self.backbone = RetinaNetFPN50()
        self.head = RetinaNetHead(num_classes)
        self.anchor_generator = AnchorBoxGenerator(sizes, aspect_ratios, strides, scales)
        # Two IoU thresholds with three labels (0, -1, 1).
        self.anchor_matcher = Matcher([0.4, 0.5], [0, -1, 1])
        self.box2box_transform = Box2BoxTransform([1., 1., 1., 1.])

    def forward(self, x, boxes, labels):
        """
        Run the network on a batch and compute its training losses.

        Args:
            x: batched input; ``x.size(0)`` is used as the batch size.
            boxes: ground-truth boxes, indexable per image.
            labels: ground-truth class labels, indexable per image.

        Returns:
            Whatever ``losses`` returns.
        """
        bs = x.size(0)

        _, C3, C4, C5 = self.base(x)
        P3, P4, P5, P6, P7 = self.backbone(C3, C4, C5)

        anchors = self.anchor_generator([P3, P4, P5, P6, P7])
        pred_logits, pred_anchor_deltas = self.head(P3, P4, P5, P6, P7)

        gt_labels, gt_boxes = self.label_anchors(bs, anchors, boxes, labels)

        losses = self.losses(anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes)
        return losses

    @staticmethod
    def _flatten_anchors(anchors):
        """Return ``anchors`` as one tensor, concatenating a list of tensors if needed."""
        if isinstance(anchors, Tensor):
            return anchors
        return cat_boxes(anchors)

    def label_anchors(self, bs, anchors, boxes, labels):
        """
        Assign a ground-truth label and box to every anchor, image by image.

        Args:
            bs: number of images to process.
            anchors: anchor boxes, either a single tensor or a list of tensors
                (a list is concatenated first).
            boxes: ground-truth boxes, indexable per image.
            labels: ground-truth class labels, indexable per image.

        Returns:
            tuple (gt_labels, matched_gt_boxes), two lists with one entry per
            image. A label is -1 for ignored anchors, ``self.num_classes`` for
            background anchors, and otherwise the class of the matched box.
        """
        # pairwise_iou / zeros_like below need a single tensor of anchors.
        anchors = self._flatten_anchors(anchors)

        gt_labels = []
        matched_gt_boxes = []
        for i in range(bs):
            match_quality_matrix = pairwise_iou(boxes[i], anchors)
            matched_idxs, anchor_labels = self.anchor_matcher(match_quality_matrix)
            del match_quality_matrix

            if len(boxes[i]) > 0:
                matched_gt_boxes_i = boxes[i][matched_idxs]
                gt_labels_i = labels[i][matched_idxs]
                # Assumes the matcher's label 0 marks background anchors and
                # -1 marks anchors to ignore (TODO confirm against Matcher).
                # Without the background assignment such anchors would keep
                # the class of their matched box and count as positives.
                gt_labels_i[anchor_labels == 0] = self.num_classes
                gt_labels_i[anchor_labels == -1] = -1
            else:
                # No ground truth in this image: every anchor is background.
                matched_gt_boxes_i = torch.zeros_like(anchors)
                gt_labels_i = torch.zeros_like(matched_idxs) + self.num_classes

            gt_labels.append(gt_labels_i)
            matched_gt_boxes.append(matched_gt_boxes_i)

        return gt_labels, matched_gt_boxes

    def losses(self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes):
        """
        Prepare the targets for the classification and box-regression losses.

        Args:
            anchors: anchor boxes (a tensor, or a list of tensors to concatenate).
            pred_logits: classification output of the head (not used yet).
            gt_labels: per-image anchor labels from ``label_anchors``.
            pred_anchor_deltas: regression output of the head (not used yet).
            gt_boxes: per-image matched ground-truth boxes from ``label_anchors``.

        Returns:
            Intended to be a dict[str, Tensor] mapping named loss to a scalar
            tensor, with keys "loss_cls" and "loss_box_reg". The loss terms
            themselves are not implemented yet, so this currently returns None.
        """
        num_images = len(gt_labels)
        gt_labels = torch.stack(gt_labels)  # (N, R)
        anchors = self._flatten_anchors(anchors)  # (R, 4)
        gt_anchor_deltas = [self.box2box_transform.get_deltas(anchors, k) for k in gt_boxes]
        gt_anchor_deltas = torch.stack(gt_anchor_deltas)  # (N, R, 4)

        # Anchors labelled -1 are excluded entirely; positives are the valid
        # anchors whose label is not the background index.
        valid_mask = gt_labels >= 0
        pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes)
        num_pos_anchors = pos_mask.sum().item()

        # classification and regression loss.
        # TODO: compute "loss_cls" and "loss_box_reg" from the masks above and
        # return them as a dict.

In [8]:
def cat_boxes(boxes_list: List[Tensor]):
    """
    Join a list of box tensors into one tensor along dim 0.

    An empty list yields ``torch.empty(0)``.
    """
    if len(boxes_list) > 0:
        merged = torch.cat(list(boxes_list), dim=0)
        return merged
    return torch.empty(0)

In [5]:
# Build the model with its default constructor arguments (num_classes=80).
model = RetinaNetModel()

In [6]:
# Bare expression: the notebook displays the model's repr as the cell output.
model

ResnetModel(
  (base): ResNet50(
    (stem): FastStem(
      (conv1_1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2))
      (conv1_2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv1_3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1_1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn1_2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (bn1_3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (res_block2_1): BottleNeckBlock(
      (block): Sequential(
        (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
 