# Feature Pyramid

![feature pyramid](https://lilianweng.github.io/lil-log/assets/images/featurized-image-pyramid.png)

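This is a more advanced approach.
Instead of only striding down between blocks, we stride down and back up _within a single repeated block_.
Skip connections concatenate each level's input with its upsampled output, so fine spatial detail is not lost along the way.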

In [None]:
from typing import *

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torch.optim as optim

import numpy as np
from torchvision import datasets, transforms

from tqdm import tqdm

import PIL

In [None]:
class SkipSequential(nn.Sequential):
    """A ``nn.Sequential`` whose result is stacked onto its own input.

    The wrapped modules run exactly as in a plain ``nn.Sequential``; the
    original input and the transformed tensor are then joined along dim 1
    (the channel axis for ``(N, C, H, W)`` tensors). Both tensors must
    therefore agree on every dimension except dim 1.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        transformed = nn.Sequential.forward(self, x)
        # Input channels come first, followed by the newly computed ones.
        pieces = (x, transformed)
        return torch.cat(pieces, dim=1)

In [None]:
def pyramid_block(in_channels: int, out_channels: int) -> nn.Module:
    """Build one feature-pyramid block: down, down, up, up, then a final down.

    The block nests three levels (outer -> middle -> inner -> middle -> outer):

    * inner:  stride-2 conv, conv, upsample x2; its input is concatenated back
      on, giving ``2 * out_channels`` channels.
    * middle: stride-2 conv, inner, upsample x2; its input is concatenated
      back on, giving ``in_channels + 2 * out_channels`` channels.
    * outer:  a stride-2 conv mapping those channels to ``out_channels``.

    Every conv is 3x3 with reflect padding of 1 and is followed by ``Tanh``.

    Args:
        in_channels: channel count of the incoming ``(N, C, H, W)`` tensor.
        out_channels: channel count produced by every conv in the block.

    Returns:
        A module mapping ``(N, in_channels, H, W)`` to
        ``(N, out_channels, H / 2, W / 2)``. H and W must be multiples of 4
        and at least 8; otherwise the upsampled features no longer line up
        with the skip input, or a reflect-padded conv receives a 1-pixel map.
    """

    def conv3x3(src: int, dst: int, step: int = 1) -> nn.Conv2d:
        # Shared conv recipe; with stride 2 the spatial size becomes ceil(H/2).
        return nn.Conv2d(
            in_channels=src,
            out_channels=dst,
            kernel_size=(3, 3),
            stride=(step, step),
            padding=(1, 1),
            padding_mode='reflect',
        )

    # Convs are created in the same order as before (inner, middle, outer),
    # so random weight initialisation consumes the RNG identically.
    inner_down = conv3x3(out_channels, out_channels, step=2)
    inner_same = conv3x3(out_channels, out_channels)
    # channels: out_channels in -> out_channels + out_channels out
    inner = SkipSequential(
        inner_down,
        nn.Tanh(),
        inner_same,
        nn.Tanh(),
        nn.Upsample(scale_factor=2),
    )

    middle_down = conv3x3(in_channels, out_channels, step=2)
    # channels: in_channels in -> in_channels + 2 * out_channels out
    middle = SkipSequential(
        middle_down,
        nn.Tanh(),
        inner,
        nn.Upsample(scale_factor=2),
    )

    skip_channels = in_channels + 2 * out_channels
    head = conv3x3(skip_channels, out_channels, step=2)
    # channels: in_channels in -> out_channels out (spatial size halved)
    return nn.Sequential(middle, head, nn.Tanh())

---

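Because this block shrinks and then expands its input, it stops working once the feature map gets too small. For example, with a 4×4 input the innermost convolution receives a 1×1 feature map, which its reflect padding cannot handle (see the second example below).
The block is also very sensitive to the input image size. Each stride-2 convolution rounds odd sizes up (H becomes ceil(H/2)). After upsampling, the two sides of a skip connection then no longer have the same height/width, and the concatenation fails. In practice, height and width must both be multiples of 4 and at least 8.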

In [None]:
# 32x32 input -> torch.Size([1, 16, 16, 16]): the final stride-2 conv halves H and W.
pyramid_block(in_channels=3, out_channels=16)(torch.zeros((1, 3, 32, 32))).shape

In [None]:
# Too small: the innermost conv receives a 1x1 map, which reflect padding of 1
# cannot pad, so this call is expected to raise an error.
pyramid_block(in_channels=3, out_channels=16)(torch.zeros((1, 3, 4, 4))).shape