# Importance of weight initialization and data normalization

**Weight initialization is a critical balancing adjustment** without which training will not work properly. Network architecture, optimizer, and hyperparameters **do NOT matter without proper initialization**.

* [Building makemore Part 3: Activations & Gradients, BatchNorm](https://www.youtube.com/watch?v=P6sfmUTpUmc) - Andrej Karpathy (MUST)
* [Stanford CS231n Lecture 6 | Training Neural Networks I](https://www.youtube.com/watch?v=wEoyxE0GP2M)
* [Deep Learning AI - The importance of effective initialization](https://www.deeplearning.ai/ai-notes/initialization/index.html) - MUST
* [He Initialization](https://arxiv.org/pdf/1502.01852.pdf)
* [Variance of product of multiple independent random variables](https://stats.stackexchange.com/questions/52646/)

> To prevent the gradients of the network’s activations from vanishing or exploding, we will stick to the following rules of thumb:
> 
> 1. The mean of the activations should be zero.
> 2. The variance of the activations should stay the same across every layer.
>
> Under these two assumptions, the backpropagated gradient signal should not be multiplied by values too small or too large in any layer. It should travel to the input layer without exploding or vanishing.
>
> In other words, all the **weights of layer ```l``` are random samples from a normal distribution** with mean ```μ=0``` and variance ```v=1/N(l-1)```, where ```N(l-1)``` is the dimension of the input (the number of outputs, i.e. neurons, of the previous layer).
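
A minimal NumPy sketch of the rule in the quote, assuming a stack of purely linear layers (no activation) of equal width N, so that `N(l-1)` is the same for every layer. With weights drawn from N(0, 1/N), the activation std stays near 1 at every layer; with unit-variance weights it is multiplied by sqrt(N) per layer. The function name `activation_std_per_layer` is hypothetical.

```python
import numpy as np

def activation_std_per_layer(n_layers=10, width=256, weight_std=None, seed=0):
    """Push a batch of unit-variance inputs through `n_layers` linear layers
    (no activation) and return the activation std after each layer.
    weight_std=None applies the 1/sqrt(N) rule, i.e. variance 1/N."""
    rng = np.random.default_rng(seed)
    std = 1.0 / np.sqrt(width) if weight_std is None else weight_std
    x = rng.standard_normal((1000, width))    # batch of inputs with variance 1
    stds = []
    for _ in range(n_layers):
        W = rng.normal(0.0, std, size=(width, width))
        x = x @ W.T                            # each neuron sums N(l-1) = width inputs
        stds.append(x.std())
    return stds

print(np.round(activation_std_per_layer(), 2))                 # stays close to 1.0
print(np.round(activation_std_per_layer(weight_std=1.0), 1))   # grows ~16x (sqrt(256)) per layer
```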

Initializing W with the same value, whether 0 or any other constant, means **variance == 0**.

* [NN - 18 - Weight Initialization 1 - What not to do?](https://youtu.be/eoNVmZDnn9w?t=250)
> 
> Backpropagation moves all $W$ in the same direction if the variance of W is 0 (every element of W is set to the same value, e.g. all zeros).
> <img src="image/initiailzation_sigmoid_backprop_impoact.png" align="left">
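
A small NumPy sketch of this symmetry problem, assuming a hypothetical one-hidden-layer network `y = w2 @ sigmoid(W1 @ x)` with a squared-error loss. When every weight starts at the same constant, every hidden unit computes the same value and receives the same gradient, so gradient descent can never make the units differ; random initialization breaks the tie.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def hidden_weight_grads(W1, w2, x, target):
    """dL/dW1 for L = 0.5 * (y - target)^2 and y = w2 @ sigmoid(W1 @ x)."""
    h = sigmoid(W1 @ x)          # hidden activations, shape (H,)
    y = w2 @ h                   # scalar output
    dy = y - target              # dL/dy
    dh = dy * w2                 # dL/dh, shape (H,)
    dz = dh * h * (1.0 - h)      # back through the sigmoid
    return np.outer(dz, x)       # dL/dW1, one row per hidden unit

x = np.array([0.5, -1.0, 2.0])
target = 1.0

# Constant initialization: every hidden unit (row of W1) is identical
W1_const = np.full((4, 3), 0.1)
w2_const = np.full(4, 0.1)
g_const = hidden_weight_grads(W1_const, w2_const, x, target)
print(np.allclose(g_const, g_const[0]))   # True: all rows get the same update

# Random initialization breaks the symmetry
rng = np.random.default_rng(0)
W1_rand = rng.normal(0.0, 1.0 / np.sqrt(3), size=(4, 3))
w2_rand = rng.normal(0.0, 0.5, size=4)
g_rand = hidden_weight_grads(W1_rand, w2_rand, x, target)
print(np.allclose(g_rand, g_rand[0]))     # False
```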

* [NN - 18 - Weight Initialization 1 - What not to do?](https://youtu.be/eoNVmZDnn9w?t=482)
> 
> <img src="image/initialization_what_to_do.png" align="left"/>

## Transformer Weight Initialization

* [T-Fixup Improving Transformer Optimization Through Better Initialization](https://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf)
* [Effective Theory of Transformers at Initialization](https://arxiv.org/pdf/2304.02034.pdf)
* [Improving Deep Transformer with Depth-Scaled Initialization and Merged Attention](https://aclanthology.org/D19-1083.pdf)
* [Fixup Initialization: Residual Learning Without Normalization](https://arxiv.org/pdf/1901.09321.pdf)
* [MS Research - DeepNet 1000 layer transformer](https://arxiv.org/pdf/2203.00555.pdf)
* [Meta AI - NormFormer](https://arxiv.org/pdf/2110.09456.pdf) - Addresses the gradient magnitude mismatch in Pre-LN transformers, where lower (earlier) layers receive much larger gradients than upper layers.
* [Learning Deep Transformer Models for Machine Translation](https://arxiv.org/pdf/1906.01787.pdf)
* [ReZero is All You Need: Fast Convergence at Large Depth](https://arxiv.org/pdf/2003.04887.pdf)
* [Training Tips for the Transformer Model](https://arxiv.org/pdf/1804.00247.pdf)
* [PyTorch - Transformer Initialization #72253](https://github.com/pytorch/pytorch/issues/72253)

### Initialization Code
* [BertConfig][4]

> initializer_range (float, optional, **defaults to 0.02**) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.

* [HuggingFace Bert Weight Initialization code](https://github.com/huggingface/transformers/blob/a9aa7456ac/src/transformers/modeling_bert.py#L520-L530)

```
    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) # <--- initializer_range = 0.02
        elif isinstance(module, BertLayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

```

* [PyTorch Tutorial - LANGUAGE TRANSLATION WITH NN.TRANSFORMER AND TORCHTEXT](https://pytorch.org/tutorials/beginner/translation_transformer.html#seq2seq-network-using-transformer)

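```
import torch.nn as nn

# `transformer` is the Seq2SeqTransformer model built earlier in the tutorial.
# Xavier-uniform init for every parameter with more than one dimension (weight matrices);
# 1-D parameters such as biases and LayerNorm weights keep their default initialization.
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
```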

* [Nano GPT](https://github.com/karpathy/nanoGPT/blob/master/model.py#L141)

```
        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
```
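
A toy NumPy sketch of why nanoGPT scales the `c_proj` weights by `1/sqrt(2 * n_layer)`, under simplifying assumptions that are not from nanoGPT: each of the `2 * n_layer` residual branches (attention and MLP per block) receives a fresh unit-variance vector, standing in for its LayerNorm output, and adds `c_proj @ u` to the residual stream. Without scaling, the variance added to the residual stream grows linearly with depth; with scaling it stays constant. The function `residual_sum_std` is hypothetical.

```python
import numpy as np

def residual_sum_std(n_layer, d_model=256, base_std=0.02, scaled=True, seed=0):
    """Std of the sum of all residual-branch outputs in a toy Pre-LN stack."""
    rng = np.random.default_rng(seed)
    n_branches = 2 * n_layer                      # attention + MLP per block
    std = base_std / np.sqrt(n_branches) if scaled else base_std
    residual = np.zeros(d_model)
    for _ in range(n_branches):
        u = rng.standard_normal(d_model)          # stand-in for the LayerNorm output
        c_proj = rng.normal(0.0, std, size=(d_model, d_model))
        residual += c_proj @ u                    # branch output added to the stream
    return residual.std()

for n_layer in (1, 12, 48):
    print(n_layer,
          round(residual_sum_std(n_layer, scaled=False), 2),  # grows ~sqrt(2 * n_layer)
          round(residual_sum_std(n_layer, scaled=True), 2))   # stays ~0.02 * sqrt(256) = 0.32
```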


---

# Problems due to Improper Initialization

## Waste of training cycles

* [Building makemore Part 3: Activations & Gradients, BatchNorm](https://youtu.be/P6sfmUTpUmc?t=259)

If the weights are not properly initialized, the first training iterations are spent undoing the bad initialization rather than learning. This manifests as a **large initial loss** that is squashed down quickly (a hockey-stick-shaped learning curve).

<img src="image/nn_weight_initialization_too_large.png" align="left" width=400/>

  [4]: https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertConfig



## Vanishing Gradients

If W is initialized to values that are too small, the layer output ```y=Wn@W(n-1)@...@W1@X``` (with an activation applied after each matmul) shrinks toward zero layer by layer for activations such as ReLU and tanh, which output 0 for input 0 (**the exact behavior depends on the activation function**). The input X(i) to the next layer i+1 is then close to zero. Because the gradient with respect to W(i+1) is proportional to X(i), the gradient update to W(i+1) is also close to zero, so the neuron effectively stops learning.

The histogram shows the activations (outputs) of the neurons at each layer being squashed toward zero, so almost no signal propagates from the neurons.

<img src="image/vanishing_gradient.png" align="left" width=650/>
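
A minimal NumPy sketch of this effect, assuming a hypothetical stack of 10 tanh layers of width 256. With weight std 0.01 (far below 1/sqrt(256) ≈ 0.06) the activation std shrinks roughly 6x per layer; the comparison run uses std (5/3)/sqrt(256), where 5/3 is the gain PyTorch recommends for tanh. Because the weight gradient of a linear layer is the upstream gradient times the layer input X(i), the last layer's weight gradient collapses together with its input; the random upstream gradient is an assumption for illustration.

```python
import numpy as np

def tanh_stack(weight_std, n_layers=10, width=256, batch=500, seed=0):
    """Forward a unit-variance batch through tanh layers.
    Returns the activation std after each layer and the input X(i) of the last layer."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((batch, width))
    stds = []
    for _ in range(n_layers):
        W = rng.normal(0.0, weight_std, size=(width, width))
        x_in = x                         # input X(i) to this layer
        x = np.tanh(x_in @ W.T)
        stds.append(x.std())
    return np.array(stds), x_in

def last_layer_weight_grad(x_in, seed=1):
    """dL/dW for z = x_in @ W.T, i.e. dL/dz.T @ x_in, with an assumed
    random unit-variance upstream gradient dL/dz."""
    rng = np.random.default_rng(seed)
    dz = rng.standard_normal(x_in.shape)
    return dz.T @ x_in

small_stds, small_in = tanh_stack(weight_std=0.01)
good_stds, good_in = tanh_stack(weight_std=(5 / 3) / np.sqrt(256))
print(small_stds)   # shrinks roughly 6x per layer toward 0
print(good_stds)    # stays at a stable, non-negligible level
print(np.abs(last_layer_weight_grad(small_in)).mean())   # nearly zero
print(np.abs(last_layer_weight_grad(good_in)).mean())
```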




<img src="image/vanishing_gradient_example.png" align="left" width=650/>

[Building makemore Part 3: Activations & Gradients, BatchNorm](https://youtu.be/P6sfmUTpUmc?t=5455)

With smaller initial weight values, the activations (outputs of the neurons) and the gradients shrink to around zero. The neurons become effectively dead, neither passing on a signal nor learning.

<img src="image/small_weight_initialization_impact_on_activation_gradient.png" align="left" width=650/>


## Exploding Gradients

<img src="image/explording_gradients.png" align="left"/>

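With large weight initializations, the activations saturate (tanh is used here, so they pile up at ±1). Because the tanh derivative, 1 - tanh(x)^2, is close to zero in the saturated region, the gradients flowing back through those units also drop to nearly zero.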

<img src="image/large_weight_initialization_impact_on_activation_gradient.png" align="left" width=650/>
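
A NumPy sketch of the saturation effect for a single tanh layer, assuming width-256 unit-variance inputs and comparing a weight std of 1.0 (large) against 1/sqrt(256). The local gradient of tanh is 1 - tanh(z)^2, so the share of units with |tanh(z)| > 0.99 measures how many units pass back almost no gradient; the 0.99 threshold is an arbitrary choice for illustration.

```python
import numpy as np

def tanh_saturation(weight_std, width=256, batch=1000, seed=0):
    """Return (fraction of saturated units, mean local gradient) for one tanh layer."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((batch, width))
    W = rng.normal(0.0, weight_std, size=(width, width))
    h = np.tanh(x @ W.T)
    local_grad = 1.0 - h**2                 # d tanh(z) / dz
    return (np.abs(h) > 0.99).mean(), local_grad.mean()

for std in (1.0, 1.0 / np.sqrt(256)):
    saturated, mean_grad = tanh_saturation(std)
    print(f"weight std {std:.4f}: saturated {saturated:.1%}, mean local grad {mean_grad:.3f}")
```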

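## Suboptimal Gradient Updates from Non-Zero-Mean Data

Zero-center the input data so that the gradient updates on W are well behaved. For a neuron y = w @ x, the gradient is dL/dw = dL/dy * x, so W is updated in proportion to X. If every element of X is positive, all elements of dL/dw share the sign of the scalar dL/dy: in a single step the weights can only all increase or all decrease together, which forces an inefficient zig-zag path toward the optimum.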

<img src="image/data_normalize_zero_center.png" align="left" width=500/>
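
A minimal NumPy sketch of this, assuming a single linear neuron `y = w @ x` and an arbitrary scalar upstream gradient `dL_dy`. The all-positive input stands in for something like raw pixel intensities.

```python
import numpy as np

rng = np.random.default_rng(0)
x_pos = rng.uniform(0.1, 1.0, size=5)   # all-positive input
x_centered = x_pos - x_pos.mean()       # zero-centered input

dL_dy = -0.7                            # arbitrary upstream gradient (scalar)
grad_pos = dL_dy * x_pos                # dL/dw for y = w @ x
grad_centered = dL_dy * x_centered

print(np.sign(grad_pos))       # all -1: every weight is pushed in the same direction
print(np.sign(grad_centered))  # mixed signs: weights can move independently
```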

### Rate of Weight Update in Gradient Descent

The ratio of the update size to the weight size (the update-to-data ratio) should be roughly ```1e-3```, i.e. 1/1000 or ```log10(1e-3) == -3```. This is a heuristic rather than a hard rule.

* [Building makemore Part 3: Activations & Gradients, BatchNorm](https://youtu.be/P6sfmUTpUmc?t=5998)

> If the update ratio is far below 1e-3 (log10 much less than -3), learning is too slow. If it is around 1e-1, the updates are too large and the weights change too much per step.  
> ```update_ratio_log = ((lr * W.grad).std() / W.data.std()).log10()  # should be approx -3```

<img src="image/proper_weight_update_ratio.png" align="left" width=750/>
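
A NumPy version of the same diagnostic, assuming plain SGD where the update is `lr * grad`; the helper name `update_to_data_log10` is hypothetical and the gradient below is a random stand-in. Because W and the stand-in gradient both have std 0.1 here, the printed value is roughly log10(lr).

```python
import numpy as np

def update_to_data_log10(W, grad, lr):
    """log10( std(lr * grad) / std(W) ) for one SGD step."""
    return np.log10((lr * grad).std() / W.std())

rng = np.random.default_rng(0)
W = rng.normal(0.0, 1.0 / np.sqrt(100), size=(50, 100))
grad = rng.normal(0.0, 0.1, size=W.shape)   # stand-in gradient for illustration

for lr in (1e-4, 1e-3, 1e-1):
    print(f"lr={lr:g}: log10 ratio = {update_to_data_log10(W, grad, lr):.2f}")
```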


# Solutions

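1. Verify that the weights are initialized from a zero-mean normal distribution with variance 1/D, where D is the input dimension (adjusted by a gain for the activation, e.g. 2/D for ReLU), and monitor the activation statistics during training.
2. Verify that the gradients are neither close to 0 (vanishing) nor too large (exploding). The update-to-data ratio of roughly 1e-3 is one practical yardstick.
3. Use a fit-for-purpose initialization, e.g. Xavier or He, depending on the activation function.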
4. Use Batch or Layer Normalization.
5. Normalize input data.

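Note: the appropriate weight initialization depends on the activation function. Xavier initialization does not work well with ReLU, because ReLU zeroes out roughly half of its inputs and so halves the variance at every layer; use He initialization (variance 2/D) instead.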

### Demonstration that Xavier initialization does not work with ReLU
<img src="image/xavier_break_with_reul.png" align="left"/>
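
A NumPy sketch of this failure, assuming a stack of 10 square ReLU layers of width 256 with unit-variance inputs (for square layers Xavier's 2/(fan_in + fan_out) equals 1/D). Since ReLU keeps only about half of the signal's second moment, variance 1/D halves the pre-activation variance at every layer, while He's 2/D compensates for that factor of 2. The function `relu_stack_std` is hypothetical.

```python
import numpy as np

def relu_stack_std(weight_var_numerator, n_layers=10, width=256, batch=1000, seed=0):
    """Pre-activation std at each layer with weights drawn from N(0, numerator / width)."""
    rng = np.random.default_rng(seed)
    std = np.sqrt(weight_var_numerator / width)
    h = rng.standard_normal((batch, width))
    stds = []
    for _ in range(n_layers):
        W = rng.normal(0.0, std, size=(width, width))
        z = h @ W.T
        stds.append(z.std())
        h = np.maximum(z, 0.0)        # ReLU
    return stds

xavier = relu_stack_std(1.0)   # variance 1/D
he = relu_stack_std(2.0)       # variance 2/D
print(np.round(xavier, 3))     # shrinks by ~1/sqrt(2) per layer
print(np.round(he, 3))         # stays roughly constant
```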

### PyTorch He Initialization

<img src="image/torch_he_initialization.png" align="left"/>
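
PyTorch exposes He (Kaiming) initialization as `torch.nn.init.kaiming_normal_`. A NumPy sketch of the same formula, assuming `mode='fan_in'` and the ReLU gain sqrt(2), so std = sqrt(2) / sqrt(fan_in) for a weight of shape (fan_out, fan_in). The helper `he_normal` is hypothetical.

```python
import numpy as np

def he_normal(fan_out, fan_in, rng=None):
    """He/Kaiming normal init for a (fan_out, fan_in) weight: N(0, 2 / fan_in).
    Same std as torch.nn.init.kaiming_normal_(w, mode="fan_in", nonlinearity="relu")."""
    rng = np.random.default_rng() if rng is None else rng
    std = np.sqrt(2.0) / np.sqrt(fan_in)      # gain sqrt(2) for ReLU
    return rng.normal(0.0, std, size=(fan_out, fan_in))

W = he_normal(512, 1024, rng=np.random.default_rng(0))
print(W.std(), np.sqrt(2.0 / 1024))           # empirical vs target std (about 0.0442)
```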

```
import numpy as np
from scipy.special import softmax

def log_loss(t, p):
    """Cross-entropy between a one-hot target t and predicted probabilities p."""
    return np.sum(-t * np.log(p))
```

# Example

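At initialization, the network output logits ```y``` should be close to 0, because the model has no basis yet for preferring one class over another, so every class should receive roughly equal probability (multi-class classification with softmax).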

## Initial Large Loss

If the weights are not initialized to produce small outputs (close to 0), the logits can be large, which results in a large loss whenever a confident prediction is wrong.

```
t = np.array([0, 1, 0, 0])
y = np.array([67., 15., 39., 77.])
p = softmax(y, axis=-1)

print(f"output: {p}")
print(f"loss  : {log_loss(t=t, p=p)}")
```

```
output: [4.53978687e-05 1.18501106e-27 3.13899028e-17 9.99954602e-01]
loss  : 62.00004539889922
```


### Expected Loss

The ideal logits at initialization are ```y=[0,0,0,0]```, which give a uniform prediction of 0.25 per class and a loss of ```-log(0.25) = log(4) ≈ 1.386```.

```
y = np.zeros(shape=4)
p = softmax(y, axis=-1)

print(f"output: {p}")
print(f"loss  : {log_loss(t=t, p=p)}")
```

```
output: [0.25 0.25 0.25 0.25]
loss  : 1.3862943611198906
```
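
More generally, a uniform prediction over M classes gives an initial cross-entropy of -log(1/M) = log(M), a quick sanity check for the first loss value printed during training. A sketch with a hypothetical helper name; the 27-class case assumes a makemore-style vocabulary of 26 letters plus one special token.

```python
import numpy as np

def expected_initial_loss(num_classes):
    """Cross-entropy of a uniform prediction: -log(1 / M) = log(M)."""
    return np.log(num_classes)

print(expected_initial_loss(4))    # 1.3862943611198906, matching the loss above
print(expected_initial_loss(27))   # about 3.2958
```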


### Mitigation

For the matmul ```y=x@W.T```, initialize W from a standard normal distribution and divide it by the square root of the input dimension D. In the image, x and w are standard normal with dimension D=10: each element of ```x@w``` is a sum of D products of independent unit-variance terms, so its variance is D and the distribution on the right is ```sqrt(10)``` times wider than the one on the left. Hence, set the standard deviation of W to ```1/sqrt(D)``` so that the variance of ```x@w``` is 1.0.

* [Building makemore Part 3: Activations & Gradients, BatchNorm](https://youtu.be/P6sfmUTpUmc?t=1800)

<img src="image/product_of_two_normal_distributions.png" align="left"/>
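
A quick NumPy check of this claim, with x and w both standard normal and D=10 as in the image; the batch size 1000 and output width 200 are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 10
x = rng.standard_normal((1000, D))   # inputs, std 1
w = rng.standard_normal((D, 200))    # unscaled weights, std 1
print((x @ w).std(), np.sqrt(D))     # both about 3.16

w_scaled = w / np.sqrt(D)            # std 1/sqrt(D)
print((x @ w_scaled).std())          # about 1.0
```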

```
t = np.array([0, 1, 0, 0])
M = len(t)  # number of classes
D = 8       # input dimension

x = np.random.normal(size=(D,))
W = np.random.normal(size=(M, D)) / np.sqrt(D)   # std 1/sqrt(D)

y = x @ W.T
p = softmax(y, axis=-1)

print(f"output: {p}")
print(f"loss  : {log_loss(t=t, p=p)}")
```

Output of one run (no random seed is set, so the values vary):

```
output: [0.15056858 0.03632626 0.69038642 0.12271874]
loss  : 3.315214342930334
```


### Xavier Initialization

This is almost the same as Xavier initialization.

* [Understanding Xavier Initialization In Deep Neural Networks](https://prateekvjoshi.com/2016/03/29/understanding-xavier-initialization-in-deep-neural-networks/)
* [Stanford CS230 Xavier Initialization](https://cs230.stanford.edu/section/4/)

<img src="image/xavier_initialization.png" align="left" width=600/>

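```
# Xavier (Glorot) normal initialization: std = sqrt(2 / (fan_in + fan_out)), with fan_in = D and fan_out = M.
# The original paper uses both input and output dimensions; using the input only (std = 1/sqrt(D)) is a common simplification.
W2 = np.random.normal(loc=0, scale=np.sqrt(2 / (D + M)), size=(M, D))
y = x @ W2.T
p = softmax(y, axis=-1)

print(f"output: {p}")
print(f"loss  : {log_loss(t=t, p=p)}")
```

Output of one run (no random seed is set, so the values vary):

```
output: [0.76176628 0.10239613 0.05155247 0.08428512]
loss  : 2.2789063413401904
```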
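
The PyTorch tutorial above uses the uniform variant, `nn.init.xavier_uniform_`. A NumPy sketch comparing the two Glorot forms with hypothetical helper names: the normal form uses std = sqrt(2 / (fan_in + fan_out)), and the uniform form samples U(-a, a) with a = sqrt(6 / (fan_in + fan_out)). Since U(-a, a) has variance a^2 / 3, both target the same variance 2 / (fan_in + fan_out).

```python
import numpy as np

def xavier_normal(fan_out, fan_in, rng):
    std = np.sqrt(2.0 / (fan_in + fan_out))
    return rng.normal(0.0, std, size=(fan_out, fan_in))

def xavier_uniform(fan_out, fan_in, rng):
    a = np.sqrt(6.0 / (fan_in + fan_out))   # variance of U(-a, a) is a**2 / 3
    return rng.uniform(-a, a, size=(fan_out, fan_in))

rng = np.random.default_rng(0)
fan_in, fan_out = 400, 300
target_var = 2.0 / (fan_in + fan_out)
print(target_var,
      xavier_normal(fan_out, fan_in, rng).var(),
      xavier_uniform(fan_out, fan_in, rng).var())   # all three about 0.00286
```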
