In [1]:
import ultralytics
from torchsummary import summary
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [58]:
class Conv(nn.Module):
    """Apply a convolution, a normalisation layer and an activation in sequence.

    The three layers are built by the caller and handed in ready-made; this
    class only registers them as sub-modules (under the names ``conv``,
    ``bn`` and ``act``) and chains them in ``forward``.

    Args:
        conv: layer applied first (the callers in this file pass ``nn.Conv2d``).
        bn: layer applied to the output of ``conv`` (callers pass ``nn.BatchNorm2d``).
        act: layer applied last (callers pass ``nn.SiLU``).
    """
    def __init__(self,conv,bn,act):
        super(Conv,self).__init__()
        self.conv = conv
        self.bn = bn
        self.act = act
    def forward(self,x):
        """Return ``act(bn(conv(x)))``."""
        # Call the registered layers directly: wrapping them in a fresh
        # nn.Sequential on every forward pass allocates a throw-away module
        # container per call without changing the result.
        return self.act(self.bn(self.conv(x)))

class Bottleneck(nn.Module):
    """Two stacked 3x3 ``Conv`` blocks with an optional residual connection.

    Both convolutions use stride 1 and padding 1 and keep the channel count
    at ``channel_size``, so the output has the same shape as the input and
    can be added to it.

    Args:
        channel_size: number of input channels, which is also the number of
            output channels of both convolutions.
        shortcut: when truthy, the block input is added to the output of the
            second convolution; otherwise the convolution output is returned
            unchanged.
    """
    def __init__(self,channel_size,shortcut=True):
        super(Bottleneck,self).__init__()
        self.cv1 = Conv(
                conv = nn.Conv2d(channel_size, channel_size, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
                bn = nn.BatchNorm2d(channel_size, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                act = nn.SiLU(),
        )
        self.cv2 = Conv(
                conv = nn.Conv2d(channel_size, channel_size, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
                bn = nn.BatchNorm2d(channel_size, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                act = nn.SiLU(),
        )
        self.shortcut = shortcut
    def forward(self,x):
        """Run both conv blocks and, if enabled, add the input back on."""
        # The conv chain is identical with or without the shortcut; only the
        # final residual addition is conditional.
        out = self.cv2(self.cv1(x))
        if self.shortcut:
            return out + x
        return out

In [87]:
class C2f(nn.Module):
    """1x1 conv, channel split, a chain of ``n`` bottlenecks, then a 1x1 fusion conv.

    ``forward`` runs ``cv1``, splits its output into two halves along the
    channel axis, feeds the first half through the bottlenecks one after the
    other, concatenates both halves plus every bottleneck output along the
    channel axis and passes the result through ``cv2``.

    Args:
        n: number of ``Bottleneck`` blocks in the chain.
        in_channel1: input channels of the first 1x1 conv (``cv1``).
        out_channel1: output channels of ``cv1``; the tensor that is split.
        in_channel2: input channels of the fusion conv (``cv2``). It has to
            equal ``out_channel1 + n * (out_channel1 // 2)``, the width of
            the concatenated tensor built in ``forward``.
        out_channel2: output channels of ``cv2`` and therefore of the block.
        shortcut: forwarded to every ``Bottleneck``.
    """
    def __init__(self,n,in_channel1,out_channel1,in_channel2,out_channel2,shortcut=True):
        super(C2f,self).__init__()
        self.cv1 = Conv(
                conv = nn.Conv2d(in_channel1, out_channel1, kernel_size=(1, 1), stride=(1, 1), bias=False),
                bn = nn.BatchNorm2d(out_channel1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                act = nn.SiLU(),
        )

        self.cv2 = Conv(
                conv = nn.Conv2d(in_channel2, out_channel2, kernel_size=(1, 1), stride=(1, 1), bias=False),
                bn = nn.BatchNorm2d(out_channel2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                act = nn.SiLU(),
        )
        self.n = n

        # nn.ModuleList (not a plain Python list) so the bottlenecks are
        # registered as sub-modules: their weights then show up in
        # parameters()/state_dict() and follow .to(), .train() and .eval().
        # The bottlenecks receive one half of cv1's output, hence the width
        # out_channel1 // 2.
        self.ModuleList = nn.ModuleList(
            [Bottleneck(channel_size=out_channel1//2,shortcut=shortcut) for _ in range(n)]
        )

    def forward(self,x):
        """Return the fused feature map; spatial size is unchanged."""
        x = self.cv1(x)
        half = x.size(1)//2
        x_half1 = x[:,:half,:,:]
        x_half2 = x[:,half:,:,:]

        # NOTE(review): the chain starts from the FIRST half here; the
        # reference Ultralytics C2f appears to chain from the second half.
        # Confirm before loading pretrained weights into this module.
        features = [x_half1,x_half2]
        for bottleneck in self.ModuleList:
            x_half1 = bottleneck(x_half1)
            features.append(x_half1)
        # One concatenation at the end instead of re-concatenating per step.
        return self.cv2(torch.cat(features,dim=1))

class SPPF(nn.Module):
    """Pooling block: 1x1 conv, three chained max-pools, concat, 1x1 conv.

    ``cv1`` halves the channel count, the same stride-1 max-pool is applied
    three times in a row, and the un-pooled tensor plus the three pooled
    tensors are concatenated along the channel axis (four times the hidden
    width) before ``cv2`` projects to ``out_channel``.

    Args:
        in_channel: channels of the input tensor. Defaults to 256, the value
            that was previously hard-coded.
        out_channel: channels of the output tensor. Defaults to 256.
        pool_kernel: max-pool kernel size. Defaults to 5. Padding is
            ``pool_kernel // 2``, so an odd value is needed for the pooled
            tensors to keep the spatial size required by the concatenation.
    """
    def __init__(self,in_channel=256,out_channel=256,pool_kernel=5):
        super(SPPF,self).__init__()
        hidden_channel = in_channel//2
        self.cv1 = Conv(
            nn.Conv2d(in_channel, hidden_channel, kernel_size=(1, 1), stride=(1, 1), bias=False),
            nn.BatchNorm2d(hidden_channel, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.SiLU()
        )
        self.maxpool = nn.MaxPool2d(kernel_size=pool_kernel, stride=1, padding=pool_kernel//2, dilation=1, ceil_mode=False)
        self.cv2 = Conv(
            # Four tensors of hidden_channel channels each are concatenated.
            nn.Conv2d(hidden_channel*4, out_channel, kernel_size=(1, 1), stride=(1, 1), bias=False),
            nn.BatchNorm2d(out_channel, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.SiLU()
          )
    def forward(self,x):
        """Return the projected concatenation of the un-pooled and pooled maps."""
        pyramid = [self.cv1(x)]
        # Each pool is applied to the previous pooled result, not to the
        # original tensor.
        for _ in range(3):
            pyramid.append(self.maxpool(pyramid[-1]))
        return self.cv2(torch.cat(pyramid,dim=1))

class Detect(nn.Module):
    """Detection head for one feature map, with two parallel branches.

    Each branch is two 3x3 ``Conv`` blocks followed by a plain 1x1
    ``nn.Conv2d``. The ``bb`` branch is 64 channels wide and the ``cls``
    branch is 80 channels wide; both read the same input tensor.

    Args:
        in_channel: number of channels of the incoming feature map.
    """
    def __init__(self,in_channel):
        super(Detect,self).__init__()

        def conv3x3(c_in, c_out):
            # 3x3, stride 1, padding 1: spatial size is preserved.
            return Conv(
                nn.Conv2d(c_in, c_out, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False),
                nn.BatchNorm2d(c_out, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
                nn.SiLU(),
            )

        def branch(width):
            # Layers are created in the same order as they are applied.
            return nn.Sequential(
                conv3x3(in_channel, width),
                conv3x3(width, width),
                nn.Conv2d(width, width, kernel_size=(1, 1), stride=(1, 1)),
            )

        # box branch
        self.bb = branch(64)

        # for cls
        self.cls = branch(80)

    def forward(self,x):
        """Return ``(bb_output, cls_output)`` computed from the same input."""
        box_out = self.bb(x)
        class_out = self.cls(x)
        return (box_out,class_out)

In [93]:
class YOLOv8n(nn.Module):
    """Backbone, top-down/bottom-up neck and three ``Detect`` heads.

    The class subclasses ``nn.Module`` and keeps its heads in an
    ``nn.ModuleList`` so that every layer is registered: ``parameters()``,
    ``state_dict()``, ``.to(device)``, ``.train()`` and ``.eval()`` then
    reach the whole network. ``model.forward(x)`` keeps working, and
    ``model(x)`` becomes available as well.

    Three feature maps are tapped from the backbone, after three, four and
    five stride-2 convolutions (1/8, 1/16 and 1/32 of the input size).
    """

    @staticmethod
    def _downsample(c_in, c_out):
        """Build a 3x3, stride-2 ``Conv`` block (conv -> batch norm -> SiLU)."""
        return Conv(
            nn.Conv2d(c_in, c_out, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
            nn.BatchNorm2d(c_out, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.SiLU()
        )

    def __init__(self):
        super(YOLOv8n,self).__init__()
        # Backbone stage ending at 64 channels, 1/8 resolution.
        self.model_component1 = nn.Sequential(
            self._downsample(3, 16),
            self._downsample(16, 32),
            C2f(n=1,in_channel1=32,out_channel1=32,in_channel2=48,out_channel2=32,shortcut=True),
            self._downsample(32, 64),
            C2f(n=2,in_channel1=64,out_channel1=64,in_channel2=128,out_channel2=64,shortcut=True)
        )

        # Backbone stage ending at 128 channels, 1/16 resolution.
        self.model_component2 = nn.Sequential(
            self._downsample(64, 128),
            C2f(n=2,in_channel1=128,out_channel1=128,in_channel2=256,out_channel2=128,shortcut=True)
        )

        # Backbone stage ending at 256 channels, 1/32 resolution.
        self.model_component3 = nn.Sequential(
            self._downsample(128, 256),
            C2f(n=1,in_channel1=256,out_channel1=256,in_channel2=384,out_channel2=256,shortcut=True),
            SPPF()
        )

        # Top-down path: one upsample layer is reused for both upsampling steps.
        self.head_upsample1 = nn.Upsample(scale_factor=2.0, mode='nearest')
        self.head_c2f1 = C2f(n=1,in_channel1=384,out_channel1=128,in_channel2=192,out_channel2=128,shortcut=False)
        self.head_c2f2 = C2f(n=1,in_channel1=192,out_channel1=64,in_channel2=96,out_channel2=64,shortcut=False)

        # Bottom-up path: stride-2 convs bring the map back down a level.
        self.conv1 = self._downsample(64, 64)
        self.conv2 = self._downsample(128, 128)

        self.head_c2f3 = C2f(n=1,in_channel1=192,out_channel1=128,in_channel2=192,out_channel2=128,shortcut=False)
        self.head_c2f4 = C2f(n=1,in_channel1=384,out_channel1=256,in_channel2=384,out_channel2=256,shortcut=False)

        # nn.ModuleList instead of a plain list so the heads are registered.
        self.Modulelist_detect = nn.ModuleList(
            [Detect(in_channel=64),Detect(in_channel=128),Detect(in_channel=256)]
        )

    def forward(self,x):
        """Run the network on an image batch.

        Args:
            x: tensor with 3 channels on dim 1 (the first conv expects 3
                input channels).

        Returns:
            Tuple ``(outreg1, outcls1, outreg2, outcls2, outreg3, outcls3)``:
            the two ``Detect`` outputs for the 1/8, 1/16 and 1/32 resolution
            maps, in that order.
        """
        x_skip1 = self.model_component1(x)
        x_skip2 = self.model_component2(x_skip1)
        x_skip3 = self.model_component3(x_skip2)

        # head: top-down fusion with the two shallower backbone maps
        x = self.head_upsample1(x_skip3)
        x = torch.cat([x_skip2,x],dim=1)
        x_skip4 = self.head_c2f1(x)
        x = self.head_upsample1(x_skip4)
        x = torch.cat([x_skip1,x],dim=1)
        x = self.head_c2f2(x)

        outreg1,outcls1 = self.Modulelist_detect[0](x)

        # bottom-up fusion with the intermediate head map
        x = self.conv1(x)
        x = torch.cat([x_skip4,x],dim=1)
        x = self.head_c2f3(x)

        outreg2,outcls2 = self.Modulelist_detect[1](x)

        # bottom-up fusion with the deepest backbone map
        x = self.conv2(x)
        x = torch.cat([x_skip3,x],dim=1)
        x = self.head_c2f4(x)

        outreg3,outcls3 = self.Modulelist_detect[2](x)

        return (outreg1,outcls1,outreg2,outcls2,outreg3,outcls3)

In [94]:
# Build the network; no checkpoint is loaded in this cell, so the weights are
# whatever the layer constructors initialise them to.
model = YOLOv8n()

In [95]:
# Smoke test: push one random 3x640x640 batch through the network.
a = torch.randn(1, 3, 640, 640)
print(a.shape)
b, c, d, e, f, g = model.forward(a)
# Shapes of the six returned tensors, printed on a single line.
print(*(t.shape for t in (b, c, d, e, f, g)))

torch.Size([1, 3, 640, 640])
torch.Size([1, 64, 80, 80]) torch.Size([1, 80, 80, 80]) torch.Size([1, 64, 40, 40]) torch.Size([1, 80, 40, 40]) torch.Size([1, 64, 20, 20]) torch.Size([1, 80, 20, 20])
