# Implementation of the FPN Backbone for the RetinaNet Object Detector.

In [3]:
from importlib.util import find_spec
# Notebook bootstrap: when no module named `model` can be located on the
# current import path, append the parent directory ('..', relative to the
# working directory) to sys.path. Presumably this lets the project-local
# imports in later cells resolve when the notebook runs from a subfolder --
# confirm against the repository layout.
if find_spec("model") is None:
    # Imported lazily: `sys` is only needed on this fallback branch.
    import sys
    sys.path.append('..')

In [4]:
from typing import Optional
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F

In [8]:
from base import BaseModel
from layers.upsample import LateralUpsampleMerge
from layers.wrappers import conv3x3, conv1x1

## Retina FPN Backbone Variant

Paper: [Focal Loss for Dense Object Detection](https://arxiv.org/abs/1708.02002)

See Section 4, "RetinaNet Detector", paragraph "Feature Pyramid Network Backbone".

In [9]:
class RetinaNetFPN50(BaseModel):
    """
    Feature pyramid producing the five levels P3, P4, P5, P6 and P7.

    The input widths are hard-coded to 512 / 1024 / 2048 channels for the
    C3 / C4 / C5 feature maps, in line with the ResNet50 backbone this
    class assumes. Every returned level carries ``out_features`` channels.

    Args:
        out_features: channel width shared by all pyramid levels
            (defaults to 256).
    """

    def __init__(self, out_features=256):
        super().__init__()

        # Channel widths expected on the C3, C4 and C5 inputs.
        c3_width, c4_width, c5_width = 512, 1024, 2048
        width = out_features

        # Layers are assigned from the coarsest level (P7) down to the
        # finest (P3); the printed module tree lists them in this order.

        # P7: stride-2 3x3 conv applied on top of P6.
        self.conv7_up = conv3x3(width, width, stride=2)

        # P6: stride-2 3x3 conv applied directly to C5.
        self.conv6_up = conv3x3(c5_width, width, stride=2)

        # P5: 1x1 lateral projection of C5, then a 3x3 output conv.
        self.lateral5 = conv1x1(c5_width, width)
        self.conv5 = conv3x3(width, width)

        # P4: merge of the P5-level map with C4, then a 3x3 output conv.
        self.lat_merge4 = LateralUpsampleMerge(c4_width, width)
        self.conv4 = conv3x3(width, width)

        # P3: merge of the P4-level map with C3, then a 3x3 output conv.
        self.lat_merge3 = LateralUpsampleMerge(c3_width, width)
        self.conv3 = conv3x3(width, width)

    def forward(self, C2, C3, C4, C5):
        """
        Build the pyramid from the backbone feature maps.

        Args:
            C2: accepted for interface compatibility; not used here.
            C3: backbone feature map fed to the P3 merge (512 channels).
            C4: backbone feature map fed to the P4 merge (1024 channels).
            C5: backbone feature map feeding P5 and P6 (2048 channels).

        Returns:
            The tuple ``(P3, P4, P5, P6, P7)``.
        """
        # Extra coarse levels: P6 comes straight from C5, P7 from ReLU(P6).
        P6 = self.conv6_up(C5)
        activated6 = F.relu(P6)
        P7 = self.conv7_up(activated6)

        # Level 5: project C5, then smooth the projection into P5.
        top_down5 = self.lateral5(C5)
        P5 = self.conv5(top_down5)

        # Level 4: combine the coarser map with C4.
        # NOTE(review): presumably upsamples the first argument and adds the
        # 1x1-projected second one -- confirm in layers.upsample.
        top_down4 = self.lat_merge4(top_down5, C4)
        P4 = self.conv4(top_down4)

        # Level 3: same merge one step finer, using C3.
        top_down3 = self.lat_merge3(top_down4, C3)
        P3 = self.conv3(top_down3)

        pyramid = (P3, P4, P5, P6, P7)
        return pyramid
        

In [10]:
# Instantiate the FPN with its default out_features=256.
model = RetinaNetFPN50()

In [11]:
# Bare expression: the notebook echoes the module's repr as the cell output.
model

RetinaNetFPN50(
  (conv7_up): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (conv6_up): Conv2d(2048, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
  (conv5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (lat_merge4): LateralUpsampleMerge(
    (lat_conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
  )
  (conv4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (lat_merge3): LateralUpsampleMerge(
    (lat_conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
  )
  (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)