
# **Building a Feature Pyramid Network (FPN) from a Pre-trained ResNet Model**

In this example, we load a pre-trained ResNet-50 model, extract feature maps from three of its stages, and use them to build a Feature Pyramid Network (FPN). An FPN is a multi-scale feature extractor: it combines low-resolution, semantically strong feature maps from deep layers with high-resolution feature maps from shallow layers, which improves the network's ability to detect objects at different scales.

# **Step 1: Load the Pre-trained ResNet Model**
First, we load a pre-trained ResNet-50 model from torchvision, the computer-vision package of the PyTorch project.

In [1]:
import torchvision
import torch.nn as nn
import torch
import torch.nn.functional as F

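# Load the pre-trained ResNet-50 model.
# `pretrained=True` is deprecated since torchvision 0.13; the weights enum below selects the same ImageNet checkpoint.
model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
model.eval()  # Use the stored BatchNorm statistics while extracting features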

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 149MB/s]


# **Step 2: Extract Feature Maps from Different Layers**
We will extract feature maps from three stages of the ResNet model: an early stage, a middle stage, and a late stage. These stages provide feature maps at different resolutions and levels of abstraction.

Early Stage: Extract features from the initial layers, that is, the first convolutional layer, batch norm, ReLU, max pooling, and layer1.

Middle Stage: Extract features up to layer3. The output of layer2 is computed along the way but is not used as a pyramid level.

Late Stage: Extract features up to layer4, which contains the most abstract representations.

In [2]:
# Build three truncated views of the ResNet model (they share its weights).
# Each one runs from the input image up to the named stage.
layer1 = nn.Sequential(*list(model.children())[:5])  # Early stage: conv1, bn1, relu, maxpool, layer1
layer3 = nn.Sequential(*list(model.children())[:7])  # Middle stage: everything above plus layer2 and layer3
layer4 = nn.Sequential(*list(model.children())[:8])  # Late stage: everything above plus layer4

# **Step 3: Generate an Example Input and Obtain Feature Maps**
To check the feature maps' sizes before merging them into the FPN, we create a random input tensor (simulating an image) and pass it through the extracted layers.

In [3]:
# Example input tensor (e.g., a 224x224 image)
x = torch.randn(1, 3, 224, 224)

# Get feature maps from the different layers
feat1 = layer1(x)  # Early stage feature map
feat3 = layer3(x)  # Middle stage feature map
feat4 = layer4(x)  # Late stage feature map

# Print the shapes of the feature maps
print(feat1.shape, feat3.shape, feat4.shape)

torch.Size([1, 256, 56, 56]) torch.Size([1, 1024, 14, 14]) torch.Size([1, 2048, 7, 7])
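
These spatial sizes follow from the cumulative strides of the three stages: 4 for layer1, 16 for layer3, and 32 for layer4. The sketch below reproduces the arithmetic in plain Python. It assumes the padding used by torchvision's ResNet, where every stride-2 layer maps a side length `n` to `floor((n - 1) / 2) + 1`; the helper names are hypothetical.

```python
def halve(n: int) -> int:
    """Side length after one stride-2 layer with ResNet-style padding."""
    return (n - 1) // 2 + 1


def expected_sizes(side: int) -> dict:
    """Side lengths of the three extracted feature maps for a square input."""
    after_conv1 = halve(side)            # 7x7 convolution, stride 2
    after_maxpool = halve(after_conv1)   # 3x3 max pooling, stride 2; layer1 keeps this size
    after_layer2 = halve(after_maxpool)  # layer2 downsamples once
    after_layer3 = halve(after_layer2)   # layer3 downsamples once
    after_layer4 = halve(after_layer3)   # layer4 downsamples once
    return {"layer1": after_maxpool, "layer3": after_layer3, "layer4": after_layer4}


print(expected_sizes(224))  # {'layer1': 56, 'layer3': 14, 'layer4': 7}
```

Since layer2 is skipped, the middle and early maps differ by a factor of 4 rather than 2, which is one reason to upsample to an explicit target size instead of a fixed scale factor.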


# **Step 4: Build the Feature Pyramid Network (FPN)**
We build an FPN that merges the extracted feature maps. The FPN uses 1x1 convolutions to reduce the number of channels to a consistent size (e.g., 256) and then applies upsampling and merging operations to create a multi-scale feature representation.

In [4]:
# Define the Feature Pyramid Network (FPN)
class CorrectedFPN(nn.Module):
    def __init__(self):
        super(CorrectedFPN, self).__init__()

        # 1x1 convolutions to reduce the number of channels to 256
        self.conv4 = nn.Conv2d(2048, 256, kernel_size=1)  # For the deepest feature map
        self.conv3 = nn.Conv2d(1024, 256, kernel_size=1)  # For the middle feature map
        self.conv1 = nn.Conv2d(256, 256, kernel_size=1)   # For the shallowest feature map

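        # A single 3x3 convolution for refinement; forward() applies it to both merged maps,
        # so the two levels share its weights. padding=1 keeps the spatial size unchanged.
        self.conv_out = nn.Conv2d(256, 256, kernel_size=3, padding=1)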

    def forward(self, feat1, feat3, feat4):
        # Convert feat4 to 256 channels (deepest feature map)
        p4 = self.conv4(feat4)  # Start with the deepest feature map

        # Upsample p4 and merge with the middle stage feature map (feat3)
        p4_upsampled = F.interpolate(p4, size=(feat3.shape[2], feat3.shape[3]), mode='nearest')
        p3 = self.conv3(feat3) + p4_upsampled
        p3 = self.conv_out(p3)  # Refine the combined feature map

        # Upsample p3 and merge with the early stage feature map (feat1)
        p3_upsampled = F.interpolate(p3, size=(feat1.shape[2], feat1.shape[3]), mode='nearest')
        p1 = self.conv1(feat1) + p3_upsampled
        p1 = self.conv_out(p1)  # Refine the combined feature map

        return p1, p3, p4  # Return the FPN feature maps

# Instantiate the FPN model
fpn = CorrectedFPN()

# Forward pass through the FPN with the extracted feature maps
p1, p3, p4 = fpn(feat1, feat3, feat4)

# Print the shapes of the output feature maps
print(p1.shape, p3.shape, p4.shape)

torch.Size([1, 256, 56, 56]) torch.Size([1, 256, 14, 14]) torch.Size([1, 256, 7, 7])
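
The class above shares one 3x3 convolution between levels and refines `p3` before upsampling it. A common alternative formulation of FPN keeps a separate 3x3 convolution per level and runs the top-down additions on the unrefined maps, refining only the outputs. The following is a minimal sketch of that variant, not a replacement for the code above; `PerLevelFPN` and its attribute names are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PerLevelFPN(nn.Module):
    """Hypothetical variant: one 3x3 conv per level, merge first, refine last."""

    def __init__(self, in_channels=(256, 1024, 2048), out_channels=256):
        super().__init__()
        # 1x1 lateral convolutions, one per input feature map
        self.lateral = nn.ModuleList(
            [nn.Conv2d(c, out_channels, kernel_size=1) for c in in_channels]
        )
        # 3x3 output convolutions, one per pyramid level (not shared)
        self.smooth = nn.ModuleList(
            [nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
             for _ in in_channels]
        )

    def forward(self, feats):
        # feats is ordered from the shallowest to the deepest feature map
        merged = [conv(f) for conv, f in zip(self.lateral, feats)]
        # Top-down pass: add each upsampled deeper map to the next shallower one
        for i in range(len(merged) - 1, 0, -1):
            upsampled = F.interpolate(merged[i], size=merged[i - 1].shape[-2:], mode="nearest")
            merged[i - 1] = merged[i - 1] + upsampled
        # Refine every level with its own 3x3 convolution
        return [conv(m) for conv, m in zip(self.smooth, merged)]


feats = [
    torch.randn(1, 256, 56, 56),   # same shape as feat1
    torch.randn(1, 1024, 14, 14),  # same shape as feat3
    torch.randn(1, 2048, 7, 7),    # same shape as feat4
]
outs = PerLevelFPN()(feats)
print([tuple(o.shape) for o in outs])
```

The output shapes match those of `CorrectedFPN`; the difference lies in which weights are shared and in whether refinement happens before or after the top-down additions.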


# **Explanation of the FPN Architecture**

**1x1 Convolutions:** These project every feature map to 256 channels. The merge is an element-wise addition, so the channel counts of the maps being merged must match.
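
As a small illustration, the sketch below applies a 1x1 convolution to a random tensor with the shape of the deepest feature map. It changes only the channel dimension and leaves the spatial size untouched; the tensor is random stand-in data, not real features.

```python
import torch
import torch.nn as nn

conv4 = nn.Conv2d(2048, 256, kernel_size=1)  # same configuration as the deepest lateral conv
feat4 = torch.randn(1, 2048, 7, 7)           # stand-in for the layer4 output

p4 = conv4(feat4)
num_params = sum(p.numel() for p in conv4.parameters())

print(p4.shape)    # channels go from 2048 to 256, the 7x7 grid is unchanged
print(num_params)  # 2048 * 256 weights + 256 biases
```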

**Upsampling:** The deeper feature maps are upsampled using nearest-neighbor interpolation to match the spatial dimensions of the higher-resolution feature maps.
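
To make the effect of nearest-neighbor interpolation concrete, this toy example upsamples a 2x2 map to 4x4. Each source value is simply repeated over a 2x2 block; no new values are computed.

```python
import torch
import torch.nn.functional as F

# A 2x2 map with batch and channel dimensions added: shape (1, 1, 2, 2)
small = torch.tensor([[1.0, 2.0], [3.0, 4.0]]).reshape(1, 1, 2, 2)

# Upsample to an explicit target size, as the FPN forward pass does
large = F.interpolate(small, size=(4, 4), mode="nearest")

print(large[0, 0])
# tensor([[1., 1., 2., 2.],
#         [1., 1., 2., 2.],
#         [3., 3., 4., 4.],
#         [3., 3., 4., 4.]])
```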

**Merging:** Each upsampled feature map is added element-wise to the 1x1-projected feature map of the next higher-resolution stage, which creates a multi-scale representation.
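
One merge step can be isolated into a small helper, sketched below under the assumption that both inputs have already been projected to the same channel count. The function name `top_down_merge` is hypothetical, and constant tensors are used so the result is easy to check by eye.

```python
import torch
import torch.nn.functional as F


def top_down_merge(lateral: torch.Tensor, top: torch.Tensor) -> torch.Tensor:
    """Upsample `top` to the spatial size of `lateral` and add the two maps."""
    if lateral.shape[1] != top.shape[1]:
        raise ValueError("channel counts must match before element-wise addition")
    top_upsampled = F.interpolate(top, size=lateral.shape[-2:], mode="nearest")
    return lateral + top_upsampled


lateral = torch.ones(1, 256, 14, 14)    # stand-in for a projected middle-stage map
top = torch.full((1, 256, 7, 7), 2.0)   # stand-in for the projected deepest map
merged = top_down_merge(lateral, top)

print(merged.shape, merged.min().item(), merged.max().item())
```

Every element of `merged` is 1 + 2 = 3, and the result keeps the 14x14 resolution of the shallower input.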

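**3x3 Convolution:** A 3x3 convolution refines each merged feature map (`p3` and `p1`), which helps smooth the blocky artifacts left by nearest-neighbor upsampling. In this implementation a single layer, `conv_out`, is shared by both levels, and `p4` is returned directly from its 1x1 convolution because nothing is merged into it.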
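
The `padding=1` argument is what lets the 3x3 convolution refine a map without changing its size. The sketch below compares it with an unpadded 3x3 convolution on a random stand-in tensor.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 256, 14, 14)  # stand-in for a merged feature map

padded = nn.Conv2d(256, 256, kernel_size=3, padding=1)
unpadded = nn.Conv2d(256, 256, kernel_size=3, padding=0)

same_size = padded(x)   # 14x14 stays 14x14
shrunk = unpadded(x)    # 14x14 shrinks to 12x12

print(same_size.shape, shrunk.shape)
```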

By combining feature maps from different stages of the ResNet model, the FPN effectively captures both fine details and abstract features, improving the model's ability to handle objects at different scales.