**[Weight Initialization in Neural Networks: A Journey From the Basics to Kaiming](https://towardsdatascience.com/weight-initialization-in-neural-networks-a-journey-from-the-basics-to-kaiming-954fb9b47c79)**    
As we showed before, keeping the standard deviation of each layer’s activations around 1 **allows us to stack many more layers** in a deep neural network without gradients exploding or vanishing.

In [1]:
import numpy as np
import torch
import math

In [2]:
inp_dim = 28*28 # flattened 28x28-pixel image
num_hidden = 784
inp_dim, num_hidden

(784, 784)

# Simulate a 100-layer network without activations

## Weights too large -> exploding

In [3]:
x = torch.randn(inp_dim)
w = torch.randn(inp_dim, num_hidden)
x.shape, w.shape
for i in range(100):
    x = w @ x
x.std(), x.mean()

(tensor(nan), tensor(nan))

In [4]:
x = torch.randn(inp_dim)
w = torch.randn(inp_dim, num_hidden)
x.shape, w.shape
for i in range(100):
    x = w @ x
    if torch.isnan(x.std()):  # std turns into NaN once the activations overflow (i is zero-based)
        print('Break after {} multiplications'.format(i))
        break

Break after 26 multiplications
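
A rough back-of-the-envelope check of why the overflow shows up after a few dozen layers, assuming each multiplication by a standard-normal weight matrix scales the activations' standard deviation by about `√inp_dim`. The constant `FLOAT32_MAX` and the helper name `layers_until_overflow` are illustrative and not part of the original notebook.

```python
import math

FLOAT32_MAX = 3.4028235e38  # largest finite float32 value

def layers_until_overflow(fan_in, limit=FLOAT32_MAX):
    """Estimate how many matmuls it takes for the std to exceed `limit`.

    Assumes the std starts at 1 and grows by sqrt(fan_in) per layer.
    """
    growth_per_layer = math.sqrt(fan_in)
    return math.ceil(math.log(limit) / math.log(growth_per_layer))

print(layers_until_overflow(784))  # 27
```

That estimate is in line with the loop above, whose zero-based counter reports the NaN at index 26.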


## Weights too small -> vanishing

In [5]:
x = torch.randn(inp_dim)
w = torch.randn(inp_dim, num_hidden)*0.01 # weights init with std=0.01
x.shape, w.shape
for i in range(100):
    x = w @ x
x.std(), x.mean()

(tensor(0.), tensor(0.))
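
The exact zeros can be explained with a similar rough estimate, assuming each layer scales the standard deviation by about `weight_std * √inp_dim`. The variable names below are illustrative.

```python
import math

FLOAT32_TINY = 1.4e-45  # approximate smallest positive (subnormal) float32 value

weight_std = 0.01
per_layer_factor = weight_std * math.sqrt(784)   # about 0.28, so the signal shrinks every layer
std_after_100 = per_layer_factor ** 100          # evaluated in float64

print(per_layer_factor)
print(std_after_100 < FLOAT32_TINY)  # True: too small for float32, so the values round to 0
```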

# Explore the sweet spot

## Element-wise calculation explained - home-grown method

In [6]:
mean, variance = 0., 0.
for i in range(100):
    x = torch.randn(1)
    w = torch.randn(1)
    y = w @ x
    mean += y
    variance += y**2
mean/100, variance/100

(tensor(0.0564), tensor(0.9571))

To compute each element of `y`, we sum `inp_dim` element-wise products of one row of the weights `w` with the inputs `x`.  
Each product has mean 0 and variance 1, so the sum has variance `inp_dim`, i.e. a standard deviation of `√inp_dim`.
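
The argument can be written out as a short derivation, assuming the weights $w_{ij}$ and the inputs $x_j$ are independent draws from a standard normal distribution (as in these cells) and writing $n$ for `inp_dim`:

```latex
\begin{aligned}
y_i &= \sum_{j=1}^{n} w_{ij}\, x_j \\
\mathbb{E}[w_{ij} x_j] &= \mathbb{E}[w_{ij}]\,\mathbb{E}[x_j] = 0 \\
\operatorname{Var}(w_{ij} x_j) &= \mathbb{E}[w_{ij}^2]\,\mathbb{E}[x_j^2] = 1 \\
\operatorname{Var}(y_i) &= \sum_{j=1}^{n} \operatorname{Var}(w_{ij} x_j) = n
\end{aligned}
```

With `inp_dim = 784` this predicts a variance near 784 and a standard deviation near 28.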

In [7]:
mean, variance = 0., 0.
for i in range(100):
    x = torch.randn(inp_dim)
    w = torch.randn(num_hidden, inp_dim)
    y = w @ x
    mean += y.mean()
    variance += y.pow(2).mean()
mean/100, variance/100

(tensor(0.1593), tensor(791.2849))

## Some methods used before 2010

**no activations**   
- the network is purely linear (just matmul)  
- good result: the mean stays near 0 and the std stays on the order of 1

In [8]:
mean, variance = 0., 0.
x = torch.randn(inp_dim)
for i in range(100):
    w = torch.randn(num_hidden, inp_dim)/math.sqrt(inp_dim)
    x = w @ x
x.mean(), x.std() # good mean and std

(tensor(-0.0017), tensor(0.6494))
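
A minimal sketch of why dividing by `√inp_dim` works, assuming unit-variance inputs: scaling standard-normal weights by `1/√fan_in` gives them variance `1/fan_in`, so the `fan_in` products that make up each output add up to a variance of about 1. The layer sizes and the helper name `scaled_normal_init` are made up for illustration.

```python
import math
import torch

torch.manual_seed(0)

def scaled_normal_init(fan_out, fan_in):
    """Standard-normal weights scaled by 1/sqrt(fan_in) (hypothetical helper)."""
    return torch.randn(fan_out, fan_in) / math.sqrt(fan_in)

fan_in, fan_out = 512, 256           # deliberately non-square
x = torch.randn(fan_in, 1000)        # 1000 unit-variance input vectors, one per column
y = scaled_normal_init(fan_out, fan_in) @ x

out_std = y.std().item()
print(out_std)  # close to 1, independent of fan_in
```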

**non-linear activations**
- weaker result: the std shrinks to about 0.05 after 100 layers

In [9]:
# test with some common non-linear activations
def tanh(x):
    return torch.tanh(x)

def sigmoid(x):
    return torch.sigmoid(x)

In [10]:
x = torch.randn(inp_dim)
for i in range(100):
    w = torch.randn(num_hidden, inp_dim)/math.sqrt(inp_dim)
    x = tanh(w @ x)
#     x = sigmoid(w @ x)
#     x = relu(w @ x)
x.mean(), x.std()

(tensor(0.0025), tensor(0.0533))

**standard init approach before 2010**
- initialize weights from a uniform distribution on [-1, 1], then scale by 1/√n
- bad result: the activations vanish

In [11]:
x = torch.randn(inp_dim)
for i in range(100):
    w = torch.empty(num_hidden, inp_dim).uniform_(-1, 1) * math.sqrt(1./inp_dim)
    x = tanh(w @ x)
x.mean(), x.std()

(tensor(-3.9219e-26), tensor(1.0964e-24))
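
One way to see why this variant vanishes so quickly: a uniform distribution on [-1, 1] has variance 1/3 rather than 1, so after scaling by `1/√n` the weights have a standard deviation of `1/√(3n)`, and each layer shrinks the signal by roughly an extra factor of `√3` even before `tanh` is applied. A small sketch to check the weight statistics (size taken from the notebook, variable names illustrative):

```python
import math
import torch

torch.manual_seed(0)
n = 784

w_uniform = torch.empty(n, n).uniform_(-1, 1) / math.sqrt(n)
w_normal = torch.randn(n, n) / math.sqrt(n)

# Express each std relative to the "ideal" value 1/sqrt(n)
uniform_ratio = w_uniform.std().item() * math.sqrt(n)
normal_ratio = w_normal.std().item() * math.sqrt(n)

print(uniform_ratio)  # about 1/sqrt(3), roughly 0.577
print(normal_ratio)   # about 1
```

Compounded over 100 layers, `(1/√3)**100` is on the order of 1e-24, the same order of magnitude as the std printed above.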

## Xavier (normalized) initialization 
[Landmark Origin Paper](http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf?source=post_page---------------------------)
- Gives the same weight variance as the home-grown method (1/n) when fan_in equals fan_out
![Formula](https://miro.medium.com/max/875/1*H6t3yYBLlinNRUwmL-d7vw.png)

In [12]:
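def xavier_uniform_init(fan_in, fan_out):
    # U(-a, a) with a = sqrt(6/(fan_in + fan_out)); shape (fan_out, fan_in) so that w @ x works
    return torch.empty(fan_out, fan_in).uniform_(-1, 1) * math.sqrt(6./(fan_in + fan_out))

def xavier_normal_init(fan_in, fan_out):
    # N(0, std^2) with std = sqrt(2/(fan_in + fan_out)); shape (fan_out, fan_in)
    return torch.empty(fan_out, fan_in).normal_() * math.sqrt(2./(fan_in + fan_out))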


In [13]:
# test with some Xavier method
x = torch.randn(inp_dim)
for i in range(100):
    w = xavier_uniform_init(inp_dim, num_hidden)
#     w = xavier_normal_init(inp_dim, num_hidden)
#     w = torch.nn.init.xavier_uniform_(torch.FloatTensor(inp_dim, num_hidden))
#     w = torch.nn.init.xavier_normal_(torch.FloatTensor(inp_dim, num_hidden))
    x = tanh(w @ x)
x.mean(), x.std()

(tensor(-0.0007), tensor(0.0585))
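
As a sanity check, the hand-written formulas can be compared against PyTorch's built-in `torch.nn.init.xavier_uniform_` and `torch.nn.init.xavier_normal_`. This is only a sketch: the layer sizes are arbitrary, and it relies on PyTorch's convention that a 2-D weight has shape `(fan_out, fan_in)`.

```python
import math
import torch

torch.manual_seed(0)
fan_in, fan_out = 300, 500

w_uniform = torch.nn.init.xavier_uniform_(torch.empty(fan_out, fan_in))
w_normal = torch.nn.init.xavier_normal_(torch.empty(fan_out, fan_in))

bound = math.sqrt(6. / (fan_in + fan_out))   # edge of the uniform range
std = math.sqrt(2. / (fan_in + fan_out))     # target std for both variants

max_ratio = w_uniform.abs().max().item() / bound
uniform_std_ratio = w_uniform.std().item() / std
normal_std_ratio = w_normal.std().item() / std

print(max_ratio)          # just below 1: samples stay inside [-bound, bound]
print(uniform_std_ratio)  # close to 1
print(normal_std_ratio)   # close to 1
```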

## Kaiming initialization
What if the activation function is **asymmetric and non-linear**, such as **ReLU**?

## ReLU inspection 
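- mean of y² ≈ inp_dim/2 (the cell accumulates `y.pow(2).mean()`, i.e. the second moment, which is larger than the mean-centered variance because the mean after ReLU is not 0)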

In [14]:
def relu(x):
    return torch.clamp(x, min=0)

In [15]:
# second moment (mean of y**2) ~ 784/2
mean, variance = 0., 0.
for i in range(100):
    x = torch.randn(inp_dim)
    w = torch.randn(num_hidden, inp_dim)
    y = relu(w @ x)
    mean += y.mean()
    variance += y.pow(2).mean()
mean/100, variance/100

(tensor(11.1279), tensor(389.5249))
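
The factor of one half follows from symmetry: if the pre-activation is symmetric around zero, ReLU zeroes the negative half and keeps the positive half, so `E[relu(y)²] = E[y²]/2`. A quick numerical sketch, assuming a zero-mean Gaussian pre-activation with `σ = √784 = 28` (variable names are illustrative):

```python
import math
import torch

torch.manual_seed(0)
sigma = math.sqrt(784)                  # std of the pre-activations when fan_in = 784
y = torch.randn(1_000_000) * sigma
r = torch.clamp(y, min=0)               # ReLU

second_moment = r.pow(2).mean().item()
mean = r.mean().item()
variance = r.var().item()

print(second_moment / sigma**2)         # about 0.5
print(mean / sigma)                     # about 1/sqrt(2*pi), roughly 0.399
print(variance / sigma**2)              # about 0.5 - 1/(2*pi), roughly 0.341
```

These ratios agree with the mean (about 11) and second moment (about 390) printed above.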

### Try the home-grown init method with an extra factor of √2
- mean of y² ≈ 1

In [16]:
mean, variance = 0., 0.
for i in range(100):
    x = torch.randn(inp_dim)
    w = torch.randn(num_hidden, inp_dim)*math.sqrt(2./inp_dim)
    y = relu(w @ x)
    mean += y.mean()
    variance += y.pow(2).mean()
mean/100, variance/100    

(tensor(0.5628), tensor(0.9925))

In their 2015 paper, He et al. demonstrated that deep networks (e.g. a 22-layer CNN) **converge much earlier** when the following weight initialization strategy is used:  
- Create a tensor with the dimensions appropriate for a weight matrix at a given layer, and populate it with numbers randomly chosen from a **standard normal distribution**.  
- Multiply each randomly chosen number by `√2/√n`, where `n` is the number of **incoming connections** to a given layer from the previous layer’s output (also known as the **“fan-in”**).  
- **Bias** tensors are **initialized to zero**.
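
The three steps above could be sketched for a PyTorch layer as follows. This uses `torch.nn.init.kaiming_normal_` with `mode='fan_in'` and `nonlinearity='relu'`, which draws from a normal distribution with std `√(2/fan_in)`. The helper name `init_kaiming_` and the layer sizes are illustrative and do not come from the paper.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)

def init_kaiming_(layer):
    """Apply He-style init in place to an nn.Linear (hypothetical helper)."""
    nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
    nn.init.zeros_(layer.bias)
    return layer

layer = init_kaiming_(nn.Linear(784, 512))      # weight has shape (512, 784)

target_std = math.sqrt(2. / 784)
weight_std_ratio = layer.weight.std().item() / target_std
bias_abs_sum = layer.bias.abs().sum().item()

print(weight_std_ratio)   # close to 1
print(bias_abs_sum)       # 0.0
```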

### Kaiming He vs Xavier on ReLU
- Xavier init: activations vanish at a depth of 100 layers
- Kaiming init: activations stay at a healthy scale

In [18]:
def kaiming_init(inp_dim, num_hidden):
    # N(0, 2/inp_dim); shape (num_hidden, inp_dim) so that w @ x also works for non-square layers
    return torch.randn(num_hidden, inp_dim)*math.sqrt(2./inp_dim)

In [21]:
# Kaiming He init
x = torch.randn(inp_dim)

for i in range(100):
    w = kaiming_init(inp_dim, num_hidden)
    x = relu(w @ x)
x.mean(), x.std()

(tensor(0.5421), tensor(0.7700))

In [23]:
# Xavier init -> vanished
x = torch.randn(inp_dim)

for i in range(100):
    w = xavier_normal_init(inp_dim, num_hidden)
    x = relu(w @ x)
x.mean(), x.std()

(tensor(7.1177e-16), tensor(1.0367e-15))
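
The gap between the two results can be estimated by hand. With `fan_in == fan_out == n`, Xavier's weight variance is `2/(2n) = 1/n`, so each matmul roughly preserves the second moment and each ReLU halves it, while Kaiming's extra factor of 2 cancels that halving. A rough sketch of the expected scale after 100 layers under those assumptions (variable names are illustrative):

```python
import math

depth = 100
relu_keep = 0.5   # fraction of the second moment that survives each ReLU

xavier_scale = math.sqrt(relu_keep ** depth)          # equals 2**-50
kaiming_scale = math.sqrt((2 * relu_keep) ** depth)   # the factor of 2 compensates

print(xavier_scale)    # about 8.9e-16
print(kaiming_scale)   # 1.0
```

The first number is the same order of magnitude as the Xavier result printed above.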