# Weight Initialization in Neural Networks: A Journey From the Basics to Kaiming

## Objective

This notebook follows [this article](https://towardsdatascience.com/weight-initialization-in-neural-networks-a-journey-from-the-basics-to-kaiming-954fb9b47c79), which demonstrates weight initialization in neural networks, the problems that poor initialization causes, and the solutions to those problems.

## Background

* A neural network is built from layers, each containing a set of weights, biases, and an activation function.

* The input to a layer is first multiplied by the weights, then the biases are added to the product, and finally the result is passed through the activation function.

* At the start of training, the weights and biases must be initialized to some values, which training then adjusts to reduce the prediction error.

    * Traditionally, weights are initialized by sampling from a standard normal distribution, while the biases are set to a constant (typically 0).

    * However, this initialization can lead to exploding activations (layer outputs overflow to infinity or NaN) or vanishing activations (layer outputs collapse to 0). The gradients suffer the same fate, which can prevent training from converging to an optimum.
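
As a minimal, dependency-free sketch of the layer computation described above (the helper name `dense_layer` and the toy values are illustrative and do not come from the article):

```python
import math

def dense_layer(x, weights, biases, activation):
    """Compute activation(weights @ x + biases) for plain Python lists."""
    outputs = []
    for row, b in zip(weights, biases):
        # Weighted sum of the inputs for one output unit, plus its bias
        z = sum(w * xi for w, xi in zip(row, x)) + b
        outputs.append(activation(z))
    return outputs

# Toy layer: 2 inputs -> 2 outputs (illustrative values)
x = [0.5, -1.0]
weights = [[1.0, 2.0],
           [0.0, 1.0]]
biases = [0.5, 0.0]
out = dense_layer(x, weights, biases, math.tanh)
print(out)
```

Both pre-activations in this toy example equal -1, so both outputs equal `math.tanh(-1.0)`.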

In [10]:
import torch
import math

Create a vector of 512 inputs:

In [2]:
x = torch.randn(512)

Simulate a forward pass through a 100-layer network (matrix multiplications only, no activation function):

In [4]:
for i in range(100):
    weight = torch.randn(512, 512)
    x = weight @ x
x.mean(), x.std()

(tensor(nan), tensor(nan))

During the forward pass through 100 layers, the layer outputs grew so large that they overflowed to infinity and then became NaN.

In [5]:
x = torch.randn(512)

for i in range(100):
    weight = torch.randn(512, 512)
    x = weight @ x
    if torch.isnan(x.std()):
        break
i

27

The activations overflowed at the 28th layer (loop index 27).

To demonstrate the opposite problem, vanishing activations, we sample the weights from a normal distribution with mean 0 and scale them to a standard deviation of 0.01:

In [7]:
x = torch.randn(512)

for i in range(100):
    weight = torch.randn(512, 512) * 0.01
    x = weight @ x
    
x.mean(), x.std()

(tensor(0.), tensor(0.))

The matrix product of an input vector $x$ and a weight matrix $weight$, both drawn from a standard normal distribution, has a standard deviation close to the square root of the number of input connections (here $\sqrt{512}$).

In [11]:
mean, var = 0.0, 0.0

for i in range(1000):
    x = torch.randn(512)
    weight = torch.randn(512, 512)
    y = weight @ x
    mean += y.mean().item()
    var += y.pow(2).mean().item()

In [12]:
mean / 1000, math.sqrt(var / 1000)

(0.002180554822087288, 22.629027893656087)

In [13]:
math.sqrt(512)

22.627416997969522

This is because each element of the matrix product in this example is a sum of 512 products of one element of $x$ with one element of $weight$. Since both $x$ and $weight$ are drawn independently from a standard normal distribution, each of the 512 products has mean 0 and standard deviation 1, so their sum has variance 512 and standard deviation $\sqrt{512}$.

In [15]:
mean, var = 0.0, 0.0

for i in range(10000):
    x = torch.randn(1)
    a = torch.randn(1)
    y = a * x
    mean += y.item()
    var += y.pow(2).item()
    
mean / 10000, math.sqrt(var / 10000)

(0.005823576229463287, 0.9878580677864147)
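
Written out, assuming the entries of $x$ and $weight$ are independent with mean 0 and variance 1 (the index notation $w_{ij}$, $x_j$, $y_i$ and the symbol $n = 512$ are introduced here for illustration):

$$y_i = \sum_{j=1}^{n} w_{ij} x_j, \qquad \mathbb{E}[w_{ij} x_j] = \mathbb{E}[w_{ij}]\,\mathbb{E}[x_j] = 0, \qquad \operatorname{Var}(w_{ij} x_j) = \mathbb{E}[w_{ij}^2]\,\mathbb{E}[x_j^2] = 1$$

$$\operatorname{Var}(y_i) = \sum_{j=1}^{n} \operatorname{Var}(w_{ij} x_j) = n, \qquad \operatorname{std}(y_i) = \sqrt{n} = \sqrt{512} \approx 22.63$$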

In our example, we would like each layer's outputs to have a standard deviation of about 1, so that the signal neither explodes nor vanishes however many layers it passes through. This is done by scaling the weights by $\frac{1}{\sqrt{512}}$:

In [19]:
mean, var = 0.0, 0.0

for i in range(1000):
    x = torch.randn(512)
    a = torch.randn(512, 512) * math.sqrt(1. / 512)
    y = a @ x
    mean += y.mean().item()
    var += y.pow(2).mean().item()
    
mean / 1000, var / 1000

(0.0011965707547497003, 0.9980551573634148)
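
A rough back-of-the-envelope sketch ties the experiments so far together. It assumes that each linear layer multiplies the standard deviation of its input by $\sigma\sqrt{n}$, where $\sigma$ is the weight standard deviation and $n$ is the fan-in; the helper names are hypothetical and chosen only for this illustration:

```python
import math

FLOAT32_MAX = 3.4028235e38  # largest finite float32 value

def layer_gain(weight_std, fan_in):
    """Factor by which one linear layer scales the std of its input."""
    return weight_std * math.sqrt(fan_in)

def log10_std_after(depth, weight_std, fan_in):
    """log10 of the expected output scale after `depth` layers (unit-std input)."""
    return depth * math.log10(layer_gain(weight_std, fan_in))

n = 512

# Standard normal weights: gain of sqrt(512), about 22.6 per layer
layers_to_overflow = math.log(FLOAT32_MAX) / math.log(layer_gain(1.0, n))
print(layers_to_overflow)  # between 28 and 29 layers

# Weights with std 0.01: gain of about 0.226, so the signal shrinks every layer
print(log10_std_after(100, 0.01, n))  # about -64.5

# Weights with std 1/sqrt(512): gain of 1 (up to rounding), so the scale is preserved
print(layer_gain(math.sqrt(1.0 / n), n))
```

The overflow estimate is consistent with the earlier experiment, where the output became NaN at loop index 27, and a scale of roughly $10^{-64.5}$ is far below the smallest positive float32 value (about $1.4 \times 10^{-45}$), which is why the second experiment printed exact zeros.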

Applying the same scaling to the 100-layer network:

In [20]:
x = torch.randn(512)

for i in range(100):
    weight = torch.randn(512, 512) * math.sqrt(1. / 512)
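    x = weight @ x
    
x.mean(), x.std()

The output standard deviation stays on the order of 1 after all 100 layers, so the activations neither explode nor vanish (the exact values vary from run to run).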

Next, we add a traditional symmetric activation function, tanh, after each layer:

In [21]:
def tanh(x):
    return torch.tanh(x)

In [22]:
x = torch.randn(512)

for i in range(100):
    weight = torch.randn(512, 512) * math.sqrt(1. / 512)
    x = tanh(weight @ x)
    
x.mean(), x.std()

(tensor(-0.0021), tensor(0.0763))

### Xavier Initialization

* Each weight is drawn from a uniform distribution bounded by $$\pm\frac{\sqrt{6}}{\sqrt{n_{i} + n_{i + 1}}}$$

    * $n_{i}$ is the number of incoming connections to the layer (fan-in) and $n_{i+1}$ is the number of outgoing connections (fan-out).
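
As a short check of what this bound implies (using a standard property of the uniform distribution, which the article does not spell out): a variable uniform on $[-a, a]$ has variance $a^2/3$, so with $a = \sqrt{6}/\sqrt{n_{i} + n_{i+1}}$

$$\operatorname{Var}(w) = \frac{a^2}{3} = \frac{1}{3}\cdot\frac{6}{n_{i} + n_{i+1}} = \frac{2}{n_{i} + n_{i+1}}$$

For a square layer with $n_{i} = n_{i+1} = 512$ this gives $\operatorname{Var}(w) = 1/512$, the same weight variance as the $\frac{1}{\sqrt{512}}$ scaling used earlier.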

In [23]:
def xavier(m, h):
    # Uniform samples in [-1, 1], scaled to the Xavier bound sqrt(6 / (fan_in + fan_out))
    return torch.empty(m, h).uniform_(-1, 1) * math.sqrt(6./(m+h))

In [24]:
x = torch.randn(512)

for i in range(100):
    weight = xavier(512, 512)
    x = tanh(weight @ x)
    
x.mean(), x.std()

(tensor(0.0008), tensor(0.0868))

Repeating the single-layer experiment with ReLU and unscaled standard normal weights:

In [25]:
def relu(x):
    return x.clamp_min(0.)

In [26]:
mean, var = 0., 0.

for i in range(1000):
    x = torch.randn(512)
    weight = torch.randn(512, 512)
    y = relu(weight @ x)
    
    mean += y.mean().item()
    var += y.pow(2).mean().item()
    
mean / 1000, math.sqrt(var / 1000)

(9.016544756889344, 15.992926274173755)

In [28]:
math.sqrt(512) / math.sqrt(2)

16.0

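Dividing the weights by this number, which is the same as multiplying them by $\sqrt{2/512}$, gives a single ReLU layer an output root mean square of about 1 (the second value in the output below):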

In [29]:
mean, var = 0., 0.

for i in range(1000):
    x = torch.randn(512)
    weight = torch.randn(512, 512) * math.sqrt(2 / 512)
    y = relu(weight @ x)
    
    mean += y.mean().item()
    var += y.pow(2).mean().item()
    
mean / 1000, math.sqrt(var / 1000)

(0.5638078674376011, 0.9993184797611605)
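
The numbers measured in the two ReLU experiments above match closed-form values. Assuming the pre-activation $y$ is approximately normal with mean 0 and standard deviation $\sigma$, ReLU zeroes the negative half, which gives $\mathbb{E}[\mathrm{relu}(y)] = \sigma/\sqrt{2\pi}$ and $\mathbb{E}[\mathrm{relu}(y)^2] = \sigma^2/2$. A small sketch (the function name `relu_moments` is hypothetical):

```python
import math

def relu_moments(sigma):
    """Mean and root mean square of relu(y) for y ~ Normal(0, sigma^2)."""
    mean = sigma / math.sqrt(2 * math.pi)
    rms = sigma / math.sqrt(2)
    return mean, rms

# Unscaled weights: pre-activation std is sqrt(512)
print(relu_moments(math.sqrt(512)))  # roughly (9.03, 16.0)

# Weights scaled by sqrt(2 / 512): pre-activation std is sqrt(2)
print(relu_moments(math.sqrt(2)))    # roughly (0.564, 1.0)
```

The factor of $1/2$ in the second moment is where the $\sqrt{2}$ in the scaling comes from.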

### Kaiming Initialization

* Create a tensor with the dimensions appropriate for a weight matrix at a given layer, and populate it with numbers randomly chosen from a standard normal distribution.

* Multiply each randomly chosen number by $\sqrt{2}/\sqrt{n}$, where $n$ is the number of connections coming into the layer from the previous layer's output (also known as the "fan-in").

* Bias tensors are initialized to zero.
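
The three steps above can be sketched without any framework. This is a minimal illustration using Python's `random` module; the function name `kaiming_layer`, the list-of-lists layout, the layer sizes, and the seed are all assumptions made for the example, not part of the article:

```python
import math
import random

def kaiming_layer(fan_in, fan_out, rng):
    """Return (weights, biases) for one layer, following the recipe above."""
    scale = math.sqrt(2.0 / fan_in)
    # Sample from a standard normal, then multiply each sample by sqrt(2 / fan_in)
    weights = [[rng.gauss(0.0, 1.0) * scale for _ in range(fan_in)]
               for _ in range(fan_out)]
    # Biases start at zero
    biases = [0.0] * fan_out
    return weights, biases

rng = random.Random(0)  # fixed seed so the example is repeatable
weights, biases = kaiming_layer(fan_in=512, fan_out=256, rng=rng)

# The empirical weight variance should be close to 2 / fan_in
flat = [w for row in weights for w in row]
variance = sum(w * w for w in flat) / len(flat)
print(variance, 2.0 / 512)
```

Each row of `weights` holds the `fan_in` incoming weights of one output unit, so the scale depends only on the number of inputs.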

In [30]:
def kaiming(m, h):
    # For `weight @ x` with weight of shape (m, h), the fan-in is h (the number of columns)
    return torch.randn(m, h) * math.sqrt(2. / h)

In [31]:
x = torch.randn(512)

for i in range(100):
    weight = kaiming(512, 512)
    x = relu(weight @ x)

x.mean(), x.std()

(tensor(0.5415), tensor(0.7596))

For comparison, the same 100-layer ReLU network initialized with Xavier instead:

In [32]:
x = torch.randn(512)

for i in range(100):
    weight = xavier(512, 512)
    x = relu(weight @ x)

x.mean(), x.std()

(tensor(5.7588e-16), tensor(8.4180e-16))
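
The collapse with Xavier initialization is what a per-layer estimate predicts. Under idealized assumptions (independent zero-mean weights, and ReLU halving the second moment of a symmetric pre-activation), each layer multiplies the mean square of the activations by $\tfrac{1}{2}\, n \operatorname{Var}(w)$, where $n$ is the fan-in. A hedged sketch with hypothetical helper names:

```python
import math

def relu_layer_gain(fan_in, weight_var):
    """Factor applied to the mean square of the activations by one ReLU layer."""
    return 0.5 * fan_in * weight_var

def rms_after(depth, fan_in, weight_var):
    """Expected root mean square after `depth` ReLU layers, starting from RMS 1."""
    return math.sqrt(relu_layer_gain(fan_in, weight_var) ** depth)

n = 512
xavier_var = 2.0 / (n + n)   # variance of uniform(-a, a) with a = sqrt(6 / (n + n))
kaiming_var = 2.0 / n        # variance of normal samples scaled by sqrt(2 / n)

print(rms_after(100, n, xavier_var))   # 0.5 ** 50, about 8.9e-16
print(rms_after(100, n, kaiming_var))  # 1.0
```

The Xavier estimate has the same order of magnitude as the values measured above, while the Kaiming factor of exactly 1 per layer is why its activations keep a usable scale after 100 ReLU layers.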