In [None]:
import numpy as np
import math
import torch
from torch import nn
from matplotlib import pyplot as plt

In [None]:
from mnist_dataset import get_data, normalize


# MNIST preparation
- load the train / validation splits
- normalize both splits using the training-set mean & std

In [None]:
# Load the raw splits, then normalise BOTH with the *training* statistics
# (using validation statistics would leak information from the held-out set).
x_train, y_train, x_valid, y_valid = get_data()
train_mean = x_train.mean()
train_std = x_train.std()
x_train, x_valid = [normalize(split, train_mean, train_std) for split in (x_train, x_valid)]

In [None]:
# Reshape to (N, 1, 28, 28) so Conv2d sees single-channel 28x28 images
# (assumes get_data returns flattened 784-pixel rows — TODO confirm).
x_train, x_valid = (split.view(-1, 1, 28, 28) for split in (x_train, x_valid))
(x_train.shape, x_valid.shape)

In [None]:
# Sanity check: first training image (its single channel) and its label.
fig, ax = plt.subplots()
ax.imshow(x_train[0, 0], cmap='gray')
ax.set_title(f'x_train[0], label = {int(y_train[0])}');

In [None]:
num_sample, *inp_shape = x_train.shape
# int(...) so num_class is a plain Python int; y_train.max() alone is a 0-d tensor.
num_class = int(y_train.max()) + 1
num_hidden = 32
num_sample, num_class, inp_shape

# Simple conv2d

In [None]:
def stats(x):
    """Return ``(mean, std)`` of a tensor — a quick check of activation/weight scale."""
    mean = x.mean()
    std = x.std()
    return mean, std

In [None]:
# A small fixed batch (first 100 validation images) reused for every scale check below.
x = x_valid[0:100]
(x.shape, stats(x))

In [None]:
# 1 input channel -> num_hidden filters, each 5x5, with PyTorch's default init.
layer1 = nn.Conv2d(in_channels=1, out_channels=num_hidden, kernel_size=5)
(layer1.weight.shape, layer1.bias.shape)

## layer1 init behind the scenes
PyTorch's default conv init calls `init.kaiming_uniform_(self.weight, a=math.sqrt(5))` — note the unusual **`a=math.sqrt(5)`**.

In [None]:
nn.Conv2d??

In [None]:
torch.nn.modules.conv._ConvNd??

In [None]:
torch.nn.modules.conv._ConvNd.reset_parameters??

In [None]:
# (mean, std) of the default-initialised weight and bias.
tuple(stats(param) for param in (layer1.weight, layer1.bias))

## stats(x) after layer1

In [None]:
# Input vs. output of the default-initialised conv: the output std is no longer 1.
tuple(map(stats, (x, layer1(x))))

## Try Kaiming init for (leaky) ReLU

In [None]:
import torch.nn.functional as F

In [None]:
def f1(x, a=0, layer=None):
    """Apply a conv layer, then a leaky ReLU with negative slope ``a``.

    @ x:     input batch for the layer
    @ a:     leaky-ReLU negative slope (0 = plain ReLU)
    @ layer: module to apply; defaults to the notebook-global ``layer1``.
             The global is looked up at call time, so re-creating or
             re-initialising ``layer1`` in a later cell is picked up.
    """
    if layer is None:
        layer = layer1
    return F.leaky_relu(layer(x), a)

Kaiming normal init (`a=0`, i.e. plain ReLU):
- the activation std stays close to 1

In [None]:
# Kaiming normal with its defaults spelled out: a=0 (plain ReLU), fan_in mode.
nn.init.kaiming_normal_(layer1.weight, a=0, mode='fan_in', nonlinearity='leaky_relu')
tuple(map(stats, (x, f1(x))))

Default PyTorch init:
- the activation std drifts away from 1

In [None]:
# A fresh layer with PyTorch's default init, for comparison.
layer1 = nn.Conv2d(in_channels=1, out_channels=num_hidden, kernel_size=5)
tuple(map(stats, (x, f1(x))))

## Reimplement Kaiming Init for conv layer

In [None]:
# (out_channels, in_channels, kernel_h, kernel_w)
layer1.weight.size()

In [None]:
# Receptive field = elements in one filter = product of the kernel dims.
receptive_field = layer1.weight.shape[2:].numel()
receptive_field

In [None]:
# Slice instead of `*_` unpacking: assigning `_` at notebook top level shadows
# IPython's `_` (last output) and stops it from being updated afterwards.
num_filter_out, num_filter_in = layer1.weight.shape[:2]
num_filter_out, num_filter_in

In [None]:
# fan = number of filters on that side x elements per filter.
fan_in, fan_out = (n * receptive_field for n in (num_filter_in, num_filter_out))
fan_in, fan_out

Kaiming init for leaky ReLU

In [None]:
def gain(a):
    """
    Kaiming gain for a leaky ReLU with negative slope ``a``: sqrt(2 / (1 + a**2)).

    @ a: negative slope. a=0 is plain ReLU (gain sqrt(2)); a=1 is the identity
         (gain 1). PyTorch's default layer init uses a=sqrt(5), giving gain
         sqrt(1/3). The uniform bound sqrt(3) * gain / sqrt(fan) then reduces
         to 1 / sqrt(fan).
    """
    denom = 1 + a**2
    return math.sqrt(2. / denom)

In [None]:
# a = 1 (linear), 0 (plain ReLU), 0.01 and 0.1 (leaky ReLU), sqrt(5) (PyTorch default)
tuple(gain(a) for a in (1, 0, 0.01, 0.1, math.sqrt(5.)))

In [None]:
# U(-1, 1) has variance 1/3, so its std is 1/sqrt(3). This is why the uniform bound carries a sqrt(3) factor.
torch.empty(10000).uniform_(-1, 1).std(), 1/math.sqrt(3)

In [None]:
# our kaiming init
def kaiming2(x, a, mode='fan_in'):
    """
    Kaiming (He) *uniform* init of a conv/linear weight tensor, in place.

    Samples U(-bound, bound) with bound = sqrt(3) * gain(a) / sqrt(fan).
    This is the same formula as ``nn.init.kaiming_uniform_`` for leaky ReLU.

    @ x:    weight tensor shaped (out_channels, in_channels, *kernel_dims);
            a 2-D linear weight also works (receptive field = 1)
    @ a:    negative slope of the leaky ReLU that follows (0 = plain ReLU)
    @ mode: 'fan_in' (preserve forward-pass variance) or 'fan_out' (backward)
    @ return: x, filled in place (mirrors the torch.nn.init functions)
    @ raises: ValueError for any other mode, instead of silently using fan_out
    """
    if mode not in ('fan_in', 'fan_out'):
        raise ValueError(f"mode must be 'fan_in' or 'fan_out', got {mode!r}")
    num_filter_out, num_filter_in = x.shape[:2]
    receptive_field = x[0, 0].numel()  # elements in one (out, in) kernel slice
    num_filters = num_filter_in if mode == 'fan_in' else num_filter_out
    fan = num_filters * receptive_field
    std = gain(a) / math.sqrt(fan)
    bound = math.sqrt(3.) * std  # U(-b, b) has std b / sqrt(3)
    with torch.no_grad():
        x.uniform_(-bound, bound)
    return x

In [None]:
# Our re-implementation with a=0 (plain ReLU), fan_in mode spelled out.
kaiming2(layer1.weight, 0, mode='fan_in')
stats(f1(x))

In [None]:
# Compare with PyTorch's own default: kaiming_uniform_ with a=sqrt(5), defaults spelled out.
nn.init.kaiming_uniform_(layer1.weight, a=math.sqrt(5), mode='fan_in', nonlinearity='leaky_relu')
stats(f1(x))

## A simple 4-layer conv net for comparison

In [None]:
class Flatten(nn.Module):
    """Collapse the *whole* tensor, batch dimension included, into 1-D.

    The conv stack below ends in 1 channel pooled to 1x1, so an (N, 1, 1, 1)
    output becomes (N,), matching the shape of the MSELoss targets. If there
    were more than one value per sample, samples would be merged together.
    """
    def forward(self, x):
        return x.view(x.numel())

In [None]:
# Four stride-2 convs (spatial 28 -> 14 -> 7 -> 4 -> 2) with ReLU between them,
# then global average pooling and Flatten: one scalar per image.
model = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=8, kernel_size=5, stride=2, padding=2),
    nn.ReLU(),
    nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, stride=2, padding=1),
    nn.AdaptiveAvgPool2d(1),
    Flatten(),
)

In [None]:
# Regression targets for MSELoss: the labels of the same 100 validation images, as float32.
y = y_valid[:100].to(torch.float32)

In [None]:
t = model(x)
# Forward pass with PyTorch's default init: compare the output scale with the input's.
(stats(x), stats(t))

In [None]:
# Backward pass with default init: gradient scale of the first conv's weights.
loss = nn.MSELoss()
# .grad accumulates across backward() calls; start from a clean slate so the
# stats describe this pass only. (Re-running this cell also needs the forward
# cell re-run, because backward() frees the graph of `t`.)
model.zero_grad()
out = loss(t, y)
out.backward()
stats(model[0].weight.grad)

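Does Kaiming init help? Re-initialise every conv with `kaiming_uniform_` (default `a=0`) and zero biases, then repeat the forward/backward checks.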

In [None]:
for layer in model:
    if not isinstance(layer, nn.Conv2d):
        continue  # ReLU / pooling / Flatten have no parameters to initialise
    nn.init.kaiming_uniform_(layer.weight)  # defaults: a=0, fan_in
    nn.init.zeros_(layer.bias)

In [None]:
# Forward pass after Kaiming init: compare with the default-init stats above.
t = model(x)
(stats(x), stats(t))

In [None]:
# Backward pass after Kaiming init. Clear the gradients left over from the
# earlier default-init backward first: .grad accumulates, and re-initialising
# the weights does not reset it. Without this the stats mix both runs.
loss = nn.MSELoss()
model.zero_grad()
out = loss(t, y)
out.backward()
stats(model[0].weight.grad)