# Weight Initialization

In [1]:
import math
import numpy as np
import torch

In [2]:
# A 512-element sample from the standard normal distribution N(0, 1):
# its sample mean should be close to 0 and its sample std close to 1.
x = torch.randn((512,))

In [3]:
# Report the sample mean and standard deviation of `x`.
for label, stat in (('Mean:', x.mean()), ('Std:', x.std())):
    print(label, stat.item())

Mean: -0.028838863596320152
Std: 0.9946354627609253


## Exploding Activations (Unscaled Weights)

In [4]:
# 100 layers of raw N(0, 1) weights with no scaling: every layer multiplies
# the activations' magnitude up, so they quickly overflow.
x = torch.randn(512)
for i in range(100):
    a = torch.randn(512, 512)
    x = a.matmul(x)

In [5]:
# Mean / std of the activations after the unscaled stack.
for label, stat in (('Mean:', x.mean()), ('Std:', x.std())):
    print(label, stat.item())

Mean: nan
Std: nan


In [6]:
# Same unscaled stack, but stop at the first layer whose activations have a
# NaN standard deviation, then display that layer index.
x = torch.randn(512)
for i in range(100):
    a = torch.randn(512, 512)
    x = a.matmul(x)
    if torch.isnan(x.std()):
        break

i

28

## Vanishing Activations (Weights Scaled Too Small)

In [7]:
# Weights shrunk by a factor of 0.01: now every layer scales the activations
# down, and after 100 layers they underflow.
x = torch.randn(512)
for i in range(100):
    a = 0.01 * torch.randn(512, 512)
    x = a.matmul(x)

In [8]:
# Mean / std of the activations after the 0.01-scaled stack.
for label, stat in (('Mean:', x.mean()), ('Std:', x.std())):
    print(label, stat.item())

Mean: 0.0
Std: 0.0


## Weight Multiplication in a Neural Network

In [9]:
# Number of independent layer samples to average over, and the layer width.
n_layers, n_inputs = 10000, 512

In [10]:
# Feed the same input through n_layers independently drawn, unscaled layers and
# accumulate the per-layer output mean and mean-square (second moment).
x = torch.randn(n_inputs)

mean = 0.
variance = 0.
for i in range(n_layers):
    a = torch.randn(n_inputs, n_inputs)
    y = a.matmul(x)
    mean += y.mean().item()
    # E[y^2]; it serves as the variance because the outputs are (near) zero-mean.
    variance += y.pow(2).mean().item()

In [11]:
avg_mean = mean / n_layers
rms = math.sqrt(variance / n_layers)
print('Mean:', avg_mean)
# A single unscaled layer multiplies the std by roughly sqrt(n_inputs) (compare
# the next cell), so stacking such layers makes the std explode.
print('Std:', rms)

Mean: -0.007916836410108954
Std: 22.787822588706245


In [12]:
# The per-layer std growth factor predicted for unscaled N(0, 1) weights.
math.sqrt(float(n_inputs))

22.627416997969522

## Weight Multiplication w/ Weight Initialization

In [13]:
# Same experiment size as before, now with scaled weights.
n_layers, n_inputs = 10000, 512

In [14]:
# Repeat the moment experiment, but scale the N(0, 1) weights by
# sqrt(1 / fan_in) so each layer's output keeps roughly unit variance.
x = torch.randn(n_inputs)

# Fan-in is n_inputs (previously hard-coded as 512, which only matched by
# coincidence); computed once since it does not change inside the loop.
scale = math.sqrt(1. / n_inputs)

mean, variance = 0., 0.
for i in range(n_layers):
    a = torch.randn(n_inputs, n_inputs) * scale
    y = a @ x
    mean += y.mean().item()
    # E[y^2]; it serves as the variance because the outputs are (near) zero-mean.
    variance += y.pow(2).mean().item()

In [15]:
avg_mean = mean / n_layers
rms = math.sqrt(variance / n_layers)
print('Mean:', avg_mean)
# With sqrt(1 / fan_in) scaling the per-layer output std stays close to 1.
print('Std:', rms)

Mean: -0.0002587269434472546
Std: 0.9993272300424068


In [16]:
# Start from a fresh N(0, 1) input: the cell no longer depends on whatever `x`
# another cell left behind, and re-running it does not stack another 100
# layers on top of the previous result.
x = torch.randn(n_inputs)

# 100 stacked linear layers with weights scaled by sqrt(1 / fan_in).
for i in range(100):
    a = torch.randn(n_inputs, n_inputs) * math.sqrt(1. / n_inputs)
    x = a @ x

In [17]:
# Mean / std after 100 scaled linear layers.
for label, stat in (('Mean:', x.mean()), ('Std:', x.std())):
    print(label, stat.item())

Mean: -0.01688116043806076
Std: 0.7799294590950012


## Weight Multiplication w/ Weight Initialization + Tanh Activation Function

In [18]:
# Layer width (both fan-in and fan-out of the square weight matrices below).
n_inputs = 512

In [19]:
# Tanh: a bounded, zero-centred (odd) non-linearity applied after each layer.
def tanh(x): return torch.tanh(x)
x = torch.randn(n_inputs)

# 100 layers: sqrt(1 / fan_in)-scaled linear map followed by tanh.
# The fan-in is n_inputs (previously hard-coded as 512).
for i in range(100):
    a = torch.randn(n_inputs, n_inputs) * math.sqrt(1. / n_inputs)
    x = tanh(a @ x)

In [20]:
# Mean / std after 100 scaled linear + tanh layers.
for label, stat in (('Mean:', x.mean()), ('Std:', x.std())):
    print(label, stat.item())

Mean: -0.0011068356689065695
Std: 0.09315230697393417


## Standard Weight Init - BAD!

In [21]:
def standard_init(m, h):
    """Standard uniform init: U(-1, 1) scaled by 1 / sqrt(fan_in).

    In this notebook weight matrices are applied as ``a @ x``, so the number of
    columns ``h`` is the fan-in (this matches PyTorch's ``(out, in)`` layout).

    Args:
        m: number of rows (layer outputs).
        h: number of columns (layer inputs, i.e. fan-in).

    Returns:
        A float tensor of shape ``(m, h)`` with entries in roughly
        ``[-1/sqrt(h), 1/sqrt(h)]``.
    """
    # The scale used to be hard-coded as sqrt(1/512); deriving it from the
    # fan-in gives the same value for the 512x512 layers used here and the
    # right value for any other width. `torch.empty` replaces the legacy
    # `torch.Tensor(m, h)` constructor (both allocate uninitialised storage).
    return torch.empty(m, h).uniform_(-1, 1) * math.sqrt(1. / h)

In [22]:
# Depth and width of the network built with the standard init.
n_layers, n_inputs = 100, 512

In [23]:
# Tanh activation applied after each layer.
def tanh(x): return torch.tanh(x)
x = torch.randn(n_inputs)

# Use the n_layers set in the previous cell instead of a separate hard-coded
# 100, so changing the depth there actually changes this experiment.
for i in range(n_layers):
    a = standard_init(n_inputs, n_inputs)
    x = tanh(a @ x)

In [38]:
# NOTE(review): the output recorded under this cell is identical to the
# ReLU + Xavier cell's output (In [37]), and this cell's execution count
# (In [38]) is later than that one. It was therefore captured after `x` had
# been overwritten by the ReLU experiment. Re-run the notebook top to bottom
# to see the real standard-init value.
for label, stat in (('Mean:', x.mean()), ('Std:', x.std())):
    # A near-zero std here means the activations (and hence gradients)
    # have almost completely vanished.
    print(label, stat.item())

Mean: 6.148139943250872e-16
Std: 9.401696196399494e-16


<img src='images/standard_init.png' width=75% />

## Tanh - [Xavier](http://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) Weight Init - GOOD!

<img src='images/xavier_formula.png' width=75% />

In [25]:
# for symmetric, non-linear activations (e.g. tanh)
def xavier_init(m, h):
    """Xavier/Glorot uniform init.

    Samples U(-b, b) with ``b = sqrt(6 / (m + h))``, which gives weight
    variance ``2 / (m + h)``. The bound is symmetric in the two dimensions, so
    it does not matter which one is treated as fan-in.

    Args:
        m: number of rows.
        h: number of columns.

    Returns:
        A float tensor of shape ``(m, h)``.
    """
    # `torch.empty` replaces the legacy `torch.Tensor(m, h)` constructor; both
    # allocate uninitialised storage that `uniform_` then fills in place.
    return torch.empty(m, h).uniform_(-1, 1) * math.sqrt(6. / (m + h))

In [26]:
# Depth and width of the tanh + Xavier network.
n_layers, n_inputs = 100, 512

In [27]:
def tanh(x):
    """Element-wise hyperbolic tangent."""
    return torch.tanh(x)


# n_layers Xavier-initialised linear layers, each followed by tanh.
x = torch.randn(n_inputs)
for i in range(n_layers):
    a = xavier_init(n_inputs, n_inputs)
    x = tanh(a.matmul(x))

In [28]:
# With Xavier init the std shrinks only gradually with depth instead of
# collapsing toward zero as with the standard init. It does not stay around 1
# (see the output below), but it is far from vanishing.
for label, stat in (('Mean:', x.mean()), ('Std:', x.std())):
    print(label, stat.item())

Mean: -0.0012303171679377556
Std: 0.048384666442871094


<img src='images/xavier_init.png' width=75% />

## ReLU - Xavier Weight Init - BAD!

In [35]:
# Depth and width of the ReLU + Xavier network.
n_layers, n_inputs = 100, 512

In [36]:
def relu(x):
    """Rectified linear unit: element-wise max(x, 0)."""
    return x.clamp(min=0.)


# n_layers Xavier-initialised linear layers, each followed by ReLU.
x = torch.randn(n_inputs)
for i in range(n_layers):
    a = xavier_init(n_inputs, n_inputs)
    x = relu(a.matmul(x))

In [37]:
# Xavier init does not account for ReLU zeroing half the inputs, so the std
# collapses with depth -> almost completely vanishing gradients.
for label, stat in (('Mean:', x.mean()), ('Std:', x.std())):
    print(label, stat.item())

Mean: 6.148139943250872e-16
Std: 9.401696196399494e-16


## ReLU - [Kaiming](https://arxiv.org/pdf/1502.01852.pdf) Weight Init - GOOD!

In [29]:
# for asymmetric, non-linear activations (e.g. ReLU)
def kaiming_init(m, h):
    """Kaiming/He normal init: N(0, 1) scaled by sqrt(2 / fan_in).

    In this notebook weight matrices are applied as ``a @ x``, so each output
    sums over the ``h`` columns: the fan-in is ``h``, matching PyTorch's
    ``(out, in)`` layout with mode ``fan_in``.

    Args:
        m: number of rows (layer outputs).
        h: number of columns (layer inputs, i.e. fan-in).

    Returns:
        A float tensor of shape ``(m, h)`` with std ``sqrt(2 / h)``.
    """
    # Previously scaled by sqrt(2 / m). That is identical for the square
    # matrices used here but wrong for non-square layers applied as `a @ x`.
    return torch.randn(m, h) * math.sqrt(2. / h)

In [30]:
# Depth and width of the ReLU + Kaiming network.
n_layers, n_inputs = 100, 512

In [31]:
def relu(x):
    """Rectified linear unit: element-wise max(x, 0)."""
    return x.clamp(min=0.)


# n_layers Kaiming-initialised linear layers, each followed by ReLU.
x = torch.randn(n_inputs)
for i in range(n_layers):
    a = kaiming_init(n_inputs, n_inputs)
    x = relu(a.matmul(x))

In [32]:
# With Kaiming init the activations' std stays around 1 through all layers,
# so gradients neither explode nor vanish.
for label, stat in (('Mean:', x.mean()), ('Std:', x.std())):
    print(label, stat.item())

Mean: 0.699047327041626
Std: 0.9852362275123596


## Courtesy

> [Weight Initialization in Neural Networks](https://towardsdatascience.com/weight-initialization-in-neural-networks-a-journey-from-the-basics-to-kaiming-954fb9b47c79)

---