# RMSNorm Gradient Derivation

Equation 9 of "Root Mean Square Layer Normalization" gives the gradient of RMSNorm with respect to the input $x$:

$$\frac{\partial \mathcal{L}}{\partial x_i}=\frac{\partial \mathcal{L}}{\partial y_i}\frac{w_i}{r}-\frac{w_i x_i}{n r^3}\sum_j\frac{\partial \mathcal{L}}{\partial y_j}x_j,$$

where $r=\sqrt{\frac{1}{n}\sum_j x_j^2+\epsilon}$. This notebook implements this formula and checks it against autograd.

In [None]:
import torch

def rmsnorm_forward(x, weight, eps=1e-8):
    rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return weight * x / rms

# Example input
d = 4
x = torch.randn(2, d, requires_grad=True)
weight = torch.ones(d, requires_grad=True)

y = rmsnorm_forward(x, weight)
loss = y.sum()
loss.backward()
print('Autograd grad:', x.grad)

In [None]:
# Manual gradient based on Eq. 9

def rmsnorm_backward(dy, x, weight, eps=1e-8):
    n = x.shape[-1]
    rms = torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    dot = (dy * x).sum(-1, keepdim=True)
    dx = (dy * weight) / rms - (weight * x / (n * rms.pow(3))) * dot
    return dx

# Verify against autograd
x.grad.zero_()
loss = rmsnorm_forward(x, weight).sum()
loss.backward()
manual_dx = rmsnorm_backward(torch.ones_like(y), x.detach(), weight.detach())
print('Manual grad:', manual_dx)