In [2]:
import sentencepiece
import torch
import torch.nn as nn

### RMSNorm instead of LayerNorm

In [3]:
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learnable per-feature scale.

    Unlike LayerNorm, no mean is subtracted and there is no bias term: the
    input is divided by the root mean square of its last dimension and then
    multiplied element-wise by ``weight``.

    Args:
        emb_dim: Size of the last (feature) dimension of the inputs; also the
            number of learnable scale values.
        eps: Small constant added to the mean square before the reciprocal
            square root, for numerical stability.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()

        self.eps = eps
        self.emb_dim = emb_dim
        # The only learnable parameter: one scale per feature, initialised to 1.
        # Create the tensor in float32 *before* wrapping it. Calling `.float()`
        # on an nn.Parameter returns a plain, non-leaf tensor whenever a dtype
        # conversion actually happens (e.g. a non-float32 default dtype), which
        # would silently leave the weight unregistered and untrained.
        self.weight = nn.Parameter(torch.ones(emb_dim, dtype=torch.float32))

    def forward(self, x):
        """Normalize ``x`` over its last dimension and apply the learned scale.

        Args:
            x: Tensor whose last dimension broadcasts against ``weight``
                (i.e. has size ``emb_dim``).

        Returns:
            Tensor of the same dtype as ``x``.
        """
        # RMS = sqrt(mean(x^2)); rsqrt yields 1 / RMS directly.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        x_normed = x * torch.rsqrt(mean_square + self.eps)
        # The product with the float32 weight may promote the dtype, so cast
        # back to the dtype of the input.
        return (x_normed * self.weight).to(dtype=x.dtype)

In [4]:
# Sanity check: the custom RMSNorm must agree with PyTorch's built-in layer.

torch.manual_seed(123)  # fixed seed so the random input is reproducible
example = torch.randn(2, 3, 4)
norm_dim = example.shape[-1]  # both layers normalize over the last dimension

rms_norm_pt = torch.nn.RMSNorm(normalized_shape=norm_dim, eps=1e-5)
rms_norm = RMSNorm(emb_dim=norm_dim)

custom_out = rms_norm(example)
builtin_out = rms_norm_pt(example)
assert torch.allclose(custom_out, builtin_out)