## Limitations of **Batch Normalization**:
- You need to maintain running means (and variances) for use at inference time
- Tricky for RNNs: do you need separate statistics for each time step?
- Doesn't work well with small batch sizes
- Means and variances must be computed across devices in distributed training
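
The first and third limitations can be seen with PyTorch's built-in `nn.BatchNorm1d`. This is a sketch, not part of the original notes. The module carries `running_mean` and `running_var` buffers that every training step updates. In training mode it also raises a `ValueError` for a batch containing a single sample, because each feature then has only one value to compute statistics from.

```python
import torch
from torch import nn

torch.manual_seed(0)

# BatchNorm over 4 features: statistics are per feature, computed across the batch
bn = nn.BatchNorm1d(4)
bn.train()

# Running statistics are buffers BN has to maintain for use at inference time
print("initial running_mean:", bn.running_mean)  # all zeros

_ = bn(torch.randn(8, 4))  # one training step updates the running statistics
mean_after_one = bn.running_mean.clone()
print("running_mean after one batch:", mean_after_one)

# With a single sample there is only one value per feature to compute
# statistics from, so PyTorch rejects the batch in training mode
failed_on_batch_of_one = False
try:
    bn(torch.randn(1, 4))
except ValueError as err:
    failed_on_batch_of_one = True
    print("batch of 1 rejected:", err)
```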

## Layer Normalization
- Layer Normalization is a simpler normalization method that works in a wider range of settings
- It transforms the inputs to have zero mean and unit variance **across all features**
- Note: BN fixes zero mean and unit variance **for each feature**, with statistics computed across the batch
- Layer normalization does it **for each sample** in the batch, with statistics computed across all of that sample's features!
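
The difference in which axis is averaged over can be sketched with plain tensor reductions. This assumes a 2-D input of shape `[batch_size, features]`. Batch-norm-style statistics reduce over dim 0, giving one mean per feature. Layer-norm-style statistics reduce over dim 1, giving one mean per sample. The layer-norm result is cross-checked against PyTorch's `torch.nn.functional.layer_norm`. The learnable affine parameters are left out for clarity.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
eps = 1e-5
x = torch.randn(3, 5) * 2 + 1  # [batch_size=3, features=5]

# Batch-norm style: one mean/variance per feature, computed across the batch
bn_mean = x.mean(dim=0, keepdim=True)                      # shape [1, 5]
bn_var = ((x - bn_mean) ** 2).mean(dim=0, keepdim=True)
x_bn = (x - bn_mean) / torch.sqrt(bn_var + eps)

# Layer-norm style: one mean/variance per sample, computed across its features
ln_mean = x.mean(dim=1, keepdim=True)                      # shape [3, 1]
ln_var = ((x - ln_mean) ** 2).mean(dim=1, keepdim=True)
x_ln = (x - ln_mean) / torch.sqrt(ln_var + eps)

print("BN column means:", x_bn.mean(dim=0))  # ~0 for every feature
print("LN row means:   ", x_ln.mean(dim=1))  # ~0 for every sample

# F.layer_norm normalizes over the trailing `normalized_shape` dimensions
reference = F.layer_norm(x, normalized_shape=(5,), eps=eps)
print("matches F.layer_norm:", torch.allclose(x_ln, reference, atol=1e-5))
```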

```python
from typing import Union, List
import torch
from torch import nn, Size
from labml_helpers.module import Module
```

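- `normalized_shape` $S$ is the shape of the elements (excluding the batch), i.e. the trailing dimensions that are normalized over
- `eps` is $\epsilon$, added to the variance for numerical stability
- `elementwise_affine` is whether to scale and shift the normalized value with a learnable gain and bias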
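
Written out, let $\mathbb{E}$ and $\operatorname{Var}$ be taken over the trailing dimensions $S$. Let $\gamma$ and $\beta$ be the learnable gain and bias, which correspond to the `gain` and `bias` parameters in the class below and are only present when `elementwise_affine` is true. Layer normalization then computes:

$$
\operatorname{LN}(x) = \gamma \odot \frac{x - \mathbb{E}[x]}{\sqrt{\operatorname{Var}[x] + \epsilon}} + \beta,
\qquad
\operatorname{Var}[x] = \mathbb{E}[x^2] - \mathbb{E}[x]^2
$$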

```python
class LayerNorm(Module):
    def __init__(self, normalized_shape: Union[int, List[int], Size], *,
                 eps: float = 1e-5, elementwise_affine: bool = True):
        super().__init__()

        # NOTE: Convert normalized_shape to torch.Size
        if isinstance(normalized_shape, int):
            normalized_shape = torch.Size([normalized_shape])
        elif isinstance(normalized_shape, list):
            normalized_shape = torch.Size(normalized_shape)
        assert isinstance(normalized_shape, torch.Size)

        self.normalized_shape = normalized_shape
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        # NOTE: Create parameters γ and β for gain and bias
        if self.elementwise_affine:
            self.gain = nn.Parameter(torch.ones(normalized_shape))
            self.bias = nn.Parameter(torch.zeros(normalized_shape))

    def forward(self, x: torch.Tensor):
        """NOTE: x is a tensor of shape [*, S[0], S[1], ..., S[n]].
        * could be any number of dimensions. For example, in
        an NLP task this will be [seq_len, batch_size, features]."""
        # NOTE: Sanity check to make sure the shapes match
        assert self.normalized_shape == x.shape[-len(self.normalized_shape):]
        # Dimensions to normalize over: the last len(S) dimensions
        dims = [-(i + 1) for i in range(len(self.normalized_shape))]
        mean = x.mean(dim=dims, keepdim=True)
        mean_x2 = (x ** 2).mean(dim=dims, keepdim=True)
        var = mean_x2 - mean ** 2  # NOTE: Var(x) = E[X^2] - E[X]^2
        x_norm = (x - mean) / torch.sqrt(var + self.eps)
        if self.elementwise_affine:
            x_norm = self.gain * x_norm + self.bias
        return x_norm
```
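
`forward` computes the variance with the one-pass identity $\operatorname{Var}(x) = \mathbb{E}[x^2] - \mathbb{E}[x]^2$. The alternative is the two-pass form $\mathbb{E}[(x - \mathbb{E}[x])^2]$. The quick numerical check below is independent of the class and uses random data chosen for illustration. It confirms that the two forms agree and that both match PyTorch's biased (population) variance. One caveat: in float32 the one-pass form can lose precision through cancellation when the mean is large relative to the spread, so the two-pass form is the numerically safer choice.

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 6, 8)  # e.g. [seq_len, batch_size, features]

mean = x.mean(dim=-1, keepdim=True)
# One-pass form, as used in forward(): E[x^2] - E[x]^2
var_one_pass = (x ** 2).mean(dim=-1, keepdim=True) - mean ** 2
# Two-pass form: E[(x - E[x])^2]
var_two_pass = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
# PyTorch's biased (population) variance, for reference
var_torch = x.var(dim=-1, unbiased=False, keepdim=True)

print("one-pass == two-pass:", torch.allclose(var_one_pass, var_two_pass, atol=1e-5))
print("two-pass == torch.var:", torch.allclose(var_two_pass, var_torch, atol=1e-6))
```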

```python
def _test():
    from labml.logger import inspect

    x = torch.zeros([2, 3, 2, 4])
    inspect(x.shape)
    ln = LayerNorm(x.shape[2:])

    x = ln(x)
    inspect(x.shape)
    inspect(ln.gain.shape)
```

```python
_test()
```
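
`_test` feeds an all-zeros tensor, so the output is all zeros and only the shapes are informative. The sketch below uses random data with PyTorch's built-in `nn.LayerNorm` as a stand-in for the class above. It takes the same `normalized_shape` argument but names its parameters `weight` and `bias` instead of `gain` and `bias`. The sketch shows what `normalized_shape = x.shape[2:]` means: each trailing `[2, 4]` block is normalized independently, giving 2 × 3 = 6 separate means and variances.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.randn(2, 3, 2, 4) * 3 + 5  # same shape as in _test, but non-constant data

ln = nn.LayerNorm(x.shape[2:])       # normalized_shape = [2, 4]
y = ln(x)

# Statistics over the last two dimensions: one value per leading [2, 3] index
means = y.mean(dim=(-2, -1))
variances = y.var(dim=(-2, -1), unbiased=False)

print("output shape:", tuple(y.shape))          # (2, 3, 2, 4), unchanged
print("weight shape:", tuple(ln.weight.shape))  # (2, 4), like ln.gain in _test
print("per-block means:", means)                # ~0
print("per-block variances:", variances)        # ~1
```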