# Putting it All Together Without Loss

In [1]:
from importlib.util import find_spec
# If the `model` package cannot be located on the current import path, append
# the parent directory ('..', resolved against the current working directory)
# to sys.path so the project imports in the following cells can succeed.
if find_spec("model") is None:
    import sys
    sys.path.append('..')

In [2]:
from typing import Type, List
import torch.nn as nn
from torch import Tensor

In [3]:
from model.backbone.retina_meta import RetinaNetFPN50, RetinaNetHead
from model.anchor_generator import AnchorBoxGenerator
from model.backbone.resnet import ResNet50
from utils.shape_utils import permute_to_N_HWA_K

In [4]:
class RetinaNet(nn.Module):
    """Detector assembled from a base network, a pyramid backbone and a head.

    ``base``, ``backbone`` and ``head`` are given as classes and instantiated
    here (``base`` and ``backbone`` without arguments, ``head`` with
    ``num_classes``).  ``anchor_generator`` is given as a ready-made instance
    and stored as is.
    """

    def __init__(self,
                 base: Type[nn.Module],
                 backbone: Type[nn.Module],
                 head: Type[nn.Module],
                 anchor_generator: AnchorBoxGenerator,
                 num_classes=20):
        super().__init__()
        self.num_classes = num_classes

        # Components are created in data-flow order: base -> backbone -> head.
        self.base = base()
        self.backbone = backbone()
        self.head = head(num_classes)
        self.anchor_generator = anchor_generator

    def forward(self, x):
        """Run the full pipeline on ``x``.

        Returns a 3-tuple ``(logits, bboxes, anchors)``: two lists holding the
        head's class and box outputs after ``permute_to_N_HWA_K`` (with
        ``num_classes`` and 4 as the last argument respectively), and whatever
        the anchor generator returns for the five pyramid features.
        """
        # The base network yields four outputs; the first one is not used.
        _, C3, C4, C5 = self.base(x)
        P3, P4, P5, P6, P7 = self.backbone(C3, C4, C5)
        pyramid = [P3, P4, P5, P6, P7]

        pred_logits, pred_bboxes = self.head(*pyramid)
        anchors = self.anchor_generator(pyramid)

        # The head outputs are iterated and then indexed by the iterated
        # value, i.e. they are treated as mappings keyed per pyramid level.
        reshaped_logits = []
        for level in pred_logits:
            reshaped_logits.append(
                permute_to_N_HWA_K(pred_logits[level], self.num_classes)
            )

        reshaped_bboxes = []
        for level in pred_bboxes:
            reshaped_bboxes.append(permute_to_N_HWA_K(pred_bboxes[level], 4))

        return reshaped_logits, reshaped_bboxes, anchors

## Test Model

In [2]:
import torch

In [3]:
# Number of object classes the head predicts (same value as RetinaNet's default).
num_classes = 20
# Anchors per feature-map location; presumably 3 aspect ratios x 3 scales, matching
# the AnchorBoxGenerator configured further down — TODO confirm.
num_anchors = 9

In [4]:
# Random stand-in input of shape (16, 3, 512, 512); presumably a batch of 16 RGB
# images of 512x512 pixels.
data = torch.randn((16, 3, 512, 512))

In [8]:
# `sizes` and `strides` each have five entries, presumably one per pyramid level
# (the model passes five feature maps to the generator).
# NOTE(review): every stride is 2. If AnchorBoxGenerator expects each level's
# stride relative to the input image, these values look too small — confirm
# against the AnchorBoxGenerator implementation.
anchor_gen = AnchorBoxGenerator(
    sizes=[32., 64., 128., 256., 512.],
    aspect_ratios=[0.5, 1., 2.],
    scales=[1., 2 ** (1 / 3), 2 ** (2 / 3)],
    strides=[2, 2, 2, 2, 2]
)

In [9]:
# Assemble the detector by hand: the three component classes are instantiated
# inside RetinaNet.__init__, while the anchor generator is passed as an instance.
model = RetinaNet(ResNet50, RetinaNetFPN50, RetinaNetHead, anchor_gen, num_classes=num_classes)

In [None]:
# Forward pass on the random batch; the third return value (the anchors) is discarded here.
pred_logits, pred_bboxes, _ = model(data)

In [None]:
# First-level logits: one row per (location, anchor) pair, one column per class.
# 64 * 64 is presumably the first pyramid level's grid for a 512x512 input — TODO confirm.
assert pred_logits[0].shape == (16, 64 * 64 * num_anchors, num_classes)

In [None]:
# First-level box predictions: same row count as the logits, 4 values per anchor.
assert pred_bboxes[0].shape == (16, 64 * 64 * num_anchors, 4)

In [5]:
from model.model import RetinaNet500

In [6]:
# Rebind `model` to the packaged RetinaNet500 from model.model; the cells below
# reach its layers through `model.model.backbone`.
model = RetinaNet500(num_classes)

In [None]:
# Forward pass on the same random batch, this time keeping the anchors as well.
pred_logits, pred_anchor_deltas, anchors = model(data)

In [None]:
# Expect five entries in each output, presumably one per pyramid level.
assert len(pred_logits) == len(pred_anchor_deltas) == 5
# First-level shapes: (batch, locations * anchors, K) with K = num_classes for
# the logits and K = 4 for the box deltas.
assert pred_logits[0].shape == (16, 64 * 64 * num_anchors, num_classes)
assert pred_anchor_deltas[0].shape == (16, 64 * 64 * num_anchors, 4)
assert len(anchors) == 5

## Check Layer Mean and Standard Deviation, and Experiment with Initialization

In [31]:
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F

In [19]:
def stats(x):
    """Return the pair ``(x.mean(), x.std())`` for quick inspection of ``x``."""
    mean = x.mean()
    std = x.std()
    return mean, std

In [22]:
# Standard-normal probe input (mean 0, std 1) of shape (16, 256, 8, 8); it is fed
# to conv5 and to a fresh 256-channel conv in the cells below.
conv5_in = torch.randn((16, 256, 8, 8))

In [21]:
# Mean/std of conv5's weight and bias as left by RetinaNet500's construction.
stats(model.model.backbone.conv5.weight), stats(model.model.backbone.conv5.bias)

((tensor(-2.0029e-05, grad_fn=<MeanBackward0>),
  tensor(0.0120, grad_fn=<StdBackward0>)),
 (tensor(0.0011, grad_fn=<MeanBackward0>),
  tensor(0.0123, grad_fn=<StdBackward0>)))

### Conv Layer W/O Activation

In [25]:
# Mean/std of conv5's raw output (no activation) on the standard-normal input.
stats(model.model.backbone.conv5(conv5_in))

(tensor(0.0002, grad_fn=<MeanBackward0>),
 tensor(0.5302, grad_fn=<StdBackward0>))

In [27]:
# Re-initialize conv5's weight in place. With a=1 the leaky-ReLU gain is
# sqrt(2 / (1 + 1**2)) = 1, i.e. the scaling for a layer with no nonlinearity.
init.kaiming_normal_(model.model.backbone.conv5.weight, a=1.)
# Output statistics after the re-initialization (still no activation).
stats(model.model.backbone.conv5(conv5_in))

(tensor(0.0009, grad_fn=<MeanBackward0>),
 tensor(0.9169, grad_fn=<StdBackward0>))

### Conv Layer W/ Activation (ReLU)

In [34]:
def f1(conv, x, a=0):
    """Apply ``conv`` to ``x``, then ``F.leaky_relu`` with negative slope ``a``.

    With the default ``a=0`` the activation reduces to a plain ReLU.
    """
    pre_activation = conv(x)
    return F.leaky_relu(pre_activation, a)

In [35]:
# Re-initialize conv5's weight in place with a=0, i.e. gain sqrt(2), the
# Kaiming scaling intended for a ReLU that follows the layer.
init.kaiming_normal_(model.model.backbone.conv5.weight, a=0)
# Output statistics after conv5 + ReLU (f1 with its default a=0).
stats(f1(model.model.backbone.conv5, conv5_in))

(tensor(0.5144, grad_fn=<MeanBackward0>),
 tensor(0.7585, grad_fn=<StdBackward0>))

In [36]:
# Fresh 3x3 conv (256 -> 256 channels) left at nn.Conv2d's built-in initialization,
# used as a comparison point for the explicitly re-initialized conv5.
l1 = nn.Conv2d(256, 256, kernel_size=3, stride=1)

In [37]:
# Mean/std after conv + ReLU for the comparison layer, on the same probe input.
stats(f1(l1, conv5_in))

(tensor(0.2294, grad_fn=<MeanBackward0>),
 tensor(0.3355, grad_fn=<StdBackward0>))