# LayerNorm and RMSNorm Analysis

本notebook实现

1. pytorch自动计算LayerNorm梯度
2. 手动计算LayerNorm梯度
3. 实现RMSNorm
4. 手动计算RMSNorm梯度
5. 分析RMSNorm的不变性
6. LayerNorm和RMSNorm对比梯度大小

# config

In [1]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
# Seed PyTorch's default RNG so the random tensors created below are reproducible.
seed = 42
torch.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)  # uncomment to also seed every CUDA device

<torch._C.Generator at 0x111d26b10>

In [2]:
# Toy problem dimensions, kept small so the printed gradients stay readable.
batch_size = 1
seq_len = 4
d_model = 8  # feature size; both norms in this file normalize over this last dimension
eps=1e-12  # stabilizer added under the square root in the manual gradient checks
N = batch_size * seq_len * d_model  # total element count; the mean-reduced loss gradient divides by this

In [3]:
src = torch.randn(batch_size, seq_len, d_model) # model input
trg = torch.randn(batch_size, seq_len, d_model) # regression target that the model output is compared against
# (unused) proj = nn.Linear(d_model, d_model)

In [4]:
def mse_loss(x, y):
    """Return the mean over all elements of half the squared difference.

    Because of the 0.5 factor, the gradient with respect to ``x`` is
    ``(x - y) / x.numel()``.
    """
    half_squared_error = (x - y).pow(2) * 0.5
    return half_squared_error.mean()
# Sanity check: loss between the raw input and the target, before any model is applied.
loss = mse_loss(src, trg)
print(loss)

# pytorch自动计算LayerNorm梯度

In [5]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension that also exposes its intermediates.

    Args:
        d_model: Size of the last (normalized) dimension.
        eps: Value added to the variance before taking the square root.
    """

    def __init__(self, d_model, eps=1e-12):
        super().__init__()
        self.eps = eps
        # Per-feature scale and shift, initialised to the identity transform.
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))

    def forward(self, x):
        """Return ``(mean, var, out_mean_var, out)`` for input ``x``.

        ``mean`` and ``var`` are taken over the last dimension (biased
        variance, dimension kept), ``out_mean_var`` is the standardized input
        and ``out`` is ``gamma * out_mean_var + beta``.
        """
        mu = x.mean(dim=-1, keepdim=True)
        sigma_sq = x.var(dim=-1, unbiased=False, keepdim=True)
        std = (sigma_sq + self.eps).sqrt()
        standardized = (x - mu) / std
        # Scale and shift are applied per feature.
        scaled = self.gamma * standardized + self.beta
        return mu, sigma_sq, standardized, scaled

class ToyModel(nn.Module):
    """Linear projection followed by the hand-written LayerNorm.

    Args:
        d_model: Input and output feature size of the linear layer, and the
            size of the normalized last dimension.
        eps: Stabilizer forwarded to ``LayerNorm``.
    """

    def __init__(self, d_model, eps=1e-12):
        super(ToyModel, self).__init__()
        self.w1 = nn.Linear(d_model,d_model)
        # Forward the caller's eps; it used to be ignored in favour of a hard-coded 1e-12.
        self.ln = LayerNorm(d_model, eps=eps)

    def forward(self, x):
        """Return ``(mean, var, out_mean_var, out, w1_x)``.

        The first four values come from ``LayerNorm.forward``. ``w1_x`` is the
        pre-normalization activation, returned so callers can call
        ``retain_grad()`` on it and inspect its gradient.
        """
        w1_x = self.w1(x)
        mean, var, out_mean_var, out = self.ln(w1_x)
        return mean, var, out_mean_var, out, w1_x

In [6]:
# Run the LayerNorm toy model once and backpropagate the loss to obtain autograd's reference gradients.
w_ln_model = ToyModel(d_model)
mean, var, out_mean_var, y, w1_x = w_ln_model(src)
# w1_x and y are intermediate (non-leaf) tensors, so ask autograd to keep their .grad after backward.
w1_x.retain_grad()
y.retain_grad()
loss = mse_loss(y, trg)
loss.backward()

In [7]:
# Autograd reference gradients for the LayerNorm model.
print(w_ln_model.ln.gamma.grad)
print(w_ln_model.ln.beta.grad)
print(w_ln_model.w1.weight.shape)
print(w_ln_model.w1.weight.grad[:3,:3]) # show only a small slice for verification

# NOTE: w1_x and y have three dimensions, so [:3,:3] slices the first two
# (batch, sequence) and keeps every feature.
print(w1_x.grad.shape)
print(w1_x.grad[:3,:3]) # show only a small slice for verification

print(y.grad.shape)
print(y.grad[:3,:3])

# 手动计算LayerNorm梯度

根据链式法则，损失对线性层权重 W 的梯度为：

de/dW = de/dy * dy/dx * dx/dW

## de/dy

In [8]:
# Gradient of mean(0.5 * (y - trg)**2) with respect to y: the 0.5 cancels the
# factor 2 from the square, and the mean contributes the division by N.
de_dy = (y-trg) / N # /N comes from the mean reduction
print(de_dy[:3,:3])
print(y.grad[:3,:3]) # autograd reference

先求出gamma和beta的梯度

In [9]:
# d(loss)/d(gamma): y = gamma * out_mean_var + beta, so every position
# contributes out_mean_var * de_dy. gamma is shared by all tokens, so the
# contributions are summed over both the batch and the sequence dimension
# (summing the sequence dimension alone only matches autograd when batch_size == 1).
dgamma = (out_mean_var * de_dy).sum(dim=(0, 1))
print(dgamma)
print(w_ln_model.ln.gamma.grad)  # autograd reference

In [10]:
# d(loss)/d(beta): beta is added to every position, so its gradient is the
# upstream gradient summed over both the batch and the sequence dimension
# (summing the sequence dimension alone only matches autograd when batch_size == 1).
dbeta = de_dy.sum(dim=(0, 1))
print(dbeta)
print(w_ln_model.ln.beta.grad)  # autograd reference

## dy/dx

In [11]:
# Per-token quantities reused by the manual LayerNorm Jacobian.
x_mean = w1_x - mean # centered activation (x - mean), not the mean itself
x_var = var # biased variance over the last dimension
x_std_var = torch.sqrt(x_var + eps) # standard deviation, one value per token
print(x_std_var)

In [12]:
# Build the centering matrix (identity - 1/d_model * all-ones), which is the
# Jacobian of x - mean(x) with respect to x.
I = torch.ones(d_model)
diag_I = I.diag() # d_model x d_model identity matrix
print(I)
print(diag_I)

# I is rebound here from the ones vector to the d_model x d_model all-ones matrix.
I = torch.ones(d_model, d_model)

left = diag_I - 1/d_model * I
print(left)

In [13]:
# Manual LayerNorm backward pass, one token at a time.
#
# For a single token x (d_model values) with standardized output x_hat, the
# d_model x d_model Jacobian is
#   d x_hat[a] / d x[b] = (delta_ab - 1/d_model) / std
#                         - (x[a] - mean) * (x[b] - mean) / (d_model * std**3)
# and y[a] = gamma[a] * x_hat[a] + beta[a], hence
#   dL/dx[b] = sum_a de_dy[a] * gamma[a] * d x_hat[a] / d x[b].
#
# gamma therefore scales the upstream gradient per OUTPUT feature before the
# vector-Jacobian product. Multiplying the Jacobian by gamma afterwards
# broadcasts over its input axis, which is only correct while gamma is all ones.
#
# de_dy holds one d_model-sized upstream gradient per token; each token's
# input gradient is that row (times gamma) @ its Jacobian.
gamma = w_ln_model.ln.gamma

grad_x_ln = torch.zeros_like(w1_x)
for i in range(batch_size):
    for j in range(seq_len):
        std = x_std_var[i, j, 0]
        centered = x_mean[i, j, :]
        d_ln = left / std - centered.outer(centered) / (std**3 * d_model)
        grad_x_ln[i, j, :] = (de_dy[i, j, :] * gamma) @ d_ln
print(de_dy.shape)
print(grad_x_ln)

In [14]:
# Autograd reference for the manually computed grad_x_ln.
print(w1_x.grad)

## dx/d_w

In [15]:
# d(loss)/d(W) for the linear layer x = src @ W^T + b: each sample contributes
# grad_x^T @ src (a d_model x d_model matrix). The weight is shared across the
# batch, so the per-sample matrices are summed; this also makes the result the
# same shape as w1.weight.grad.
dx_d_w = (grad_x_ln.transpose(1,2) @ src).sum(dim=0)
print(dx_d_w)
print(w_ln_model.w1.weight.grad)  # autograd reference
ln_w_gradient = dx_d_w  # kept for the LayerNorm vs RMSNorm comparison at the end

综上我们验证成功手写layernorm梯度

# 实现RMSNorm

In [16]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension (no centering, no bias).

    Args:
        d_model: Size of the last (normalized) dimension.
        eps: Value added to the mean square before taking the square root.
    """

    def __init__(self, d_model, eps=1e-12):
        super().__init__()
        self.eps = eps
        # Per-feature scale, initialised to ones.
        self.gamma = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Return ``(mean, out_mean, out)`` for input ``x``.

        ``mean`` is the mean of ``x**2`` over the last dimension (dimension
        kept), ``out_mean`` is ``x`` divided by the root of that value plus
        eps, and ``out`` is ``gamma * out_mean``.
        """
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        rms = (mean_square + self.eps).sqrt()
        normalized = x / rms
        return mean_square, normalized, self.gamma * normalized

class ToyModelRMS(nn.Module):
    """Linear projection followed by the hand-written RMSNorm.

    Args:
        d_model: Input and output feature size of the linear layer, and the
            size of the normalized last dimension.
        eps: Stabilizer forwarded to ``RMSNorm``.
    """

    def __init__(self, d_model, eps=1e-12):
        super(ToyModelRMS, self).__init__()
        self.w1 = nn.Linear(d_model,d_model)
        # Forward the caller's eps; it used to be ignored in favour of a hard-coded 1e-12.
        self.rms = RMSNorm(d_model, eps=eps)

    def forward(self, x):
        """Return ``(mean, out_mean, out, w1_x)``.

        The first three values come from ``RMSNorm.forward``. ``w1_x`` is the
        pre-normalization activation, returned so callers can call
        ``retain_grad()`` on it and inspect its gradient.
        """
        w1_x = self.w1(x)
        mean, out_mean, out = self.rms(w1_x)
        return mean, out_mean, out, w1_x

In [17]:
w_rms_model = ToyModelRMS(d_model)
# Start from exactly the same linear layer as the LayerNorm model. The bias is
# copied together with the weight: copying the weight alone leaves the two
# models with different random biases, so their gradients are not comparable.
w_rms_model.w1.load_state_dict(w_ln_model.w1.state_dict())
w_rms_model.w1.zero_grad()

# Run the RMSNorm toy model once and backpropagate to obtain autograd's reference gradients.
mean, out_mean, y, w1_x = w_rms_model(src)
# w1_x and y are intermediate (non-leaf) tensors, so ask autograd to keep their .grad after backward.
w1_x.retain_grad()
y.retain_grad()
loss = mse_loss(y, trg)
loss.backward()

In [18]:
# Autograd reference gradients for the RMSNorm model.
print(w_rms_model.rms.gamma.grad)
# RMSNorm defines no beta parameter, so there is no beta gradient to print.
print(w_rms_model.w1.weight.shape) 
print(w_rms_model.w1.weight.grad) 

# NOTE: w1_x and y have three dimensions, so [:3,:3] slices the first two
# (batch, sequence) and keeps every feature.
print(w1_x.grad.shape)
print(w1_x.grad[:3,:3])

print(y.grad.shape)
print(y.grad[:3,:3])

# 手动计算RMSNorm梯度

根据链式法则，损失对线性层权重 W 的梯度为：

de/dW = de/dy * dy/dx * dx/dW

## de/dy

In [19]:
# Gradient of mean(0.5 * (y - trg)**2) with respect to y: the 0.5 cancels the
# factor 2 from the square, and the mean contributes the division by N.
de_dy = (y-trg) / N # /N comes from the mean reduction
print(de_dy[:3,:3])
print(y.grad[:3,:3]) # autograd reference

## de/dy * dy/dgamma

In [20]:
# d(loss)/d(gamma): y = gamma * out_mean, so every position contributes
# out_mean * de_dy. gamma is shared by all tokens, so the contributions are
# summed over both the batch and the sequence dimension (summing the sequence
# dimension alone only matches autograd when batch_size == 1).
dgamma = (out_mean * de_dy).sum(dim=(0, 1))
print(dgamma)
print(w_rms_model.rms.gamma.grad)  # autograd reference

## dy/dx

In [21]:
# Per-token RMS reused by the manual RMSNorm Jacobian.
x = w1_x
x_rms = torch.sqrt(mean + eps) # `mean` here is the mean of squares returned by the RMSNorm forward pass
print(x_rms)

In [22]:
# Manual RMSNorm backward pass, one token at a time.
#
# For a single token x (d_model values) with rms = sqrt(mean(x**2) + eps) and
# normalized output x_hat = x / rms, the d_model x d_model Jacobian is
#   d x_hat[a] / d x[b] = delta_ab / rms - x[a] * x[b] / (d_model * rms**3)
# and y[a] = gamma[a] * x_hat[a], hence
#   dL/dx[b] = sum_a de_dy[a] * gamma[a] * d x_hat[a] / d x[b].
#
# gamma therefore scales the upstream gradient per OUTPUT feature before the
# vector-Jacobian product. Writing `A - B * gamma` applies gamma to the
# outer-product term only (operator precedence) and along the wrong axis; that
# is only correct while gamma is all ones.
#
# de_dy holds one d_model-sized upstream gradient per token; each token's
# input gradient is that row (times gamma) @ its Jacobian.
I_diag = torch.diag(torch.ones(d_model))
gamma = w_rms_model.rms.gamma

grad_x_rms = torch.zeros_like(w1_x)
for i in range(batch_size):
    for j in range(seq_len):
        rms = x_rms[i, j, 0]
        token = w1_x[i, j, :]
        d_rms = I_diag / rms - token.outer(token) / (rms**3 * d_model)
        grad_x_rms[i, j, :] = (de_dy[i, j, :] * gamma) @ d_rms
print(de_dy.shape)
print(grad_x_rms )

In [23]:
# Autograd reference for the manually computed grad_x_rms.
print(w1_x.grad)

## dx/d_w

In [24]:
# d(loss)/d(W) for the linear layer x = src @ W^T + b: each sample contributes
# grad_x^T @ src (a d_model x d_model matrix). The weight is shared across the
# batch, so the per-sample matrices are summed; this also makes the result the
# same shape as w1.weight.grad.
dx_d_w = (grad_x_rms.transpose(1,2) @ src).sum(dim=0)
print(dx_d_w)
print(w_rms_model.w1.weight.grad)  # autograd reference

rms_w_gradient = dx_d_w  # kept for the LayerNorm vs RMSNorm comparison at the end

# 分析RMSNorm的不变性

## 前向输入不变性

见RMS Paper的4.1 Invariance Analysis

尺度不变性：RMS 对输入的缩放是线性的，

$$
RMS(\alpha x) = \alpha RMS(x), \quad \alpha > 0
$$

因此归一化结果 $x / RMS(x)$ 不随输入的缩放而改变。

In [28]:
# A freshly initialised RMSNorm toy model; only its linear layer and its RMSNorm forward pass are used here.
w_rms_model = ToyModelRMS(d_model, eps=1e-12)
x = w_rms_model.w1(src)

# Normalize the activation scaled by 10.
x_10 = x * 10
mean, out_mean, out,  = w_rms_model.rms(x_10)
print(mean)
print(out_mean[0,0,:])
print(out[0,0,:])
x_rms_10 = torch.sqrt(mean + eps) # RMS of the scaled activation, reused by the gradient comparison

# Normalize the unscaled activation; compare these prints with the ones above.
mean, out_mean, out,  = w_rms_model.rms(x)
print(out_mean[0,0,:])
print(out[0,0,:])
x_rms = torch.sqrt(mean + eps) # RMS of the unscaled activation

print('rms forward is scalar invariant')

## 反向梯度不变性

见RMS Paper的4.2 Gradient Analysis

In [34]:
# Compare the RMSNorm input gradient for the activation scaled by 10 with the unscaled one.
# The same upstream gradient de_dy (whatever was last assigned to it) is fed to
# both runs, and gamma is left out of both Jacobians.
I_diag = torch.diag(torch.ones(d_model))

# scale factor 10
grad_x_rms_10 = torch.zeros_like(x)
for i in range(batch_size):
    for j in range(seq_len):
        d_rms = I_diag/x_rms_10[i,j,0] -  (x_10[i,j,:].outer(x_10[i,j,:]))/(x_rms_10[i,j,0]**3 * d_model)
        # Left disabled: multiplying by 10 would turn this into a gradient w.r.t. x rather than x_10 (x_10 = x * 10).
        # d_rms = d_rms * 10
        grad_x_rms_10[i,j,:] = de_dy[i, j, :] @ d_rms 
print(grad_x_rms_10 )


# scale factor 1
grad_x_rms = torch.zeros_like(x)
for i in range(batch_size):
    for j in range(seq_len):
        d_rms = I_diag/x_rms[i,j,0] -  (x[i,j,:].outer(x[i,j,:]))/(x_rms[i,j,0]**3 * d_model) 
        grad_x_rms[i,j,:] = de_dy[i, j, :] @ d_rms
print(grad_x_rms )

# The RMSNorm gradient is inversely proportional to the input scale, so this ratio is expected to be about 1/10.
print(grad_x_rms_10 / grad_x_rms )

# LayerNorm和RMSNorm对比梯度大小

In [35]:
# Compare the overall magnitude (tensor norm) of the hand-derived weight gradients of the two models.
print(rms_w_gradient.norm())
print(ln_w_gradient.norm())

# 去Center化作用

ref：https://spaces.ac.cn/archives/8620#%E7%9B%B4%E6%8E%A5%E6%A0%87%E5%87%86%E5%8C%96

“
一个直观的猜测是，center操作，类似于全连接层的bias项，储存到的是关于预训练任务的一种先验分布信息，而把这种先验分布信息直接储存在模型中，反而可能会导致模型的迁移能力下降。所以T5不仅去掉了Layer Normalization的center操作，它把每一层的bias项也都去掉了。
”

# Reference

LayerNorm 梯度推导： [Layer Normalization, and how to compute its Jacobian for Backpropagation?](https://neuralthreads.medium.com/layer-normalization-and-how-to-compute-its-jacobian-for-backpropagation-55a549d5936f)

RMSNorm Paper: [Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467)

感谢 @Julian Lou 给出修改意见。