In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchinfo

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [9]:
class MaxDepthPool2d(nn.Module):
    """Max-pool across the channel ("depth") axis instead of the spatial axes.

    Each run of ``pool_size`` consecutive channels is collapsed into one
    channel holding their element-wise maximum. Height and width are left
    untouched; the channel count shrinks by a factor of ``pool_size``.

    Args:
        pool_size: Number of adjacent channels merged into one output channel.
            Must be at least 1.

    Shape:
        - Input: ``(*, C, H, W)`` with ``C`` divisible by ``pool_size``;
          ``*`` is any number of leading dimensions (including none).
        - Output: ``(*, C // pool_size, H, W)``.

    Raises:
        ValueError: If ``pool_size`` is smaller than 1.
    """

    def __init__(self, pool_size=2):
        super().__init__()
        if pool_size < 1:
            raise ValueError(f"pool_size must be >= 1, got {pool_size}")
        self.pool_size = pool_size

    def forward(self, x):
        # Locate the channel axis from the right (dim -3) rather than assuming
        # it is dim 1, so unbatched (C, H, W) inputs and inputs with extra
        # leading dimensions work as well as the usual (N, C, H, W).
        *leading, depth, height, width = x.shape
        channels = depth // self.pool_size
        # Only the channel axis is split (no axes are merged), so a view is
        # always stride-compatible here, even for non-contiguous tensors.
        # If `depth` is not divisible by `pool_size` the element counts differ
        # and `view` raises.
        grouped = x.view(*leading, channels, self.pool_size, height, width)
        # The pool axis sits just before (H, W) in the grouped tensor.
        return torch.amax(grouped, dim=-3)

In [10]:
# Sanity check: pooling 16 channels in pairs should leave 8 channels while the
# 224x224 spatial size is unchanged.
max_depth_pool2d = MaxDepthPool2d(pool_size=2)
x = torch.randn(32, 16, 224, 224)
max_depth_pool2d(x).shape

torch.Size([32, 8, 224, 224])

In [11]:
class SqueezeExcitation(nn.Module):
    """Channel-wise gating block: rescale each channel by a learned weight in (0, 1).

    The input is global-average-pooled to one value per channel ("squeeze"),
    passed through a two-layer bottleneck MLP (Linear -> Mish -> Linear ->
    Sigmoid, the "excitation"), and the resulting per-channel gate is
    multiplied back onto the input.

    Args:
        in_channels: Number of channels of the input feature map.
        squeeze_factor: Reduction ratio of the bottleneck. The hidden width is
            ``in_channels // squeeze_factor``, but never less than 1.

    Shape:
        - Input: ``(N, in_channels, H, W)``.
        - Output: same shape as the input.
    """

    def __init__(self, in_channels, squeeze_factor=8):
        super().__init__()
        # Floor division yields 0 when in_channels < squeeze_factor, which would
        # create a zero-width bottleneck whose gate no longer depends on the
        # input. Keep at least one hidden unit.
        squeeze_channels = max(1, in_channels // squeeze_factor)
        self.feed_forward = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size=1),  # (N, C, H, W) -> (N, C, 1, 1)
            nn.Flatten(),                         # (N, C, 1, 1) -> (N, C)
            nn.Linear(in_channels, squeeze_channels),
            nn.Mish(),
            nn.Linear(squeeze_channels, in_channels),
            nn.Sigmoid(),                         # gates lie in (0, 1)
        )

    def forward(self, x):
        gates = self.feed_forward(x)  # (N, C)
        # Add singleton spatial axes so the gate broadcasts over H and W.
        return x * gates.view(-1, x.shape[1], 1, 1)

In [12]:
# Smoke-test the SE block on a random 0/1 tensor: because every input element
# is exactly 0.0 or 1.0, each non-zero output equals that channel's gate value.
# NOTE(review): the last line displays the whole (32, 64, 224, 224) result;
# showing `.shape` or a small slice would keep the notebook output lighter.
in_channels = 64
x = torch.randint(low=0, high=2, size=(32, in_channels, 224, 224), dtype=torch.float32)
se = SqueezeExcitation(in_channels, squeeze_factor=8)
se(x)

tensor([[[[0.4623, 0.0000, 0.4623,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.4623, 0.0000,  ..., 0.4623, 0.0000, 0.0000],
          [0.4623, 0.4623, 0.0000,  ..., 0.4623, 0.4623, 0.4623],
          ...,
          [0.0000, 0.4623, 0.4623,  ..., 0.0000, 0.0000, 0.0000],
          [0.4623, 0.4623, 0.4623,  ..., 0.0000, 0.4623, 0.4623],
          [0.0000, 0.0000, 0.4623,  ..., 0.4623, 0.4623, 0.4623]],

         [[0.4518, 0.4518, 0.4518,  ..., 0.0000, 0.4518, 0.0000],
          [0.4518, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.4518],
          [0.0000, 0.4518, 0.4518,  ..., 0.4518, 0.4518, 0.4518],
          ...,
          [0.4518, 0.4518, 0.4518,  ..., 0.4518, 0.0000, 0.0000],
          [0.4518, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.4518],
          [0.0000, 0.4518, 0.4518,  ..., 0.0000, 0.0000, 0.4518]],

         [[0.0000, 0.0000, 0.5535,  ..., 0.0000, 0.0000, 0.5535],
          [0.5535, 0.0000, 0.5535,  ..., 0.5535, 0.0000, 0.0000],
          [0.5535, 0.0000, 0.0000,  ..., 0

In [13]:
class ResidualConnection(nn.Module):
    """Residual block (Conv-BN-Mish-Conv-BN plus shortcut) with optional SE gating.

    The residual branch is two ``kernel_size`` convolutions, each followed by
    batch norm, with a Mish activation in between. The shortcut is the
    identity unless the channel count changes or ``stride > 1``, in which case
    it is a strided 1x1 convolution followed by batch norm.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the block.
        kernel_size: Odd kernel size of both residual convolutions. "Same"
            padding of ``kernel_size // 2`` is applied.
        stride: Stride of the first residual convolution and of the projection
            shortcut.
        squeeze_active: When true, the activated output is passed through a
            ``SqueezeExcitation`` block before being returned.
        squeeze_factor: Reduction ratio handed to ``SqueezeExcitation``.

    Raises:
        ValueError: If ``kernel_size`` is even. With ``kernel_size // 2``
            padding an even kernel enlarges the residual branch spatially, so
            it could not be added to the shortcut.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        stride=1,
        squeeze_active=False,
        squeeze_factor=8,
    ):
        super().__init__()
        # Fail fast: an even kernel would only surface later as a shape
        # mismatch inside forward().
        if kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be odd so the residual and shortcut shapes match, got {kernel_size}"
            )
        pad = kernel_size // 2
        self.squeeze_active = squeeze_active
        # Always constructed (even when inactive) so the module layout and
        # state-dict keys do not depend on `squeeze_active`.
        self.squeeze_excitation = SqueezeExcitation(out_channels, squeeze_factor)
        self.feed_forward = nn.Sequential(
            # Bias is omitted because the following batch norm has its own shift.
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=pad, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.Mish(),
            nn.Conv2d(out_channels, out_channels, kernel_size, padding=pad, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        if in_channels != out_channels or stride > 1:
            # Project the input so it matches the residual branch's shape.
            self.shortcut_connection = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # An empty Sequential acts as the identity.
            self.shortcut_connection = nn.Sequential()

    def forward(self, x):
        x_residual = self.feed_forward(x)
        x_shortcut = self.shortcut_connection(x)
        residual_output = F.mish(x_residual + x_shortcut)
        if self.squeeze_active:
            # NOTE(review): the shortcut is added a second time here, after the
            # activation and the SE gate. Behaviour is preserved as written;
            # confirm this is intended rather than gating the residual branch
            # before the single shortcut addition.
            return self.squeeze_excitation(residual_output) + x_shortcut
        return residual_output

In [15]:
# Smoke-test a down-sampling SE residual block: 64 -> 128 channels with
# stride 2 should halve the 224x224 spatial size.
in_channels = 64
x = torch.randint(low=0, high=2, size=(32, in_channels, 224, 224), dtype=torch.float32)
res = ResidualConnection(in_channels, out_channels=128, kernel_size=3, stride=2, squeeze_active=True)
res(x).shape

torch.Size([32, 128, 112, 112])

In [18]:
# Seed the generators so the weight initialisation is reproducible.
torch.manual_seed(42)
torch.cuda.manual_seed(42)

# Spatial sizes in the comments below assume a 224x224 input.
se_res_net = nn.Sequential(
    # Stem: 3 -> 32 channels, 224 -> 112 (strided conv) -> 56 (max pool).
    nn.Conv2d(3, 32, kernel_size=5, stride=2, padding=2),
    nn.BatchNorm2d(num_features=32),
    nn.Mish(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Stage 1: widen to 64 channels, depth-pool back to 32, 56 -> 28.
    ResidualConnection(32, 64, kernel_size=3, stride=1, squeeze_active=True),
    ResidualConnection(64, 64, kernel_size=3, stride=1, squeeze_active=True),
    MaxDepthPool2d(pool_size=2),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Stage 2: widen to 96 channels, depth-pool to 48, 28 -> 14.
    ResidualConnection(32, 96, kernel_size=5, stride=1, squeeze_active=True),
    ResidualConnection(96, 96, kernel_size=5, stride=1, squeeze_active=True),
    MaxDepthPool2d(pool_size=2),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Stage 3: widen to 128 channels, depth-pool (x4) to 32, 14 -> 7.
    ResidualConnection(48, 128, kernel_size=3, stride=1, squeeze_active=True),
    ResidualConnection(128, 128, kernel_size=3, stride=1, squeeze_active=True),
    MaxDepthPool2d(pool_size=4),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # (N, 32, 7, 7) -> (N, 32 * 7 * 7)
    nn.Flatten(),
    # Head, hidden layer 1.
    nn.Linear(32 * 7 * 7, 256, bias=False),
    nn.BatchNorm1d(num_features=256),
    nn.Mish(),
    # Element-wise dropout. nn.Dropout1d would treat these 2-D (N, 256)
    # activations as an unbatched (C, L) input and zero entire rows, i.e.
    # blank out whole samples of the batch instead of individual features.
    nn.Dropout(0.4),
    # Head, hidden layer 2.
    nn.Linear(256, 256, bias=False),
    nn.BatchNorm1d(num_features=256),
    nn.Mish(),
    nn.Dropout(0.4),
    # One scalar output per sample.
    nn.Linear(256, 1),
).to(DEVICE)

torchinfo.summary(se_res_net, input_size=(32, 3, 224, 224), depth=2)

Layer (type:depth-idx)                        Output Shape              Param #
Sequential                                    [32, 1]                   --
├─Conv2d: 1-1                                 [32, 32, 112, 112]        2,432
├─BatchNorm2d: 1-2                            [32, 32, 112, 112]        64
├─Mish: 1-3                                   [32, 32, 112, 112]        --
├─MaxPool2d: 1-4                              [32, 32, 56, 56]          --
├─ResidualConnection: 1-5                     [32, 64, 56, 56]          --
│    └─Sequential: 2-1                        [32, 64, 56, 56]          55,552
│    └─Sequential: 2-2                        [32, 64, 56, 56]          2,176
│    └─SqueezeExcitation: 2-3                 [32, 64, 56, 56]          1,096
├─ResidualConnection: 1-6                     [32, 64, 56, 56]          --
│    └─Sequential: 2-4                        [32, 64, 56, 56]          73,984
│    └─Sequential: 2-5                        [32, 64, 56, 56]          --
│  

In [20]:
# End-to-end check: a batch of 32 random RGB 224x224 images should map to one
# output value per sample.
x = torch.randn(32, 3, 224, 224).to(device=DEVICE, dtype=torch.float32)
se_res_net(x).shape

torch.Size([32, 1])