In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchinfo

In [42]:
class Identity(nn.Module):
    """Pass-through module that returns its input unchanged.

    The residual blocks in this notebook use it as the shortcut when no
    ``downsample`` module is supplied. It behaves like ``torch.nn.Identity``.
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        # No computation: the very same tensor object is handed back.
        output = x
        return output

In [43]:
class DepthWise(nn.Module):
    """Depthwise convolution block: grouped Conv2d -> BatchNorm2d -> ReLU6.

    ``groups == in_channels``, so every input channel is convolved with its own
    filter and the channel count is preserved. The padding ``(kernel_size - 1) // 2``
    keeps the spatial size unchanged at ``stride=1`` for odd kernel sizes.

    Args:
        in_channels: number of input (and output) channels.
        kernel_size: size of the square convolution kernel.
        stride: convolution stride.
    """

    def __init__(self, in_channels, kernel_size, stride):
        super().__init__()
        # Bias-free conv: the following BatchNorm supplies the per-channel shift.
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            groups=in_channels,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU6(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return self.relu(x)

class PointWise(nn.Module):
    """1x1 convolution followed by BatchNorm2d, with an optional ReLU6.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the 1x1 convolution.
        use_relu: if True, apply ReLU6 after BatchNorm. If False, the block is
            linear (conv + BN only) and no ``relu`` submodule is registered.
    """

    def __init__(self, in_channels, out_channels, use_relu=False):
        super().__init__()
        self.use_relu = use_relu
        # Bias-free conv: the following BatchNorm supplies the per-channel shift.
        self.conv_1x1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=1, stride=1, padding=0, bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        if use_relu:
            self.relu = nn.ReLU6(inplace=True)

    def forward(self, x):
        out = self.bn(self.conv_1x1(x))
        if not self.use_relu:
            return out
        return self.relu(out)

class InvertedResidual(nn.Module):
    """MobileNetV2-style inverted residual block.

    Residual branch:
      1. ``PointWise`` expansion to ``hidden_dim`` channels (1x1 conv + BN + ReLU6).
      2. ``DepthWise`` 3x3 conv with the given ``stride`` (+ BN + ReLU6).
      3. Linear ``PointWise`` projection to ``out_channels`` (1x1 conv + BN, no ReLU).
    The branch output is added to ``downsample(x)``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        expand_ratio: hidden-width multiplier; the hidden width is
            ``round(in_channels * expand_ratio)``.
        stride: stride of the depthwise convolution.
        downsample: module mapping the input to the residual branch's output
            shape for the shortcut. If ``None``, the identity is used, which is
            only shape-compatible when ``stride == 1`` and
            ``in_channels == out_channels``.

    Raises:
        ValueError: if ``downsample`` is ``None`` but the identity shortcut
            cannot match the residual branch's output shape.
    """

    def __init__(self, in_channels, out_channels, expand_ratio, stride, downsample):
        super().__init__()
        hidden_dim = int(round(in_channels * expand_ratio))

        if downsample is None:
            # An identity shortcut only lines up with the residual branch when the
            # branch keeps both the channel count and the spatial size. Otherwise
            # forward() would fail later with an opaque shape-mismatch error.
            unit_stride = (
                tuple(stride) == (1, 1)
                if isinstance(stride, (tuple, list))
                else stride == 1
            )
            if not unit_stride or in_channels != out_channels:
                raise ValueError(
                    "downsample must be provided when stride != 1 or "
                    "in_channels != out_channels "
                    f"(got stride={stride!r}, in_channels={in_channels}, "
                    f"out_channels={out_channels})"
                )
            downsample = Identity()
        self.downsample = downsample

        # 1x1 expansion.
        self.conv1 = PointWise(
            in_channels=in_channels,
            out_channels=hidden_dim,
            use_relu=True,
        )

        # 3x3 depthwise conv; this is where any spatial downsampling happens.
        self.conv2 = DepthWise(
            in_channels=hidden_dim,
            kernel_size=3,
            stride=stride,
        )

        # Linear 1x1 projection (no activation), as in MobileNetV2.
        self.conv3 = PointWise(
            in_channels=hidden_dim,
            out_channels=out_channels,
            use_relu=False,
        )

    def forward(self, x):
        shortcut = self.downsample(x)
        residual = self.conv3(self.conv2(self.conv1(x)))
        return residual + shortcut

In [44]:
# Case 1: depth (channel count) is unchanged, 160 -> 160 channels with stride 1,
# so downsample=None and the shortcut is the identity.
# Summary is computed for an input of size (1, 160, 7, 7).
inverted_bn = InvertedResidual(in_channels=160, out_channels=160, expand_ratio=6, stride=1, downsample=None)
torchinfo.summary(inverted_bn, (1, 160, 7, 7))

Layer (type:depth-idx)                   Output Shape              Param #
InvertedResidual                         [1, 160, 7, 7]            --
├─Identity: 1-1                          [1, 160, 7, 7]            --
├─PointWise: 1-2                         [1, 960, 7, 7]            --
│    └─Conv2d: 2-1                       [1, 960, 7, 7]            153,600
│    └─BatchNorm2d: 2-2                  [1, 960, 7, 7]            1,920
│    └─ReLU6: 2-3                        [1, 960, 7, 7]            --
├─DepthWise: 1-3                         [1, 960, 7, 7]            --
│    └─Conv2d: 2-4                       [1, 960, 7, 7]            8,640
│    └─BatchNorm2d: 2-5                  [1, 960, 7, 7]            1,920
│    └─ReLU6: 2-6                        [1, 960, 7, 7]            --
├─PointWise: 1-4                         [1, 160, 7, 7]            --
│    └─Conv2d: 2-7                       [1, 160, 7, 7]            153,600
│    └─BatchNorm2d: 2-8                  [1, 160, 7, 7]           

In [48]:
# Case 2: depth changes (160 -> 320 channels). An identity shortcut would no
# longer match the block output, so a 3x3 conv (stride 1, padding 1, keeping the
# 7x7 spatial size) projects the input to 320 channels for the shortcut.
# NOTE(review): this 3x3 projection alone has 461,120 params (see the summary
# below), more than the rest of the block; a 1x1 projection would be much
# cheaper. Confirm the 3x3 choice is intended.
downsample = nn.Conv2d(in_channels=160, out_channels=320, stride=1, kernel_size=3, padding=1)
inverted_block = InvertedResidual(in_channels=160, out_channels=320, expand_ratio=6, stride=1, downsample=downsample)
torchinfo.summary(inverted_block, (1, 160, 7, 7))

Layer (type:depth-idx)                   Output Shape              Param #
InvertedResidual                         [1, 320, 7, 7]            --
├─Conv2d: 1-1                            [1, 320, 7, 7]            461,120
├─PointWise: 1-2                         [1, 960, 7, 7]            --
│    └─Conv2d: 2-1                       [1, 960, 7, 7]            153,600
│    └─BatchNorm2d: 2-2                  [1, 960, 7, 7]            1,920
│    └─ReLU6: 2-3                        [1, 960, 7, 7]            --
├─DepthWise: 1-3                         [1, 960, 7, 7]            --
│    └─Conv2d: 2-4                       [1, 960, 7, 7]            8,640
│    └─BatchNorm2d: 2-5                  [1, 960, 7, 7]            1,920
│    └─ReLU6: 2-6                        [1, 960, 7, 7]            --
├─PointWise: 1-4                         [1, 320, 7, 7]            --
│    └─Conv2d: 2-7                       [1, 320, 7, 7]            307,200
│    └─BatchNorm2d: 2-8                  [1, 320, 7, 7]      

In [54]:
class BottleneckResidual(nn.Module):
    """ResNet-style bottleneck residual block.

    Residual branch: 1x1 conv squeezing to ``in_channels // squeeze_ratio``
    channels, then a 3x3 conv with the given ``stride``, then a 1x1 conv expanding
    to ``out_channels``. Each conv is followed by BatchNorm, and ReLU follows the
    first two BNs. The branch output is added to ``downsample(x)``, and the final
    ReLU is applied *after* that addition.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        squeeze_ratio: divisor for the hidden width
            (``hidden_dim = in_channels // squeeze_ratio``).
        stride: stride of the 3x3 convolution.
        downsample: module mapping the input to the residual branch's output
            shape for the shortcut. If ``None``, the identity is used, which is
            only shape-compatible when ``stride == 1`` and
            ``in_channels == out_channels``.

    Raises:
        ValueError: if ``downsample`` is ``None`` but the identity shortcut
            cannot match the branch output, or if the hidden width would be < 1.
    """

    def __init__(self, in_channels, out_channels, squeeze_ratio, stride, downsample):
        super().__init__()

        if downsample is None:
            # An identity shortcut only lines up with the residual branch when the
            # branch keeps both the channel count and the spatial size.
            unit_stride = (
                tuple(stride) == (1, 1)
                if isinstance(stride, (tuple, list))
                else stride == 1
            )
            if not unit_stride or in_channels != out_channels:
                raise ValueError(
                    "downsample must be provided when stride != 1 or "
                    "in_channels != out_channels "
                    f"(got stride={stride!r}, in_channels={in_channels}, "
                    f"out_channels={out_channels})"
                )
            downsample = nn.Identity()
        self.downsample = downsample

        hidden_dim = int(in_channels // squeeze_ratio)
        if hidden_dim < 1:
            raise ValueError(
                f"in_channels // squeeze_ratio must be >= 1 "
                f"(got in_channels={in_channels}, squeeze_ratio={squeeze_ratio})"
            )

        # 1x1 squeeze.
        self.conv1 = nn.Conv2d(in_channels, hidden_dim, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden_dim)
        self.relu1 = nn.ReLU(inplace=True)

        # 3x3 conv; honours `stride` so spatial downsampling matches a strided
        # `downsample` shortcut (previously stride was silently ignored).
        self.conv2 = nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(hidden_dim)
        self.relu2 = nn.ReLU(inplace=True)

        # 1x1 expansion back to out_channels.
        self.conv3 = nn.Conv2d(hidden_dim, out_channels, kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu3 = nn.ReLU(inplace=True)

    def forward(self, x):
        shortcut = self.downsample(x)

        out = self.relu1(self.bn1(self.conv1(x)))
        out = self.relu2(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # Add first, then ReLU (as in ResNet). Applying ReLU to the branch before
        # the addition would restrict the residual update to non-negative values.
        return self.relu3(out + shortcut)

In [46]:
# Case 1: depth (channel count) is unchanged, 160 -> 160 channels with stride 1,
# so downsample=None and the shortcut is the identity. The hidden width is
# 160 // 6 = 26 channels.
# NOTE(review): this cell's execution count (46) is lower than that of the cell
# defining BottleneckResidual (54), so the output below came from an earlier
# definition. Re-run the notebook top to bottom to refresh it.
block = BottleneckResidual(in_channels=160, out_channels=160, squeeze_ratio=6, stride=1, downsample=None)
torchinfo.summary(block, (1, 160, 7, 7))

Layer (type:depth-idx)                   Output Shape              Param #
BottleneckResidual                       [1, 160, 7, 7]            --
├─Identity: 1-1                          [1, 160, 7, 7]            --
├─Conv2d: 1-2                            [1, 26, 7, 7]             4,160
├─BatchNorm2d: 1-3                       [1, 26, 7, 7]             52
├─ReLU: 1-4                              [1, 26, 7, 7]             --
├─Conv2d: 1-5                            [1, 26, 7, 7]             6,084
├─BatchNorm2d: 1-6                       [1, 26, 7, 7]             52
├─ReLU: 1-7                              [1, 26, 7, 7]             --
├─Conv2d: 1-8                            [1, 160, 7, 7]            4,160
├─BatchNorm2d: 1-9                       [1, 160, 7, 7]            320
├─ReLU: 1-10                             [1, 160, 7, 7]            --
Total params: 14,828
Trainable params: 14,828
Non-trainable params: 0
Total mult-adds (M): 0.71
Input size (MB): 0.03
Forward/backward pass size 

In [55]:
# Case 2: depth changes (160 -> 320 channels). An identity shortcut would no
# longer match the block output, so a 3x3 conv (stride 1, padding 1, keeping the
# 7x7 spatial size) projects the input to 320 channels for the shortcut.
# NOTE(review): that projection has 461,120 params (see the summary below), far
# more than the bottleneck branch itself. Confirm a 3x3 (rather than 1x1)
# projection is intended.
downsample = nn.Conv2d(in_channels=160, out_channels=320, stride=1, kernel_size=3, padding=1)
block = BottleneckResidual(in_channels=160, out_channels=320, squeeze_ratio=6, stride=1, downsample=downsample)
torchinfo.summary(block, (1, 160, 7, 7))

Layer (type:depth-idx)                   Output Shape              Param #
BottleneckResidual                       [1, 320, 7, 7]            --
├─Conv2d: 1-1                            [1, 320, 7, 7]            461,120
├─Conv2d: 1-2                            [1, 26, 7, 7]             4,160
├─BatchNorm2d: 1-3                       [1, 26, 7, 7]             52
├─ReLU: 1-4                              [1, 26, 7, 7]             --
├─Conv2d: 1-5                            [1, 26, 7, 7]             6,084
├─BatchNorm2d: 1-6                       [1, 26, 7, 7]             52
├─ReLU: 1-7                              [1, 26, 7, 7]             --
├─Conv2d: 1-8                            [1, 320, 7, 7]            8,320
├─BatchNorm2d: 1-9                       [1, 320, 7, 7]            640
├─ReLU: 1-10                             [1, 320, 7, 7]            --
Total params: 480,428
Trainable params: 480,428
Non-trainable params: 0
Total mult-adds (M): 23.51
Input size (MB): 0.03
Forward/backward pa