In [2]:
import sentencepiece
import torch
import torch.nn as nn

### RMSNorm instead of LayerNorm

In [3]:
class RMSNorm(nn.Module):
    """Root Mean Square normalization over the last dimension.

    Unlike LayerNorm, no mean is subtracted and there is no bias: each vector
    along the last dimension is divided by its root mean square, then scaled
    by a learnable per-feature ``weight``.

    Args:
        emb_dim: Size of the last (feature) dimension being normalized.
        eps: Small constant added to the mean square for numerical stability.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()

        self.eps = eps
        self.emb_dim = emb_dim
        # Only 1 learnable parameter: a per-feature scale (no bias).
        # Do NOT chain `.float()` here: whenever an actual dtype conversion
        # happens (e.g. a non-float32 default dtype), `Parameter.float()`
        # returns a plain tensor, and `weight` would silently drop out of
        # parameters() / state_dict().
        self.weight = nn.Parameter(torch.ones(emb_dim))

    def forward(self, x):
        """Normalize ``x`` over its last dimension; the result keeps ``x.dtype``."""
        # Compute in at least float32: squaring and averaging in bf16/fp16
        # loses precision. float64 inputs stay float64.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_compute = x.to(compute_dtype)
        # RMS -> Root Mean Square: x^2 -> mean -> root; rsqrt yields 1 / RMS.
        mean_sq = x_compute.pow(2).mean(dim=-1, keepdim=True)
        x_normed = x_compute * torch.rsqrt(mean_sq + self.eps)
        # Scale by the learnable weight, then cast back to the input dtype.
        return (x_normed * self.weight).to(dtype=x.dtype)

In [4]:
# Verify: our RMSNorm must match PyTorch's built-in torch.nn.RMSNorm
# (added in PyTorch 2.4) on the same seeded random input with the same eps.
torch.manual_seed(123)
example = torch.randn(2, 3, 4)
feature_dim = example.shape[-1]

rms_norm = RMSNorm(emb_dim=feature_dim)
rms_norm_pt = torch.nn.RMSNorm(normalized_shape=feature_dim, eps=1e-5)

ours, reference = rms_norm(example), rms_norm_pt(example)
assert torch.allclose(ours, reference)

### SiLU instead of GELU

In [5]:
class SiLU(nn.Module):
    """SiLU (Swish) activation, applied elementwise: ``x * sigmoid(x)``.

    The module has no learnable parameters.
    """

    def __init__(self):
        # Zero-argument super(), matching the other modules in this notebook.
        super().__init__()

    def forward(self, x):
        return x * torch.sigmoid(x)

In [6]:
# Verify: our SiLU must agree with PyTorch's functional SiLU on `example`.
silu = SiLU()
silu_reference = torch.nn.functional.silu
assert torch.allclose(silu(example), silu_reference(example))

### Update the FeedForward module: SwiGLU (gated SiLU)

In [8]:
class FeedForward(nn.Module):
    """SwiGLU feed-forward block: ``fc3(silu(fc1(x)) * fc2(x))``.

    ``fc1`` and ``fc2`` both project from ``emb_dim`` to ``hidden_dim``.
    The ``fc1`` branch is passed through SiLU and multiplied elementwise by the
    purely linear ``fc2`` branch (the gating). ``fc3`` projects the result back
    to ``emb_dim``. All three projections are bias-free.

    Args:
        cfg: dict-like config with integer ``'emb_dim'`` and ``'hidden_dim'``
            entries. ``'dtype'`` is optional: a ``torch.dtype`` for the weights
            (e.g. a lower-precision format). If it is missing or ``None``,
            PyTorch's default dtype is used.
    """

    def __init__(self, cfg):
        super().__init__()

        # cfg['dtype'] allows loading the weights in a lower-precision format;
        # .get() keeps configs without a 'dtype' entry working (None = default).
        dtype = cfg.get('dtype')
        self.fc1 = nn.Linear(cfg['emb_dim'], cfg['hidden_dim'], dtype=dtype, bias=False)
        self.fc2 = nn.Linear(cfg['emb_dim'], cfg['hidden_dim'], dtype=dtype, bias=False)
        self.fc3 = nn.Linear(cfg['hidden_dim'], cfg['emb_dim'], dtype=dtype, bias=False)
        self.silu = SiLU()

    def forward(self, x):
        x_fc1 = self.fc1(x)
        x_fc2 = self.fc2(x)
        # Gating: SiLU-activated branch scaled elementwise by the linear branch.
        x = self.silu(x_fc1) * x_fc2
        return self.fc3(x)