In [31]:
# import
import torch
import torch.nn as nn

In [32]:
"""
Return: depth scaling factor (d), width scaling factor (w), resolution scaling factor (r)
"""
def params(version):
    """
    Look up the scaling factors of a model version.

    version: str, one of 'n', 's', 'm', 'l', 'x'

    Return: tuple (d, w, r) = depth scaling factor, width scaling factor, resolution scaling factor
    Raises: ValueError if version is not one of the known versions
    """
    scales = {
        'n': (1/3, 1/4, 2.0),
        's': (1/3, 1/2, 2.0),
        'm': (2/3, 3/4, 1.5),
        'l': (1.0, 1.0, 1.0),
        'x': (1.0, 1.25, 1.0),
    }

    # Fail loudly here instead of returning None, which would only surface later
    # as an unpacking error at the `d, w, r = params(version)` call sites.
    if version not in scales:
        raise ValueError(f"unknown version {version!r}; expected one of {sorted(scales)}")

    return scales[version]

# Components

## Atoms

### 1. Conv
![Conv](images/conv.jpg)

In [33]:
class Conv(nn.Module):
    """
    Basic building block: 2-D convolution, then batch normalisation, then an optional SiLU.

    in_c: int, number of input channels (typically 3 for RGB images)
    out_c: int, number of output channels (number of filters)
    k: int, size of the kernel
    s: int, stride of the kernel
    p: int, padding of the kernel
    g: int, number of groups
    act: bool, whether to use activation function SiLU
    """
    def __init__(self, in_c, out_c, k = 3, s = 1, p = 1, g = 1, act = True):
        super().__init__()

        # Convolution without a bias term
        self.conv = nn.Conv2d(
            in_channels = in_c,
            out_channels = out_c,
            kernel_size = k,
            stride = s,
            padding = p,
            groups = g,
            bias = False,
        )

        # Normalisation over the out_c feature maps
        # (eps: added to the denominator for numerical stability,
        #  momentum: used for the running_mean / running_var update)
        self.bn = nn.BatchNorm2d(out_c, eps = 1e-3, momentum = 0.03)

        # In-place SiLU when requested, otherwise a pass-through
        if act:
            self.act = nn.SiLU(inplace = True)
        else:
            self.act = nn.Identity()


    def forward(self, x):
        # Conv2d -> BatchNorm2d -> SiLU (or Identity)
        features = self.conv(x)
        normalised = self.bn(features)
        return self.act(normalised)
    


# Sanity check (First Convolutional Layer)
if __name__ == "__main__":
    version = 's'
    d, w, r = params(version)

    print("(0):")

    # First convolutional layer of the backbone:
    #   input channels: 3
    #   output channels: 64 * width scaling factor (0.5)
    #   kernel size: 3, stride: 2, padding: 1, groups: 1, activation: True
    # Built once so the printed module is the same instance that is exercised below.
    stem = Conv(in_c = 3, out_c = int(64*w), k = 3, s = 2, p = 1, g = 1, act = True)
    print(stem)

    # Dummy input: batch size 1, 3 channels, 640 x 640 image.
    # Only the output shape is inspected, so autograd is switched off.
    with torch.no_grad():
        print(stem(torch.randn(1, 3, 640, 640)).shape)

(0):
Conv(
  (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
  (act): SiLU(inplace=True)
)
torch.Size([1, 32, 320, 320])


### 2. Bottleneck
![Bottleneck](images/bottleneck.jpg)

In [34]:
class Bottleneck(nn.Module):
    """
    Two stacked 3x3 Conv blocks with an optional residual connection.

    in_c: int, number of input channels
    out_c: int, number of output channels
    shortcut: bool, whether to use a residual connection
              (only applied when in_c == out_c, otherwise the shapes cannot be added)
    """
    def __init__(self, in_c, out_c, shortcut=True):
        super().__init__()

        # Convolutional layers (3x3, stride 1, padding 1: spatial size is preserved)
        self.conv1 = Conv(in_c, out_c, k = 3, s = 1, p = 1)
        self.conv2 = Conv(out_c, out_c, k = 3, s = 1, p = 1)

        # shortcut: a residual connection.
        # The input has in_c channels and the output has out_c channels, so the
        # element-wise add is only well-defined when both counts are equal.
        self.shortcut = shortcut and in_c == out_c


    # Conv1 -> Conv2 -> Shortcut
    def forward(self, x):
        out = self.conv2(self.conv1(x))

        if self.shortcut:
            out = out + x

        return out
    


# Sanity check (First Bottleneck in the First C2f block)
if __name__ == "__main__":
    version = 's'
    d, w, r = params(version)

    print("(1):")

    # First Bottleneck in the first C2f block:
    #   input channels: 64 * width scaling factor (0.5)
    #   output channels: 64 * width scaling factor (0.5)
    #   shortcut: True
    # Built once so the printed module is the same instance that is exercised below.
    bottleneck = Bottleneck(in_c = int(64*w), out_c = int(64*w), shortcut = True)
    print(bottleneck)

    # Dummy input: batch size 1, 64 * w channels, 224 x 224 feature map.
    # Only the output shape is inspected, so autograd is switched off.
    with torch.no_grad():
        print(bottleneck(torch.randn(1, int(64*w), 224, 224)).shape)

(1):
Bottleneck(
  (conv1): Conv(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
    (act): SiLU(inplace=True)
  )
  (conv2): Conv(
    (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
    (act): SiLU(inplace=True)
  )
)
torch.Size([1, 32, 224, 224])


### 3. Upsample

In [35]:
class Upsample(nn.Module):
    """
    Parameter-free layer that resizes its input with nn.functional.interpolate.

    scale_factor: int, scaling factor
    mode: str, interpolation mode
    """
    def __init__(self, scale_factor = 2, mode = 'nearest'):
        super().__init__()
        self.scale_factor, self.mode = scale_factor, mode


    def forward(self, x):
        # Forward the stored settings to the functional interpolation call
        resize_kwargs = {'scale_factor': self.scale_factor, 'mode': self.mode}
        return nn.functional.interpolate(x, **resize_kwargs)

### 4. DFL

In [36]:
class DFL(nn.Module):
    """
    Turns per-side logits over `ch` bins into one value per side:
    softmax over the bins, then a weighted sum with the bin indices 0..ch-1.

    ch: int, number of channels
    """
    def __init__(self, ch = 16):
        super().__init__()

        self.ch = ch

        # Frozen 1x1 convolution (no bias, no gradients) ...
        self.conv = nn.Conv2d(ch, 1, kernel_size = 1, bias = False)
        self.conv.requires_grad_(False)

        # ... whose weights are overwritten with the bin indices 0, 1, ..., ch - 1
        bin_indices = torch.arange(ch, dtype = torch.float)
        self.conv.weight.data.copy_(bin_indices.view(1, ch, 1, 1))


    def forward(self, x):
        batch, _, anchors = x.shape

        # (batch, 4 * ch, anchors) -> (batch, 4, ch, anchors) -> (batch, ch, 4, anchors)
        per_side = x.view(batch, 4, self.ch, anchors).transpose(1, 2)

        # Softmax over the bin axis
        probabilities = per_side.softmax(1)

        # Weighted sum over the bins: (batch, 1, 4, anchors)
        expected = self.conv(probabilities)

        return expected.view(batch, 4, anchors)
    


# Sanity check (DFL)
if __name__ == "__main__":
    print("(dfl):")

    # number of channels: 16
    # Built once so the printed module is the same instance that is exercised below.
    dfl = DFL(ch = 16)
    print(dfl)

    # Dummy input: batch size 1, 4 * 16 = 64 channels, 128 positions.
    # Only the output shape is inspected, so autograd is switched off.
    with torch.no_grad():
        print(dfl(torch.randn(1, 64, 128)).shape)

(dfl):
DFL(
  (conv): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
)
torch.Size([1, 4, 128])


## Molecules

### 1. C2f
![C2f](images/c2f.jpg)

In [37]:
class C2f(nn.Module):
    """
    1x1 Conv, channel split in two halves, a chain of Bottlenecks on the first half,
    concatenation of every intermediate result, and a final 1x1 Conv.

    in_c: int, number of input channels
    out_c: int, number of output channels
    num_bottlenecks: int, number of bottlenecks
    shortcut: bool, whether to use a residual connection
    """
    def __init__(self, in_c, out_c, num_bottlenecks, shortcut = True):
        super().__init__()

        self.mid_channels = out_c // 2
        self.num_bottlenecks = num_bottlenecks

        # 1x1 projection applied before the split
        self.conv1 = Conv(in_c, out_c, k = 1, s = 1, p = 0)

        # 1x1 fusion applied after the concat: two halves plus one output per bottleneck
        fused_c = (num_bottlenecks + 2) * out_c // 2
        self.conv2 = Conv(fused_c, out_c, k = 1, s = 1, p = 0)

        # Bottleneck sequence, each one working on mid_channels channels
        blocks = []
        for _ in range(num_bottlenecks):
            blocks.append(Bottleneck(self.mid_channels, self.mid_channels, shortcut))
        self.m = nn.ModuleList(blocks)


    def forward(self,x):
        # Conv1 -> Split -> Bottleneck Sequence -> Concat -> Conv2
        projected = self.conv1(x)

        half = projected.shape[1] // 2
        first_half = projected[:, :half, :, :]
        second_half = projected[:, half:, :, :]

        # Each bottleneck consumes the output of the previous one, starting from the first half
        chain = []
        current = first_half
        for idx in range(self.num_bottlenecks):
            current = self.m[idx](current)
            chain.append(current)

        # Newest bottleneck output first, then the two untouched halves
        fused = torch.cat(chain[::-1] + [first_half, second_half], dim = 1)

        return self.conv2(fused)
    


# Sanity check (First C2f block)
if __name__ == "__main__":
    version = 's'
    d, w, r = params(version)

    print("(2):")

    # First C2f block:
    #   input channels: 128 * width scaling factor (0.5)
    #   output channels: 128 * width scaling factor (0.5)
    #   number of bottlenecks: 3 * depth scaling factor (1/3) = 1
    #   shortcut: True
    # Built once so the printed module is the same instance that is exercised below.
    c2f = C2f(in_c = int(128*w), out_c = int(128*w), num_bottlenecks = int(3*d), shortcut = True)
    print(c2f)

    # Dummy input: batch size 1, 128 * w channels, 160 x 160 feature map.
    # Only the output shape is inspected, so autograd is switched off.
    with torch.no_grad():
        print(c2f(torch.randn(1, int(128*w), 160, 160)).shape)

(2):
C2f(
  (conv1): Conv(
    (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
    (act): SiLU(inplace=True)
  )
  (conv2): Conv(
    (conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
    (act): SiLU(inplace=True)
  )
  (m): ModuleList(
    (0): Bottleneck(
      (conv1): Conv(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (conv2): Conv(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
    )
  )
)
to

### 2. SPPF
![SPPF](images/sppf.jpg)

In [38]:
class SPPF(nn.Module):
    """
    1x1 Conv, three chained max-pools, concatenation of all four maps, 1x1 Conv.

    in_c: int, number of input channels
    out_c: int, number of output channels
    k: int, size of the kernel
    """
    def __init__(self, in_c, out_c, k = 5):
        super().__init__()

        hidden_c = in_c // 2

        # 1x1 reduction before pooling, 1x1 fusion of the 4 concatenated maps afterwards
        self.conv1 = Conv(in_c, hidden_c, k = 1, s = 1, p = 0)
        self.conv2 = Conv(4 * hidden_c, out_c, k = 1, s = 1, p = 0)

        # Stride-1 max-pool padded by k // 2
        # (kernel_size, stride, padding, dilation: spacing between kernel elements,
        #  ceil_mode: whether to use ceil when computing the output size)
        self.m = nn.MaxPool2d(kernel_size = k, stride = 1, padding = k // 2, dilation = 1, ceil_mode = False)


    def forward(self,x):
        # Conv1 -> MaxPool2ds -> Concat -> Conv2
        pyramid = [self.conv1(x)]

        # Each pooling step is applied to the result of the previous one
        for _ in range(3):
            pyramid.append(self.m(pyramid[-1]))

        return self.conv2(torch.cat(pyramid, dim = 1))
    


# Sanity check (SPPF block)
if __name__ == "__main__":
    version = 's'
    d, w, r = params(version)

    print("(9):")

    # SPPF block:
    #   input channels: 512 * width scaling factor (0.5) * resolution scaling factor (2.0)
    #   output channels: 512 * width scaling factor (0.5) * resolution scaling factor (2.0)
    #   kernel size: 5
    # Built once so the printed module is the same instance that is exercised below.
    sppf = SPPF(in_c = int(512*w*r), out_c = int(512*w*r), k = 5)
    print(sppf)

    # Dummy input: batch size 1, 512 * w * r channels, 20 x 20 feature map.
    # Only the output shape is inspected, so autograd is switched off.
    with torch.no_grad():
        print(sppf(torch.randn(1, int(512*w*r), 20, 20)).shape)

(9):
SPPF(
  (conv1): Conv(
    (conv): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(256, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
    (act): SiLU(inplace=True)
  )
  (conv2): Conv(
    (conv): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(512, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
    (act): SiLU(inplace=True)
  )
  (m): MaxPool2d(kernel_size=5, stride=1, padding=2, dilation=1, ceil_mode=False)
)
torch.Size([1, 512, 20, 20])


## Organisms

### 1. Backbone
![Backbone](images/backbone.jpg)

In [39]:
class Backbone(nn.Module):
    """
    Feature extractor: alternating stride-2 Conv and C2f blocks, finished by an SPPF block.

    version: str, version of the model
    in_c: int, number of input channels (typically 3 for RGB images)
    shortcut: bool, whether the C2f bottlenecks use a residual connection

    forward returns three feature maps (out1, out2, out3), taken after
    c2f_4, c2f_6 and sppf respectively.
    """
    def __init__(self, version, in_c = 3, shortcut = True):
        super().__init__()
        d, w, r = params(version)

        # Convolutional and C2f blocks
        # (every Conv here has stride 2; `shortcut` is forwarded to each C2f block)
        self.conv_0 = Conv(in_c, int(64*w), k = 3, s = 2, p = 1)
        self.conv_1 = Conv(int(64*w), int(128*w), k = 3, s = 2, p = 1)
        self.c2f_2 = C2f(int(128*w), int(128*w), num_bottlenecks = int(3*d), shortcut = shortcut)
        self.conv_3 = Conv(int(128*w), int(256*w), k = 3, s = 2, p = 1)
        self.c2f_4 = C2f(int(256*w), int(256*w), num_bottlenecks = int(6*d), shortcut = shortcut)
        self.conv_5 = Conv(int(256*w), int(512*w), k = 3, s = 2, p = 1)
        self.c2f_6 = C2f(int(512*w), int(512*w), num_bottlenecks = int(6*d), shortcut = shortcut)
        self.conv_7 = Conv(int(512*w), int(512*w*r), k = 3, s = 2, p = 1)
        self.c2f_8 = C2f(int(512*w*r), int(512*w*r), num_bottlenecks = int(3*d), shortcut = shortcut)

        # SPPF block
        self.sppf = SPPF(int(512*w*r), int(512*w*r))


    # Conv0 -> Conv1 -> C2f2 -> Conv3 -> C2f4 -> Conv5 -> C2f6 -> Conv7 -> C2f8 -> SPPF
    def forward(self, x):
        x = self.conv_0(x)
        x = self.conv_1(x)
        x = self.c2f_2(x)
        x = self.conv_3(x)
        out1 = self.c2f_4(x)       # first returned feature map
        x = self.conv_5(out1)
        out2 = self.c2f_6(x)       # second returned feature map
        x = self.conv_7(out2)
        x = self.c2f_8(x)
        out3 = self.sppf(x)        # third returned feature map

        return out1, out2, out3

In [40]:
# Sanity check (Backbone)
if __name__ == "__main__":
    version = 's'
    d, w, r = params(version)

    # Backbone:
    #   version: s
    #   input channels: 3
    #   shortcut: True
    # Built once so the printed module is the same instance that is exercised below.
    backbone = Backbone(version, in_c = 3, shortcut = True)
    print(backbone)

    # Dummy input: batch size 1, 3 channels, 640 x 640 image.
    x = torch.randn(1, 3, 640, 640)

    # Only the output shapes are inspected, so autograd is switched off.
    with torch.no_grad():
        out1, out2, out3 = backbone(x)
    print(out1.shape)
    print(out2.shape)
    print(out3.shape)

Backbone(
  (conv_0): Conv(
    (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
    (act): SiLU(inplace=True)
  )
  (conv_1): Conv(
    (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
    (act): SiLU(inplace=True)
  )
  (c2f_2): C2f(
    (conv1): Conv(
      (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (conv2): Conv(
      (conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (m): ModuleList(
      (0): Bottleneck(
        (conv

### 2. Neck
![Neck](images/neck.jpg)

In [41]:
class Neck(nn.Module):
    """
    Feature fusion over three feature maps: a top-down pass (upsample + concat + C2f)
    followed by a bottom-up pass (stride-2 Conv + concat + C2f).

    version: str, version of the model
    """
    def __init__(self, version):
        super().__init__()
        d, w, r = params(version)

        # Parameter-free upsampling layer, shared by both top-down steps
        self.up = Upsample()

        n_blocks = int(3*d)
        c_small, c_mid, c_large = int(256*w), int(512*w), int(512*w*r)

        # Top-down C2f blocks
        self.c2f_1 = C2f(in_c = int(512*w*(1+r)), out_c = c_mid, num_bottlenecks = n_blocks, shortcut = False)
        self.c2f_2 = C2f(in_c = int(768*w), out_c = c_small, num_bottlenecks = n_blocks, shortcut = False)

        # Bottom-up stride-2 Conv and C2f blocks
        self.conv_1 = Conv(in_c = c_small, out_c = c_small, k = 3, s = 2, p = 1)
        self.c2f_3 = C2f(in_c = int(768*w), out_c = c_mid, num_bottlenecks = n_blocks, shortcut = False)
        self.conv_2 = Conv(in_c = c_mid, out_c = c_mid, k = 3, s = 2, p = 1)
        self.c2f_4 = C2f(in_c = int(512*w*(1+r)), out_c = c_large, num_bottlenecks = n_blocks, shortcut = False)


    def forward(self, x_res_1, x_res_2, x):
        # Upsample -> Concat -> C2f1 -> Upsample -> Concat -> C2f2
        #          -> Conv1 -> Concat -> C2f3 -> Conv2 -> Concat -> C2f4
        deepest = x

        # Top-down pass
        fused = torch.cat([self.up(deepest), x_res_2], dim = 1)
        middle = self.c2f_1(fused)
        fused = torch.cat([self.up(middle), x_res_1], dim = 1)
        out_1 = self.c2f_2(fused)

        # Bottom-up pass
        fused = torch.cat([self.conv_1(out_1), middle], dim = 1)
        out_2 = self.c2f_3(fused)
        fused = torch.cat([self.conv_2(out_2), deepest], dim = 1)
        out_3 = self.c2f_4(fused)

        return out_1, out_2, out_3

In [42]:
# Sanity check (Neck)
if __name__ == "__main__":
    version = 's'
    d, w, r = params(version)

    # Neck:
    #   version: s
    # Built once so the printed module is the same instance that is exercised below.
    neck = Neck(version)
    print(neck)

    # Dummy input: batch size 1, 3 channels, 640 x 640 image.
    # It is first passed through a Backbone to obtain the three feature maps the Neck consumes.
    x = torch.rand((1,3,640,640))

    # Only the output shapes are inspected, so autograd is switched off.
    with torch.no_grad():
        out1, out2, out3 = Backbone(version = version)(x)
        out_1, out_2, out_3 = neck(out1, out2, out3)
    print(out_1.shape)
    print(out_2.shape)
    print(out_3.shape)

Neck(
  (up): Upsample()
  (c2f_1): C2f(
    (conv1): Conv(
      (conv): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (conv2): Conv(
      (conv): Conv2d(384, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(256, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (m): ModuleList(
      (0): Bottleneck(
        (conv1): Conv(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
          (act): SiLU(inplace=True)
        )
        (conv2): Conv(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(128, eps=0.001, momentum=0.03, affine=True,

### 3. Head
![Head](images/head2.jpg)

In [43]:
class Head(nn.Module):
    """
    Detection head: one box branch and one classification branch per feature map.

    In training mode forward returns the raw per-level maps (box and class channels
    concatenated). Otherwise the maps are flattened, the box channels are decoded
    through DFL and the anchors, and the class channels go through a sigmoid.

    version: str, version of the model
    ch: int, number of channels
    num_classes: int, number of classes
    """
    def __init__(self, version, ch = 16, num_classes = 80):
        super().__init__()
        d, w, r = params(version = version)

        self.ch = ch
        self.coordinates = self.ch * 4
        self.nc = num_classes

        # Extra hidden width of the box branch (non-zero only for version 'x').
        # NOTE: this is not the number of outputs per anchor; that is coordinates + nc.
        self.no = self.ch if version == 'x' else 0

        # Per-level strides used to scale decoded boxes. These zeros are a placeholder:
        # until real values are assigned, every decoded box is multiplied by 0.
        self.stride = torch.zeros(3)

        box_c = self.coordinates + self.no
        cls_c = self.nc if version == 'n' else int(256*w)

        # Channel counts of the three incoming feature maps
        in_channels = (int(256*w), int(512*w), int(512*w*r))

        # Bounding Box
        self.box = nn.ModuleList([
            nn.Sequential(Conv(c, box_c, k = 3, s = 1, p = 1),
                          Conv(box_c, box_c, k = 3, s = 1, p = 1),
                          nn.Conv2d(box_c, self.coordinates, kernel_size = 1, stride = 1))
            for c in in_channels
        ])

        # Classification
        self.cls = nn.ModuleList([
            nn.Sequential(Conv(c, cls_c, k = 3, s = 1, p = 1),
                          Conv(cls_c, cls_c, k = 3, s = 1, p = 1),
                          nn.Conv2d(cls_c, self.nc, kernel_size = 1, stride = 1))
            for c in in_channels
        ])

        # DFL (must use the same number of bins as the box branch emits per side)
        self.dfl = DFL(self.ch)


    def forward(self, x):
        # Build a new list instead of overwriting the caller's `x` in place
        feats = []
        for i in range(len(self.box)):
            box = self.box[i](x[i])
            cls = self.cls[i](x[i])
            feats.append(torch.cat((box, cls), dim = 1))

        if self.training:
            return feats

        # anchors: (2, A), ss: (1, A) after the transpose, A = total number of positions
        anchors, ss = (t.transpose(0, 1) for t in self.make_anchors(feats, self.stride))

        # Flatten every level to (batch, coordinates + nc, positions) and join the levels
        batch = feats[0].shape[0]
        outputs_per_anchor = self.coordinates + self.nc
        flat = torch.cat([f.view(batch, outputs_per_anchor, -1) for f in feats], dim = 2)

        box, cls = flat.split(split_size = (self.coordinates, self.nc), dim = 1)

        # DFL yields 4 values per position; the first two are subtracted from the
        # anchor point and the last two are added to it
        a, b = self.dfl(box).chunk(2, 1)
        a = anchors.unsqueeze(0) - a
        b = anchors.unsqueeze(0) + b

        # Midpoint of the two corners and their difference
        box = torch.cat(tensors = ((a + b) / 2, b - a), dim = 1)

        return torch.cat(tensors = (box * ss, cls.sigmoid()), dim = 1)


    def make_anchors(self, x, ss, offset = 0.5):
        """
        Build one anchor point per position of every feature map.

        x: sequence of tensors shaped (batch, channels, h, w)
        ss: iterable with one stride value per feature map
        offset: float, added to every grid coordinate

        Return: (anchors, strides) with shapes (A, 2) and (A, 1), A = sum of h * w
        """
        assert x is not None
        anchor_tensor, s_tensor = [],[]
        dtype, device  =  x[0].dtype, x[0].device
        for i, s in enumerate(ss):
            _, _, h, w  =  x[i].shape
            sx  =  torch.arange(end = w, device = device, dtype = dtype) + offset
            sy  =  torch.arange(end = h, device = device, dtype = dtype) + offset
            # Explicit 'ij' keeps the historical default and avoids the deprecation warning
            sy, sx  =  torch.meshgrid(sy, sx, indexing = 'ij')
            anchor_tensor.append(torch.stack((sx, sy), -1).view(-1, 2))
            s_tensor.append(torch.full((h * w, 1), s, dtype = dtype, device = device))

        return torch.cat(anchor_tensor), torch.cat(s_tensor)

In [44]:
# Sanity check (Head)
if __name__ == "__main__":
    version = 's'
    d, w, r = params(version)

    # Head:
    #   version: s
    #   number of channels: 16 (default)
    #   number of classes: 80 (default)
    # Built once so the printed module is the same instance that is exercised below.
    head = Head(version)
    print(head)

    # Input: the three feature maps out_1, out_2, out_3 produced by the Neck sanity
    # check in the previous cell. A freshly built Head is in training mode, so the
    # result is a list with one raw output map per level.
    # Only the output shapes are inspected, so autograd is switched off.
    with torch.no_grad():
        output = head([out_1, out_2, out_3])
    print(output[0].shape)
    print(output[1].shape)
    print(output[2].shape)

Head(
  (box): ModuleList(
    (0): Sequential(
      (0): Conv(
        (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (1): Conv(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (2): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): Sequential(
      (0): Conv(
        (conv): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (1): Conv(
        (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       

# YOLO

In [45]:
class Yolo(nn.Module):
    """
    Full model: Backbone -> Neck -> Head.

    version: str, version of the model
    """
    def __init__(self,version):
        super().__init__()

        self.backbone = Backbone(version = version)
        self.neck = Neck(version = version)
        self.head = Head(version = version)

        # The head is created with all-zero strides, which would scale every decoded
        # box to 0 at inference. The three maps it receives sit behind 3, 4 and 5
        # stride-2 convolutions of the backbone, i.e. at 1/8, 1/16 and 1/32 of the input.
        self.head.stride = torch.tensor([8.0, 16.0, 32.0])

    def forward(self, x):
        x = self.backbone(x)
        x = self.neck(x[0], x[1], x[2])
        # Hand the head its own list (the neck returns a tuple)
        return self.head(list(x))
    
# Build the 's' model, then report its size and layout
model = Yolo(version = 's')
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params/1e6} million parameters")
print(model)

11.16656 million parameters
Yolo(
  (backbone): Backbone(
    (conv_0): Conv(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (conv_1): Conv(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): SiLU(inplace=True)
    )
    (c2f_2): C2f(
      (conv1): Conv(
        (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (act): SiLU(inplace=True)
      )
      (conv2): Conv(
        (conv): Conv2d(96, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn): BatchNorm2d(64, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
        (