# RMSNorm

## Introduction

Deep neural networks, especially Transformers, stack dozens or even hundreds of layers. Without normalization, intermediate activations can explode (grow uncontrollably) or vanish (shrink toward zero) as they pass through the stack, which leads to unstable gradients and poor convergence. Normalization layers keep activations and gradients in a stable range.
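
To make the explode/vanish behaviour concrete, here is a toy sketch in plain Python. It assumes purely linear layers with Gaussian random weights (no attention, no nonlinearity), and it uses "rescale to unit root mean square" as a simplified stand-in for a normalization layer. The names `run_stack` and `gain` are illustrative only.

```python
import math
import random

def rms(v):
    """Root mean square of a list of floats."""
    return math.sqrt(sum(a * a for a in v) / len(v))

def run_stack(gain, normalize, d=64, layers=40, seed=0):
    """Push a random vector through `layers` random linear maps.

    Weights are drawn from N(0, gain^2 / d), so each layer multiplies the
    activation scale by roughly `gain`. Returns the RMS of the final output.
    """
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(d)]
    std = gain / math.sqrt(d)
    for _ in range(layers):
        w = [[rng.gauss(0.0, std) for _ in range(d)] for _ in range(d)]
        x = [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]
        if normalize:
            # Simplified stand-in for a normalization layer: rescale to unit RMS.
            scale = rms(x)
            x = [a / scale for a in x]
    return rms(x)

exploding = run_stack(gain=1.5, normalize=False)
vanishing = run_stack(gain=0.5, normalize=False)
stable = run_stack(gain=1.5, normalize=True)

print(f"gain 1.5, no normalization: {exploding:.3e}")
print(f"gain 0.5, no normalization: {vanishing:.3e}")
print(f"gain 1.5, normalized:       {stable:.3e}")
```

With these settings the unnormalized runs drift by many orders of magnitude in opposite directions, while the normalized run stays at a scale of one regardless of `gain`.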

### Background

**LayerNorm** (Ba et al. (2016). [Layer Normalization](https://arxiv.org/abs/1607.06450)) is the normalization used in the original Transformer from 2017. It normalizes each token’s feature vector by re-centering it to zero mean and re-scaling it to unit variance, then scales and shifts the result with learned parameters $\gamma$ and $\beta$. This gives it re-centering and re-scaling invariance, which stabilizes training. Because the statistics are computed per token rather than across the batch, it is also insensitive to batch size.


$$
\mathrm{LN}(x)=\gamma\odot\frac{x-\mu}{\sqrt{\sigma^2+\epsilon}}+\beta
$$

Where

$$
\mu=\frac{1}{d}\sum_{i=1}^d x_i,\ \sigma^2=\frac{1}{d}\sum_{i=1}^d(x_i-\mu)^2
$$

- $\mu$: mean of $x$
- $\sigma^2$: variance across $x$
- $\epsilon$: a tiny positive constant to avoid division-by-zero 
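
As a worked check of the formula, the following plain-Python sketch applies it to a single 4-dimensional vector. It assumes $\gamma = 1$ and $\beta = 0$ (the usual initial values) and $\epsilon = 10^{-5}$; the function name `layer_norm` is just for illustration.

```python
import math

def layer_norm(x, gamma, beta, eps=1e-5):
    """Apply the LN formula to one feature vector given as a list of floats."""
    d = len(x)
    mu = sum(x) / d                               # mean over the d features
    var = sum((xi - mu) ** 2 for xi in x) / d     # biased (1/d) variance
    inv_std = 1.0 / math.sqrt(var + eps)
    return [g * (xi - mu) * inv_std + b for xi, g, b in zip(x, gamma, beta)]

x = [1.0, 2.0, 3.0, 4.0]                          # mu = 2.5, sigma^2 = 1.25
y = layer_norm(x, gamma=[1.0] * 4, beta=[0.0] * 4)
print([round(v, 4) for v in y])                   # [-1.3416, -0.4472, 0.4472, 1.3416]
```

The output has zero mean and (up to the effect of $\epsilon$) unit variance, which is what the re-centering and re-scaling steps are meant to guarantee.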


Sample implementation:

```python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, d, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d))   # learned scale, initialized to 1
        self.beta  = nn.Parameter(torch.zeros(d))  # learned shift, initialized to 0
        self.eps = eps

    def forward(self, x):  # x: [..., d]
        # Statistics are taken over the last (feature) dimension only.
        mu = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, unbiased=False, keepdim=True)  # biased (1/d) variance, as in the formula
        xhat = (x - mu) / torch.sqrt(var + self.eps)
        return self.gamma * xhat + self.beta
```
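
One way to sanity-check a hand-written module like this is to compare it against PyTorch's built-in `torch.nn.LayerNorm`. The sketch below repeats the class so that it runs on its own; the tensor shape, the seed, and the random $\gamma$/$\beta$ values are arbitrary choices for the test, not part of the method.

```python
import torch
import torch.nn as nn

class LayerNorm(nn.Module):
    def __init__(self, d, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d))
        self.beta = nn.Parameter(torch.zeros(d))
        self.eps = eps

    def forward(self, x):  # x: [..., d]
        mu = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, unbiased=False, keepdim=True)
        xhat = (x - mu) / torch.sqrt(var + self.eps)
        return self.gamma * xhat + self.beta

torch.manual_seed(0)
d = 16
x = torch.randn(2, 5, d)                 # [batch, tokens, features]

custom = LayerNorm(d)
builtin = nn.LayerNorm(d, eps=1e-5)

# Copy non-trivial gamma/beta into both modules so the affine part is tested too.
with torch.no_grad():
    custom.gamma.copy_(torch.randn(d))
    custom.beta.copy_(torch.randn(d))
    builtin.weight.copy_(custom.gamma)
    builtin.bias.copy_(custom.beta)

max_diff = (custom(x) - builtin(x)).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

The two outputs should agree up to float32 rounding, because `nn.LayerNorm` also uses the biased variance and adds `eps` inside the square root.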

## RMSNorm

RMSNorm was introduced in 2019 (B. Zhang, R. Sennrich, [Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467)). It removes the mean subtraction and scales only by the root mean square of the features. The hypothesis is that LayerNorm's success comes mainly from re-scaling invariance rather than re-centering invariance. The resulting quality is similar to LayerNorm, and the computation is slightly cheaper because the mean no longer has to be computed and subtracted.

The LLaMA family of models adopted and popularized RMSNorm, and it has become the de facto normalization in modern LLMs.

Formula:

$$
\mathrm{RN}(x)_i = \frac{x_i}{\mathrm{RMS}(x)}\, g_i, \quad \mathrm{RMS}(x) = \sqrt{\frac{1}{n}\sum_{j=1}^n x_j^2}
$$

Where:
- $x \in \mathbb{R}^n$ is the input vector (one token's features; $n$ plays the role of $d$ in the LayerNorm section)
- $g \in \mathbb{R}^n$ is a learned scaling parameter, applied element-wise
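
Since $\mathrm{RMS}(\alpha x) = |\alpha|\,\mathrm{RMS}(x)$, the output is unchanged when the input is multiplied by a positive constant, but nothing in the formula cancels an added constant. The plain-Python sketch below checks this numerically. It omits $\epsilon$ and uses $g = 1$ to keep the arithmetic exact, and `standardize` is a hypothetical helper standing in for LayerNorm without $\gamma$ and $\beta$.

```python
import math

def rms_norm(x, g=None):
    """RMSNorm of one vector (no epsilon); g defaults to all ones."""
    n = len(x)
    if g is None:
        g = [1.0] * n
    rms = math.sqrt(sum(v * v for v in x) / n)
    return [gi * v / rms for v, gi in zip(x, g)]

def standardize(x):
    """LayerNorm-style normalization without gamma, beta or epsilon."""
    n = len(x)
    mu = sum(x) / n
    sigma = math.sqrt(sum((v - mu) ** 2 for v in x) / n)
    return [(v - mu) / sigma for v in x]

def close(a, b, tol=1e-9):
    return all(abs(p - q) <= tol for p, q in zip(a, b))

x = [1.0, 2.0, 3.0, 6.0]
scaled = [10.0 * v for v in x]             # re-scaled input
shifted = [v + 5.0 for v in x]             # re-centered input
centered = [v - sum(x) / len(x) for v in x]

print(close(rms_norm(scaled), rms_norm(x)))          # True: invariant to re-scaling
print(close(standardize(shifted), standardize(x)))   # True: LayerNorm-style is shift-invariant
print(close(rms_norm(shifted), rms_norm(x)))         # False: RMSNorm is not shift-invariant
print(close(rms_norm(centered), standardize(x)))     # True: on zero-mean input the two coincide
```

The last line reflects that for a zero-mean vector the RMS equals the standard deviation, so RMSNorm and LayerNorm differ only in how they treat the mean.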

### Implementation

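```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d, eps=1e-8):
        super().__init__()
        self.g = nn.Parameter(torch.ones(d))  # learned scale, initialized to 1
        self.eps = eps

    def forward(self, x):  # x: [..., d]
        # Root mean square over the feature dimension. eps is not part of the
        # formula above; it guards against division by zero for all-zero input.
        rms = torch.sqrt((x * x).mean(dim=-1, keepdim=True) + self.eps)
        return self.g * (x / rms)
```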
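
A short usage sketch for the module follows. It repeats the class so that it runs on its own; the shapes, the seed, and the input scale are arbitrary. It checks three things: the output shape matches the input, each token's output has an RMS of about 1 at initialization, and an all-zero input yields zeros rather than NaN thanks to `eps`. It also compares parameter counts with `torch.nn.LayerNorm`.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d, eps=1e-8):
        super().__init__()
        self.g = nn.Parameter(torch.ones(d))
        self.eps = eps

    def forward(self, x):  # x: [..., d]
        rms = torch.sqrt((x * x).mean(dim=-1, keepdim=True) + self.eps)
        return self.g * (x / rms)

torch.manual_seed(0)
d = 8
norm = RMSNorm(d)
x = 5.0 * torch.randn(2, 3, d) + 1.0       # [batch, tokens, features], deliberately not unit scale

y = norm(x)
out_rms = y.pow(2).mean(dim=-1).sqrt()     # one RMS value per token, shape [2, 3]
print(y.shape)                             # torch.Size([2, 3, 8])
print(out_rms)

zeros_out = norm(torch.zeros(1, d))        # 0 / sqrt(eps) = 0, not NaN
has_nan = torch.isnan(zeros_out).any().item()
print(has_nan)                             # False

n_rms = sum(p.numel() for p in norm.parameters())
n_ln = sum(p.numel() for p in nn.LayerNorm(d).parameters())
print(n_rms, n_ln)                         # 8 16
```

RMSNorm carries only the scale vector `g`, so it has half the parameters of a LayerNorm with both $\gamma$ and $\beta$.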