In [10]:
import torch
import torch.nn as nn
from torchvision.models import vgg19_bn, resnet34, densenet161

from typing import Any, Callable, Optional, Tuple, List

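For each image, the final output tensor has shape $S \times S \times (B \cdot 5 + C)$, where $S$ is the grid size, $B$ is the number of bounding boxes predicted per grid cell, and $C$ is the number of classes.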

In [2]:
class YOLOv1(nn.Module):
    """YOLOv1-style detection head stacked on top of a convolutional backbone.

    The backbone output goes through four 3x3 convolutions (the second one has
    stride 2), is flattened, passed through two fully connected layers with a
    final sigmoid, and reshaped to ``[N, S, S, 5 * B + C]``.

    Args:
        feature_extractor: Callable (typically an ``nn.Module``) applied to the
            input first. Its output is expected to be
            ``[N, in_channels, 14, 14]`` so that the stride-2 convolution
            produces the 7x7 map the first linear layer is sized for.
        num_grid: Grid size ``S`` of the output tensor.
        num_bboxes: Number of boxes ``B`` predicted per grid cell.
        num_classes: Number of classes ``C``.
        in_channels: Channel count of the backbone output (512 by default).
        conv_channels: Width of the convolutional head (1024 by default).
        fc_hidden: Width of the hidden fully connected layer (4096 by default).
    """

    # Spatial size of the conv head's output: a 14x14 backbone map halved by
    # the single stride-2 convolution. The first linear layer is sized for it.
    _HEAD_MAP_SIZE = 7

    def __init__(self,
                 feature_extractor: Callable,
                 num_grid: int = 7,
                 num_bboxes: int = 2,
                 num_classes: int = 80,
                 in_channels: int = 512,
                 conv_channels: int = 1024,
                 fc_hidden: int = 4096) -> None:
        super().__init__()

        self.S = num_grid
        self.B = num_bboxes
        self.C = num_classes

        # Head dimensions; must be set before the create_* builders run.
        self.in_channels = in_channels
        self.conv_channels = conv_channels
        self.fc_hidden = fc_hidden

        self.feature_extractor = feature_extractor
        self.conv_layers = self.create_conv_layers()
        self.fc_layers = self.create_fc_layers()

        # TODO: initialize weights manually?

    def create_conv_layers(self) -> nn.Sequential:
        """Build the convolutional head applied to the backbone output.

        With the default sizes the input is expected to look like
        ``[N, 512, 14, 14]`` and the output like ``[N, 1024, 7, 7]``.
        """
        # TODO: add batch_norm?
        width = self.conv_channels
        return nn.Sequential(
            nn.Conv2d(self.in_channels, width, 3, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            # The only stride-2 layer: halves the spatial resolution.
            nn.Conv2d(width, width, 3, stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True),

            nn.Conv2d(width, width, 3, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Conv2d(width, width, 3, padding=1),
            nn.LeakyReLU(0.1, inplace=True),
        )

    def create_fc_layers(self) -> nn.Sequential:
        """Build the fully connected head producing ``S * S * (5 * B + C)`` values.

        With the default sizes the (flattened) input comes from a tensor that
        looks like ``[N, 1024, 7, 7]``.
        """
        flat_features = self._HEAD_MAP_SIZE * self._HEAD_MAP_SIZE * self.conv_channels
        out_features = self.S * self.S * (5 * self.B + self.C)
        return nn.Sequential(
            nn.Linear(flat_features, self.fc_hidden),
            nn.LeakyReLU(0.1, inplace=True),
            # Dropout must NOT be in-place here: the in-place LeakyReLU above
            # saves its output for the backward pass, and overwriting that
            # tensor makes autograd raise during backward().
            nn.Dropout(0.5),
            nn.Linear(self.fc_hidden, out_features),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run backbone and head; returns a ``[N, S, S, 5 * B + C]`` tensor."""
        x = self.feature_extractor(x)
        x = self.conv_layers(x)
        # flatten() also copes with non-contiguous conv output, unlike view().
        x = x.flatten(start_dim=1)
        x = self.fc_layers(x)

        return x.view(-1, self.S, self.S, 5 * self.B + self.C)

In [3]:
def create_model(feature_extractor: Callable, **model_kwargs: Any) -> Callable:
    """Build a YOLOv1 model on top of ``feature_extractor``.

    Args:
        feature_extractor: Backbone passed to ``YOLOv1`` as its first argument.
        **model_kwargs: Optional keyword arguments forwarded unchanged to
            ``YOLOv1`` (e.g. ``num_grid``, ``num_bboxes``, ``num_classes``).
            When none are given the model is built with YOLOv1's defaults.

    Returns:
        The constructed ``YOLOv1`` instance.
    """
    return YOLOv1(feature_extractor, **model_kwargs)

In [4]:
# Load torchvision's VGG-19 (batch-norm variant) with pretrained weights.
feature_extractor = vgg19_bn(pretrained=True)
# Echo the module so the notebook displays its layer listing.
feature_extractor

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256

In [5]:
# Use only VGG's convolutional `.features` trunk as the backbone.
model = create_model(feature_extractor.features)

# Smoke test on a random batch of two 3x448x448 inputs; with the default
# S=7, B=2, C=80 the expected output shape is [2, 7, 7, 90].
X = torch.rand(2, 3, 448, 448)
y = model(X)
y.shape

torch.Size([2, 7, 7, 90])

In [None]:
# Load ResNet-34 with pretrained weights and print its full architecture.
resnet = resnet34(pretrained=True, progress=False)
print(resnet)
print()

# Keep everything except the last two child modules as the backbone
# (presumably the pooling and classifier layers; verify against the printout).
backbone_children = list(resnet.children())[:-2]
feature_extractor = nn.Sequential(*backbone_children)
feature_extractor

In [None]:
# Wrap the truncated ResNet-34 backbone with the YOLOv1 head.
model = create_model(feature_extractor)

# Smoke test: forward a random batch of two 3x448x448 inputs and show the
# output shape.
X = torch.rand(2, 3, 448, 448)
y = model(X)
y.shape

In [None]:
# Load torchvision's DenseNet-161 with pretrained weights (no download progress bar).
feature_extractor = densenet161(pretrained=True, progress=False)
# Echo the module so the notebook displays its layer listing.
feature_extractor

In [None]:
# Keep only DenseNet's convolutional `.features` trunk as the backbone.
# getattr with a fallback keeps this cell idempotent: on a second run
# `feature_extractor` is already the trunk, which has no `.features` attribute.
feature_extractor = getattr(feature_extractor, "features", feature_extractor)
feature_extractor

In [None]:
# DenseNet-161's trunk emits far more channels (2208) than the 512 the YOLOv1
# head's first convolution is built for, so project down with a 1x1 convolution.
# The channel count is read from the trunk's final batch-norm layer (`norm5`).
densenet_channels = feature_extractor.norm5.num_features
densenet_backbone = nn.Sequential(
    feature_extractor,
    nn.Conv2d(densenet_channels, 512, kernel_size=1),
)
model = create_model(densenet_backbone)

# Smoke test: forward a random batch of two 3x448x448 inputs and show the
# output shape.
X = torch.rand(2, 3, 448, 448)
y = model(X)
y.shape