In [1]:
%load_ext autoreload
%autoreload 2

In [54]:
import torch
import torch.nn as nn
import torchvision
from torchvision.models import resnet18, resnet34
import sys
sys.path.append('../')
import dataset

In [55]:
pascal_voc_train = torchvision.datasets.VOCDetection(
    root="../data",
    year="2007",
    image_set="train",
    download=False
)

In [116]:
voc_train = dataset.PascalVOC(pascal_voc=pascal_voc_train)

TRANSFORMING PASCAL VOC


In [243]:
class ResNet18YOLOv1(nn.Module):
    def __init__(self, S=7, B=2, C=20):
        super().__init__()
        self.S = S
        self.B = B
        self.C = C
        self.resnet = self.init_resnet()
        self.fc = nn.Sequential(
            nn.Linear(7 * 7 * 512, 4096),
            nn.Linear(4096, self.S**2 * (5 * self.B + self.C))
        )
        
    def init_resnet(self):
        resnet = resnet18(weights="IMAGENET1K_V1") 
        
        # replace relu with leaky relu
        resnet = self.replace_with_leaky_relu(resnet)
        
        # remove feedforward layer
        named_children = resnet.named_children()
        layers_to_remove = set(["fc", "avgpool"])
        layers = [module for name, module in named_children if name not in layers_to_remove]
        
        # add a conv layer at the end to reduce feature map to (512, 7, 7)
        layers.append(nn.Conv2d(512, 512, kernel_size=2, stride=2))
        
        return nn.Sequential(*layers)
        
    def replace_with_leaky_relu(self, nn_module):     
        named_children = nn_module.named_children()
        
        # loop over immediate children modules
        for name, module in named_children:
            is_relu = isinstance(module, nn.ReLU)
            
            if is_relu:
                leaky_relu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
                setattr(nn_module, name, leaky_relu)
            else:
                self.replace_with_leaky_relu(module)
                
        return nn_module
        
    def forward(self, x):
        x = self.resnet(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = x.view(-1, self.S, self.S, 5 * self.B + self.C)
        
        return x

In [244]:
yolo = ResNet18YOLOv1()
yolo

ResNet18YOLOv1(
  (resnet): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.1, inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): LeakyReLU(negative_slope=0.1, inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-

In [245]:
img = voc_train[0][0]
img = img.unsqueeze(0)
img.shape

torch.Size([1, 3, 448, 448])

In [246]:
out = yolo(img)
out.shape, out

(torch.Size([1, 7, 7, 30]),
 tensor([[[[-0.6154, -0.1316,  0.0526,  ...,  0.1136, -0.0582, -0.1567],
           [ 0.3633,  0.2270, -0.1410,  ..., -0.2798, -0.1981,  0.2558],
           [-0.1685,  0.1644, -0.8382,  ..., -0.5201, -0.1070, -0.2146],
           ...,
           [-0.1184,  0.0367,  0.0250,  ...,  0.2513, -0.0289,  0.3806],
           [ 0.1446,  0.0924,  0.2452,  ...,  0.1709,  0.3458, -0.3615],
           [ 0.3931, -0.0309,  0.1540,  ...,  0.0830,  0.3806,  0.6287]],
 
          [[-0.7503,  0.2484, -0.4309,  ...,  0.2381,  0.0977,  0.3878],
           [ 0.4302, -0.1472,  0.3353,  ..., -0.0331, -0.0993,  0.4362],
           [ 0.2035, -0.0211,  0.2381,  ..., -0.1506, -0.5608, -0.3045],
           ...,
           [-0.1417,  0.0473, -0.0132,  ..., -0.0146, -0.1674, -0.2383],
           [ 0.1072, -0.1049, -0.2156,  ..., -0.1790,  0.3378, -0.0987],
           [-0.3682, -0.1965,  0.1584,  ...,  0.0257,  0.1124,  0.6327]],
 
          [[ 0.0234,  0.0463, -0.1522,  ...,  0.2997,  0.1