Weight initialization for bad mathematicians
=================================


Well, I like math, I'm just bad at it, very bad. But this should not stop me from understanding deep learning. Actually, the best deep learning practitioner that I know, Jeremy Howard, also proclaims he is a bad mathematician.  

Weight initialization matters, it matters A LOT! This simple "trick" can be the difference between convergence and total failure.  

As described by this great paper, the loss surface of a neural network can be very chaotic (and as explained by the paper, this heavily depends on the architecture you're using).  

The loss landscapes for all the networks considered seem to be partitioned into a well-defined region of low loss value and convex contours, surrounded by a well-defined region of high loss value and non-convex contours.  

Near the local minimum the loss structure is well behaved and very smooth, but as we move away to the regions of high loss value, the contours start to become sharp and chaotic.
With a good weight initialization strategy, your initial loss will likely lie in the well-behaved region, and your model might never see the chaotic space that lies in the high-loss region.  

If your model already starts in this very chaotic high-loss space, your gradients will be very uninformative and training will be nearly impossible: your model will keep jumping here and there and will not be able to converge.

Let's take a look at a very simple example to see the harm that a very normal initialization can cause:

Before we start, let's import the necessary dependencies and define some helper functions.

In [1]:
import tensorflow as tf
import math

In [23]:
# Monkey patching some utility functions, you can ignore these
tf.Tensor.mean = lambda o, **kwargs: tf.reduce_mean(o, **kwargs)
tf.Tensor.pow = lambda o, p, **kwargs: tf.pow(o, p, **kwargs)

In [2]:
def lin(x, w): return x@w
def relu(x): return tf.math.maximum(0, x)
def stats(x):
  # Print the mean and the variance (std squared), which is what the text tracks
  mean, var = float(tf.reduce_mean(x)), float(tf.math.reduce_std(x))**2
  print(f'Mean: {mean}\nVar: {var}')
def forward(x, ws, act=None):
  for w in ws: x = act(lin(x, w)) if act else lin(x, w)
  return x

Our inputs will come from a normal distribution since data is almost always normalized in ML.  

In [3]:
x = tf.random.normal((1, 1000))

Our weights will also come from a normal distribution. Very harmless huh? 

In [4]:
w = tf.random.normal((1000, 1000))

Let's do a forward pass and take a look at what happens to the mean and variance.

In [5]:
stats(forward(x, [w]))

Mean: -0.2934545874595642
Var: 1067.616436898301


Woah! After the very first layer the variance is already huge! You think that's not concerning enough? Let's see what happens if we go deep and stack 50 of those layers.

In [49]:
ws = [tf.random.normal((1000, 1000)) for _ in range(50)]
stats(forward(x, ws))

Mean: nan
Var: nan


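And there you go, it already exploded. What happens is that our activations get so huge that the computer cannot represent them anymore: they overflow to infinity, and operations on infinities (such as adding a positive and a negative one) then give us `NaN`, short for `Not A Number`.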
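How huge values turn into `NaN` can be seen in isolation with two floating-point operations. This is a minimal sketch using NumPy rather than TensorFlow so that it runs on its own; the variable names are made up for illustration, and it assumes 32-bit floats, the default dtype of `tf.random.normal`.

```python
import numpy as np

# Silence NumPy's overflow/invalid warnings: producing inf and nan is the point here.
with np.errstate(over="ignore", invalid="ignore"):
    big = np.float32(3e38)               # near the float32 maximum (about 3.4e38)
    overflowed = big * np.float32(10)    # too large for float32, so it becomes inf
    undefined = overflowed - overflowed  # inf - inf has no defined value, so it becomes nan

print(overflowed, undefined)
```

In a deep stack of layers the same thing happens inside the matrix multiplication: once some activations are `inf` or `-inf`, summing them produces `nan`, and every later layer inherits it.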

So why do we get such high numbers for the variance?  
It turns out that when we calculate the matrix multiplication `x@w`, each output is the sum of 1000 products (one per input unit) of one element of `w` by one element of `x`. The mean of each of these little products is 0, so the final mean tends to be close to 0. But the variance of each product is 1, and when we sum them the individual variances add up, so we end up with a variance that tends towards the number of inputs to the layer (1000 here, which matches the roughly 1067 we measured).  
As we stack more layers the variance keeps growing exponentially: each layer multiplies it by roughly another 1000.  
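The claim that the variance of such a sum grows with the number of terms can be checked empirically. The sketch below uses NumPy instead of TensorFlow to stay self-contained; `output_variance`, the seed and the sample count are illustrative choices, not part of the notebook.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary fixed seed, only for repeatability

def output_variance(n_inputs, n_samples=2000):
    """Empirical variance of sum_i x_i * w_i, with x_i and w_i drawn from N(0, 1)."""
    x = rng.standard_normal((n_samples, n_inputs))
    w = rng.standard_normal((n_samples, n_inputs))
    # Each row is one independent "output unit": a sum of n_inputs products.
    return float((x * w).sum(axis=1).var())

for n in (10, 100, 1000):
    print(f"n_inputs={n:5d}  variance of the sum ~ {output_variance(n):.1f}")
```

Each printed variance should land near its `n_inputs`, which is the behaviour described above.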

The solution? Scale the weights in such a way that the variance of the output of each layer is 1. This way we can stack as many layers as we want and the output of every layer will have a variance of one.

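So the question becomes: how do we scale these layers?  
Well, we already concluded that the output variance is going to be equal to the number of inputs to the layer, so we need a way to shrink it by that factor. To get there, let's first look at what variance actually is.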

Recall that variance is the average of how far our values are from the mean, let's try doing that:

In [21]:
x = tf.random.normal((1000,))
(x - x.mean()).mean()

<tf.Tensor: id=215, shape=(), dtype=float32, numpy=-1.7642975e-08>

Cool, so our variance is zero? Well... **NO**.  

Some values are going to be greater than the mean and some are going to be smaller. By definition the mean has the same **sum of distances** on its right and on its left side, so when we sum all of these little distances from the mean together, we get zero.

What kills us here is that some distances are positive and some are negative, so what do we do? We square all the distances so they are always positive:  


In [24]:
(x - x.mean()).pow(2).mean()

<tf.Tensor: id=229, shape=(), dtype=float32, numpy=0.9731436>

And thus we have variance, easy peasy. And this is the mathematical formula for what we just coded:  

$$\frac{1}{n}{\displaystyle\sum_{i=1}^{n}(x_i - \mu)^2}$$  

Remember our original problem, we want to change all $x_i$ so we can modify the variance. How would we make the variance of the previous example 10 times smaller? We can try dividing all $x$ by 10:  

In [25]:
x1 = x/10
(x1 - x1.mean()).pow(2).mean()

<tf.Tensor: id=238, shape=(), dtype=float32, numpy=0.009731436>

But that made the variance 100 times smaller! That is because the variance formula squares all of our values.  

So if we want to make our variance 10 times smaller, we have to divide all our values by $\sqrt{10}$:

In [26]:
x1 = x/math.sqrt(10)
(x1 - x1.mean()).pow(2).mean()

<tf.Tensor: id=247, shape=(), dtype=float32, numpy=0.097314365>
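Both results follow from the variance formula. As a short worked step (the constant $c$ is notation introduced here): dividing every value by $c$ also divides the mean by $c$, so each squared distance shrinks by $c^2$:

$$\frac{1}{n}\sum_{i=1}^{n}\left(\frac{x_i}{c} - \frac{\mu}{c}\right)^2 = \frac{1}{c^2}\cdot\frac{1}{n}\sum_{i=1}^{n}(x_i - \mu)^2$$

With $c = 10$ the variance becomes 100 times smaller, and with $c = \sqrt{10}$ it becomes 10 times smaller, matching the two outputs above.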

In our original problem the variance of the output was equal to the number of inputs to that layer. We want the variance to be `n_inputs` times smaller, so we just have to divide our weights by $\sqrt{\text{n_inputs}}$:
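As a rough sketch of why this works, assume every input $x_i$ and every weight $w_{ij}$ is independent with mean 0, that the inputs have variance 1, and that the weights are standard normal values divided by $\sqrt{n}$. The symbols $y_j$, $x_i$, $w_{ij}$ and $n$ (standing for `n_inputs`) are notation introduced here. Then, for one output $y_j$:

$$\mathrm{Var}(y_j) = \mathrm{Var}\left(\sum_{i=1}^{n} x_i w_{ij}\right) = \sum_{i=1}^{n}\mathrm{Var}(x_i)\,\mathrm{Var}(w_{ij}) = n \cdot 1 \cdot \frac{1}{n} = 1$$

The factor $n$ that comes from summing $n$ products is cancelled by the $\frac{1}{n}$ variance of the scaled weights.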

In [27]:
x = tf.random.normal((1, 1000))

In [28]:
w = tf.random.normal((1000, 1000)) / math.sqrt(1000)
stats(forward(x, [w]))

Mean: 0.039295364171266556
Var: 1.0486676543465023


And there we have it! The variance of the output is close to 1! Let's try doing the same with 50 stacked layers:  

In [29]:
ws = [tf.random.normal((1000, 1000))/math.sqrt(1000) for _ in range(50)]
stats(forward(x, ws))

Mean: -0.05769466981291771
Var: 0.9657826813236312
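To watch the two behaviours side by side, the variance can be recorded after every layer instead of only at the end. This is a small sketch under stated assumptions: it uses NumPy rather than TensorFlow, a smaller network (256 units, 10 layers) so the unscaled case stays finite, and a hypothetical helper named `layer_variances`.

```python
import numpy as np

rng = np.random.default_rng(0)  # arbitrary fixed seed, only for repeatability
n, depth = 256, 10              # smaller than the 1000-unit, 50-layer setup above

def layer_variances(scale):
    """Variance of the activations after each of `depth` linear layers."""
    x = rng.standard_normal((64, n))  # a batch of 64 standard-normal inputs
    variances = []
    for _ in range(depth):
        w = rng.standard_normal((n, n)) * scale
        x = x @ w
        variances.append(float(x.var()))
    return variances

unscaled = layer_variances(scale=1.0)
scaled = layer_variances(scale=1.0 / np.sqrt(n))

print("unscaled:", [f"{v:.1e}" for v in unscaled])
print("scaled:  ", [f"{v:.2f}" for v in scaled])
```

The unscaled variances should grow by a factor of about `n` per layer, while the scaled ones should hover around 1 at every depth.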


# Talk about relu and kaiming init
