In [50]:
import torch
import torch.nn as nn

In [51]:
config = [
    (32, 3, 1), # Tuple : (out_channels,kernel_size,stride)
    (64, 3, 2),
    ["B", 1], # List : (residual block, number_of_repeats)
    (128, 3, 2),
    ["B", 2],
    (256, 3, 2),
    ["B", 8],
    (512, 3, 2),
    ["B", 8],
    (1024, 3, 2),
    ["B", 4],  # To this point is Darknet-53
    (512, 1, 1),
    (1024, 3, 1),
    "S", # branch where scale predicts
    (256, 1, 1),

    # Upsampling is essential in CNNs for tasks requiring the restoration of spatial resolution, 
    # generation of high-resolution images, and pixel-wise prediction alignment

    # often used in tasks that require high-resolution output,
    #  such as image segmentation, super-resolution, and generative models.
    "U", # upsampling
    
    (256, 1, 1),
    (512, 3, 1),
    "S",
    (128, 1, 1),
    "U",
    (128, 1, 1),
    (256, 3, 1),
    "S",
]

In [52]:
class CNNBlock(nn.Module):
    def __init__(self,in_channels,out_channels,bn_act=True,**kwargs):
        super(CNNBlock,self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels,out_channels=out_channels,bias=not bn_act,**kwargs)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.leakyrelu = nn.LeakyReLU(0.1)
        self.bn_act = bn_act
    
    def forward(self,x):
        if self.bn_act:
            return self.leakyrelu(self.batch_norm(self.conv(x)))
        else:
            return self.conv(x)

# ResidualBlock
### The ResidualBlock is more complex and is used to create residual networks. It consists of:

# Multiple CNNBlock instances.
### An optional residual connection that adds the input of the block to the output, which helps in training very deep networks by mitigating the vanishing gradient problem.

In [53]:
class ResidualBlock(nn.Module):
    def __init__(self,channels,use_residual=True,num_repeats=1):
        super(ResidualBlock,self).__init__()
        self.layers = nn.ModuleList()
        for repeat in range(num_repeats):
            self.layers += [
                nn.Sequential(
                    CNNBlock(channels,channels//2,kernel_size=1), # (N, C, H, W) -> (N, C//2, H, W)
                    CNNBlock(channels//2,channels,kernel_size=3,padding=1) # (N, C//2, H, W) -> (N, C, H, W)
            )]
        self.use_residual = use_residual
        self.numrepeats = num_repeats

    def forward(self,x):
        for layer in self.layers:

            # ALL 4 DIMENSIONS SHOULD MATCH FOR ADDITION : (N, C, H, W)
            x = layer(x) + x if self.use_residual else layer(x)

        return x

In [54]:
class ScalePrediction(nn.Module):
    def __init__(self,in_channels,num_classes):
        super(ScalePrediction,self).__init__()
        self.pred = nn.Sequential(
            CNNBlock(in_channels,2*in_channels,kernel_size=3,padding=1),

            # (num_classes + 5) -> (num_classes,proability of bbox,x,y,w,h)
            # 3*(num_classes + 5) -> for each of the 3 anchor boxes
            CNNBlock(2*in_channels,3*(num_classes+5),bn_act=False,kernel_size=1)
        )
        self.num_classes = num_classes

    def forward(self,x):

        ##############################
        # FOR FIRST SCALE PREDICTION #
        ##############################
        # when the YOLOv3 netwrok reaches scale prediction for first time;
        # self.pred(x) returns (N,C,H,W) which is (N,3*(num_classes+5),13,13)
        # rehape (Batch,3 Anchor Boxes,num_classes+5 for elements of each Anchor box)
        # .permute(0,1,3,4,2) -> change the position of dimensions to the way we want it
        
        # So final shape: N,3,13,13,5+num_classes
        # N is batch size
        # 3 Anchor boxes
        # 13 is grid size (S)
        # 5+num_classes are elements fro each bbox prediction
        return self.pred(x).reshape(x.shape[0],3,self.num_classes+5,x.shape[2],x.shape[3]).permute(0,1,3,4,2)

In [59]:
class YOLOv3(nn.Module):
    def __init__(self,in_channels=3,num_classes=20):
        super(YOLOv3,self).__init__()
        self.num_classes = num_classes
        self.in_channels=in_channels
        self.layers = self._create_conv_layers()
    
    def forward(self,x):
        outputs = []
        root_connections = []

        for layer in self.layers:
            if isinstance(layer,ScalePrediction):
                outputs.append(layer(x))
                continue
            
            x = layer(x)

            if isinstance(layer,ResidualBlock) and layer.numrepeats == 8:
                root_connections.append(x)

            elif isinstance(layer,nn.Upsample):


                # output from Residual block -> (N, C1, H, W)
                # output from upscaling -> (N, C2, H, W)
                # After concatenation -> (N,C1+C2,H,W)
                print(x.shape, root_connections[-1].shape)
                x = torch.cat([x,root_connections[-1]],dim=1)
                root_connections.pop() # remove the last one ie, the latest addition to the list

        return outputs

    def _create_conv_layers(self):
        layers = nn.ModuleList()
        in_channels = self.in_channels

        for module in config:
            if type(module) == tuple:
                out_channels,kernel_size,stride = module
                layers += [
                    CNNBlock(in_channels,
                             out_channels,
                             kernel_size=kernel_size,
                             stride=stride,
                             padding=1 if kernel_size == 3 else 0)
                ]
                in_channels=out_channels

            if type(module) == list:
                repeat = module[1]
                layers += [
                    ResidualBlock(in_channels,use_residual=True,num_repeats=repeat)
                ]

            if type(module) == str and module == "S":
                layers += [
                    ResidualBlock(in_channels,use_residual=False,num_repeats=1),
                    CNNBlock(in_channels,in_channels//2,kernel_size=1),
                    ScalePrediction(in_channels//2,num_classes=self.num_classes)
                ]

                in_channels = in_channels//2
            
            if type(module) == str and module == "U":
                layers += [
                    nn.Upsample(scale_factor=2),
                    
                ]
                in_channels = in_channels * 3

        return layers
    
if __name__ == "__main__":
    num_classes = 20
    IMAGE_SIZE = 416
    model = YOLOv3(num_classes=num_classes)
    x = torch.randn((2, 3, IMAGE_SIZE, IMAGE_SIZE))
    out = model(x)
    assert model(x)[0].shape == (2, 3, IMAGE_SIZE//32, IMAGE_SIZE//32, num_classes + 5)
    assert model(x)[1].shape == (2, 3, IMAGE_SIZE//16, IMAGE_SIZE//16, num_classes + 5)
    assert model(x)[2].shape == (2, 3, IMAGE_SIZE//8, IMAGE_SIZE//8, num_classes + 5)
    print("Success!")

torch.Size([2, 256, 26, 26]) torch.Size([2, 512, 26, 26])
torch.Size([2, 128, 52, 52]) torch.Size([2, 256, 52, 52])
torch.Size([2, 256, 26, 26]) torch.Size([2, 512, 26, 26])
torch.Size([2, 128, 52, 52]) torch.Size([2, 256, 52, 52])
torch.Size([2, 256, 26, 26]) torch.Size([2, 512, 26, 26])
torch.Size([2, 128, 52, 52]) torch.Size([2, 256, 52, 52])
torch.Size([2, 256, 26, 26]) torch.Size([2, 512, 26, 26])
torch.Size([2, 128, 52, 52]) torch.Size([2, 256, 52, 52])
Success!
