# FPN

Paper: [Feature Pyramid Networks for Object Detection](https://openaccess.thecvf.com/content_cvpr_2017/papers/Lin_Feature_Pyramid_Networks_CVPR_2017_paper.pdf)

---

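Feature maps from the shallow layers of a convolutional network have high resolution, while those from deeper layers, after repeated strided convolution or pooling, have lower resolution. Resolution is set mainly by the accumulated stride rather than by the kernel size.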

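Low-resolution feature maps capture more global context and carry richer semantics, whereas high-resolution feature maps focus on local detail and provide more accurate spatial information.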
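This trade-off can be made concrete with a little arithmetic. The sketch below assumes a hypothetical stack of three 3×3, stride-2, padding-1 convolutions on a 256×256 input (the helper name `trace_pyramid` is illustrative only) and tracks the output size and receptive field of each layer with the standard recurrences.

```python
def trace_pyramid(input_size, layers):
    """Return (output_size, receptive_field) after each conv layer.

    layers: list of (kernel_size, stride, padding) tuples.
    """
    size, rf, jump = input_size, 1, 1  # jump = spacing of outputs in input pixels
    stats = []
    for k, s, p in layers:
        size = (size + 2 * p - k) // s + 1  # standard conv output-size formula
        rf = rf + (k - 1) * jump            # receptive field grows with the current jump
        jump = jump * s                     # striding widens the spacing between outputs
        stats.append((size, rf))
    return stats


stats = trace_pyramid(256, [(3, 2, 1)] * 3)
for level, (size, rf) in enumerate(stats, start=1):
    print(f"layer {level}: {size}x{size} map, receptive field {rf}x{rf}")
# layer 1: 128x128 map, receptive field 3x3
# layer 2: 64x64 map, receptive field 7x7
# layer 3: 32x32 map, receptive field 15x15
```

Under these assumptions, each deeper layer covers a larger patch of the input with fewer, coarser cells: more context per cell, less precise localization.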

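FPN aims to combine high-resolution and low-resolution feature maps so that the resulting features are both spatially accurate and semantically rich.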

---

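In the figure below, the left side is the bottom-up pathway, which is usually a backbone network such as ResNet. The right side is the top-down pathway, which fuses feature maps of different scales.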

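The two sides are joined by 1×1 convolutions (lateral connections), which adjust the channel dimension of the left-hand feature maps to match that of the right-hand ones.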
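A minimal sketch of one lateral connection, assuming a hypothetical 128-channel bottom-up map and a 32-channel pyramid (both numbers are chosen for illustration). It shows that a 1×1 convolution changes only the channel count, and that it amounts to the same linear map applied at every pixel.

```python
import torch
import torch.nn as nn

# Hypothetical bottom-up feature map: 128 channels at 64x64 resolution.
c2 = torch.randn(1, 128, 64, 64)

# A 1x1 convolution mixes channels at each pixel independently, so it
# changes the channel count but leaves height and width untouched.
lateral = nn.Conv2d(128, 32, kernel_size=1)
lat2 = lateral(c2)

print(c2.shape)    # torch.Size([1, 128, 64, 64])
print(lat2.shape)  # torch.Size([1, 32, 64, 64])

# Equivalent view: one (32 x 128) matrix applied to every pixel's channel vector.
weight = lateral.weight[:, :, 0, 0]  # shape (32, 128)
manual = torch.einsum('oc,bchw->bohw', weight, c2) + lateral.bias[None, :, None, None]
print(torch.allclose(lat2, manual, atol=1e-4))
```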

Each feature map on the right is upsampled to the resolution of the corresponding feature map on the left.
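As a small illustration of the upsampling step, the sketch below applies nearest-neighbour interpolation with `F.interpolate` to a hypothetical 2×2 map. Nearest-neighbour mode is an assumption here; other modes such as bilinear are also possible.

```python
import torch
import torch.nn.functional as F

# A tiny single-channel "coarse" feature map with shape (N, C, H, W) = (1, 1, 2, 2).
coarse = torch.tensor([[1., 2.],
                       [3., 4.]]).reshape(1, 1, 2, 2)

# Nearest-neighbour upsampling by 2 copies every value into a 2x2 block.
up = F.interpolate(coarse, scale_factor=2, mode='nearest')

print(up[0, 0])
# tensor([[1., 1., 2., 2.],
#         [1., 1., 2., 2.],
#         [3., 3., 4., 4.],
#         [3., 3., 4., 4.]])
```

No new information is created by this step. The upsampled map only brings coarse semantics onto the finer grid, where the lateral features supply the spatial detail.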

The two feature maps are then merged by element-wise addition.
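The merge step can be sketched as a small helper (the name `merge` is hypothetical). It upsamples to the lateral map's exact size rather than by a fixed factor, on the assumption that input sizes may be odd. For example, a 250×250 input passed through three stride-2 convolutions (3×3, padding 1) gives 125, 63, and 32 pixels per side, so doubling 32 would give 64 and the addition with the 63×63 map would fail.

```python
import torch
import torch.nn.functional as F


def merge(top, lateral):
    """Upsample `top` to the spatial size of `lateral`, then add element-wise."""
    top_up = F.interpolate(top, size=lateral.shape[-2:], mode='nearest')
    return top_up + lateral


# Both inputs must already have the same channel count (32 here, by assumption).
top = torch.randn(1, 32, 32, 32)      # coarser map from the level above
lateral = torch.randn(1, 32, 63, 63)  # lateral map with an odd spatial size

merged = merge(top, lateral)
print(merged.shape)  # torch.Size([1, 32, 63, 63])
```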

![](img/FPN/FPN-architecture.png)


```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```

```python
class FPN(nn.Module):
    """Minimal three-level FPN on top of a toy convolutional backbone."""

    def __init__(self):
        super().__init__()
        # Bottom-up layers: each stride-2 convolution halves the resolution
        self.bottomlayer1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1)
        self.bottomlayer2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.bottomlayer3 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)

        # Lateral layers: 1x1 convolutions that project every level to 32 channels
        self.latlayer1 = nn.Conv2d(256, 32, kernel_size=1, stride=1, padding=0)
        self.latlayer2 = nn.Conv2d(128, 32, kernel_size=1, stride=1, padding=0)
        self.latlayer3 = nn.Conv2d(64, 32, kernel_size=1, stride=1, padding=0)

        # Top-down layers: 3x3 convolutions that smooth each pyramid level
        self.toplayer1 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.toplayer2 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.toplayer3 = nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Bottom-up
        c1 = self.bottomlayer1(x)   # 1/2
        c2 = self.bottomlayer2(c1)  # 1/4
        c3 = self.bottomlayer3(c2)  # 1/8

        # Lateral connections
        lat1 = self.latlayer1(c3)
        lat2 = self.latlayer2(c2)
        lat3 = self.latlayer3(c1)

        # Top-down: the coarsest level has nothing above it to merge with,
        # so it is built from its lateral map alone
        p1 = self.toplayer1(lat1)

        # Upsample to the lateral map's size (works for odd input sizes), add, smooth
        p2 = F.interpolate(p1, size=lat2.shape[-2:], mode='nearest') + lat2
        p2 = self.toplayer2(p2)

        p3 = F.interpolate(p2, size=lat3.shape[-2:], mode='nearest') + lat3
        p3 = self.toplayer3(p3)

        return p1, p2, p3
```

```python
x = torch.randn(1, 3, 256, 256)

model = FPN()
p1, p2, p3 = model(x)

print(p1.shape, p2.shape, p3.shape, sep='\n')
```

```
torch.Size([1, 32, 32, 32])
torch.Size([1, 32, 64, 64])
torch.Size([1, 32, 128, 128])
```
