<a href="https://colab.research.google.com/github/manhcuong02/voc-semantic-segmentation/blob/main/models.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install timm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting timm
  Downloading timm-0.9.2-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m38.4 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub (from timm)
  Downloading huggingface_hub-0.15.1-py3-none-any.whl (236 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m236.8/236.8 kB[0m [31m27.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting safetensors (from timm)
  Downloading safetensors-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m79.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: safetensors, huggingface-hub, timm
Successfully installed huggingface-hub-0.15.1 safetensors-0.3.1 timm-0.9.2


In [2]:
import torch

from torch import nn
from torch.nn import functional as F
import timm


In [None]:
def unet_block(in_channels, out_channels):
    """Standard U-Net double convolution.

    Two 3x3 convolutions (stride 1, padding 1, so H and W are preserved), each
    followed by a ReLU. Channels go in_channels -> out_channels -> out_channels.

    Returns:
        nn.Sequential with layers [Conv2d, ReLU, Conv2d, ReLU] at indices 0-3.
    """
    layers = []
    for conv_in in (in_channels, out_channels):
        layers += [
            nn.Conv2d(conv_in, out_channels, kernel_size = 3, stride = 1, padding = 1),
            nn.ReLU(),
        ]
    return nn.Sequential(*layers)

# Unet
![](https://www.researchgate.net/publication/334287825/figure/fig2/AS:778191392210944@1562546694325/The-architecture-of-Unet.ppm)

In [None]:
class UNet(nn.Module):
    """Classic U-Net for semantic segmentation.

    Four encoder stages (64 -> 512 channels) separated by 2x2 max-pooling, a
    1024-channel bottleneck, and four decoder stages. Each decoder stage upsamples,
    concatenates the matching encoder output (skip connection) and applies
    `unet_block`. A final 1x1 convolution produces one logit map per class.

    Args:
        num_classes: number of output channels.

    Input:  (batch, 3, H, W); H and W must be >= 16 so that four poolings stay non-empty.
    Output: (batch, num_classes, H, W) logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.down_sample = nn.MaxPool2d(kernel_size = 2)
        # Not used by forward() (see `_up_and_concat`); kept so code that accesses
        # `model.up_sample` keeps working. It has no parameters.
        self.up_sample = nn.Upsample(scale_factor = 2, mode = 'bilinear')

        # Encoder.
        self.block_down1 = unet_block(3, 64)
        self.block_down2 = unet_block(64, 128)
        self.block_down3 = unet_block(128, 256)
        self.block_down4 = unet_block(256, 512)

        self.neck = unet_block(512, 1024)

        # Decoder: input channels = upsampled features + skip-connection features.
        self.block_up1 = unet_block(1024 + 512, 512)
        self.block_up2 = unet_block(512 + 256, 256)
        self.block_up3 = unet_block(256 + 128, 128)
        self.block_up4 = unet_block(128 + 64, 64)

        self.out = nn.Conv2d(64, num_classes, kernel_size = 1, stride = 1)

    @staticmethod
    def _up_and_concat(x, skip):
        """Resize `x` to `skip`'s spatial size and return torch.cat([skip, x], dim=1).

        Max-pooling floors odd sizes, so a fixed x2 upsample would not match the
        skip connection when H or W is not divisible by 16. Resizing to the exact
        skip size avoids that. When the skip is exactly twice as large, this is the
        same bilinear (align_corners=False) result as nn.Upsample(scale_factor=2).
        """
        x = F.interpolate(x, size = skip.shape[-2:], mode = 'bilinear', align_corners = False)
        return torch.cat([skip, x], dim = 1)

    def forward(self, x):
        # x: batch_size, channels, height, width
        # Encoder. Stage outputs are kept as skip connections; no copy is needed
        # because nothing below modifies them in place.
        x1 = self.block_down1(x)
        x2 = self.block_down2(self.down_sample(x1))
        x3 = self.block_down3(self.down_sample(x2))
        x4 = self.block_down4(self.down_sample(x3))
        x = self.neck(self.down_sample(x4))

        # Decoder.
        x = self.block_up1(self._up_and_concat(x, x4))
        x = self.block_up2(self._up_and_concat(x, x3))
        x = self.block_up3(self._up_and_concat(x, x2))
        x = self.block_up4(self._up_and_concat(x, x1))

        return self.out(x)


# ResNet + Unet
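- Replace the plain convolutional encoder with a pretrained ResNet-101 backbone (loaded through `timm`); the decoder keeps the U-Net structure.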

In [None]:
class ResUNet(nn.Module):
    """U-Net-style decoder on top of a pretrained timm ResNet-101 encoder.

    The encoder (features_only=True) returns five feature maps x1..x5. For a
    256x256 input these have 64/256/512/1024/2048 channels at 128/64/32/16/8 pixels.
    Each decoder stage resizes the current map to the next skip's size,
    concatenates [decoder, skip] on the channel axis and applies `unet_block`.
    The result is resized to the input resolution and mapped to class logits by
    a 1x1 convolution.

    Args:
        num_classes: number of output channels.

    Input:  (batch, 3, H, W).
    Output: (batch, num_classes, H, W) logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Neither attribute is used by forward(): the backbone does the downsampling,
        # and decoder resizing targets exact sizes (see `_resize`). Both are kept for
        # backward compatibility and have no parameters.
        self.down_sample = nn.MaxPool2d(kernel_size = 2)
        self.up_sample = nn.Upsample(scale_factor = 2, mode = 'bilinear')

        # source: https://huggingface.co/docs/timm/feature_extraction#multiscale-feature-maps-feature-pyramid
        self.backbone = timm.create_model(model_name = 'resnet101', pretrained = True, features_only = True)

        # Each stage maps (deeper decoder features + skip features) to the skip's channel count.
        self.block_up1 = unet_block(2048 + 1024, 1024)
        self.block_up2 = unet_block(1024 + 512, 512)
        self.block_up3 = unet_block(512 + 256, 256)
        self.block_up4 = unet_block(256 + 64, 64)

        self.out = nn.Conv2d(64, num_classes, kernel_size = 1, stride = 1)

    @staticmethod
    def _resize(x, size):
        """Bilinearly resize `x` to spatial `size`.

        When `size` is exactly twice x's size, this matches nn.Upsample(scale_factor=2).
        """
        return F.interpolate(x, size = size, mode = 'bilinear', align_corners = False)

    def forward(self, x):
        # x: batch_size, channels, height, width
        input_size = x.shape[-2:]
        x1, x2, x3, x4, x5 = self.backbone(x)

        # Resizing to each skip's exact size keeps the concatenations valid for
        # inputs whose sides are odd or not divisible by 32.
        x = self.block_up1(torch.cat([self._resize(x5, x4.shape[-2:]), x4], dim = 1))
        x = self.block_up2(torch.cat([self._resize(x, x3.shape[-2:]), x3], dim = 1))
        x = self.block_up3(torch.cat([self._resize(x, x2.shape[-2:]), x2], dim = 1))
        x = self.block_up4(torch.cat([self._resize(x, x1.shape[-2:]), x1], dim = 1))

        # x1 (and hence x) is at roughly half the input resolution; restore the full size.
        return self.out(self._resize(x, input_size))

# PSPNet
![](https://production-media.paperswithcode.com/methods/new_pspnet-eps-converted-to.jpg)

In [None]:
# Pyramid Pooling Module
class PPM(nn.Module):
    """Pyramid Pooling Module (PSPNet).

    For each size b in `bins` there is one branch:
    AdaptiveAvgPool2d(b) -> 1x1 Conv2d(in_channels, out_channels, no bias)
    -> BatchNorm2d -> ReLU. Its output is bilinearly resized (align_corners=True)
    back to the input's H x W. The input and all branch outputs are concatenated on
    the channel axis, giving in_channels + len(bins) * out_channels channels.
    """

    def __init__(self, in_channels, out_channels, bins):
        super().__init__()
        self.features = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(grid_size),
                nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )
            for grid_size in bins
        ])

    def forward(self, x):
        # x: batch, channels, height, width
        spatial_size = x.shape[2:]
        pooled = [
            F.interpolate(branch(x), spatial_size, mode='bilinear', align_corners=True)
            for branch in self.features
        ]
        return torch.cat([x, *pooled], 1)


In [None]:
class PSPNet(nn.Module):
    """PSPNet: dilated timm backbone + Pyramid Pooling Module + classification head.

    Args:
        num_classes: number of output channels.
        dropout: Dropout2d probability in the main and auxiliary heads.
        backbone: timm model name. The code reads `conv1`, `bn1`, `act1`, `maxpool`
            and `layer1`-`layer4` from it and assumes 2048 channels out of layer4
            and 1024 out of layer3. It is therefore presumably limited to
            ResNet-50/101-style (bottleneck) backbones; TODO confirm before using
            anything else.
        bins: output grid sizes of the pyramid pooling branches.

    forward(x) returns logits bilinearly resized to x's H x W. In training mode it
    returns a tuple (main_logits, aux_logits), where the auxiliary logits are computed
    from layer3's output and resized the same way.
    """

    def __init__(self, num_classes, dropout = 0.1, backbone = 'resnet101', bins = [1, 2, 3, 6]):
        super().__init__()
        self.backbone = timm.create_model(backbone, pretrained = True, features_only = True)

        # Stem of the backbone, grouped so forward() can run the stages one by one.
        self.layer0 = nn.Sequential(
            self.backbone.conv1, self.backbone.bn1, self.backbone.act1, self.backbone.maxpool
        )

        self.layer1 = self.backbone.layer1
        self.layer2 = self.backbone.layer2
        self.layer3 = self.backbone.layer3
        self.layer4 = self.backbone.layer4

        # Enlarge the receptive field (the input region each feature-map value depends on).
        # Stride is set to 1 on every 'conv2' and 'downsample.0' conv of layer3/layer4,
        # and those 'conv2' convs are dilated (x2 in layer3, x4 in layer4; padding matches
        # the dilation). Presumably these are the strided convs of ResNet bottlenecks, so
        # layer3/layer4 keep layer2's resolution; verify for other backbones.
        for n, m in self.layer3.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)

        for n, m in self.layer4.named_modules():
            if 'conv2' in n:
                m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
            elif 'downsample.0' in n:
                m.stride = (1, 1)

        feature_dims = 2048
        self.ppm = PPM(feature_dims, feature_dims//(len(bins)), bins)
        # PPM outputs its input plus len(bins) branches of feature_dims // len(bins)
        # channels. That equals feature_dims * 2 only when len(bins) divides 2048
        # (e.g. it is 4094 for 3 bins), so derive it instead of hard-coding * 2.
        ppm_out_channels = feature_dims + len(bins) * (feature_dims // len(bins))

        self.cls = nn.Sequential(
            nn.Conv2d(ppm_out_channels, 512, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(512, num_classes, kernel_size=1)
        )
        # Auxiliary head on layer3's 1024-channel output; forward() only uses it in
        # training mode. It is always built, because a freshly constructed nn.Module is
        # in training mode, so a `self.training` check here would always be true.
        self.aux_layer = nn.Sequential(
            nn.Conv2d(1024, 256, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout2d(p=dropout),
            nn.Conv2d(256, num_classes, kernel_size=1)
        )

    def forward(self, x):
        # B, C, H, W
        h, w = x.shape[2:]

        # Backbone; layer3's output is also kept for the auxiliary head.
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x_tmp = self.layer3(x)
        x = self.layer4(x_tmp)

        # Pyramid Pooling Module
        x = self.ppm(x)

        # Head, then back to the input resolution.
        x = self.cls(x)
        x = F.interpolate(x, size = (h, w), mode = 'bilinear', align_corners = True)

        if self.training:
            aux = self.aux_layer(x_tmp)
            aux = F.interpolate(aux, size=(h, w), mode='bilinear', align_corners=True)
            return x, aux

        return x


# PSPNet + Unet
![](https://ars.els-cdn.com/content/image/1-s2.0-S0010482522011866-gr1.jpg)

In [None]:
class PSPUnet(nn.Module):
    """U-Net decoder whose skip connections pass through Pyramid Pooling Modules.

    The timm backbone (features_only=True) yields five feature maps x1..x5. For
    resnet101 these have 64, 256, 512, 1024 and 2048 channels. Starting from x5,
    each decoder stage:
    - resizes the current map to the next skip's size;
    - concatenates it with PPM(skip);
    - fuses the result with `unet_block`.
    The output is resized to the input resolution and classified by a small conv head.

    Args:
        num_classes: number of output channels.
        bins: pyramid pooling grid sizes used by every PPM.
        backbone: timm model name; it must return five feature levels.
    """

    def __init__(self, num_classes, bins = [1, 2, 3, 6], backbone = 'resnet101'):
        super().__init__()
        self.backbone = timm.create_model(backbone, pretrained = True, features_only = True)
        # Channel count of each backbone level, read from timm rather than hard-coding
        # the resnet101 values [64, 256, 512, 1024, 2048].
        c1, c2, c3, c4, c5 = self.backbone.feature_info.channels()
        n_bins = len(bins)

        def ppm_out(c):
            # A PPM returns its input plus n_bins branches of c // n_bins channels each.
            # This is 2 * c only when n_bins divides c, so compute it exactly.
            return c + n_bins * (c // n_bins)

        self.ppm1 = PPM(c4, c4 // n_bins, bins)
        self.ppm2 = PPM(c3, c3 // n_bins, bins)
        self.ppm3 = PPM(c2, c2 // n_bins, bins)
        self.ppm4 = PPM(c1, c1 // n_bins, bins)

        self.cls = nn.Sequential(
            nn.Conv2d(c1, c1, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(c1),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1, num_classes, kernel_size=1)
        )

        # Not used by forward() (see `_resize`); kept for backward compatibility.
        self.up_sample = nn.Upsample(scale_factor = 2, mode = 'bilinear')
        # Stage i maps (deeper decoder features + PPM(skip)) to the skip's channel count.
        self.block_up1 = unet_block(c5 + ppm_out(c4), c4)
        self.block_up2 = unet_block(c4 + ppm_out(c3), c3)
        self.block_up3 = unet_block(c3 + ppm_out(c2), c2)
        self.block_up4 = unet_block(c2 + ppm_out(c1), c1)

    @staticmethod
    def _resize(x, size):
        """Bilinearly resize `x` to spatial `size`.

        When `size` is exactly twice x's size, this matches nn.Upsample(scale_factor=2).
        """
        return F.interpolate(x, size = size, mode = 'bilinear', align_corners = False)

    def forward(self, x):
        input_size = x.shape[-2:]
        x1, x2, x3, x4, x5 = self.backbone(x)

        # Resizing to each skip's exact size keeps the concatenations valid for
        # inputs whose sides are odd or not divisible by 32.
        x = self.block_up1(torch.cat([self._resize(x5, x4.shape[-2:]), self.ppm1(x4)], dim = 1))
        x = self.block_up2(torch.cat([self._resize(x, x3.shape[-2:]), self.ppm2(x3)], dim = 1))
        x = self.block_up3(torch.cat([self._resize(x, x2.shape[-2:]), self.ppm3(x2)], dim = 1))
        x = self.block_up4(torch.cat([self._resize(x, x1.shape[-2:]), self.ppm4(x1)], dim = 1))

        return self.cls(self._resize(x, input_size))



# Feature Pyramid Network
![](https://d3i71xaburhd42.cloudfront.net/7aae14c686757ba33d49c49aef4193038a4a7529/3-Figure2-1.png)

In [None]:
def _resolve_device(device):
    """Turn a device spec into a torch.device, falling back to the CPU.

    'cuda', 'gpu' and 'cuda:N' select CUDA when it is available. 'gpu' is mapped to
    'cuda', because torch.device has no 'gpu' type. Any other string, or a CUDA
    request on a machine without CUDA, selects the CPU. Non-string values (e.g. a
    torch.device) are returned unchanged.
    """
    if not isinstance(device, str):
        return device
    wants_cuda = device in ('cuda', 'gpu') or device.startswith('cuda:')
    if wants_cuda and torch.cuda.is_available():
        return torch.device('cuda' if device == 'gpu' else device)
    return torch.device('cpu')


def FeaturePyramid(in_channel_list, out_channels, device = 'cuda'):
    """Build the 1x1 lateral convolutions of a feature pyramid.

    Args:
        in_channel_list: channel count of each backbone feature map.
        out_channels: common channel count every lateral conv projects to.
        device: where to allocate the weights (see `_resolve_device`).

    Returns:
        nn.ModuleList with one Conv2d(in_c, out_channels, kernel_size=1) per entry of
        `in_channel_list`, in the same order. A ModuleList (not a plain list) means
        that assigning it to a module attribute registers the convolutions. They then
        appear in parameters() and state_dict() and follow .to().
    """
    device = _resolve_device(device)
    return nn.ModuleList([
        nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = 1, device = device)
        for in_channels in in_channel_list
    ])


def conv3_block(in_channels, out_channels, num_block = 4, device = 'cuda'):
    """Build `num_block` identical per-level smoothing heads.

    Each head is Sequential(Conv2d(in_channels, out_channels, 3, padding=1),
    Conv2d(out_channels, out_channels, 3, padding=1)), so H and W are preserved.
    There is no normalization or activation between the two convolutions.

    Returns:
        nn.ModuleList of the heads, so that assigning it to a module attribute
        registers their parameters.
    """
    device = _resolve_device(device)
    return nn.ModuleList([
        nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, device = device),
            nn.Conv2d(out_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, device = device)
        )
        for _ in range(num_block)
    ])

In [None]:
class FPN(nn.Module):
    """Segmentation model with a Feature Pyramid Network decoder.

    The timm backbone provides four feature levels (out_indices 1-4). The decoder:
    - projects each level to `out_fpn_channels` with a 1x1 lateral conv;
    - merges the levels top-down (each coarser map is bilinearly resized and added
      to the next finer one);
    - smooths each level to 128 channels with a pair of 3x3 convs;
    - resizes every level to the input size, concatenates them and classifies the
      result with a 1x1 conv.

    Args:
        backbone_name: timm model name.
        out_fpn_channels: channels of the lateral projections.
        num_classes: number of output channels.
        device: where to place the whole model. 'cuda'/'gpu' fall back to the CPU
            when CUDA is unavailable.
    """

    def __init__(self, backbone_name, out_fpn_channels, num_classes, device = 'cpu'):
        super().__init__()

        self.backbone = timm.create_model(backbone_name, pretrained = True, features_only = True, out_indices = (1,2,3,4))
        in_channel_list = self.backbone.feature_info.channels()
        # Wrap in nn.ModuleList so the layers are registered submodules. Otherwise
        # their weights are missing from parameters() (never trained) and state_dict()
        # (never saved), and .to() does not move them.
        self.fpn = nn.ModuleList(FeaturePyramid(in_channel_list, out_fpn_channels, device = 'cpu'))
        self.blocks = nn.ModuleList(conv3_block(out_fpn_channels, 128, len(in_channel_list), device = 'cpu'))

        # Every level contributes 128 channels to the concatenation.
        self.conv_cls = nn.Conv2d(128 * len(in_channel_list), num_classes, kernel_size = 1, stride = 1)

        # Move the whole model (backbone + head) together so all parts share a device.
        if isinstance(device, str) and device in ('cuda', 'gpu'):
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.to(device)

    def forward(self, x):
        ori_shape = x.shape[2:]
        # Lateral 1x1 projections, finest level first (four levels assumed).
        x1, x2, x3, x4 = [lateral(feat) for lateral, feat in zip(self.fpn, self.backbone(x))]

        # Top-down pathway: resize the coarser map to the finer one's size and add.
        x3 = F.interpolate(x4, size = x3.shape[2:], mode = 'bilinear', align_corners = False) + x3
        x2 = F.interpolate(x3, size = x2.shape[2:], mode = 'bilinear', align_corners = False) + x2
        x1 = F.interpolate(x2, size = x1.shape[2:], mode = 'bilinear', align_corners = False) + x1

        # Smooth each level, bring it to the input resolution and stack along channels.
        levels = [
            F.interpolate(block(level), ori_shape, mode = 'bilinear', align_corners = False)
            for block, level in zip(self.blocks, (x1, x2, x3, x4))
        ]
        x = torch.cat(levels, dim = 1)

        return self.conv_cls(x)

# Pyramid Attention Network (PAN): Feature Pyramid Attention + Global Attention Upsample
![](https://user-images.githubusercontent.com/527241/58665416-1f1ad880-8331-11e9-9a2f-3cf289a1df96.png)

In [8]:
# source https://github.com/qubvel/segmentation_models.pytorch/blob/master/segmentation_models_pytorch/decoders/pan/decoder.py

class ConvBnRelu(nn.Module):
    """Conv2d -> BatchNorm2d -> optional ReLU -> optional x2 bilinear upsampling.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias:
            passed unchanged to nn.Conv2d.
        add_relu: apply an in-place ReLU after batch normalization.
        interpolate: upsample the result by a factor of 2 (bilinear, align_corners=True).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        add_relu: bool = True,
        interpolate: bool = False,
    ):
        super().__init__()
        conv_options = dict(
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
            groups=groups,
        )
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_options)
        self.add_relu, self.interpolate = add_relu, interpolate
        self.bn = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.bn(self.conv(x))
        out = self.activation(out) if self.add_relu else out
        if not self.interpolate:
            return out
        return F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True)

- FPA (Feature Pyramid Attention) block

![](https://miro.medium.com/v2/resize:fit:906/1*JhWWLQsxkJZoyuR-oIxAPA.png)

In [9]:
# source https://github.com/qubvel/segmentation_models.pytorch/blob/master/segmentation_models_pytorch/decoders/pan/decoder.py
class FPABlock(nn.Module):
    """Feature Pyramid Attention (FPA) block of PAN.

    Three branches are combined:
    - global: global average pool -> 1x1 ConvBnRelu, resized back to (h, w);
    - middle: 1x1 ConvBnRelu on the input;
    - pyramid: ConvBnRelu layers (7x7, 5x5, 3x3) on the input max-pooled by 2, 4
      and 8. These reduce it to a single-channel map, which is merged coarse-to-fine
      and resized back to (h, w).
    Output = pyramid_map * middle + global: out_channels channels at the input's
    h x w. The input presumably needs h, w >= 8 so the third pooling is non-empty.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output.
        upscale_mode: F.interpolate mode used for every resize.
    """

    def __init__(self, in_channels, out_channels, upscale_mode="bilinear"):
        super(FPABlock, self).__init__()

        self.upscale_mode = upscale_mode
        # F.interpolate accepts align_corners only for interpolating modes. For
        # "nearest"/"area"/"nearest-exact" it must be None, otherwise the call raises
        # ValueError.
        if self.upscale_mode == "bilinear":
            self.align_corners = True
        elif self.upscale_mode in ("nearest", "area", "nearest-exact"):
            self.align_corners = None
        else:
            self.align_corners = False

        # global pooling branch
        self.branch1 = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            ConvBnRelu(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
        )

        # middle branch
        self.mid = nn.Sequential(
            ConvBnRelu(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )
        )
        # pyramid branch: 1/2, 1/4 and 1/8 resolution, single channel after down1
        self.down1 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(
                in_channels=in_channels,
                out_channels=1,
                kernel_size=7,
                stride=1,
                padding=3,
            ),
        )
        self.down2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2),
        )
        self.down3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            ConvBnRelu(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1),
            ConvBnRelu(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=1),
        )
        self.conv2 = ConvBnRelu(in_channels=1, out_channels=1, kernel_size=5, stride=1, padding=2)
        self.conv1 = ConvBnRelu(in_channels=1, out_channels=1, kernel_size=7, stride=1, padding=3)

    def forward(self, x):
        h, w = x.size(2), x.size(3)
        upscale_parameters = dict(mode=self.upscale_mode, align_corners=self.align_corners)

        b1 = self.branch1(x)
        b1 = F.interpolate(b1, size=(h, w), **upscale_parameters)

        mid = self.mid(x)
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)

        # Merge the pyramid coarse-to-fine: 1/8 -> 1/4 -> 1/2 -> full resolution.
        # Pooling floors sizes, so x2 is (h // 4, w // 4) and x1 is (h // 2, w // 2).
        x3 = F.interpolate(x3, size=(h // 4, w // 4), **upscale_parameters)
        x = self.conv2(x2) + x3
        x = F.interpolate(x, size=(h // 2, w // 2), **upscale_parameters)

        x = x + self.conv1(x1)
        x = F.interpolate(x, size=(h, w), **upscale_parameters)

        # The single-channel pyramid map broadcasts over the middle branch's channels.
        x = torch.mul(x, mid)
        return x + b1

- GAU (Global Attention Upsample) block

![](https://miro.medium.com/v2/resize:fit:1052/1*hxW0710PijUYHe-BCimWoQ.png)

In [10]:
# source https://github.com/qubvel/segmentation_models.pytorch/blob/master/segmentation_models_pytorch/decoders/pan/decoder.py
class GAUBlock(nn.Module):
    """Global Attention Upsample (GAU) block of PAN.

    A channel gate computed from the high-level feature (global average pool ->
    1x1 ConvBnRelu without ReLU -> sigmoid) re-weights a 3x3 ConvBnRelu of the
    low-level feature. The result is added to the high-level feature resized to the
    low-level feature's size.

    Args:
        in_channels: channels of the low-level input `x`.
        out_channels: channels of the high-level input `y` and of the output.
        upscale_mode: F.interpolate mode used to resize `y`.
    """

    def __init__(self, in_channels: int, out_channels: int, upscale_mode: str = "bilinear"):
        super().__init__()

        self.upscale_mode = upscale_mode
        if upscale_mode == "bilinear":
            self.align_corners = True
        else:
            self.align_corners = None

        gate_conv = ConvBnRelu(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=1,
            add_relu=False,
        )
        self.conv1 = nn.Sequential(nn.AdaptiveAvgPool2d(1), gate_conv, nn.Sigmoid())
        self.conv2 = ConvBnRelu(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)

    def forward(self, x, y):
        """
        Args:
            x: low level feature
            y: high level feature
        """
        target_size = (x.size(2), x.size(3))
        high_up = F.interpolate(y, size=target_size, mode=self.upscale_mode, align_corners=self.align_corners)
        low = self.conv2(x)
        gate = self.conv1(y)
        return high_up + low * gate

In [17]:
class FPANet(nn.Module):
    """Pyramid Attention Network (PAN)-style segmentation model.

    An FPA block on the deepest backbone feature produces `num_classes` channels.
    Three GAU blocks then fuse it with progressively shallower backbone features
    (x4, x3, x2). The result is bilinearly resized to the input size. The first
    backbone feature (x1) is not used.

    Args:
        num_classes: number of output channels.
        backbone_name: timm model name; it must return five feature levels with
            features_only=True.
    """

    def __init__(self, num_classes, backbone_name = 'resnet101'):
        super().__init__()
        self.backbone = timm.create_model(backbone_name, pretrained = True, features_only = True)
        # Channels of the five backbone levels, e.g. [64, 256, 512, 1024, 2048] for
        # resnet101. Reading them from timm (instead of hard-coding the ResNet-50/101
        # values) makes `backbone_name` usable with other channel layouts.
        _, c2, c3, c4, c5 = self.backbone.feature_info.channels()
        self.fpa_block = FPABlock(in_channels = c5, out_channels = num_classes)

        self.gau_block1 = GAUBlock(in_channels = c4, out_channels = num_classes)
        self.gau_block2 = GAUBlock(in_channels = c3, out_channels = num_classes)
        self.gau_block3 = GAUBlock(in_channels = c2, out_channels = num_classes)

    def forward(self, x):
        input_size = x.shape[2:]
        _, x2, x3, x4, x5 = self.backbone(x)

        x = self.fpa_block(x5)

        x = self.gau_block1(x4, x)
        x = self.gau_block2(x3, x)
        x = self.gau_block3(x2, x)

        return F.interpolate(x, size = input_size, mode = 'bilinear', align_corners = False)

In [18]:
# Smoke test: a random 256x256 RGB batch should come back as full-resolution logits.
model = FPANet(num_classes = 5, backbone_name = 'resnet101')
x = torch.rand(2, 3, 256, 256)
with torch.no_grad():  # shape check only; no autograd graph needed
    y = model(x)
assert y.shape == (2, 5, 256, 256)
y.shape

torch.Size([2, 5, 256, 256])

In [4]:
# Inspect the multi-scale feature maps returned by the timm backbone (features_only=True).
backbone = timm.create_model('resnet101', pretrained = True, features_only = True)
x = torch.rand(2, 3, 256, 256)
with torch.no_grad():  # shape inspection only; no autograd graph needed
    features = backbone(x)
for feature in features:
    print(feature.shape)

torch.Size([2, 64, 128, 128])
torch.Size([2, 256, 64, 64])
torch.Size([2, 512, 32, 32])
torch.Size([2, 1024, 16, 16])
torch.Size([2, 2048, 8, 8])
