In [2]:
import torch
import torch.nn as nn

### Reused components (RMSNorm, SiLU, FeedForward, RoPE, multi-head attention)

In [3]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learnable per-feature scale.

    Normalizes over the last dimension: ``x / sqrt(mean(x**2) + eps) * weight``.
    There is no mean subtraction and no bias.

    Args:
        emb_dim: size of the last (feature) dimension of the input.
        eps: constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()

        self.eps = eps
        self.emb_dim = emb_dim
        # Only one learnable parameter. Create it directly in float32 rather than
        # calling ``.float()`` on an ``nn.Parameter``: when that call actually has
        # to cast (e.g. under a float64 default dtype) it returns a plain tensor,
        # and the weight would silently not be registered as a parameter.
        self.weight = nn.Parameter(torch.ones(emb_dim, dtype=torch.float32))

    def forward(self, x):
        """Normalize ``x`` over its last dimension; the result keeps ``x.dtype``."""
        input_dtype = x.dtype
        # Reduce in at least float32 so low-precision inputs (e.g. bfloat16) do
        # not lose accuracy in the mean of squares; float32/float64 are unchanged.
        compute_dtype = torch.promote_types(input_dtype, torch.float32)
        x_c = x.to(compute_dtype)
        mean_sq = x_c.pow(2).mean(dim=-1, keepdim=True)
        x_normed = x_c * torch.rsqrt(mean_sq + self.eps)
        return (x_normed * self.weight).to(dtype=input_dtype)

In [4]:
class SiLU(nn.Module):
    """Sigmoid Linear Unit (swish), applied elementwise: ``f(x) = x * sigmoid(x)``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

In [5]:
class FeedForward(nn.Module):
    """Gated feed-forward network: ``fc3(silu(fc1(x)) * fc2(x))``.

    Reads ``cfg['emb_dim']``, ``cfg['hidden_dim']`` and ``cfg['dtype']``; the
    dtype lets the weights be created directly in a lower-precision format.
    All three projections are bias-free.
    """

    def __init__(self, cfg):
        super().__init__()

        emb, hidden, dt = cfg['emb_dim'], cfg['hidden_dim'], cfg['dtype']
        # gate branch (passed through SiLU)
        self.fc1 = nn.Linear(emb, hidden, dtype=dt, bias=False)
        # second branch, multiplied elementwise with the activated gate
        self.fc2 = nn.Linear(emb, hidden, dtype=dt, bias=False)
        # projection back to the embedding size
        self.fc3 = nn.Linear(hidden, emb, dtype=dt, bias=False)
        self.silu = SiLU()

    def forward(self, x):
        gated = self.silu(self.fc1(x)) * self.fc2(x)
        return self.fc3(gated)

In [6]:
def precompute_rope_params(head_dim, theta_base=10_000, context_len=4096):
    """Precompute the RoPE cosine/sine tables.

    Args:
        head_dim: per-head feature size; must be even.
        theta_base: base of the geometric frequency schedule.
        context_len: number of positions to precompute.

    Returns:
        ``(cos, sin)``, each of shape ``(context_len, head_dim)``. The
        ``head_dim // 2`` angles are tiled twice along the last dimension so both
        halves of a head share the same angles (half-split rotation).
    """
    assert head_dim % 2 == 0, 'Head dimension must be even'

    # inv_freq[i] = theta_base ** (-2i / head_dim), i = 0 .. head_dim//2 - 1
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (theta_base ** exponents)

    positions = torch.arange(context_len).unsqueeze(1)   # (context_len, 1)
    angles = positions * inv_freq.unsqueeze(0)           # (context_len, head_dim // 2)
    angles = angles.repeat(1, 2)                         # (context_len, head_dim)

    return torch.cos(angles), torch.sin(angles)

def compute_rope(x, cos, sin):
    """Apply rotary position embeddings to ``x``.

    Args:
        x: tensor of shape ``(batch, num_heads, seq_len, head_dim)``; head_dim even.
        cos, sin: tables with at least ``seq_len`` rows and ``head_dim`` columns
            (as from ``precompute_rope_params``); only the first ``seq_len`` rows
            are used.

    Returns:
        Tensor with the shape and dtype of ``x``. Feature ``i`` is rotated
        together with feature ``i + head_dim // 2`` (half-split pairing).
    """
    _, _, seq_len, head_dim = x.shape
    assert head_dim % 2 == 0, 'Head dimension must be even'

    half = head_dim // 2
    first, second = x[..., :half], x[..., half:]
    # rotate-half: (x1, x2) -> (-x2, x1)
    rotated = torch.cat((-second, first), dim=-1)

    # broadcast the tables over batch and head dims: (1, 1, seq_len, head_dim)
    cos_t = cos[:seq_len, :][None, None]
    sin_t = sin[:seq_len, :][None, None]

    return ((x * cos_t) + (rotated * sin_t)).to(dtype=x.dtype)

In [7]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary position embeddings.

    Every head has its own query, key and value projection; projections carry
    no bias. A causal mask and RoPE cos/sin tables covering ``context_len``
    positions are registered as buffers.

    Args:
        d_in: input feature size.
        d_out: output feature size; must be divisible by ``num_heads``.
        context_len: number of positions covered by the mask and RoPE tables.
        num_heads: number of attention heads.
        dtype: optional dtype for the projection weights.
    """

    def __init__(self, d_in, d_out, context_len, num_heads, dtype=None):
        super().__init__()

        assert d_out % num_heads == 0, 'd_out must be divisible by num_heads'

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        def proj(n_in, n_out):
            return nn.Linear(n_in, n_out, bias=False, dtype=dtype)

        self.W_q = proj(d_in, d_out)
        self.W_k = proj(d_in, d_out)
        self.W_v = proj(d_in, d_out)
        self.out_proj = proj(d_out, d_out)

        # 1.0 strictly above the diagonal marks the future positions to hide
        causal = torch.ones(context_len, context_len).triu(diagonal=1)
        self.register_buffer('mask', causal)

        cos, sin = precompute_rope_params(head_dim=self.head_dim, context_len=context_len)
        self.register_buffer('cos', cos)
        self.register_buffer('sin', sin)

    def _split_heads(self, t, b, num_tokens):
        # (b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim)
        return t.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        b, num_tokens, _ = x.shape

        q = self._split_heads(self.W_q(x), b, num_tokens)
        k = self._split_heads(self.W_k(x), b, num_tokens)
        v = self._split_heads(self.W_v(x), b, num_tokens)

        q = compute_rope(q, self.cos, self.sin)
        k = compute_rope(k, self.cos, self.sin)

        scores = q @ k.transpose(2, 3)
        future = self.mask.bool()[:num_tokens, :num_tokens]
        scores.masked_fill_(future, -torch.inf)

        weights = torch.softmax(scores / self.head_dim ** 0.5, dim=-1)
        context = (weights @ v).transpose(1, 2).reshape(b, num_tokens, self.d_out)
        return self.out_proj(context)


### Modified RoPE

In [8]:
# Llama 3 raises the RoPE base (theta_base) from 10K to 500K.
# inv_freq[i] = theta_base ** (-2i / head_dim) decreases with the pair index i,
# so later dimension pairs rotate more slowly (smaller angle per position);
# a larger theta_base makes that decrease steeper, i.e. longer wavelengths.

def precompute_rope_params(head_dim, theta_base=10_000, context_len=4096, freq_config=None):
    """Precompute RoPE cos/sin tables, optionally with Llama 3.1/3.2 frequency scaling.

    Args:
        head_dim: per-head feature size; must be even.
        theta_base: base of the geometric frequency schedule.
        context_len: number of positions to precompute.
        freq_config: ``None`` for plain RoPE, or a dict with keys
            ``'original_context_len'``, ``'low_freq_factor'``,
            ``'high_freq_factor'`` and ``'factor'``. Inverse frequencies whose
            wavelength exceeds ``original_context_len / low_freq_factor`` are
            divided by ``factor``; those with wavelength below
            ``original_context_len / high_freq_factor`` are kept; those in
            between are blended as ``(1 - s) * inv_freq / factor + s * inv_freq``
            with ``s = (original_context_len / wavelen - low_freq_factor)
            / (high_freq_factor - low_freq_factor)``.

    Returns:
        ``(cos, sin)``, each of shape ``(context_len, head_dim)``.
    """
    assert head_dim % 2 == 0, 'Head dim must be even'

    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / theta_base ** exponents

    if freq_config is not None:
        inv_freq = _llama3_scale_inv_freq(inv_freq, freq_config)

    positions = torch.arange(context_len).unsqueeze(1)   # (context_len, 1)
    angles = positions * inv_freq.unsqueeze(0)           # (context_len, head_dim // 2)
    angles = angles.repeat(1, 2)                         # (context_len, head_dim)

    return torch.cos(angles), torch.sin(angles)


def _llama3_scale_inv_freq(inv_freq, freq_config):
    # Piecewise rescaling of RoPE inverse frequencies by wavelength band.
    orig_ctx = freq_config['original_context_len']
    lo_factor = freq_config['low_freq_factor']
    hi_factor = freq_config['high_freq_factor']
    scale = freq_config['factor']

    low_freq_wavelen = orig_ctx / lo_factor
    high_freq_wavelen = orig_ctx / hi_factor
    wavelen = 2 * torch.pi / inv_freq

    scaled = inv_freq / scale
    smooth = (orig_ctx / wavelen - lo_factor) / (hi_factor - lo_factor)
    blended = (1 - smooth) * scaled + smooth * inv_freq

    is_low_freq = wavelen > low_freq_wavelen
    is_mid_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    result = torch.where(is_low_freq, scaled, inv_freq)
    return torch.where(is_mid_freq, blended, result)

In [9]:
# Llama 2 vs. Llama 3 RoPE settings: only the context length and theta base differ
# (the Llama 2 values are kept here for comparison).
llama_2_context_len = 4096
llama_3_context_len = 8192

llama_2_theta_base = 10_000
llama_3_theta_base = 500_000

batch_size = 2
num_heads = 4
head_dim = 16

cos, sin = precompute_rope_params(head_dim, llama_3_theta_base, llama_3_context_len)

torch.manual_seed(123)
queries = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)
keys = torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)

queries_rot = compute_rope(queries, cos, sin)
keys_rot = compute_rope(keys, cos, sin)

# Sanity checks: RoPE keeps the shape, leaves position 0 untouched (angle 0),
# and is a rotation, so per-position vector norms are preserved.
assert queries_rot.shape == queries.shape and keys_rot.shape == keys.shape
assert torch.allclose(queries_rot[:, :, 0], queries[:, :, 0])
assert torch.allclose(queries_rot.norm(dim=-1), queries.norm(dim=-1), atol=1e-4)
assert torch.allclose(keys_rot.norm(dim=-1), keys.norm(dim=-1), atol=1e-4)

### Grouped-Query Attention
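Each query head has its own query projection, but several query heads share a single key/value head. <br>
This needs fewer key/value parameters (see the parameter counts below) and less compute for the key/value projections.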

In [10]:
class GroupedQueryAttention(nn.Module):
    """Causal self-attention where groups of query heads share key/value heads.

    There are ``num_heads`` query heads but only ``num_kv_groups`` key/value
    heads; each key/value head is shared by ``num_heads // num_kv_groups``
    consecutive query heads. The causal mask and RoPE tables are passed to
    ``forward`` instead of being stored on the module, so long-context models do
    not need a per-layer ``context_len x context_len`` mask buffer.

    Args:
        d_in: input feature size.
        d_out: output feature size; must be divisible by ``num_heads``.
        num_heads: number of query heads; must be divisible by ``num_kv_groups``.
        num_kv_groups: number of key/value heads.
        dtype: optional dtype for the projection weights.
    """

    def __init__(self, d_in, d_out, num_heads,
                num_kv_groups, dtype=None):
        super().__init__()

        # e.g. d_out = 1024, num_heads = 16 -> head_dim = 1024 / 16 = 64
        assert d_out % num_heads == 0, 'd_out must be divisible by num_heads'
        # e.g. num_heads = 16, num_kv_groups = 4 -> each K/V head serves 4 query heads
        assert num_heads % num_kv_groups == 0, 'num_heads must be divisible by num_kv_groups'

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # K/V project to num_kv_groups * head_dim features instead of d_out
        self.W_k = nn.Linear(d_in, num_kv_groups*self.head_dim, bias=False, dtype=dtype)
        self.W_v = nn.Linear(d_in, num_kv_groups*self.head_dim, bias=False, dtype=dtype)
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        self.W_q = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

    def forward(self, x, mask=None, cos=None, sin=None):
        """Attend over ``x`` of shape ``(b, num_tokens, d_in)``; returns ``(b, num_tokens, d_out)``.

        Args:
            x: input tensor.
            mask: optional boolean tensor, ``True`` where attention is blocked.
                Its last two dims may be larger than ``num_tokens`` (e.g. a mask
                precomputed for the full context); only the leading
                ``num_tokens x num_tokens`` window is used. Defaults to a causal mask.
            cos, sin: optional RoPE tables passed to ``compute_rope``; when
                ``cos`` is ``None`` no rotary embedding is applied.
        """
        b, num_tokens, _ = x.shape

        queries = self.W_q(x)  # (b, num_tokens, d_out)
        keys = self.W_k(x)     # (b, num_tokens, num_kv_groups * head_dim)
        values = self.W_v(x)   # (b, num_tokens, num_kv_groups * head_dim)

        # queries: (b, num_heads, num_tokens, head_dim)
        # keys/values: (b, num_kv_groups, num_tokens, head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)

        if cos is not None:
            keys = compute_rope(keys, cos, sin)
            queries = compute_rope(queries, cos, sin)

        # Expand K/V to one head per query head (num_heads = group_size * num_kv_groups).
        # repeat_interleave gives [K1, K1, K2, K2] rather than [K1, K2, K1, K2], so
        # consecutive query heads share the same K/V head.
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        attn_scores = queries @ keys.transpose(2, 3)

        if mask is None:
            mask = torch.triu(
                torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool),
                diagonal=1,
            )
        else:
            # allow masks built for a longer context; only this window is needed
            mask = mask[..., :num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask, -torch.inf)

        attn_weights = torch.softmax(attn_scores / self.head_dim ** 0.5, dim=-1)

        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)

In [None]:
# Compare MHA and GQA parameter counts.

# Sizes reduced to fit in memory; the commented numbers are the full-size settings.
batch_size = 1
context_len = 256  # 3000
max_context_len = 1024  # 8192
emb_dim = 32  # 4096
num_heads = 8  # 32

example = torch.randn(batch_size, context_len, emb_dim)

mha = MultiHeadAttention(emb_dim, emb_dim, max_context_len, num_heads)
# smoke-test the forward pass instead of silently discarding its result
assert mha(example).shape == example.shape

print(f'W_q: {mha.W_q.weight.shape}')
print(f'W_k: {mha.W_k.weight.shape}')
print(f'W_v: {mha.W_v.weight.shape}')

W_q: torch.Size([32, 32])
W_k: torch.Size([32, 32])
W_v: torch.Size([32, 32])


In [12]:
gqa = GroupedQueryAttention(emb_dim, emb_dim, num_heads, num_kv_groups=2)

# smoke-test the forward pass (output keeps the input shape)
assert gqa(example).shape == example.shape

print(f'W_q: {gqa.W_q.weight.shape}')
print(f'W_k: {gqa.W_k.weight.shape}')
print(f'W_v: {gqa.W_v.weight.shape}')

W_q: torch.Size([32, 32])
W_k: torch.Size([8, 32])
W_v: torch.Size([8, 32])


In [13]:
# total number of learnable scalars in each attention module, thousands-separated
print('MHA params:', format(sum(map(torch.numel, mha.parameters())), ','))
print('GQA params:', format(sum(map(torch.numel, gqa.parameters())), ','))

MHA params: 4,096
GQA params: 2,560


In [14]:
# drop the references to the demo modules so their memory can be reclaimed
del mha, gqa

### TransformerBlock

In [15]:
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: RMSNorm -> GQA -> residual, then RMSNorm -> FeedForward -> residual.

    Reads ``emb_dim``, ``n_heads``, ``n_kv_groups`` and ``dtype`` from ``cfg``;
    ``FeedForward`` reads its own keys from the same dict.
    """

    def __init__(self, cfg):
        super().__init__()

        self.attn = GroupedQueryAttention(
            d_in=cfg['emb_dim'],
            d_out=cfg['emb_dim'],
            num_heads=cfg['n_heads'],
            num_kv_groups=cfg['n_kv_groups'],
            dtype=cfg['dtype'],
        )

        self.ff = FeedForward(cfg)
        self.norm1 = RMSNorm(cfg['emb_dim'], eps=1e-5)
        self.norm2 = RMSNorm(cfg['emb_dim'], eps=1e-5)

    def forward(self, x, mask=None, cos=None, sin=None):
        """Run the block; ``mask``, ``cos`` and ``sin`` are forwarded to the attention layer."""
        shortcut = x
        x = self.norm1(x)
        # Cast to the attention weights' dtype rather than a hard-coded bfloat16:
        # otherwise any cfg['dtype'] other than bfloat16 (or a later .to()) gives a
        # dtype mismatch inside nn.Linear.
        x = self.attn(x.to(self.attn.W_q.weight.dtype), mask, cos, sin)
        x = x + shortcut

        shortcut = x
        x = self.norm2(x)
        x = self.ff(x.to(self.ff.fc1.weight.dtype))
        x = x + shortcut

        return x

### Model

In [16]:
class Llama3Model(nn.Module):
    """Llama 3 style decoder-only language model.

    Token embedding -> ``n_layers`` TransformerBlocks -> final RMSNorm -> linear
    output head. The RoPE cos/sin tables are computed once here and passed to
    every block; they are registered as non-persistent buffers (not saved in the
    state dict).

    Expected ``cfg`` keys: ``vocab_size``, ``emb_dim``, ``n_heads``, ``n_layers``,
    ``context_len``, ``rope_base``, ``dtype`` (plus those read by
    TransformerBlock), and optionally ``rope_freq`` (RoPE frequency-scaling
    config; omitted or ``None`` means no scaling).
    """

    def __init__(self, cfg):
        super().__init__()

        self.tok_emb = nn.Embedding(cfg['vocab_size'], cfg['emb_dim'], dtype=cfg['dtype'])
        self.trf_blocks = nn.Sequential(*[
            TransformerBlock(cfg)
            for _ in range(cfg['n_layers'])
        ])
        self.final_norm = RMSNorm(cfg['emb_dim'], eps=1e-5)
        self.out_head = nn.Linear(cfg['emb_dim'], cfg['vocab_size'], bias=False, dtype=cfg['dtype'])

        # shared RoPE tables, handed to each block in forward()
        cos, sin = precompute_rope_params(
            head_dim=cfg['emb_dim'] // cfg['n_heads'],
            theta_base=cfg['rope_base'],
            context_len=cfg['context_len'],
            freq_config=cfg.get('rope_freq'),
        )
        self.register_buffer('cos', cos, persistent=False)
        self.register_buffer('sin', sin, persistent=False)

        self.cfg = cfg

    def forward(self, in_idx):
        """Map token ids of shape ``(b, num_tokens)`` to logits ``(b, num_tokens, vocab_size)``."""
        x = self.tok_emb(in_idx)

        # causal mask built per call (True above the diagonal = blocked),
        # so no context_len x context_len buffer has to be stored
        num_tokens = x.shape[1]
        mask = torch.triu(
            torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool),
            diagonal=1,
        )

        for blk in self.trf_blocks:
            x = blk(x, mask, self.cos, self.sin)

        x = self.final_norm(x)
        # match the output head's actual weight dtype (also correct after .to())
        return self.out_head(x.to(self.out_head.weight.dtype))