In [42]:
import math
import torch

This is a replication of the blog post by James Dellinger titled *Weight Initialization in Neural Networks: A Journey From the Basics to Kaiming* (https://towardsdatascience.com/weight-initialization-in-neural-networks-a-journey-from-the-basics-to-kaiming-954fb9b47c79).

The aim here is to gain a first-hand understanding of the **challenge in training deep neural networks** and the **importance of proper weight initialization**.

## Exploding gradients

- neural nets are essentially a sequence of matrix multiplication operations, e.g. $Y = Sigmoid(ReLu(f_3(ReLu(f_2(ReLu(f_1(X)))))))$, and this is only 3 layers deep

In [43]:
# the following initializes weights from a normal distribution with 0 mean and unit variance
xdim, nodes = 512, 512  # assumed sizes; the original values are not shown in this notebook
w = torch.randn(xdim,nodes)
w.mean(), w.std()

(tensor(-0.0033), tensor(1.0014))

In [85]:
def show_stats(x): print(x.mean(), x.std())

def fwd(dim, nlayers, var_scale=1):
    x = torch.randn(dim)
    for i in range(nlayers):
        w = torch.randn(dim,dim)*var_scale
        x = w@x
        if torch.isnan(x.mean()) or x.mean()==0:
            print('layer %d - mean: %f, std: %f'%(i,x.mean(),x.std())); return
    print('layer %d - mean: %f, std: %f'%(i,x.mean(),x.std()))

In [86]:
fwd(512, 100)

layer 27 - mean: nan, std: 79049849710233833933011818076318990336.000000


As the code snippet above demonstrates, the magnitude of $x$ in the forward pass explodes toward infinity: by layer 27 (zero-indexed) the mean is already `nan`. Gradients in the backward pass flow through the same chain of weight matrices, so they explode in the same way.
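The backward-pass claim can be checked with autograd. The following is a minimal sketch, assuming a shallow stack of 10 layers (chosen here so that float32 does not overflow) and a simple sum as a stand-in loss; neither choice comes from the original post.

```python
import torch

torch.manual_seed(0)  # fixed seed so the sketch is repeatable

dim, nlayers = 512, 10  # assumed sizes; kept shallow so float32 stays finite
x = torch.randn(dim, requires_grad=True)

h = x
for _ in range(nlayers):
    w = torch.randn(dim, dim)  # unit-variance weights, as in fwd() above
    h = w @ h

# Backpropagate a simple scalar stand-in loss through the whole stack.
h.sum().backward()

out_std = h.std().item()
grad_norm = x.grad.norm().item()
print('output std: %.3e, input-gradient norm: %.3e' % (out_std, grad_norm))
```

Both numbers grow by roughly a factor of $\sqrt{512}$ per layer, so a deeper stack would overflow in the backward pass just as it does in the forward pass.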

## Vanishing gradients

- conversely, we have the opposite problem when we initialize the weights too small
- below, let's see what happens when we initialize weights from a normal distribution with 0 mean but a standard deviation of 0.01

In [63]:
# the following initializes weights from a normal distribution with 0 mean and 0.01 standard deviation
w = torch.randn(xdim,nodes) * 0.01
w.mean(), w.std()

(tensor(-2.9085e-05), tensor(0.0100))

In [87]:
fwd(512, 100, 0.01)

layer 65 - mean: -0.000000, std: 0.000000


### Magnitude of Std-dev

- the code snippet below demonstrates that the standard deviation of $y = wx$ (where the entries of $w$ and $x$ have zero mean and unit variance) is approximately $\sqrt{dim(x)}$

In [76]:
dim, mean, std, trials = 512, 0, 0, int(1e4)
for i in range(trials):
    w, x = torch.randn(dim,dim), torch.randn(dim)
    y = w@x
    mean += y.mean().item()
    std += y.std().item()
print('average mean: {}; average std-dev: {}'.format(mean/trials, std/trials))

average mean: 0.014554840754531324; average std-dev: 22.600583814239503


In [80]:
math.sqrt(dim)

22.627416997969522
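The $\sqrt{dim(x)}$ value can also be derived directly. The sketch below assumes the entries of $w$ and $x$ are independent with zero mean and unit variance, and writes $n = dim(x)$ (a symbol introduced here for brevity):

$$y_i = \sum_{j=1}^{n} w_{ij} x_j, \qquad \mathrm{Var}(w_{ij} x_j) = \mathrm{E}[w_{ij}^2]\,\mathrm{E}[x_j^2] = 1$$

$$\mathrm{Var}(y_i) = \sum_{j=1}^{n} \mathrm{Var}(w_{ij} x_j) = n \quad\Longrightarrow\quad \mathrm{std}(y_i) = \sqrt{n}$$

For $n = 512$ this gives $\sqrt{512} \approx 22.63$, in line with the empirical average of about 22.60.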

- **as a result, it is a good idea to scale the initial $w$ by $\frac{1}{\sqrt{dim(x)}}$, where $dim(x)$ is the dimension of the input $x$**
- running `fwd(512, 100, math.sqrt(1./512))` validates this intuition: the forward pass makes it through all 100 layers, and the output stays much closer to 0 mean and unit variance than before
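The scaled forward pass can be sketched as follows. `scaled_forward` is a hypothetical helper written for this note (it mirrors `fwd` but returns the statistics instead of printing them), and the fixed seed is only there for repeatability.

```python
import math
import torch

torch.manual_seed(0)  # fixed seed so the sketch is repeatable


def scaled_forward(dim, nlayers):
    """Hypothetical helper: linear layers with weights scaled by 1/sqrt(dim)."""
    x = torch.randn(dim)
    for _ in range(nlayers):
        # Unit-variance weights scaled down so that Var(w @ x) stays near Var(x).
        w = torch.randn(dim, dim) * math.sqrt(1.0 / dim)
        x = w @ x
    return x.mean().item(), x.std().item()


mean, std = scaled_forward(512, 100)
print('layer 99 - mean: %f, std: %f' % (mean, std))
```

The exact numbers vary from run to run, but the standard deviation should stay within an order of magnitude of 1 rather than overflowing or collapsing to 0.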

### Activation functions

- modern neural net architectures typically apply a non-linear activation function after each hidden layer; it is this injection of non-linearity at each stage that gives neural nets the flexibility to act as universal function approximators.
- however, once activations are included in the mix, the vanishing/exploding gradient problems start to re-emerge, as demonstrated below.

In [258]:
def fwd_with_activation(dim, nlayers, var_scale=1, tanh=False, relu=False):
    x = torch.randn(dim)
    for i in range(nlayers):
        w = torch.randn(dim,dim)*var_scale
        if tanh: x = torch.tanh(w@x)
        elif relu: x = torch.relu(w@x)
        else: x = w@x
        if torch.isnan(x.mean()) or x.mean()==0:
            print('layer %d - mean: %f, std: %f'%(i,x.mean(),x.std())); return
    print('layer %d - mean: %f, std: %f'%(i,x.mean(),x.std()))

In [205]:
fwd_with_activation(512, 100, math.sqrt(1./512), tanh=True)

layer 99 - mean: -0.005211, std: 0.059854


In [206]:
fwd_with_activation(512, 100, math.sqrt(1./512), relu=True)

layer 99 - mean: 0.000000, std: 0.000000


### Xavier (Glorot) initialization

- Xavier Glorot and Yoshua Bengio (in their paper, _Understanding the difficulty of training deep feedforward neural networks_) proposed drawing the weights from a uniform distribution bounded between $\pm\frac{\sqrt{6}}{\sqrt{n_{i}+n_{i+1}}}$

- where $n_{i}$ is the number of input dimensions coming into layer $i$ (aka *"fan-in"*); and
- $n_{i+1}$ is the number of output dimensions coming out of layer $i$ (aka *"fan-out"*)
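PyTorch ships this scheme as `torch.nn.init.xavier_uniform_`. The sketch below compares its output against the bound written out by hand; the layer sizes 512 and 256 are assumptions chosen only so that fan-in and fan-out differ.

```python
import math
import torch

torch.manual_seed(0)  # fixed seed so the sketch is repeatable

fan_in, fan_out = 512, 256  # assumed layer sizes, for illustration only

# Bound from the formula above, computed by hand.
bound = math.sqrt(6.0) / math.sqrt(fan_in + fan_out)

# PyTorch lays out linear weights as (fan_out, fan_in).
w = torch.empty(fan_out, fan_in)
torch.nn.init.xavier_uniform_(w)

# A uniform distribution on [-b, b] has standard deviation b / sqrt(3).
expected_std = bound / math.sqrt(3.0)
max_abs = w.abs().max().item()
empirical_std = w.std().item()
print('bound: %f, max |w|: %f' % (bound, max_abs))
print('expected std: %f, empirical std: %f' % (expected_std, empirical_std))
```

When fan-in equals fan-out (as in the square layers used in this notebook), the resulting standard deviation reduces to $\sqrt{1/dim}$, the same scale used in the previous section.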

In [187]:
def fwd_xavier(dim, nlayers, var_scale=1, tanh=False, relu=False):
    x = torch.randn(dim)
    bound = math.sqrt(6.)/math.sqrt(dim+dim)
    for i in range(nlayers):
        w = torch.empty(dim,dim).uniform_(-bound,bound)*var_scale
        if tanh: x = torch.tanh(w@x)
        elif relu: x = torch.relu(w@x)
        else: x = w@x
        if torch.isnan(x.mean()) or x.mean()==0:
            print('layer %d - mean: %f, std: %f'%(i,x.mean(),x.std())); return
    print('layer %d - mean: %f, std: %f'%(i,x.mean(),x.std()))

- Xavier init works fairly well with plain forward matrix multiplications (no activation)

In [190]:
fwd_xavier(512, 100)

layer 99 - mean: -0.040578, std: 0.909953


- however, it works less well once activation functions are added to the network: with tanh the standard deviation shrinks to about 0.1, and with ReLu it vanishes entirely

In [191]:
fwd_xavier(512, 100, tanh=True)

layer 99 - mean: 0.008958, std: 0.096193


In [192]:
fwd_xavier(512, 100, relu=True)

layer 99 - mean: 0.000000, std: 0.000000


### Magnitude of Std-dev when including ReLu

- Let's examine what happens to the average mean and standard deviation of the output from a single layer with ReLu activation

In [198]:
dim, mean, std, trials = 512, 0, 0, int(1e4)
for i in range(trials):
    w, x = torch.randn(dim,dim), torch.randn(dim)
    y = torch.relu(w@x)
    mean += y.mean().item()
    std += y.std().item()
print('average mean: {}; average std-dev: {}'.format(mean/trials, std/trials))

average mean: 9.021990146636963; average std-dev: 13.182903115558624


### Kaiming initialization

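- Kaiming initialization (Kaiming He et al., _Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification_) builds on the observation that the root mean square of the output from a single layer with ReLu activation is very close to $\sqrt{\frac{dim_{input}}{2}}$; the standard deviation measured above (about 13.18) is somewhat lower because ReLu outputs have a nonzero mean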

In [219]:
print(math.sqrt(512/2))

16.0
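The gap between 16 and the measured 13.18 can be worked out explicitly. The sketch below assumes each pre-activation is approximately $y \sim \mathcal{N}(0, n)$ with $n = dim_{input}$ (a symbol introduced here), which follows from unit-variance $w$ and $x$:

$$\mathrm{E}[ReLu(y)^2] = \frac{1}{2}\,\mathrm{E}[y^2] = \frac{n}{2}, \qquad \mathrm{E}[ReLu(y)] = \sqrt{\frac{n}{2\pi}}$$

$$\mathrm{Var}(ReLu(y)) = \frac{n}{2} - \frac{n}{2\pi} = n\left(\frac{1}{2} - \frac{1}{2\pi}\right)$$

For $n = 512$ this gives a root mean square of 16, a mean of about 9.03, and a standard deviation of about 13.21, which agree with the empirical 9.02 and 13.18. Since $1/(\frac{1}{2} - \frac{1}{2\pi}) \approx 2.93$, a scale of $\sqrt{3/512}$ brings the standard deviation (rather than the root mean square) close to 1.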


- let's scale the output activations by $\sqrt{\frac{2}{dim_{input}}}$ and observe that the average standard deviation moves much closer to 1 (about 0.82); scaling by $\sqrt{\frac{3}{dim_{input}}}$ instead brings it to about 1.01

In [214]:
dim, mean, std, trials = 512, 0, 0, int(1e4)
for i in range(trials):
    w, x = torch.randn(dim,dim), torch.randn(dim)
    y = torch.relu(w@x) * math.sqrt(2/512)
    mean += y.mean().item()
    std += y.std().item()
print('average mean: {}; average std-dev: {}'.format(mean/trials, std/trials))

average mean: 0.5632098125249148; average std-dev: 0.8239358275055886


In [215]:
dim, mean, std, trials = 512, 0, 0, int(1e4)
for i in range(trials):
    w, x = torch.randn(dim,dim), torch.randn(dim)
    y = torch.relu(w@x) * math.sqrt(3/512)
    mean += y.mean().item()
    std += y.std().item()
print('average mean: {}; average std-dev: {}'.format(mean/trials, std/trials))

average mean: 0.689736943590641; average std-dev: 1.0087386183917522


- Let's now try our simple feed-forward network with the Kaiming initialization scheme

In [267]:
fwd_with_activation(512, 100, math.sqrt(2./512), relu=True)

layer 99 - mean: 0.374733, std: 0.598513


In [272]:
fwd_with_activation(512, 1000, math.sqrt(2./512), relu=True)

layer 999 - mean: 0.141731, std: 0.212768


- on average, Kaiming initialization keeps the mean and standard deviation of the output activations at a healthy level through the 100-layer (or even 1000-layer) network, neither exploding nor vanishing
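The same experiment can be written with PyTorch's built-in initializer instead of the hand-written $\sqrt{2/512}$ factor. This is a minimal sketch, assuming `torch.nn.init.kaiming_normal_` with `mode='fan_in'` and `nonlinearity='relu'`, which draws weights with standard deviation $\sqrt{2/fan\_in}$; the seed and layer sizes are illustrative choices.

```python
import math
import torch

torch.manual_seed(0)  # fixed seed so the sketch is repeatable

dim, nlayers = 512, 100
x = torch.randn(dim)
for _ in range(nlayers):
    w = torch.empty(dim, dim)
    # fan_in mode with a ReLU gain gives weights with std sqrt(2 / fan_in).
    torch.nn.init.kaiming_normal_(w, mode='fan_in', nonlinearity='relu')
    x = torch.relu(w @ x)

weight_std = w.std().item()  # std of the last layer's weights
final_mean, final_std = x.mean().item(), x.std().item()
print('weight std: %f (target %f)' % (weight_std, math.sqrt(2.0 / dim)))
print('layer %d - mean: %f, std: %f' % (nlayers - 1, final_mean, final_std))
```

The activation statistics drift from run to run, but they should stay finite and nonzero across all 100 layers.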