<a href="https://colab.research.google.com/github/heerboi/AI-from-scratch/blob/main/initialization_normalization_deep_dive.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

References

1. Andrej Karpathy's insane godlike awesome series! - https://www.youtube.com/watch?v=P6sfmUTpUmc

2. Kaiming init paper - https://arxiv.org/abs/1502.01852

3. Batch norm paper - https://arxiv.org/pdf/1502.03167

In [None]:
# Imports
import torch
import torch.nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [None]:
class Linear():

    def __init__(self, in_features, out_features, bias=True):
        self.weight = torch.randn((in_features, out_features)) / (in_features**0.5)
        self.bias = None
        if bias:
            self.bias = torch.randn(out_features) / (out_features**0.5)

    def parameters(self):
        return [self.weight] + ([] if self.bias is None else [self.bias])

    def __call__(self, x):
        self.out = x @ self.weight

        if self.bias is not None:
            self.out += self.bias

        return self.out

In [None]:
class Tanh():
    def __call__(self, x):
        self.out = torch.tanh(x)
        return self.out
    def parameters(self):
        return []

For a layer with $d$-dimensional input $x = (x^{(1)}, \dots, x^{(d)})$, we normalize each dimension as follows:

$$\hat{x}^{(k)} = \frac{x^{(k)} - E[x^{(k)}]}{\sqrt{Var[x^{(k)}]}}$$

which normalizes the inputs to a particular layer using statistics taken across the BATCH dimension, giving the inputs to a particular node zero mean and unit variance over the batch. The paper motivates this as reducing _internal covariate shift_: the distribution of each layer's inputs keeps changing during training as the parameters of the earlier layers change. On top of that, Gaussian initialization and repeated matrix muls can produce larger and larger values as we go deeper into the network, resulting in a lot of saturated neurons that don't end up learning! By keeping the values roughly Gaussian just after the matrix mul and before the activation, the network learns much faster.
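A rough sketch of the effect described above, assuming made-up pre-activations that are deliberately too wide (the tensor `pre_act`, the helper `saturated_fraction` and the 0.97 threshold are illustrative choices, not from the paper). It normalizes over the batch dimension and compares how many tanh outputs end up saturated before and after:

```python
import torch

torch.manual_seed(0)

# Hypothetical pre-activations: 256 examples, 8 features, shifted and far too wide.
pre_act = torch.randn(256, 8) * 5.0 + 2.0

# Statistics are taken over the batch dimension (dim=0): one mean/variance per feature.
mean = pre_act.mean(dim=0, keepdim=True)
var = pre_act.var(dim=0, keepdim=True, unbiased=False)
x_hat = (pre_act - mean) / torch.sqrt(var + 1e-5)

def saturated_fraction(h, threshold=0.97):
    # Fraction of tanh outputs sitting in the flat tails, where gradients vanish.
    return (torch.tanh(h).abs() > threshold).float().mean().item()

frac_before = saturated_fraction(pre_act)
frac_after = saturated_fraction(x_hat)

print("per-feature mean after normalization:", x_hat.mean(dim=0))
print("per-feature var after normalization: ", x_hat.var(dim=0, unbiased=False))
print("saturated fraction before / after:", frac_before, frac_after)
```

Each column of `x_hat` should come out with mean close to 0 and variance close to 1, and far fewer tanh outputs should be stuck in the tails.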

There is a HUGE catch, as you might've also guessed - this normalization can mess up the learning and change what a neuron "represents" for the outputs! Imagine that, AFTER learning, the network wants a few neurons to make _sharp_ 0/1 decisions, i.e. to sit in the saturated region on purpose. Forcing their inputs back to zero mean and unit variance works against that.

So, while this helps jumpstart the neural network, it might learn worse than a network whose weights were configured manually by scaling, etc., because plain normalization stops neurons from ever becoming saturated!

To offset this normalization, then, the paper introduces two new parameters! (ouch, we're glad to have so much compute available to us now!): $\gamma^{(k)}$ and $\beta^{(k)}$ for each input $x^{(k)}$, which are applied as follows:

$$y^{(k)} = \gamma^{(k)}\hat{x}^{(k)} + \beta^{(k)}$$

Note: These are trainable/learnable parameters, just like weights and biases! This is helpful when a neuron learns best un-normalized, since the network can then learn $\gamma^{(k)} = \sqrt{Var[x^{(k)}]}$ and $\beta^{(k)} = E[x^{(k)}]$, which recovers the original activations.
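A quick numeric check of that last claim, as a sketch on a made-up batch (the tensor `x` and its scale are arbitrary, and `eps` is left out to match the formula above). Setting the scale to the batch standard deviation and the shift to the batch mean should undo the normalization:

```python
import torch

torch.manual_seed(0)

# Hypothetical activations: 64 examples, 4 features, not zero-mean or unit-variance.
x = torch.randn(64, 4) * 3.0 + 1.5

# Normalize each feature over the batch dimension.
mean = x.mean(dim=0, keepdim=True)
var = x.var(dim=0, keepdim=True, unbiased=False)
x_hat = (x - mean) / torch.sqrt(var)

# Choose gamma = sqrt(Var[x]) and beta = E[x], as in the note above.
gamma = torch.sqrt(var)
beta = mean
y = gamma * x_hat + beta

max_error = (y - x).abs().max().item()
print("largest difference between y and x:", max_error)
```

The difference should be at floating-point rounding level, so the layer can represent the identity if that is what training prefers.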

How do we compute the square root of the variance (which is simply the standard deviation) and the expected value of $x^{(k)}$? Well, the latter is easy, because the expected value is estimated by the mean of all observed values in the mini-batch (or the entire dataset).

The standard deviation is calculated similarly on the batch or entire dataset.
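As a small sketch of the two options, assuming a made-up dataset (the sizes, scales and the mini-batch size of 128 are arbitrary choices for illustration), the per-feature statistics can be computed on the whole dataset or on one mini-batch:

```python
import torch

torch.manual_seed(0)

# Hypothetical dataset: 10,000 examples, 3 features with different scales and offsets.
scales = torch.tensor([1.0, 0.5, 2.0])
offsets = torch.tensor([0.0, 1.0, -1.0])
dataset = torch.randn(10_000, 3) * scales + offsets

# Statistics over the entire dataset: one value per feature.
full_mean = dataset.mean(dim=0)
full_std = dataset.std(dim=0)

# Statistics over one random mini-batch of 128 examples.
idx = torch.randperm(dataset.shape[0])[:128]
batch = dataset[idx]
batch_mean = batch.mean(dim=0)
batch_std = batch.std(dim=0)

print("full  mean/std:", full_mean, full_std)
print("batch mean/std:", batch_mean, batch_std)
```

The mini-batch values are noisy estimates of the dataset values, which is why a running average of them is handy at inference time.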

$\gamma$ and $\beta$ tensors are initialized with ones and zeros, respectively, so the layer starts out as a plain normalization ($y = \hat{x}$).

In [None]:
class BatchNorm1D():
    def __init__(self, dim, training=True, momentum=0.1, eps=1e-05):
        self.training = training
        self.momentum = momentum
        self.eps = eps

        # One running mean and variance per neuron; gamma and beta are the learnable scale and shift
        self.running_mean = torch.zeros(dim)
        self.running_variance = torch.ones(dim)
        self.gamma = torch.ones(dim)
        self.beta = torch.zeros(dim)

    def __call__(self, x):

        if self.training:
            # Per-feature statistics over the batch dimension
            x_mean = x.mean(dim=0, keepdim=True)
            # Note: Tensor.var defaults to the unbiased estimate (divides by N - 1)
            x_var = x.var(dim=0, keepdim=True)

            # Update the running statistics under no_grad so autograd does not track them
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * x_mean
                self.running_variance = (1 - self.momentum) * self.running_variance + self.momentum * x_var
        else:
            x_mean = self.running_mean
            x_var = self.running_variance

        normalized = (x - x_mean) / torch.sqrt(x_var + self.eps)
        self.out = self.gamma * normalized + self.beta

        return self.out

    def parameters(self):
        return [self.gamma, self.beta]
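To sanity-check the running-statistics logic, here is a sketch using PyTorch's built-in `torch.nn.BatchNorm1d` as a stand-in for the class above (it is not the same code: its attributes are `running_var`, `weight` and `bias` rather than `running_variance`, `gamma` and `beta`). It follows the same momentum update, and in eval mode it normalizes with the running statistics instead of the batch statistics:

```python
import torch

torch.manual_seed(0)

bn = torch.nn.BatchNorm1d(4, momentum=0.1, eps=1e-5)
x = torch.randn(32, 4) * 2.0 + 3.0

# Training mode: normalize with batch statistics and update the running buffers once.
bn.train()
_ = bn(x)

# Buffers start at mean=0 and var=1, so after one step:
expected_mean = 0.1 * x.mean(dim=0)
expected_var = 0.9 * torch.ones(4) + 0.1 * x.var(dim=0, unbiased=True)
print("running_mean matches:", torch.allclose(bn.running_mean, expected_mean, atol=1e-6))
print("running_var matches: ", torch.allclose(bn.running_var, expected_var, atol=1e-6))

# Eval mode: a single example can be normalized, because the running statistics are used.
bn.eval()
single = torch.randn(1, 4)
with torch.no_grad():
    y_eval = bn(single)

# The same computation by hand (weight is all ones and bias all zeros at initialization).
manual = (single - bn.running_mean) / torch.sqrt(bn.running_var + bn.eps)
print("eval output matches manual formula:", torch.allclose(y_eval, manual, atol=1e-6))
```

With the class above, the equivalent of `bn.eval()` is setting `training = False` on the layer. One difference to keep in mind: the built-in layer normalizes the training batch with the biased variance, while the class above uses the default (unbiased) `x.var`, so training-mode outputs can differ slightly.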