# Experimenting with Different Weight Initializations for Models

In [5]:
from importlib.util import find_spec
# Make the project-local `model` package importable: if it cannot be located
# from the current working directory, add the parent directory to sys.path.
if find_spec("model") is None:
    import sys
    sys.path.append('..')

In [6]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

In [7]:
from model.backbone.retina_meta import RetinaNetFPN50, RetinaNetHead
from model.backbone.resnet import ResNet50
from utils.weight_init import c2_msra_fill

## Superhuman Init
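Paper: [Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification](https://arxiv.org/abs/1502.01852) — the source of the Kaiming (He) initialization used by `init.kaiming_normal_` / `init.kaiming_uniform_` below.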

In [8]:
def stats(x):
    """Summarise a tensor as a short string with its mean and standard deviation.

    Args:
        x: Object exposing ``.data`` with ``mean()`` and ``std()`` methods
            (used with ``torch.Tensor`` throughout this notebook).

    Returns:
        A string of the form ``"mean of <mean>; std of <std>"``.
    """
    values = x.data
    mean = values.mean()
    std = values.std()
    return f"mean of {mean}; std of {std}"

## ResNet Init

In [5]:
arch = ResNet50()

In [6]:
# Synthetic input batch: 32 tensors of shape 3x512x512 drawn from a standard normal.
# NOTE(review): no torch.manual_seed is set in the visible cells, so the
# statistics recorded below will differ from run to run.
data = torch.randn(32, 3, 512, 512)
stats(data)

'mean of 7.630876643816009e-05; std of 0.9998598694801331'

In [7]:
# The backbone's output unpacks into four values; the first one is discarded.
# Presumably C3/C4/C5 are feature maps of successive stages — confirm in model/backbone/resnet.py.
_, C3, C4, C5 = arch(data)

In [8]:
stats(C3), stats(C4), stats(C5)

('mean of 1.2017172574996948; std of 1.286324381828308',
 'mean of 1.5148167610168457; std of 1.5278385877609253',
 'mean of 1.0509231090545654; std of 1.1662362813949585')

In [47]:
def simple_kaiming(model: nn.Module, a=0):
    """Re-initialise every ``nn.Conv2d`` in ``model`` with Kaiming-normal weights and zero biases.

    Args:
        model: Module whose ``modules()`` are scanned. Modules that are not
            ``nn.Conv2d`` are left untouched.
        a: Negative slope of the rectifier assumed to follow each convolution.
            ``a=0`` (the default) corresponds to a plain ReLU.

    Returns:
        None. Parameters are modified in place.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            # The gain only depends on `a` for nonlinearity="leaky_relu"; with
            # "relu" the argument was silently ignored. Leaky ReLU with a=0 has
            # the same gain (sqrt(2)) as ReLU, so the default is unchanged.
            init.kaiming_normal_(m.weight, a=a, mode="fan_in", nonlinearity="leaky_relu")
            if m.bias is not None:
                init.zeros_(m.bias)
                
def uniform_kaiming(model: nn.Module, a=0):
    """Re-initialise every ``nn.Conv2d`` in ``model`` with Kaiming-uniform weights and zero biases.

    Args:
        model: Module whose ``modules()`` are scanned. Modules that are not
            ``nn.Conv2d`` are left untouched.
        a: Negative slope of the rectifier assumed to follow each convolution.
            ``a=0`` (the default) corresponds to a plain ReLU.

    Returns:
        None. Parameters are modified in place.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            # The gain only depends on `a` for nonlinearity="leaky_relu"; with
            # "relu" the argument was silently ignored. Leaky ReLU with a=0 has
            # the same gain (sqrt(2)) as ReLU, so the default is unchanged.
            init.kaiming_uniform_(m.weight, a=a, mode="fan_in", nonlinearity="leaky_relu")
            if m.bias is not None:
                init.zeros_(m.bias)
                
                

        
def advanced_kaiming(model: nn.Module, a=0):
    """Re-initialise every ``nn.Conv2d`` in ``model``: Kaiming-normal weights, small uniform biases.

    Args:
        model: Module whose ``modules()`` are scanned. Modules that are not
            ``nn.Conv2d`` are left untouched.
        a: Negative slope of the rectifier assumed to follow each convolution.
            ``a=0`` (the default) corresponds to a plain ReLU.

    Returns:
        None. Parameters are modified in place.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            # "leaky_relu" is used so that `a` actually affects the gain; with
            # a=0 the gain equals the ReLU gain (sqrt(2)), so defaults are unchanged.
            nn.init.kaiming_normal_(m.weight, a=a, mode='fan_in', nonlinearity='leaky_relu')
            if m.bias is not None:
                _, fan_out = nn.init._calculate_fan_in_and_fan_out(m.weight)
                bound = 1 / math.sqrt(fan_out)
                # Draw biases from U(-bound, bound). The previous call,
                # normal_(bias, -bound, bound), was interpreted as
                # (mean=-bound, std=bound) and produced a negatively shifted bias.
                # NOTE(review): the bound is derived from fan_out, as in the
                # original code; PyTorch's own Conv2d default uses fan_in — confirm intent.
                nn.init.uniform_(m.bias, -bound, bound)

In [10]:
# NOTE(review): leftover IPython source lookup (`??`); it only displays the
# library source and can be dropped from the finished notebook.
init.kaiming_normal_??

[0;31mSignature:[0m [0minit[0m[0;34m.[0m[0mkaiming_normal_[0m[0;34m([0m[0mtensor[0m[0;34m,[0m [0ma[0m[0;34m=[0m[0;36m0[0m[0;34m,[0m [0mmode[0m[0;34m=[0m[0;34m'fan_in'[0m[0;34m,[0m [0mnonlinearity[0m[0;34m=[0m[0;34m'leaky_relu'[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m   
[0;32mdef[0m [0mkaiming_normal_[0m[0;34m([0m[0mtensor[0m[0;34m,[0m [0ma[0m[0;34m=[0m[0;36m0[0m[0;34m,[0m [0mmode[0m[0;34m=[0m[0;34m'fan_in'[0m[0;34m,[0m [0mnonlinearity[0m[0;34m=[0m[0;34m'leaky_relu'[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0;34mr"""Fills the input `Tensor` with values according to the method[0m
[0;34m    described in `Delving deep into rectifiers: Surpassing human-level[0m
[0;34m    performance on ImageNet classification` - He, K. et al. (2015), using a[0m
[0;34m    normal distribution. The resulting tensor will have values sampled from[0m
[0;34m    :math:`\mathcal{N}(0, \text{std}^2)` where[0m

In [11]:
stats(data)

'mean of 7.630876643816009e-05; std of 0.9998598694801331'

In [12]:
# Overwrite, in place, the weights of every Conv2d in `arch` with Kaiming-normal
# values; Conv2d biases, where present, are set to zero.
simple_kaiming(arch)

In [13]:
_, C3, C4, C5 = arch(data)

In [14]:
stats(C3), stats(C4), stats(C5)

('mean of 1.54548180103302; std of 1.6600478887557983',
 'mean of 1.9610862731933594; std of 2.0283734798431396',
 'mean of 1.7808983325958252; std of 2.0884199142456055')

In [15]:
# Re-initialise the same `arch` object again, in place: this replaces the Conv2d
# weights written by the earlier simple_kaiming call rather than starting from a fresh model.
advanced_kaiming(arch)

In [16]:
_, C3, C4, C5 = arch(data)

In [18]:
stats(C3)

'mean of 1.3309593200683594; std of 1.4333189725875854'

In [21]:
stats(data)

'mean of 7.630876643816009e-05; std of 0.9998598694801331'

In [24]:
arch = ResNet50()

In [25]:
# Baseline: statistics of the stem's output for a freshly constructed ResNet50.
out = arch.stem(data)
stats(out)

'mean of 1.039498209953308; std of 0.6766116619110107'

In [27]:
# Re-initialise only the Conv2d layers inside the stem, then measure the stem output again.
# NOTE(review): the result is almost identical to the baseline cell; possibly the
# stem normalises its activations — confirm against the stem definition.
simple_kaiming(arch.stem)
out = arch.stem(data)
stats(out)

'mean of 1.0428764820098877; std of 0.6753635406494141'

In [43]:
# Fresh model; inspect the output of the stem's `conv1_1` layer on its own.
# No ReLU is applied in this cell.
arch = ResNet50()
out = arch.stem.conv1_1(data)
stats(out)

'mean of -0.00012122180487494916; std of 0.5770288109779358'

In [44]:
# Fresh model; Kaiming-normal weights (zero bias) on `conv1_1` only, measured after a ReLU.
# NOTE(review): the `conv1_1` baseline cell does not apply F.relu, so its numbers
# are not directly comparable with this cell's.
arch = ResNet50()
simple_kaiming(arch.stem.conv1_1)
out = F.relu(arch.stem.conv1_1(data))
stats(out)

'mean of 0.5831742286682129; std of 0.8645067811012268'

In [48]:
# Fresh model; same measurement as the simple_kaiming cell, but using
# advanced_kaiming (Kaiming-normal weights plus a non-zero bias init) on `conv1_1`.
arch = ResNet50()
advanced_kaiming(arch.stem.conv1_1)
out = F.relu(arch.stem.conv1_1(data))
stats(out)

'mean of 0.5521239042282104; std of 0.8213089108467102'

In [46]:
# Fresh model; same measurement with Kaiming-uniform weights (zero bias) on `conv1_1`.
arch = ResNet50()
uniform_kaiming(arch.stem.conv1_1)
out = F.relu(arch.stem.conv1_1(data))
stats(out)

'mean of 0.5679893493652344; std of 0.8348188996315002'

In [54]:
from utils.weight_init import c2_msra_fill, c2_xavier_fill

In [50]:
def custom_init(model: nn.Module, init_func):
    """Call ``init_func`` once on ``model`` and on each of its sub-modules.

    Args:
        model: Root module; it is visited first, followed by its descendants
            in the order produced by ``named_modules()``.
        init_func: Callable taking a single module. It is invoked on every
            visited module without any type filtering, so it must cope with
            whatever module types ``model`` contains.

    Returns:
        None.
    """
    for _name, submodule in model.named_modules():
        init_func(submodule)

In [57]:
# Fresh model; initialise `conv1_1` with the project's c2_msra_fill helper and
# measure the post-ReLU output.
# NOTE(review): the resulting std is much smaller than with the Kaiming helpers above;
# presumably c2_msra_fill scales by fan_out rather than fan_in — confirm in utils/weight_init.py.
arch = ResNet50()
custom_init(arch.stem.conv1_1, c2_msra_fill)
out = F.relu(arch.stem.conv1_1(data))
stats(out)

'mean of 0.1722812056541443; std of 0.25469303131103516'

In [63]:
# Random regression target with the same shape as `out`.
# torch.randn accepts a torch.Size directly; the earlier `(*out.shape)` spelling
# (a starred expression inside plain parentheses) is a SyntaxError on Python 3.9+.
target = torch.randn(out.shape)
stats(target)

'mean of 6.653039599768817e-05; std of 0.9998788833618164'

In [65]:
# Mean-squared error between the post-ReLU conv output and the random target;
# used only to obtain a gradient for the init comparison.
loss = F.mse_loss(out, target)
loss

tensor(1.0941, grad_fn=<MseLossBackward>)

In [66]:
# Back-propagate and summarise the gradient that reaches the conv1_1 weights.
loss.backward()
stats(arch.stem.conv1_1.weight.grad)

'mean of 0.0001164769273600541; std of 0.0026100981049239635'

In [55]:
# Fresh model; initialise `conv1_1` with the project's c2_xavier_fill helper and
# measure the post-ReLU output.
arch = ResNet50()
custom_init(arch.stem.conv1_1, c2_xavier_fill)
out = F.relu(arch.stem.conv1_1(data))
stats(out)

'mean of 0.3949778974056244; std of 0.5821265578269958'

## Test Pure Retina Inits

In [9]:
import random
import torch

In [10]:
from model.model import RetinaNet500
from model.loss import RetinaLoss

In [11]:
# Rebinds `data` to a smaller synthetic batch (16 x 3 x 512 x 512, standard normal)
# for the RetinaNet experiments; the 32-sample batch used earlier is no longer referenced.
data = torch.randn((16, 3, 512, 512))
stats(data)

'mean of -0.000146355465403758; std of 1.0004196166992188'

In [12]:
# Fake ground truth: each of the 16 images gets between 1 and 7 objects (randint is inclusive).
objs = [random.randint(1, 7) for _ in range(16)]
# torch.randint's upper bound is exclusive, so labels fall in [0, 78], each tensor shaped (num_o, 1).
# NOTE(review): if the model predicts 80 classes the bound should probably be 80 — confirm.
labels = [torch.randint(0, 79, (num_o, 1)) for num_o in objs]
# NOTE(review): standard-normal samples are not valid boxes (coordinates can be negative
# and unordered); this may be why both losses evaluate to 0.0 further down — confirm the
# box format RetinaLoss expects.
boxes = [torch.randn((num_o, 4)) for num_o in objs]
stats(boxes[0])

'mean of 0.34604740142822266; std of 0.6397870182991028'

In [13]:
# NOTE(review): this rebinds `loss`, which earlier held the MSE loss tensor, to a
# RetinaLoss instance; a distinct name would avoid confusion on out-of-order re-runs.
loss = RetinaLoss()

In [14]:
model = RetinaNet500()

In [15]:
# The model's output unpacks into three values. `pred_logits` is indexed with [0] and [4]
# below, so it is a sequence — presumably one entry per feature level; confirm in model/model.py.
pred_logits, pred_bboxes, anchors = model(data)

In [16]:
stats(pred_logits[0])

'mean of -4.601955413818359; std of 0.025077205151319504'

In [17]:
stats(pred_logits[4])

'mean of -4.595100402832031; std of 0.004645262844860554'

In [18]:
# Evaluate the loss on the predictions against the random targets; the result is
# indexed by 'loss_cls' and 'loss_box_reg' in the following cells.
losses = loss(pred_logits, pred_bboxes, anchors, boxes, labels)

In [19]:
# NOTE(review): both terms print as exactly 0.0, which suggests no anchor was matched
# to the random ground-truth boxes — verify with realistic boxes before drawing conclusions.
losses['loss_cls'].item(), losses['loss_box_reg'].item()

(0.0, 0.0)

In [20]:
total_loss = losses['loss_cls'] + losses['loss_box_reg']

In [21]:
# Back-propagate the summed loss.
# NOTE(review): both loss terms were 0.0 above, so the gradients produced here are
# unlikely to say anything about the initialisation — confirm once the targets are fixed.
total_loss.backward()