In [35]:
import torch
import math

### Why you need a good init

To understand why initialization matters in a neural net, we'll focus on its most basic operation: matrix multiplication. So let's take a vector `x` and a matrix `a` initialized randomly, then multiply them 100 times (as if we had 100 layers).

In [36]:
x = torch.randn(512)
a = torch.randn(512,512)

In [37]:
for i in range(100): x = a @ x

In [38]:
x.mean(),x.std()

(tensor(nan), tensor(nan))

The problem here is activation explosion: very soon, the activations become nan. We can ask the loop to break the first time that happens:

In [39]:
x = torch.randn(512)
a = torch.randn(512,512)

In [40]:
for i in range(100):
    x = a @ x
    if torch.isnan(x.std()): break  # stop at the first nan

In [41]:
i

27

It only takes around 30 multiplications! On the other hand, if you initialize your weights with a scale that is too low, you get the opposite problem:

In [42]:
x = torch.randn(512)
a = torch.randn(512,512) * 0.01

In [43]:
for i in range(100): x = a @ x

In [44]:
x.mean(),x.std()

(tensor(0.), tensor(0.))

Here, every activation vanished to 0. To avoid both problems, people have come up with several strategies to initialize their weight matrices, such as:
- use a standard deviation that makes sure x and Ax have exactly the same scale
- use an orthogonal matrix to initialize the weights (orthogonal matrices have the special property that they preserve the L2 norm, so x and Ax have the same sum of squares in that case)
- use [spectral normalization](https://arxiv.org/pdf/1802.05957.pdf) on the matrix A (the spectral norm of A is the smallest number M such that `torch.norm(A@x) <= M*torch.norm(x)` for every x, so dividing A by M ensures you don't overflow; activations can still vanish with this)
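The two norm-based ideas in this list can be checked on tiny matrices. Below is a minimal sketch in plain Python (standard library only, so a matrix is a list of rows), assuming a plain power iteration to estimate the spectral norm, i.e. the largest singular value. The helpers `matvec`, `transpose`, `norm` and `spectral_norm` are hypothetical names for illustration.

```python
import math

def matvec(a, x):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(w * v for w, v in zip(row, x)) for row in a]

def transpose(a):
    return [list(col) for col in zip(*a)]

def norm(x):
    """L2 norm of a vector."""
    return math.sqrt(sum(v * v for v in x))

def spectral_norm(a, n_iter=100):
    """Estimate the largest singular value of a by power iteration on a^T a."""
    at = transpose(a)
    v = [1.0] * len(a[0])             # fixed start vector (assumed not orthogonal
                                      # to the top singular vector)
    for _ in range(n_iter):
        v = matvec(at, matvec(a, v))  # v <- a^T a v
        length = norm(v)
        v = [c / length for c in v]   # keep ||v|| = 1
    return norm(matvec(a, v))         # ||a v|| for the unit vector v

# An orthogonal (rotation) matrix preserves the L2 norm, so its spectral norm is 1.
theta = 0.3
rotation = [[math.cos(theta), -math.sin(theta)],
            [math.sin(theta), math.cos(theta)]]
x = [3.0, 4.0]
print(norm(x), norm(matvec(rotation, x)))  # both 5.0, up to rounding

# Spectral normalization: divide the matrix by its spectral norm.
a = [[2.0, 1.0],
     [0.0, 1.0]]
sigma = spectral_norm(a)                   # sqrt(3 + sqrt(5)), about 2.288
a_sn = [[w / sigma for w in row] for row in a]
print(sigma, spectral_norm(a_sn))          # about 2.288 and 1.0
```

After dividing, `norm(matvec(a_sn, x)) <= norm(x)` holds for every `x`, which is the "no overflow" guarantee from the last bullet; nothing stops the norm from shrinking, which is why activations can still vanish.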

### The magic number for scaling

Here we will focus on the first one, which is the idea behind Xavier initialization. For a square matrix like ours, it tells us to use a scale equal to `1/math.sqrt(n_in)`, where `n_in` is the number of inputs of our matrix.

In [45]:
import math

In [46]:
x = torch.randn(512)
a = torch.randn(512,512) / math.sqrt(512)

In [47]:
for i in range(100): x = a @ x

In [48]:
x.mean(),x.std()

(tensor(0.0982), tensor(4.2881))

And indeed it works. Note that this magic number isn't very far from the 0.01 we had earlier.

In [49]:
1/ math.sqrt(512)

0.044194173824159216
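A factor of about 4.4 between the two scales matters because it compounds over layers. Under the simplifying assumption that each product multiplies the standard deviation of `x` by `scale * math.sqrt(n)`, a few lines of arithmetic estimate how many layers it takes for that standard deviation to leave the range of a float32 (the default dtype of `torch.randn`). The helper `layers_until_float32_limit` is hypothetical, and its result is an order-of-magnitude estimate, not an exact prediction.

```python
import math

FLOAT32_MAX = 3.4028234663852886e38    # largest finite float32
FLOAT32_TINY = 1.401298464324817e-45   # smallest positive (subnormal) float32

def layers_until_float32_limit(scale, n=512):
    """Rough number of products before std(x) leaves the float32 range,
    assuming std(x) starts at 1 and is multiplied by scale * sqrt(n) per layer."""
    gain = scale * math.sqrt(n)
    if math.isclose(gain, 1.0):
        return math.inf                # the std neither grows nor shrinks
    limit = FLOAT32_MAX if gain > 1 else FLOAT32_TINY
    return math.ceil(math.log(limit) / math.log(gain))

for scale in (1.0, 0.01, 1 / math.sqrt(512)):
    print(scale, scale * math.sqrt(512), layers_until_float32_limit(scale))
```

For `scale = 1` this gives 29 layers, in the same ballpark as the loop above that hit nan at `i = 27`; for `scale = 0.01` it gives 70 layers before even the smallest subnormal float32 is reached, so after 100 layers everything is 0; with `1/math.sqrt(512)` the gain is 1 and neither limit is ever reached.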

But where does it come from? It's not that mysterious if you remember the definition of the matrix multiplication. When we do `y = a @ x`, the coefficients of `y` are defined by

$$y_{i} = a_{i,0} x_{0} + a_{i,1} x_{1} + \cdots + a_{i,n-1} x_{n-1} = \sum_{k=0}^{n-1} a_{i,k} x_{k}$$

or in code:
```python
y[i] = sum([c*d for c,d in zip(a[i], x)])
```
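Applying that line to every row gives a complete, if slow, matrix-vector product in plain Python. This is a minimal sketch assuming `a` is stored as a list of rows; `matvec` is a hypothetical name, and in practice you would simply use `a @ x`.

```python
def matvec(a, x):
    """y = a @ x computed from the definition y[i] = sum_k a[i][k] * x[k]."""
    return [sum(c * d for c, d in zip(row, x)) for row in a]

a = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]
x = [1.0, 0.0, -1.0]
print(matvec(a, x))  # [-2.0, -2.0]
```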

Now at the very beginning, our `x` vector has a mean of roughly 0 and a standard deviation of roughly 1 (since we picked it that way).

In [50]:
x = torch.randn(512)
x.mean(), x.std()

(tensor(-0.0648), tensor(0.9609))

NB: This is why it's extremely important to normalize your inputs in Deep Learning: the initialization rules have been designed for inputs that have a mean of 0 and a standard deviation of 1.
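In practice, normalizing means subtracting the mean and dividing by the standard deviation, both usually computed on the training set. A minimal sketch with Python's `statistics` module; the uniform data between 100 and 200 is made up to stand in for real, unnormalized inputs, and `normalize` is a hypothetical helper.

```python
import random
import statistics

def normalize(values):
    """Shift and rescale values to mean 0 and (population) std 1."""
    m = statistics.fmean(values)
    s = statistics.pstdev(values)
    return [(v - m) / s for v in values]

rng = random.Random(0)
raw = [rng.uniform(100, 200) for _ in range(1000)]  # made-up unnormalized inputs
z = normalize(raw)
print(statistics.fmean(z), statistics.pstdev(z))    # about 0.0 and 1.0
```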

If you need a refresher on statistics: the mean is the sum of all the elements divided by the number of elements (a basic average). The variance of a dataset measures whether the data stays close to the mean or generally takes values far away from it; in other words, variance measures the spread of the data. It's computed with the following formula:

$$\sigma^2 = \frac{1}{n}\left[(x_{0}-m)^{2} + (x_{1}-m)^{2} + \cdots + (x_{n-1}-m)^{2}\right]$$


where $m$ is the mean and $\sigma^2$ ($\sigma$ is the Greek letter sigma) is the variance. The square root of the variance, $\sigma$, is called the standard deviation, so changing one always changes the other.
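The formula translates directly into code. This small sketch uses a made-up dataset and compares it with Python's `statistics` module: `pvariance` uses the same $\frac{1}{n}$ formula, while `variance` divides by $n-1$ (Bessel's correction). PyTorch's `x.std()` and `x.var()` also use $n-1$ by default, which makes no visible difference with 512 elements.

```python
import math
import statistics

def variance(xs):
    """Population variance: mean of squared deviations from the mean (divide by n)."""
    m = sum(xs) / len(xs)
    return sum((x - m) ** 2 for x in xs) / len(xs)

data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]   # mean is 5
var = variance(data)
print(var, math.sqrt(var))          # 4.0 2.0
print(statistics.pvariance(data))   # 4.0, same formula (divide by n)
print(statistics.variance(data))    # 4.571428571428571, divides by n - 1
```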

If we go back to `y = a @ x` and assume that we chose weights for `a` that also have a mean of 0, we can compute the variance of `y` quite easily. Since the weights are random and we may land on unlucky draws, we repeat the operation 100 times and average the results.

In [52]:
mean, var = 0., 0.
n_dim = 512
n_iter = 100
for i in range(n_iter):
    x = torch.randn(n_dim)
    a = torch.randn(n_dim, n_dim)
    y = a @ x
    mean += y.mean().item()
    var  += (y - y.mean()).pow(2).mean().item()
    
mean/n_iter, var/n_iter

(-0.000560310110449791, 508.57024475097654)

That looks very close to the dimension of our matrix, 512, and that's no coincidence! When you compute y, each coefficient is a sum of 512 products of one element of a by one element of x. So what are the mean and the variance of such a sum? We can show mathematically that, as long as the elements of a and x are independent with mean 0 and variance 1, each product has mean 0 and variance 1, so the sum of 512 of them has mean 0 and variance 512. This can also be seen experimentally:

In [58]:
mean,sqr = 0.,0.
n_iter = 10000
for i in range(n_iter):
    x = torch.randn(512)
    a = torch.randn(512) # just like one row of a
    y = a @ x
    mean += y.item()
    sqr  += y.pow(2).item()
mean/n_iter,sqr/n_iter

(-0.351292731320858, 517.5591897146537)

### Proof that $E[Y] = 0$ and $Var[Y] = 512$

We are given that $x$ and $a$ come from a distribution with mean 0 and variance 1. That is, to create $x$, we pick 512 random numbers from a distribution with mean 0 and std 1. Similarly, to create $a$, we pick 512 * 512 random numbers from a distribution with mean 0 and std 1. Then the $i^{th}$ element of $y$ is computed by multiplying the 512 elements of row $a[i]$ with the 512 elements of $x$ and summing the products.

$$y_{i} = a_{i,0} x_{0} + a_{i,1} x_{1} + \cdots + a_{i,511} x_{511} = \sum_{k=0}^{511} a_{i,k} x_{k}$$

Let $A$ be the random variable from which the elements of $a$ are drawn, $X$ the random variable from which the elements of $x$ are drawn, and $Y$ the random variable from which each element of $y$ is sampled. One sample of $Y$ is created by drawing 512 independent copies $A_0, \ldots, A_{511}$ of $A$ and $X_0, \ldots, X_{511}$ of $X$, multiplying them pairwise, and summing the products. Thus, so far we have:


$
\begin{align}
& A \sim \mathcal{N}(0, 1) \\
& X \sim \mathcal{N}(0, 1) \\ 
& E[A] = 0 \\
& E[X] = 0 \\ 
& Var[A] = Std[A] = 1 \\ 
& Var[X] = Std[X] = 1 \\ 
\end{align}
$

$
\begin{align}
Y = \sum_{k=0}^{511} A_k X_k
\end{align}
$

Let's start by calculating the mean of $Y$.

#### Expectation (Mean) of Y


$
\begin{align}
E[Y] & = \sum_{k=0}^{511} E[A_k X_k] \\
& = \sum_{k=0}^{511} E[A] \, E[X] & (\text{A and X are independent}) \\
& = 0 & (\because E[A] = E[X] = 0)
\end{align}
$

#### Variance of Y

Let's first calculate the variance of $AX$. That is, what would be the variance if we pick one element randomly from $A$ and $X$ and then multiply them?

$
\begin{align}
Var[AX] & = Var[A] \, (E[X])^2 + Var[X] \, (E[A])^2 + Var[A] \, Var[X] \\
 & = Var[A] \, Var[X] = 1 & (\because E[A] = E[X] = 0)
\end{align}
$
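The first line of this formula holds for any independent $A$ and $X$, not only zero-mean ones; the zero means are what reduce it to $Var[A] \, Var[X]$. A quick Monte Carlo check with the standard library, using arbitrary means and standard deviations chosen only for illustration:

```python
import random

rng = random.Random(0)
n_samples = 100_000
mu_a, sd_a = 1.0, 2.0   # A ~ N(1, 2^2): arbitrary, deliberately non-zero mean
mu_x, sd_x = 3.0, 1.0   # X ~ N(3, 1^2), drawn independently of A

products = [rng.gauss(mu_a, sd_a) * rng.gauss(mu_x, sd_x) for _ in range(n_samples)]

# Empirical variance (divide by n, as in the formula for sigma^2)
m = sum(products) / n_samples
empirical = sum((p - m) ** 2 for p in products) / n_samples

var_a, var_x = sd_a ** 2, sd_x ** 2
predicted = var_a * mu_x ** 2 + var_x * mu_a ** 2 + var_a * var_x  # 36 + 1 + 4 = 41
print(predicted, empirical)
```

Setting `mu_a = mu_x = 0.0` removes the first two terms and leaves `var_a * var_x`, the case used in this proof.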

We know that $Y$ is formed by summing 512 such products:

$
\begin{align}
Y = \sum_{k=0}^{511} A_k X_k
\end{align}
$

Thus

$
\begin{align}
Var[Y] & = Var\left[\sum_{k=0}^{511} A_k X_k\right] \\
    & = \sum_{k=0}^{511} Var[A_k X_k] & (\text{the 512 products are independent}) \\
& = \sum_{k=0}^{511} 1  & (\text{Var[AX] = 1 from above}) \\
& = 512
\end{align}
$

In other words, $Y$ has mean 0 and variance 512 (a standard deviation of about 22.6), which is bad, since $Y$ varies a lot more than $X$! The experiment is reproduced below for reference. *Each coefficient of y has a large variance!*

In [55]:
mean, var = 0., 0.
n_iter = 10000
n_dim = 512
for i in range(n_iter):
    x = torch.randn(n_dim)
    a = torch.randn(n_dim) #just like one row of a
    y = a @x
    mean += y.item()
    var  += y.pow(2).item()
    
mean/n_iter, var/n_iter

(-0.3848510993003845, 505.62580416496115)

Each product of one element of a by one element of x has a mean of 0 and a mean of squares of 1; summing 512 of them gives something with a mean of 0 and a mean of squares of 512. If we divide the weights of the matrix a by math.sqrt(512), we pick the elements of a from a distribution with mean 0 and variance $1 / 512$. This in turn gives each element of y mean 0 and std 1, which lets us repeat the product as many times as we want without overflowing or vanishing. The proof and the experiments follow.

### Proof that $Var[Y]$ is now 1

After the normalization, when we create $a$, instead of sampling from $\mathcal{N}(0, 1)$, we sample from $\mathcal{N}(0, 1/512)$, i.e. with a standard deviation of $1/\sqrt{512}$. This is the case since we divide every element of $a$ by $\sqrt{512}$. We will now prove that this gives $Y$ a much better distribution, sticking to our example of 512.

$
\begin{align}
& A \sim \mathcal{N}(0, 1/512) \\
& X \sim \mathcal{N}(0, 1) \\ 
& E[A] = 0 \\
& E[X] = 0 \\ 
& Var[A] = 1/512, \ Std[A] = 1/\sqrt{512} \\ 
& Var[X] = Std[X] = 1 \\ 
\end{align}
$

#### Expectation (Mean) of Y (unchanged)


$
\begin{align}
E[Y] & = \sum_{k=0}^{511} E[A_k X_k] \\
& = \sum_{k=0}^{511} E[A] \, E[X] & (\text{A and X are independent}) \\
& = 0 & (\because E[A] = E[X] = 0)
\end{align}
$

#### Variance of Y

As before, let's first calculate the variance of $AX$. That is, what would be the variance if we pick one element randomly from $A$ and $X$ and then multiply them?

$
\begin{align}
Var[AX] & = Var[A] \, (E[X])^2 + Var[X] \, (E[A])^2 + Var[A] \, Var[X] \\
 & = Var[A] \, Var[X] = 1/512 & (\because E[A] = E[X] = 0)
\end{align}
$

$
\begin{align}
Var[Y] & = Var\left[\sum_{k=0}^{511} A_k X_k\right] \\
    & = \sum_{k=0}^{511} Var[A_k X_k] & (\text{the 512 products are independent}) \\
& = \sum_{k=0}^{511} 1/512  & (\text{Var[AX] = 1/512 from above}) \\
& = 1
\end{align}
$

In other words, $Y$ now has mean 0 and variance 1, which is what we wanted! Let's do some quick experiments to make sure this holds:

In [56]:
mean, var = 0., 0.
n_iter = 10000
n_dim = 512
for i in range(n_iter):
    x = torch.randn(n_dim)
    a = torch.randn(n_dim) / math.sqrt(n_dim) # just like one row of a
    y = a @ x
    mean += y.item()
    var  += y.pow(2).item()
    
mean/n_iter, var/n_iter

(0.006510300114750862, 1.0147866140232764)

Works! Back to our original problem:

In [57]:
mean, var = 0., 0.
n_dim = 512
n_iter = 100
for i in range(n_iter):
    x = torch.randn(n_dim)
    a = torch.randn(n_dim, n_dim) / math.sqrt(n_dim)
    y = a @ x
    mean += y.mean().item()
    var  += (y - y.mean()).pow(2).mean().item()
    
mean/n_iter, var/n_iter

(0.003834125765133649, 0.9961219137907028)

Makes sense!
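The rule uses the number of inputs `n_in` (the fan-in) because each coefficient of `y` sums `n_in` products. For a non-square matrix this matters. A standard-library sketch with a hypothetical 32×256 matrix (`n_out = 32`, `n_in = 256`) shows that scaling by `1/sqrt(n_in)` keeps the output variance near 1, while scaling by `1/sqrt(n_out)` would inflate it by about `n_in / n_out`. Since `y` has mean 0, the average of `y_i ** 2` estimates its variance; `mean_square_output` is a hypothetical helper.

```python
import math
import random

def mean_square_output(n_in, n_out, std, n_trials=100, seed=0):
    """Average of y_i ** 2 for y = a @ x, with x ~ N(0, 1) of length n_in
    and a ~ N(0, std^2) of shape (n_out, n_in)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_trials):
        x = [rng.gauss(0, 1) for _ in range(n_in)]
        for _ in range(n_out):                 # one row of a per output coefficient
            row = [rng.gauss(0, std) for _ in range(n_in)]
            total += sum(w * v for w, v in zip(row, x)) ** 2
    return total / (n_trials * n_out)

n_in, n_out = 256, 32
fan_in = mean_square_output(n_in, n_out, std=1 / math.sqrt(n_in))
fan_out = mean_square_output(n_in, n_out, std=1 / math.sqrt(n_out))
print(fan_in, fan_out)   # close to 1, and close to n_in / n_out = 8
```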

### Adding ReLU in the mix

We can reproduce the previous experiment with a ReLU to see that, this time, the mean shifts away from 0 and the mean of squares is halved to 0.5. This time the magic number to properly scale the weights of the matrix will be `math.sqrt(2/512)`.

In [None]:
mean,sqr = 0.,0.
for i in range(10000):
    x = torch.randn(1)
    a = torch.randn(1)
    y = a*x
    y = 0 if y < 0 else y.item()
    mean += y
    sqr  += y ** 2
mean/10000,sqr/10000

(0.313308998029304, 0.4902977162997965)
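Both numbers can be derived with the same tools as before. Since $A$ and $X$ are independent with distributions symmetric around 0, $AX$ is symmetric too, so ReLU keeps exactly the positive half: it halves both $E[|AX|]$ and $E[(AX)^2]$. Using $E[|A|] = \sqrt{2/\pi}$ for a standard normal (an extra fact not used earlier in this notebook):

$$
\begin{aligned}
E[\mathrm{ReLU}(AX)] &= \tfrac{1}{2}\,E\big[|AX|\big] = \tfrac{1}{2}\,E\big[|A|\big]\,E\big[|X|\big] = \tfrac{1}{2}\left(\sqrt{2/\pi}\right)^{2} = \tfrac{1}{\pi} \approx 0.318 \\
E\big[\mathrm{ReLU}(AX)^{2}\big] &= \tfrac{1}{2}\,E\big[(AX)^{2}\big] = \tfrac{1}{2}\,E\big[A^{2}\big]\,E\big[X^{2}\big] = \tfrac{1}{2}
\end{aligned}
$$

This matches the experiment above (0.313 and 0.490) up to sampling noise. The halved mean of squares is what the factor 2 in `math.sqrt(2/512)` compensates for.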

We can double check by running the experiment on the whole matrix product.

In [None]:
mean,sqr = 0.,0.
for i in range(100):
    x = torch.randn(512)
    a = torch.randn(512, 512)
    y = a @ x
    y = y.clamp(min=0)
    mean += y.mean().item()
    sqr  += y.pow(2).mean().item()
mean/100,sqr/100

(9.005213165283203, 257.3701454162598)

We can also check that scaling the weights by the magic number brings the mean of squares back to about 1.

In [None]:
mean,sqr = 0.,0.
for i in range(100):
    x = torch.randn(512)
    a = torch.randn(512, 512) * math.sqrt(2/512)
    y = a @ x
    y = y.clamp(min=0)
    mean += y.mean().item()
    sqr  += y.pow(2).mean().item()
mean/100,sqr/100

(0.5643280637264252, 1.0055165421962737)
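Over a single layer, the difference between `1/math.sqrt(512)` and `math.sqrt(2/512)` is only a factor of 2 in the mean of squares, but it compounds with depth. A standard-library sketch comparing both scales on a stack of ReLU layers, assuming a smaller hypothetical width `n = 128`, 20 layers, and fresh weights for every layer; `relu_stack_mean_square` is a hypothetical helper.

```python
import math
import random

def relu_stack_mean_square(gain, n=128, n_layers=20, seed=0):
    """Mean of squares of x after n_layers of x = relu(a @ x),
    with a ~ N(0, gain^2 / n) drawn fresh for each layer."""
    rng = random.Random(seed)
    std = gain / math.sqrt(n)
    x = [rng.gauss(0, 1) for _ in range(n)]
    for _ in range(n_layers):
        a = [[rng.gauss(0, std) for _ in range(n)] for _ in range(n)]
        x = [max(0.0, sum(w * v for w, v in zip(row, x))) for row in a]  # ReLU
    return sum(v * v for v in x) / n

plain = relu_stack_mean_square(gain=1.0)             # scale 1/sqrt(n)
kaiming = relu_stack_mean_square(gain=math.sqrt(2))  # scale sqrt(2/n)
print(plain, kaiming)
```

With `gain=1.0` the mean of squares shrinks roughly like `0.5 ** 20` (about 1e-6), while with `gain=math.sqrt(2)` it stays of order 1, up to random fluctuations.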

The math behind this is a bit more involved, and you can find all the details in the [Kaiming](https://arxiv.org/abs/1502.01852) and [Xavier](http://proceedings.mlr.press/v9/glorot10a.html) papers, but this gives the intuition behind those results.