In [2]:
import sentencepiece
import torch
import torch.nn as nn

### RMSNorm instead of LayerNorm

In [3]:
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (Llama's replacement for LayerNorm).

    Each vector along the last dimension is divided by its RMS,
    ``sqrt(mean(x**2) + eps)``, and then scaled by a learned per-feature gain.
    Unlike LayerNorm there is no mean subtraction and no shift/bias term.

    Args:
        emb_dim: Size of the last (feature) dimension of the input.
        eps: Small constant added before the square root for numerical stability.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()

        self.eps = eps
        self.emb_dim = emb_dim
        # Only one learnable tensor: the per-feature gain (no shift as in LayerNorm).
        # Create it directly in float32. Calling `.float()` on an nn.Parameter
        # returns a plain, unregistered tensor whenever a conversion actually
        # happens (e.g. under a float64 default dtype). The weight would then
        # silently vanish from parameters()/state_dict().
        self.weight = nn.Parameter(torch.ones(emb_dim, dtype=torch.float32))

    def forward(self, x):
        """Normalize ``x`` over its last dimension; the result has ``x``'s dtype."""
        # Compute the statistics in at least float32 so low-precision inputs
        # (bfloat16/float16) don't lose accuracy in x^2 / mean / rsqrt.
        # float32 and float64 inputs are left in their own dtype.
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        x_f = x.to(compute_dtype)
        # RMS -> Root Mean Square: square, mean over features, then 1/sqrt.
        mean_sq = x_f.pow(2).mean(dim=-1, keepdim=True)
        x_normed = x_f * torch.rsqrt(mean_sq + self.eps)
        return (x_normed * self.weight).to(dtype=x.dtype)

In [4]:
# verify: compare against PyTorch's built-in RMSNorm (same eps) on a fixed random input
torch.manual_seed(123)
example = torch.randn(2, 3, 4)

rms_norm = RMSNorm(emb_dim=example.size(-1))
rms_norm_pt = nn.RMSNorm(normalized_shape=example.size(-1), eps=1e-5)
assert torch.allclose(rms_norm(example), rms_norm_pt(example))

### SiLU instead of GELU

In [5]:
class SiLU(nn.Module):
    """Sigmoid Linear Unit ("swish") activation: ``silu(x) = x * sigmoid(x)``.

    Llama uses SiLU where GPT-2 used GELU. This is numerically the same as
    ``torch.nn.functional.silu``; it is written out here for clarity.
    The module has no parameters, so the inherited ``nn.Module.__init__`` is
    sufficient and no custom constructor is needed.
    """

    def forward(self, x):
        return x * torch.sigmoid(x)

In [6]:
# verify: our SiLU must agree with PyTorch's functional SiLU
# (reuses `example` from the RMSNorm check cell)
silu = SiLU()
assert torch.allclose(silu(example), nn.functional.silu(example))

### Update the feed-forward module: SwiGLU (gated SiLU)

In [8]:
class FeedForward(nn.Module):
    """SwiGLU feed-forward block as used in Llama.

    Computes ``fc3(silu(fc1(x)) * fc2(x))``: ``fc1`` produces the gate (passed
    through SiLU), ``fc2`` produces the values being gated, and ``fc3`` projects
    back from ``hidden_dim`` to ``emb_dim``. All projections are bias-free.

    Args:
        cfg: Mapping with keys ``'emb_dim'`` and ``'hidden_dim'``. It may also
            contain ``'dtype'``, the parameter dtype of the linear layers (e.g. a
            lower-precision dtype for loading weights). If ``'dtype'`` is absent,
            PyTorch's default dtype is used.
    """

    def __init__(self, cfg):
        super().__init__()

        emb_dim, hidden_dim = cfg['emb_dim'], cfg['hidden_dim']
        # cfg['dtype'] allows loading in a lower precision format; None -> torch default.
        dtype = cfg.get('dtype')
        self.fc1 = nn.Linear(emb_dim, hidden_dim, dtype=dtype, bias=False)
        self.fc2 = nn.Linear(emb_dim, hidden_dim, dtype=dtype, bias=False)
        self.fc3 = nn.Linear(hidden_dim, emb_dim, dtype=dtype, bias=False)
        self.silu = SiLU()

    def forward(self, x):
        # Gating: SiLU(fc1 x) multiplies the linear path fc2 x element-wise.
        gated = self.silu(self.fc1(x)) * self.fc2(x)
        return self.fc3(gated)

### Rotary Positional Embeddings (RoPE)

In [9]:
def precompute_rope_params(head_dim, theta_base=10_000, context_len=4096):
    """Precompute the cosine/sine lookup tables for rotary positional embeddings.

    Args:
        head_dim: Per-head embedding size; must be even.
        theta_base: Base of the geometric progression of rotation frequencies.
        context_len: Number of positions (rows) to precompute.

    Returns:
        ``(cos, sin)``, each of shape ``(context_len, head_dim)``. The
        ``head_dim // 2`` angles of each position are repeated twice along the
        last dimension to match the half-split rotation in ``compute_rope``.
    """
    assert head_dim % 2 == 0, 'Head dimension must be even'

    half = head_dim // 2
    # Exponents 0, 2/d, 4/d, ... give inverse frequencies theta_base^(-2i/d).
    exponents = torch.arange(0, head_dim, 2)[:half].float() / head_dim
    inv_freq = 1.0 / (theta_base ** exponents)

    # angle[pos, i] = pos * inv_freq[i]  -> (context_len, head_dim // 2)
    angles = torch.arange(context_len).unsqueeze(1) * inv_freq.unsqueeze(0)
    # Tile the columns twice -> (context_len, head_dim)
    angles = angles.repeat(1, 2)

    return torch.cos(angles), torch.sin(angles)

def compute_rope(x, cos, sin):
    """Apply rotary positional embeddings (RoPE) to queries or keys.

    Uses the "half-split" layout: feature ``i`` in the first half of the head
    dimension is rotated together with feature ``i + head_dim // 2``.

    Args:
        x: Tensor whose last two dims are ``(seq_len, head_dim)``, typically
            ``(batch, num_heads, seq_len, head_dim)``. Any leading dims are
            handled by broadcasting.
        cos, sin: Lookup tables of shape ``(context_len, head_dim)`` (as built by
            ``precompute_rope_params``), with ``context_len >= seq_len``.

    Returns:
        The rotated tensor, with the same shape and dtype as ``x``.
    """
    seq_len, head_dim = x.shape[-2], x.shape[-1]
    assert head_dim % 2 == 0, 'Head dimension must be even'
    # Explicit checks turn otherwise cryptic broadcasting errors into clear messages.
    assert cos.shape[0] >= seq_len and sin.shape[0] >= seq_len, (
        f'RoPE tables cover {cos.shape[0]} positions but seq_len is {seq_len}'
    )
    assert cos.shape[-1] == head_dim and sin.shape[-1] == head_dim, (
        f'RoPE tables have head_dim {cos.shape[-1]} but x has {head_dim}'
    )

    # split into 2 halves
    half = head_dim // 2
    x1, x2 = x[..., :half], x[..., half:]

    # (seq_len, head_dim): broadcasts against any leading (batch, head) dims
    cos_t = cos[:seq_len]
    sin_t = sin[:seq_len]

    rotated = torch.cat((-x2, x1), dim=-1)
    # Float32 tables promote low-precision x during the math; cast back afterwards.
    x_rotated = (x * cos_t) + (rotated * sin_t)

    return x_rotated.to(dtype=x.dtype)

In [10]:
batch_size = 2
context_len = 5
num_heads = 4
head_dim = 16

cos, sin = precompute_rope_params(head_dim, context_len=context_len)

torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, context_len, head_dim)
keys = torch.randn(batch_size, num_heads, context_len, head_dim)

queries_rot = compute_rope(queries, cos, sin)
keys_rot = compute_rope(keys, cos, sin)

# verify: RoPE rotates pairs of features, so shapes and per-vector norms are
# preserved, and position 0 (angle 0) is left unchanged.
assert queries_rot.shape == queries.shape and keys_rot.shape == keys.shape
assert torch.allclose(queries_rot.norm(dim=-1), queries.norm(dim=-1), atol=1e-5)
assert torch.allclose(keys_rot.norm(dim=-1), keys.norm(dim=-1), atol=1e-5)
assert torch.allclose(queries_rot[..., 0, :], queries[..., 0, :])
assert torch.allclose(keys_rot[..., 0, :], keys[..., 0, :])