source: https://www.youtube.com/watch?v=n9_XyCGr-MI

Dataset YOLOv1 was trained on:
- PASCAL VOC
- 20 classes

In [1]:
import torch
import torch.nn as nn

In [2]:
# Report whether PyTorch can see a CUDA device (cell output shown below).
torch.cuda.is_available()

True

In [3]:
# YOLOv1 Darknet-style backbone (Figure 3 of the YOLOv1 paper): 24 conv layers.
# Entry formats consumed by Yolov1._create_conv_layers:
#   tuple -> (kernel_size, num_filters, stride, padding) conv block
#   "M"   -> 2x2 max-pool with stride 2
#   list  -> [conv_tuple, conv_tuple, num_repeats]: the conv pair is repeated
# Two stride-2 convs plus four max-pools downsample by 64x in total, so a
# 448x448 input yields a 7x7x1024 feature map.
architecture_config = [
    (7, 64, 2, 3),
    "M",
    (3, 192, 1, 1),
    "M",
    (1, 128, 1, 0),
    (3, 256, 1, 1),
    (1, 256, 1, 0),
    (3, 512, 1, 1),
    "M",
    [(1, 256, 1, 0), (3, 512, 1, 1), 4],
    (1, 512, 1, 0),
    (3, 1024, 1, 1),
    "M",
    [(1, 512, 1, 0), (3, 1024, 1, 1), 2],
    (3, 1024, 1, 1),
    (3, 1024, 2, 1),
    (3, 1024, 1, 1),
    (3, 1024, 1, 1),
]

In [4]:
class CNNBlock(nn.Module):
    """Convolution -> batch norm -> LeakyReLU(0.1).

    The convolution is created without a bias term; the BatchNorm2d that
    follows has its own learnable shift. Any extra keyword arguments
    (kernel_size, stride, padding, ...) are forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs, bias=False)
        self.batchnorm = nn.BatchNorm2d(out_channels)
        self.leakyrelu = nn.LeakyReLU(0.1)

    def forward(self, x):
        features = self.conv(x)
        normalized = self.batchnorm(features)
        activated = self.leakyrelu(normalized)
        return activated

In [51]:
class Yolov1(nn.Module):
    """YOLOv1 detector: Darknet-style conv backbone + fully connected head.

    Args:
        in_channels: number of input image channels.
        **kwargs: forwarded to ``_create_fcs``; must provide ``split_size``
            (S), ``num_boxes`` (B) and ``num_classes`` (C).

    The output has shape (N, S * S * (C + B * 5)).
    """

    def __init__(self, in_channels=3, **kwargs):
        super().__init__()
        self.architecture = architecture_config
        self.in_channels = in_channels
        self.darknet = self._create_conv_layers(self.architecture)
        self.fcs = self._create_fcs(**kwargs)

    def forward(self, x):
        # self.fcs begins with nn.Flatten, so the backbone feature map can be
        # fed to it directly without flattening here as well.
        return self.fcs(self.darknet(x))

    def _create_conv_layers(self, architecture):
        """Build the Darknet backbone from an architecture description.

        Supported entries:
          * tuple ``(kernel_size, out_channels, stride, padding)`` -> CNNBlock
          * str (e.g. ``"M"``) -> 2x2 max-pool with stride 2
          * list ``[conv_tuple, ..., conv_tuple, num_repeats]`` -> the conv
            tuples applied in order, the whole group repeated
            ``num_repeats`` times

        Raises:
            ValueError: if an entry is none of the above.
        """
        layers = []
        in_channels = self.in_channels

        for entry in architecture:
            if isinstance(entry, str):
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            if isinstance(entry, tuple):
                convs, num_repeats = [entry], 1
            elif isinstance(entry, list):
                *convs, num_repeats = entry
            else:
                raise ValueError(f"Unsupported architecture entry: {entry!r}")

            for _ in range(num_repeats):
                for kernel_size, out_channels, stride, padding in convs:
                    layers.append(
                        CNNBlock(
                            in_channels,
                            out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            padding=padding,
                        )
                    )
                    # Each conv feeds the next one in the chain.
                    in_channels = out_channels

        return nn.Sequential(*layers)

    def _create_fcs(self, split_size, num_boxes, num_classes):
        """Build the detection head.

        The head maps the flattened 1024 x S x S backbone output to
        S * S * (C + B * 5) values per sample (30 per cell for C=20, B=2).
        YoloLoss reshapes these to (N, S, S, C + B * 5).
        """
        S, B, C = split_size, num_boxes, num_classes
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(1024 * S * S, 496),  # original paper = 4096
            nn.Dropout(0.0),  # p=0.0: dropout is effectively disabled
            nn.LeakyReLU(0.1),
            nn.Linear(496, S * S * (C + B * 5)),
        )
    

In [57]:
def test(S=7, B=2, C=20):
    """Smoke-test Yolov1 on a random batch of two 448x448 RGB images.

    Checks that the head emits S * S * (C + B * 5) values per image and
    prints the output shape. With a 448x448 input the backbone produces a
    7x7 feature map, so the first Linear layer only fits when S == 7.
    """
    model = Yolov1(split_size=S, num_boxes=B, num_classes=C)
    x = torch.randn((2, 3, 448, 448))
    op = model(x)
    expected = (2, S * S * (C + B * 5))
    assert tuple(op.shape) == expected, f"expected {expected}, got {tuple(op.shape)}"
    print(op.shape)

# test()

# Loss function

In [58]:
class YoloLoss(nn.Module):
    """Sum-squared-error loss from the YOLOv1 paper.

    Per-cell channel layout (C classes, two predicted boxes):
      predictions: [0:C] class scores | [C] confidence 1 | [C+1:C+5] box 1
                   | [C+5] confidence 2 | [C+6:C+10] box 2
      target:      [0:C] class scores | [C] object indicator | [C+1:C+5] box
    In each box, entries [2:4] are square-rooted before the MSE (the
    paper's width/height treatment).

    Only two predicted boxes per cell are handled: the choice between box 1
    and box 2 is written out explicitly, so B must be 2.

    Args:
        S: grid size (cells per side).
        B: boxes predicted per cell (must be 2; see above).
        C: number of classes.
    """

    def __init__(self, S=7, B=2, C=20):
        super().__init__()
        self.mse = nn.MSELoss(reduction="sum")
        self.S = S
        self.B = B
        self.C = C
        # Loss weights: down-weight no-object confidence, up-weight coordinates.
        self.lambda_noobj = 0.5
        self.lambda_coord = 5

    def forward(self, predictions, target):
        """Return the scalar loss summed over the batch.

        Args:
            predictions: model output; reshaped to (N, S, S, C + B * 5).
            target: 4-D tensor (N, S, S, K) with K >= C + 5, laid out as in the
                class docstring.
        """
        c = self.C
        predictions = predictions.reshape(-1, self.S, self.S, self.C + self.B * 5)

        pred_conf1 = predictions[..., c:c + 1]
        pred_box1 = predictions[..., c + 1:c + 5]
        pred_conf2 = predictions[..., c + 5:c + 6]
        pred_box2 = predictions[..., c + 6:c + 10]
        target_conf = target[..., c:c + 1]
        target_box = target[..., c + 1:c + 5]
        # Object indicator 1_i^obj for each cell, shape (N, S, S, 1).
        exists_box = target_conf

        # The "responsible" predictor is whichever box overlaps the target more.
        iou_b1 = intersection_over_union(pred_box1, target_box)
        iou_b2 = intersection_over_union(pred_box2, target_box)
        ious = torch.stack([iou_b1, iou_b2], dim=0)
        _, best_box = torch.max(ious, dim=0)  # 0 -> box 1, 1 -> box 2

        # ---- box coordinate loss ----
        box_predictions = exists_box * (
            best_box * pred_box2 + (1 - best_box) * pred_box1
        )
        box_targets = exists_box * target_box

        # Predictions can be negative, so take the sqrt of the magnitude and
        # restore the sign. The epsilon sits outside abs() so the sqrt
        # argument is always >= 1e-6 and its gradient stays finite. Built with
        # torch.cat rather than in-place slice assignment to keep autograd safe.
        pred_wh = box_predictions[..., 2:4]
        box_predictions = torch.cat(
            [
                box_predictions[..., :2],
                torch.sign(pred_wh) * torch.sqrt(torch.abs(pred_wh) + 1e-6),
            ],
            dim=-1,
        )
        box_targets = torch.cat(
            [box_targets[..., :2], torch.sqrt(box_targets[..., 2:4])],
            dim=-1,
        )

        # (N, S, S, 4) -> (N*S*S, 4)
        box_loss = self.mse(
            torch.flatten(box_predictions, end_dim=-2),
            torch.flatten(box_targets, end_dim=-2),
        )

        # ---- object (confidence) loss for the responsible box ----
        pred_conf = best_box * pred_conf2 + (1 - best_box) * pred_conf1
        object_loss = self.mse(
            torch.flatten(exists_box * pred_conf),
            torch.flatten(exists_box * target_conf),
        )

        # ---- no-object loss: both boxes in empty cells ----
        # (N, S, S, 1) -> (N, S*S)
        no_obj = 1 - exists_box
        no_object_loss = self.mse(
            torch.flatten(no_obj * pred_conf1, start_dim=1),
            torch.flatten(no_obj * target_conf, start_dim=1),
        )
        no_object_loss = no_object_loss + self.mse(
            torch.flatten(no_obj * pred_conf2, start_dim=1),
            torch.flatten(no_obj * target_conf, start_dim=1),
        )

        # ---- class loss ----
        # (N, S, S, C) -> (N*S*S, C)
        class_loss = self.mse(
            torch.flatten(exists_box * predictions[..., :c], end_dim=-2),
            torch.flatten(exists_box * target[..., :c], end_dim=-2),
        )

        return (
            self.lambda_coord * box_loss
            + object_loss
            + self.lambda_noobj * no_object_loss
            + class_loss
        )

In [None]:
def intersection_over_union():
    