# Implementation of the Feature Pyramid Network (FPN), assuming a ResNet-50 backbone.

In [2]:
from importlib.util import find_spec
# Notebook bootstrap: when the project's top-level `model` package cannot be
# located from the current working directory, append the parent directory to
# the module search path (presumably this notebook lives one level below the
# project root -- confirm against the repository layout).
if not find_spec("model"):
    import sys
    sys.path.append('..')

In [3]:
from typing import Optional
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
from base import BaseModel

## Feature Pyramid Network (FPN)
Paper: [Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144)

In [8]:
def lateral_connection(in_channels, out_channels):
    """Build the 1x1 lateral convolution used at an FPN level.

    Projects a feature map with ``in_channels`` channels onto
    ``out_channels`` channels without changing its spatial size
    (kernel 1, stride 1, no padding, with bias).
    """
    lateral = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=1,
    )
    return lateral

def conv3x3(in_channels: int, out_channels: int, stride: int = 1) -> nn.Conv2d:
    """Return a 3x3 convolution with padding 1.

    With the default ``stride=1`` the output keeps the input's spatial size.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride (an integer; ``None`` is not a valid
            stride for ``nn.Conv2d``).

    Returns:
        The configured ``nn.Conv2d`` module (with bias).
    """
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)

In [18]:
class LateralUpsampleMerge(nn.Module):
    """Merge a bottom-up lateral connection with the top-down upsampled path.

    The bottom-up feature map is projected by a 1x1 lateral convolution to
    ``out_channels`` channels. The coarser top-down map is upsampled with
    nearest-neighbour interpolation to the lateral map's spatial size, and
    the two are summed element-wise.

    Args:
        in_channels: Channels of the bottom-up (backbone) feature map.
        out_channels: Channels of the top-down map and of the merged output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.lat_conv = lateral_connection(in_channels, out_channels)

    def forward(self, x, feature_map):
        """Return ``lat_conv(feature_map) + upsample(x)``.

        Args:
            x: Top-down map from the coarser level; must already have
                ``out_channels`` channels.
            feature_map: Bottom-up backbone map at this level.

        Returns:
            The merged map, with ``out_channels`` channels and the spatial
            size of ``feature_map``.
        """
        lat_out = self.lat_conv(feature_map)
        # Upsample to the exact lateral size rather than by a fixed factor of 2.
        # A stride-2 backbone stage maps an odd size n to ceil(n / 2), so a
        # plain 2x upsample can overshoot the finer map by one pixel and make
        # the addition fail. For exact 2x sizes the result is unchanged.
        upsampled = F.interpolate(x, size=lat_out.shape[-2:], mode="nearest")
        return lat_out + upsampled

In [21]:
class FPN50(BaseModel):
    """Feature Pyramid Network head for a ResNet-50 backbone.

    Implements the top-down pathway with lateral connections from the FPN
    paper. Each backbone map ``C_i`` is projected to ``out_features``
    channels by a 1x1 lateral convolution. It is then summed with the
    upsampled map from the level above and smoothed by a 3x3 convolution
    to give ``P_i``.

    The lateral input channel counts (256, 512, 1024, 2048 for C2..C5)
    are those of the ResNet-50 bottleneck stages.

    Args:
        out_features: Number of channels of every output pyramid level.
    """

    def __init__(self, out_features=256):
        super().__init__()

        # Stage 5: top of the pyramid. There is no coarser map to merge,
        # so only the lateral projection is needed.
        self.lateral5 = lateral_connection(2048, out_features)
        self.conv5 = conv3x3(out_features, out_features)

        # Stage 4
        self.lat_merge4 = LateralUpsampleMerge(1024, out_features)
        self.conv4 = conv3x3(out_features, out_features)

        # Stage 3
        self.lat_merge3 = LateralUpsampleMerge(512, out_features)
        self.conv3 = conv3x3(out_features, out_features)

        # Stage 2
        self.lat_merge2 = LateralUpsampleMerge(256, out_features)
        self.conv2 = conv3x3(out_features, out_features)

    def forward(self, C2, C3, C4, C5):
        """Run the top-down pathway and return the pyramid outputs.

        Args:
            C2, C3, C4, C5: Backbone feature maps from finest to coarsest,
                with 256, 512, 1024 and 2048 channels respectively. Each
                level is expected to be roughly half the spatial size of
                the previous one (presumably ResNet-50 stages
                conv2_x..conv5_x; confirm against the caller).

        Returns:
            Tuple ``(P2, P3, P4, P5)``. Every map has ``out_features``
            channels and the spatial size of the corresponding ``C`` input.
        """
        # Stage 5: lateral projection of the coarsest map.
        out = self.lateral5(C5)
        P5 = self.conv5(out)

        # Stages 4..2: merge with the upsampled coarser map. The *unsmoothed*
        # merged map `out` is carried down the top-down path; the 3x3
        # convolution only produces each level's pyramid output.
        out = self.lat_merge4(out, C4)
        P4 = self.conv4(out)

        out = self.lat_merge3(out, C3)
        P3 = self.conv3(out)

        out = self.lat_merge2(out, C2)
        P2 = self.conv2(out)

        return P2, P3, P4, P5
        

In [22]:
# Build the FPN head with its default 256 output channels and show its module
# tree (the trailing bare expression is rendered as the cell output).
model = FPN50(out_features=256)
model

FPN50(
  (lateral5): Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
  (conv5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (lat_merge4): LateralUpsampleMerge(
    (lat_conv): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
  )
  (conv4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (lat_merge3): LateralUpsampleMerge(
    (lat_conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1))
  )
  (conv3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (lat_merge2): LateralUpsampleMerge(
    (lat_conv): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
  )
  (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)