# Implementation of RetinaNet Head.

In [1]:
from importlib.util import find_spec
# If no module named "model" can be located on the current import path,
# append the parent directory ('..', resolved against the current working
# directory) to sys.path.
# NOTE(review): presumably this makes the project's packages importable when
# the notebook is run from a subdirectory -- confirm against the repo layout.
if find_spec("model") is None:
    import sys
    sys.path.append('..')

In [2]:
from typing import Optional
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F

In [4]:
from base import BaseModel
from layers.wrappers import conv3x3

## RetinaNet Head
Paper: [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)

See Section 4, "RetinaNet Detector", in particular the "Classification Subnet" and "Box Regression Subnet" paragraphs.

In [13]:
class RetinaNetHead(BaseModel):
    """
    RetinaNet head: a classification subnet and a box-regression subnet.

    See "Focal Loss for Dense Object Detection",
    https://arxiv.org/abs/1708.02002 (Section 4).

    Each subnet is a stack of four ``conv3x3`` + ReLU stages followed by one
    final ``conv3x3``. The same two subnets (i.e. the same weights) are
    applied to every pyramid level given to ``forward``.

    Args:
        num_classes: number of object classes; the classifier emits
            ``num_anchors * num_classes`` channels.
        num_anchors: number of anchors per location; the regressor emits
            ``num_anchors * 4`` channels.
        num_channels: channel count of the incoming feature maps and of
            every hidden layer in both subnets.
        prior_prob: optional prior probability ``pi`` in (0, 1). When given,
            the bias of the final classification conv is set to
            ``-log((1 - pi) / pi)`` as described in the paper (which uses
            ``pi = 0.01``). ``None`` (default) leaves the initialisation
            produced by ``conv3x3`` untouched.
    """

    # Number of hidden conv3x3 + ReLU stages that precede the output conv.
    _NUM_HIDDEN_STAGES = 4

    def __init__(self, num_classes, num_anchors=9, num_channels=256,
                 prior_prob: Optional[float] = None):
        super().__init__()
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        self.num_channels = num_channels

        # The classifier is built before the regressor so that parameters are
        # created in the same order as before (matters for seeded runs).
        self.classifier_subnet = self._make_subnet(
            self.num_channels, self.num_anchors * self.num_classes
        )
        self.regressor_subnet = self._make_subnet(
            self.num_channels, self.num_anchors * 4
        )

        if prior_prob is not None:
            self._init_classifier_bias(prior_prob)

    @staticmethod
    def _make_subnet(num_channels, out_channels):
        """Build one subnet: hidden conv3x3 + ReLU stages, then an output conv3x3."""
        layers = []
        for _ in range(RetinaNetHead._NUM_HIDDEN_STAGES):
            layers.append(conv3x3(num_channels, num_channels))
            layers.append(nn.ReLU(inplace=True))
        layers.append(conv3x3(num_channels, out_channels))
        # Positional construction keeps the module indices (0..8), so existing
        # state_dict keys such as ``classifier_subnet.8.weight`` stay valid.
        return nn.Sequential(*layers)

    def _init_classifier_bias(self, prior_prob):
        """Set the final classification bias so initial scores are about ``prior_prob``."""
        import math

        if not 0.0 < prior_prob < 1.0:
            raise ValueError(
                f"prior_prob must be in the open interval (0, 1), got {prior_prob!r}"
            )
        final_conv = self.classifier_subnet[-1]
        bias = getattr(final_conv, "bias", None)
        if bias is None:
            # NOTE(review): conv3x3 is a project wrapper; this guards against
            # it being configured without a bias term.
            raise ValueError(
                "prior_prob was given but the final classification layer has no bias"
            )
        nn.init.constant_(bias, -math.log((1.0 - prior_prob) / prior_prob))

    def forward(self, P3, P4, P5, P6, P7):
        """
        Apply both subnets to each of the five pyramid levels.

        Args:
            P3, P4, P5, P6, P7: feature maps, one per pyramid level.
                Assumed to be tensors with ``num_channels`` channels --
                TODO confirm against the FPN that feeds this head.

        Returns:
            A ``(logits, bbox_reg)`` tuple of dicts keyed ``"p3"`` .. ``"p7"``.
            ``logits[level]`` is the classifier output for that level and
            ``bbox_reg[level]`` the regressor output.
        """
        pyramid = {"p3": P3, "p4": P4, "p5": P5, "p6": P6, "p7": P7}

        # All classifier passes run first, then all regressor passes, which
        # mirrors the original evaluation order.
        logits = {
            level: self.classifier_subnet(feature)
            for level, feature in pyramid.items()
        }
        bbox_reg = {
            level: self.regressor_subnet(feature)
            for level, feature in pyramid.items()
        }

        return logits, bbox_reg

In [14]:
# Build a head for 20 classes; num_anchors and num_channels keep their
# defaults (9 and 256).
model = RetinaNetHead(20)

In [15]:
# Bare expression as the last line of the cell: the notebook displays the
# module's repr (the layer listing below).
model

RetinaNetHead(
  (classifier_subnet): Sequential(
    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 180, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (regressor_subnet): Sequential(
    (0): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU(inplace=True)
    (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=T