In [None]:
import torch.nn as nn
import torch
from layers import *

class DetectionModel(nn.Module):
    def __init__(self, depth=0.5, width=0.25, max_channels=1024, nc=80):
        """
        Initialize the DetectionModel.

        Args:
            depth (float): Depth multiplier for scaling the model.
            width (float): Width multiplier for scaling the model.
            max_channels (int): Maximum number of channels.
            nc (int): Number of classes for detection.
        """
        super().__init__()

        # Backbone
        self.model = nn.Sequential(
            Conv(3, int(64 * width), 3, 2),  # 0-P1/2
            Conv(int(64 * width), int(128 * width), 3, 2),  # 1-P2/4
            C3k2(int(128 * width), int(256 * width), n=int(2 * depth), c3k=False, e=0.25),
            Conv(int(256 * width), int(256 * width), 3, 2),  # 3-P3/8
            C3k2(int(256 * width), int(512 * width), n=int(2 * depth), c3k=False, e=0.25),
            ##
            Conv(int(512 * width), int(512 * width), 3, 2),  # 5-P4/16
            C3k2(int(512 * width), int(512 * width), n=int(2 * depth), c3k=True),
            ##
            Conv(int(512 * width), int(1024 * width), 3, 2),  # 7-P5/32
            C3k2(int(1024 * width), int(1024 * width), n=int(2 * depth), c3k=True),
            SPPF(int(1024 * width), int(1024 * width), k=5),  # 9
            C2PSA(int(1024 * width), int(1024 * width), n=int(2 * depth)),
            )  # 10
            # Head
        self.neck = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),  # -1 
            Concat(dimension=1),  # cat backbone P4
            C3k2(int(1536 * width), int(512 * width), n=int(2 * depth), c3k=False),  # 13

            nn.Upsample(scale_factor=2, mode="nearest"),  # -1
            Concat(dimension=1),  # cat backbone P3
            C3k2(int(1024 * width), int(256 * width), n=int(2 * depth), c3k=False),  # 16 (P3/8-small)

            Conv(int(256 * width), int(256 * width), 3, 2),  # -1
            Concat(dimension=1),  # cat head P4
            C3k2(int(768 * width), int(512 * width), n=int(2 * depth), c3k=False),  # 19 (P4/16-medium)

            Conv(int(512 * width), int(512 * width), 3, 2),  # -1
            Concat(dimension=1),  # cat head P5
            C3k2(int(1536 * width), int(1024 * width), n=int(2 * depth), c3k=True),  # 22 (P5/32-large)
        )

        self.head = Detect(nc=nc, ch=[int(256 * width), int(512 * width), int(1024 * width)])  # Detect(P3, P4, P5)

    def forward(self, x):
        """
        Forward pass through the DetectionModel.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor.
        """
        x_out = []
        for n, layer in enumerate(self.model):
            x = layer(x)
            if n==4 or n==6:
                x_out.append(x)
        x_out.append(x)
        p4, p6, p10 = x_out

        for n, layer in enumerate(self.neck):
            
            if n==1 and isinstance(layer, Concat):
                x = layer((x, p6))

            elif n==4 and isinstance(layer, Concat):
                x = layer((x, p4))

            elif n==7 and isinstance(layer, Concat):
                x = layer((x, p12))
                
            elif n==10 and isinstance(layer, Concat):
                x = layer((x, p10))

            elif n==2 and isinstance(layer, C3k2):
                x = layer(x)
                p12 = x.clone()

            elif n==5 and isinstance(layer, C3k2):
                x = layer(x)
                p14 = x.clone()

            elif n==8 and isinstance(layer, C3k2):
                x = layer(x)
                p17 = x.clone()
                
            
            else:
                x = layer(x)
        print(x.shape, p14.shape, p17.shape)
        # x = self.head([x,p14,p17])
        return x

In [40]:
import torch.nn as nn
import torch
from layers import *

class DetectionModel(nn.Module):
    def __init__(self, depth=0.5, width=0.25, max_channels=1024, nc=80):
        """
        Initialize the DetectionModel.

        Args:
            depth (float): Depth multiplier for scaling the model.
            width (float): Width multiplier for scaling the model.
            max_channels (int): Maximum number of channels.
            nc (int): Number of classes for detection.
        """
        super().__init__()

        # Backbone
        self.model = nn.Sequential(
            Conv(3, int(64 * width), 3, 2),  # 0-P1/2
            Conv(int(64 * width), int(128 * width), 3, 2),  # 1-P2/4
            C3k2(int(128 * width), int(256 * width), n=int(2 * depth), c3k=False, e=0.25),
            Conv(int(256 * width), int(256 * width), 3, 2),  # 3-P3/8
            C3k2(int(256 * width), int(512 * width), n=int(2 * depth), c3k=False, e=0.25),
            ##
            Conv(int(512 * width), int(512 * width), 3, 2),  # 5-P4/16
            C3k2(int(512 * width), int(512 * width), n=int(2 * depth), c3k=True),
            ##
            Conv(int(512 * width), int(1024 * width), 3, 2),  # 7-P5/32
            C3k2(int(1024 * width), int(1024 * width), n=int(2 * depth), c3k=True),
            SPPF(int(1024 * width), int(1024 * width), k=5),  # 9
            C2PSA(int(1024 * width), int(1024 * width), n=int(2 * depth)),
        #     )  # 10
        #     # Head
        # self.neck = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),  # -1 
            Concat(dimension=1),  # cat backbone P4
            C3k2(int(1536 * width), int(512 * width), n=int(2 * depth), c3k=False),  # 13

            nn.Upsample(scale_factor=2, mode="nearest"),  # -1
            Concat(dimension=1),  # cat backbone P3
            C3k2(int(1024 * width), int(256 * width), n=int(2 * depth), c3k=False),  # 16 (P3/8-small)

            Conv(int(256 * width), int(256 * width), 3, 2),  # -1
            Concat(dimension=1),  # cat head P4
            C3k2(int(768 * width), int(512 * width), n=int(2 * depth), c3k=False),  # 19 (P4/16-medium)

            Conv(int(512 * width), int(512 * width), 3, 2),  # -1
            Concat(dimension=1),  # cat head P5
            C3k2(int(1536 * width), int(1024 * width), n=int(2 * depth), c3k=True),  # 22 (P5/32-large)
        )

        self.head = Detect(nc=nc, ch=[int(256 * width), int(512 * width), int(1024 * width)])  # Detect(P3, P4, P5)

    def forward(self, x):
        """
        Forward pass through the DetectionModel.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor.
        """
        for n, layer in enumerate(self.model):
            if n==4:
                x = layer(x)
                p4 = x.clone() 
            elif n==6:
                x = layer(x)
                p6 = x.clone()
            elif n==10:
                x = layer(x)
                p10 = x.clone()
                
            elif n==12 and isinstance(layer, Concat):
                x = layer((x, p6))

            elif n==15 and isinstance(layer, Concat):
                x = layer((x, p4))

            elif n==18 and isinstance(layer, Concat):
                x = layer((x, p13))
                
            elif n==21 and isinstance(layer, Concat):
                x = layer((x, p10))

            elif n==13 and isinstance(layer, C3k2):
                x = layer(x)
                p13 = x.clone()

            elif n==16 and isinstance(layer, C3k2):
                x = layer(x)
                p16 = x.clone()

            elif n==19 and isinstance(layer, C3k2):
                x = layer(x)
                p19 = x.clone()
            else:
                x = layer(x)
        # print(x.shape, p14.shape, p17.shape)
        # x = self.head([x,p14,p17])
        return [x,p19, p16]

In [41]:
model = DetectionModel()
    # print(model)
# weights = torch.load('/home/bibhabasum/projects/IIIT/ultralytics/model_state_dict.pth') # Load weights

# model.load_state_dict(weights, strict=False)  # Load model weights

image = torch.randn(1, 3, 640, 640)  # Example input
output = model(image)  # Forward pass
# print(output)  # Print output shape
 # Print output shape

In [42]:
[out.shape for out in output]

[torch.Size([1, 256, 20, 20]),
 torch.Size([1, 128, 40, 40]),
 torch.Size([1, 64, 80, 80])]

## Final Detect Fix

In [48]:
import torch.nn as nn
import torch
from layers import *

class DetectionModel(nn.Module):
    def __init__(self, depth=0.5, width=0.25, max_channels=1024, nc=80):
        """
        Initialize the DetectionModel.

        Args:
            depth (float): Depth multiplier for scaling the model.
            width (float): Width multiplier for scaling the model.
            max_channels (int): Maximum number of channels.
            nc (int): Number of classes for detection.
        """
        super().__init__()

        # Backbone
        self.model = nn.Sequential(
            Conv(3, int(64 * width), 3, 2),  # 0-P1/2
            Conv(int(64 * width), int(128 * width), 3, 2),  # 1-P2/4
            C3k2(int(128 * width), int(256 * width), n=int(2 * depth), c3k=False, e=0.25),
            Conv(int(256 * width), int(256 * width), 3, 2),  # 3-P3/8
            C3k2(int(256 * width), int(512 * width), n=int(2 * depth), c3k=False, e=0.25),
            ##
            Conv(int(512 * width), int(512 * width), 3, 2),  # 5-P4/16
            C3k2(int(512 * width), int(512 * width), n=int(2 * depth), c3k=True),
            ##
            Conv(int(512 * width), int(1024 * width), 3, 2),  # 7-P5/32
            C3k2(int(1024 * width), int(1024 * width), n=int(2 * depth), c3k=True),
            SPPF(int(1024 * width), int(1024 * width), k=5),  # 9
            C2PSA(int(1024 * width), int(1024 * width), n=int(2 * depth)),
        #     )  # 10
        #     # Head
        # self.neck = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),  # -1 
            Concat(dimension=1),  # cat backbone P4
            C3k2(int(1536 * width), int(512 * width), n=int(2 * depth), c3k=False),  # 13

            nn.Upsample(scale_factor=2, mode="nearest"),  # -1
            Concat(dimension=1),  # cat backbone P3
            C3k2(int(1024 * width), int(256 * width), n=int(2 * depth), c3k=False),  # 16 (P3/8-small)

            Conv(int(256 * width), int(256 * width), 3, 2),  # -1
            Concat(dimension=1),  # cat head P4
            C3k2(int(768 * width), int(512 * width), n=int(2 * depth), c3k=False),  # 19 (P4/16-medium)

            Conv(int(512 * width), int(512 * width), 3, 2),  # -1
            Concat(dimension=1),  # cat head P5
            C3k2(int(1536 * width), int(1024 * width), n=int(2 * depth), c3k=True),  # 22 (P5/32-large)
            Detect(nc=nc, ch=[int(256 * width), int(512 * width), int(1024 * width)])
        )

        # self.head = Detect(nc=nc, ch=[int(256 * width), int(512 * width), int(1024 * width)])  # Detect(P3, P4, P5)

    def forward(self, x):
        """
        Forward pass through the DetectionModel.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            (torch.Tensor): Output tensor.
        """
        for n, layer in enumerate(self.model):
            if n==4:
                x = layer(x)
                p4 = x.clone() 
            elif n==6:
                x = layer(x)
                p6 = x.clone()
            elif n==10:
                x = layer(x)
                p10 = x.clone()
                
            elif n==12 and isinstance(layer, Concat):
                x = layer((x, p6))

            elif n==15 and isinstance(layer, Concat):
                x = layer((x, p4))

            elif n==18 and isinstance(layer, Concat):
                x = layer((x, p13))
                
            elif n==21 and isinstance(layer, Concat):
                x = layer((x, p10))

            elif n==13 and isinstance(layer, C3k2):
                x = layer(x)
                p13 = x.clone()

            elif n==16 and isinstance(layer, C3k2):
                x = layer(x)
                p16 = x.clone()

            elif n==19 and isinstance(layer, C3k2):
                x = layer(x)
                p19 = x.clone()

            ## head
            elif n==23 and isinstance(layer, Detect):
                out = layer([p16, p19, x])
            else:
                x = layer(x)
            
                
        return out

In [52]:
model = DetectionModel()
    # print(model)
weights = torch.load('/home/bibhabasum/projects/IIIT/ultralytics/model_state_dict.pth') # Load weights

model.load_state_dict(weights, strict=False)  # Load model weights

image = torch.randn(1, 3, 640, 640)  # Example input
output = model(image)  # Forward pass
# print(output)  # Print output shape
 # Print output shape

  weights = torch.load('/home/bibhabasum/projects/IIIT/ultralytics/model_state_dict.pth') # Load weights


In [58]:
output["one2many"][0].shape

torch.Size([1, 144, 80, 80])