# Implementation of RetinaNet Head.

In [1]:
from importlib.util import find_spec
# If no module/package named "model" can be located, append '..' to the module
# search path. Presumably the notebook lives one directory below the project
# root -- confirm against the repository layout.
if find_spec("model") is None:
    import sys
    sys.path.append('..')

In [2]:
from typing import Optional
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F

In [4]:
from base import BaseModel
from layers.wrappers import conv3x3

## Retina Head
Paper: [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)

See Section 4 (RetinaNet Detector), paragraphs "Classification Subnet" and "Box Regression Subnet".

In [None]:
class RetinaNetHead(BaseModel):
    """
    RetinaNet prediction head. See https://arxiv.org/abs/1708.02002

    The head owns two towers built by ``_create_subnet``:

    * ``classifier_subnet`` -- ends in ``num_anchors * num_classes`` channels.
    * ``regressor_subnet``  -- ends in ``num_anchors * 4`` channels.

    ``forward`` runs both towers on each of its five inputs; the same tower
    instance (and therefore the same weights) is reused for every input.
    """

    def __init__(self, num_classes, num_anchors=9, num_channels=256):
        super().__init__()
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        self.num_channels = num_channels

        # The classifier tower is built before the regressor tower.
        self.classifier_subnet = self._create_subnet()
        self.regressor_subnet = self._create_subnet(reg=True)

    def _create_subnet(self, reg=False):
        """
        Build one tower: four ``conv3x3`` + ReLU pairs followed by an output
        ``conv3x3`` layer.

        ``reg=True`` selects the box tower (``num_anchors * 4`` output
        channels); otherwise the classification tower
        (``num_anchors * num_classes`` output channels) is produced.
        """
        width = self.num_channels
        per_anchor = 4 if reg else self.num_classes

        # The output layer is constructed first, ahead of the hidden layers,
        # so that modules are created in a fixed order.
        output_layer = conv3x3(width, self.num_anchors * per_anchor)

        stack = []
        for _ in range(4):
            stack.append(conv3x3(width, width))
            stack.append(nn.ReLU(inplace=True))
        stack.append(output_layer)

        return nn.Sequential(*stack)

    def forward(self, P3, P4, P5, P6, P7):
        """
        Apply both towers to every input.

        Returns a ``(logits, bbox_reg)`` pair of dicts keyed ``"p3"`` .. ``"p7"``,
        holding the classifier and regressor outputs for the matching input.
        """
        levels = (("p3", P3), ("p4", P4), ("p5", P5), ("p6", P6), ("p7", P7))

        # All classifier outputs are computed before any regressor output.
        logits = {key: self.classifier_subnet(feat) for key, feat in levels}
        bbox_reg = {key: self.regressor_subnet(feat) for key, feat in levels}

        return logits, bbox_reg

In [19]:
# Build a head with num_classes=20; num_anchors and num_channels keep their defaults.
model = RetinaNetHead(20)

In [20]:
# Bare expression: the notebook echoes the module's repr as the cell output.
model

RetinaNetHead(
  (classifier_subnet): Sequential(
    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 180, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (regressor_subnet): Sequential(
    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 36, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

## Test Model

In [2]:
import torch

In [3]:
from model.backbone.retina_meta import RetinaNetFPN50, RetinaNetHead
from model.backbone.resnet import ResNet50

In [4]:
# Instantiate the two project modules that produce the inputs for the head:
# `backbone` is called on the raw input, `fpn_backbone` on the backbone outputs.
backbone = ResNet50()
fpn_backbone = RetinaNetFPN50()

In [5]:
# Random tensor of shape (64, 3, 512, 512) used as dummy input; presumably
# (batch, channels, height, width) -- confirm against the backbone.
data = torch.randn((64, 3, 512, 512))

In [6]:
# The backbone returns four values; the first is discarded and the remaining
# three are kept as C3, C4 and C5.
_, C3, C4, C5 = backbone(data)

In [7]:
# The FPN module maps the three backbone outputs to five outputs, P3..P7.
P3, P4, P5, P6, P7 = fpn_backbone(C3, C4, C5)

In [8]:
# Head under test, constructed with 20 as its first argument (the class count
# used in the shape assertions).
model = RetinaNetHead(20)

In [9]:
# Run the head on all five FPN outputs; it returns two results, which the
# assertions that follow index by level name ("p3" .. "p7").
pred_logits, bbox_regs = model(P3, P4, P5, P6, P7)

In [10]:
# Expected values for the shape assertions: 20 is the argument passed to
# RetinaNetHead(20); 9 is presumably the head's default anchor count -- confirm.
num_anchors = 9
num_classes = 20

### Logits

In [14]:
# "p3" logits: 64 samples, num_anchors * num_classes channels, 64x64 map (512 / 8).
assert pred_logits["p3"].shape == (64, num_anchors * num_classes, 64, 64)

In [15]:
# "p4" logits: 64 samples, num_anchors * num_classes channels, 32x32 map (512 / 16).
assert pred_logits["p4"].shape == (64, num_anchors * num_classes, 32, 32)

In [16]:
# "p5" logits: 64 samples, num_anchors * num_classes channels, 16x16 map (512 / 32).
assert pred_logits["p5"].shape == (64, num_anchors * num_classes, 16, 16)

In [17]:
# "p6" logits: 64 samples, num_anchors * num_classes channels, 8x8 map (512 / 64).
assert pred_logits["p6"].shape == (64, num_anchors * num_classes, 8, 8)

In [19]:
# "p7" logits: 64 samples, num_anchors * num_classes channels, 4x4 map (512 / 128).
assert pred_logits["p7"].shape == (64, num_anchors * num_classes, 4, 4)

### Boxes

In [20]:
# "p3" box output: 64 samples, 4 values per anchor (4 * num_anchors channels), 64x64 map.
assert bbox_regs["p3"].shape == (64, 4 * num_anchors, 64, 64)

In [21]:
# "p4" box output: 64 samples, 4 values per anchor (4 * num_anchors channels), 32x32 map.
assert bbox_regs["p4"].shape == (64, 4 * num_anchors, 32, 32)

In [22]:
# "p5" box output: 64 samples, 4 values per anchor (4 * num_anchors channels), 16x16 map.
assert bbox_regs["p5"].shape == (64, 4 * num_anchors, 16, 16)

In [23]:
# "p6" box output: 64 samples, 4 values per anchor (4 * num_anchors channels), 8x8 map.
assert bbox_regs["p6"].shape == (64, 4 * num_anchors, 8, 8)

In [24]:
# "p7" box output: 64 samples, 4 values per anchor (4 * num_anchors channels), 4x4 map.
assert bbox_regs["p7"].shape == (64, 4 * num_anchors, 4, 4)