In [40]:
# import
import torch
import torch.nn as nn

In [41]:
"""
Return: depth scaling factor (d), width scaling factor (w), resolution scaling factor (r)
"""
def params(version):
    if version == 'n':
        return 1/3, 1/4, 2.0
    elif version == 's':
        return 1/3, 1/2, 2.0
    elif version == 'm':
        return 2/3, 3/4, 1.5
    elif version == 'l':
        return 1.0, 1.0, 1.0
    elif version == 'x':
        return 1.0, 1.25, 1.0

# 1. Backbone

## (a) Conv
![Conv](images/conv.jpg)

In [42]:
class Conv(nn.Module):
    """
    Conv2d -> BatchNorm2d -> SiLU (or identity) building block.

    in_c: int, number of input channels (typically 3 for RGB images)
    out_c: int, number of output channels (number of filters)
    k: int (or tuple of ints), size of the kernel
    s: int, stride of the kernel
    p: int or None, padding of the kernel. Defaults to 1, which only preserves the
       spatial size for k=3. Pass None to use k // 2 instead, which keeps the
       spatial size unchanged at stride 1 for any odd k.
    g: int, number of groups
    act: bool, whether to use activation function SiLU
    """
    def __init__(self, in_c, out_c, k = 3, s = 1, p = 1, g = 1, act = True):
        super().__init__()

        # Derive "same"-style padding from the kernel size when asked to
        if p is None:
            p = k // 2 if isinstance(k, int) else tuple(x // 2 for x in k)

        # Convolution without bias: the BatchNorm that follows has its own learnable shift
        self.conv = nn.Conv2d(in_c, out_c, k, s, p, bias = False, groups = g)

        # Batch normalization; eps is added to the variance for numerical stability,
        # momentum controls how fast running_mean / running_var are updated
        self.bn = nn.BatchNorm2d(num_features = out_c, eps = 0.001, momentum = 0.03)

        # SiLU activation applied in place, or a pass-through when act is False
        self.act = nn.SiLU(inplace = True) if act else nn.Identity()

    def forward(self, x):
        """Apply convolution, then batch normalization, then the activation."""
        return self.act(self.bn(self.conv(x)))
    


# Sanity check: build layer (0) of the backbone and push a dummy image through it
if __name__ == "__main__":
    version = 's'
    d, w, r = params(version)

    print("(0):")

    # input channels: 3
    # output channels: int(64 * w)  (32 for version 's')
    # kernel size: 3, stride: 2, padding: 1, groups: 1, activation: True
    layer0 = Conv(in_c = 3, out_c = int(64*w), k = 3, s = 2, p = 1, g = 1, act = True)
    print(layer0)

    # Dummy input: batch size 1, 3 input channels, 640 x 640 image.
    # The stride of 2 halves height and width, so the output is (1, int(64 * w), 320, 320).
    with torch.no_grad():
        out = layer0(torch.randn(1, 3, 640, 640))
    assert out.shape == (1, int(64*w), 320, 320), out.shape
    print(out.shape)

(0):
Conv(
  (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
  (act): SiLU(inplace=True)
)
torch.Size([1, 32, 320, 320])
