In [1]:
import torch
import torch.nn as nn

In [3]:
class FeedForward_Gemma(nn.Module):
    """Gated feed-forward block: ``fc3(gelu_tanh(fc1(x)) * fc2(x))``.

    ``cfg`` must provide the keys ``'emb_dim'``, ``'hidden_dim'`` and
    ``'dtype'``. All three projections are created without bias.
    """

    def __init__(self, cfg):
        super().__init__()

        emb_dim, hidden_dim = cfg['emb_dim'], cfg['hidden_dim']
        linear_kwargs = {'dtype': cfg['dtype'], 'bias': False}

        # Layers are built in the order fc1, fc2, fc3.
        self.fc1 = nn.Linear(emb_dim, hidden_dim, **linear_kwargs)
        self.fc2 = nn.Linear(emb_dim, hidden_dim, **linear_kwargs)
        self.fc3 = nn.Linear(hidden_dim, emb_dim, **linear_kwargs)

    def forward(self, x):
        """Project up twice, gate one branch with tanh-GELU, project back down."""
        # Gate branch uses GELU (tanh approximation) rather than SiLU.
        gate = nn.functional.gelu(self.fc1(x), approximate='tanh')
        up = self.fc2(x)
        return self.fc3(gate * up)

In [4]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    The learnable ``scale`` starts at zero and is applied as ``(1 + scale)``,
    so a freshly constructed layer performs plain RMS normalisation. An
    optional learnable ``shift`` (also zero-initialised) is added when
    ``bias=True``.
    """

    def __init__(self, emb_dim, eps=1e-6, bias=False):
        super().__init__()

        self.eps = eps
        # Zero init (not one) because forward multiplies by (1 + scale).
        self.scale = nn.Parameter(torch.zeros(emb_dim))
        if bias:
            self.shift = nn.Parameter(torch.zeros(emb_dim))
        else:
            self.shift = None

    def forward(self, x):
        """Normalise ``x`` in float32 and return it in ``x``'s original dtype."""
        orig_dtype = x.dtype
        x32 = x.float()

        # Mean of squares along the last axis, kept for broadcasting.
        mean_sq = x32.pow(2).mean(dim=-1, keepdim=True)
        normed = x32 * torch.rsqrt(mean_sq + self.eps)

        gain = 1.0 + self.scale.float()
        result = normed * gain
        if self.shift is not None:
            result = result + self.shift.float()

        return result.to(orig_dtype)

In [5]:
from deps.other_components import precompute_rope_params, compute_rope

In [6]:
class GroupedQueryAttention_Gemma(nn.Module):
    """Grouped-query attention with per-head query/key RMS normalisation and RoPE.

    ``num_heads`` query heads share ``num_kv_groups`` key/value heads; each
    key/value head serves ``num_heads // num_kv_groups`` query heads.

    Args:
        d_in: Size of the last dimension of the input.
        num_heads: Number of query heads.
        num_kv_groups: Number of key/value heads; must divide ``num_heads``.
        head_dim: Width of each head. Defaults to ``d_in // num_heads``
            (which requires ``d_in`` to be divisible by ``num_heads``).
        query_pre_attn_scalar: If given, queries are scaled by
            ``query_pre_attn_scalar ** -0.5``; otherwise by ``head_dim ** -0.5``.
        dtype: dtype passed to the linear projections.
    """

    def __init__(self, d_in, num_heads, num_kv_groups, head_dim=None,
                query_pre_attn_scalar=None, dtype=None):
        super().__init__()
        assert num_heads % num_kv_groups == 0, 'num_heads must be divisible by num_kv_groups'

        self.num_heads = num_heads
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        if head_dim is None:
            assert d_in % num_heads == 0, 'd_in must be divisible by num_heads if head_dim is not provided'
            # makes d_out == d_in
            head_dim = d_in // num_heads

        self.head_dim = head_dim
        self.d_out = num_heads * head_dim

        self.W_q = nn.Linear(d_in, self.d_out, bias=False, dtype=dtype)
        self.W_k = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)
        self.W_v = nn.Linear(d_in, num_kv_groups * head_dim, bias=False, dtype=dtype)

        self.out_proj = nn.Linear(self.d_out, d_in, bias=False, dtype=dtype)

        # Separate QK norms, each acting on the per-head (last) dimension.
        self.q_norm = RMSNorm(head_dim, eps=1e-6)
        self.k_norm = RMSNorm(head_dim, eps=1e-6)

        if query_pre_attn_scalar is not None:
            self.scaling = query_pre_attn_scalar**-0.5
        else:
            self.scaling = head_dim**-0.5

    def forward(self, x, mask, cos, sin):
        """Compute attention over ``x``.

        Args:
            x: Rank-3 input ``(b, n, d_in)``.
            mask: Boolean tensor broadcastable to ``(b, h, n, n)``; positions
                that are ``True`` are filled with ``-inf`` before the softmax.
            cos, sin: Forwarded unchanged to ``compute_rope`` for both queries
                and keys (presumably precomputed RoPE tables; verify against
                ``precompute_rope_params``).

        Returns:
            Tensor of shape ``(b, n, d_in)``.
        """
        b, num_tokens, _ = x.shape

        queries = self.W_q(x)  # (b, n, h*d)
        keys = self.W_k(x)  # (b, n, g*d)
        values = self.W_v(x)  # (b, n, g*d)

        # Split the last dim into heads.
        # (b, n, h, d)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        # (b, n, g, d)
        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)
        # (b, n, g, d)
        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)

        # (b, h, n, d)
        queries = queries.transpose(1, 2)
        # (b, g, n, d)
        keys = keys.transpose(1, 2)
        # (b, g, n, d)
        values = values.transpose(1, 2)

        queries = self.q_norm(queries)
        # Keys use their own norm; previously q_norm was applied here and
        # k_norm's parameters were never used.
        keys = self.k_norm(keys)

        queries = compute_rope(queries, cos, sin)
        keys = compute_rope(keys, cos, sin)

        # Expand the head dim (dim 1) from g to h so keys/values line up with queries.
        # (b, h, n, d)
        keys = keys.repeat_interleave(self.group_size, dim=1)
        # (b, h, n, d)
        values = values.repeat_interleave(self.group_size, dim=1)

        # scaling can be different than the usual 1/sqrt(d)
        queries = queries * self.scaling

        # (b, h, n, d) x (b, h, d, n) -> (b, h, n, n)
        attn_scores = queries @ keys.transpose(2, 3)
        attn_scores = attn_scores.masked_fill(mask, -torch.inf)
        # (b, h, n, n)
        attn_weights = torch.softmax(attn_scores, dim=-1)

        # (b, h, n, n) x (b, h, n, d) -> (b, h, n, d)
        # (b, h, n, d) -> (b, n, h, d)
        context = (attn_weights @ values).transpose(1, 2)
        # (b, n, h*d)
        context = context.reshape(b, num_tokens, self.d_out)

        return self.out_proj(context)

In [7]:
class TransformerBlock(nn.Module):
    """One transformer layer: attention and feed-forward, each wrapped in a
    pre-norm / post-norm pair and added back onto a residual shortcut.

    Args:
        cfg: Mapping providing ``'emb_dim'``, ``'n_heads'``, ``'n_kv_groups'``,
            ``'head_dim'``, ``'query_pre_attn_scalar'`` and ``'dtype'``, plus
            whatever ``FeedForward_Gemma`` reads from it.
        attn_type: ``'sliding'`` selects the local mask / RoPE tables in
            ``forward``; any other value selects the global ones.
    """

    def __init__(self, cfg, attn_type):
        super().__init__()

        self.attn_type = attn_type
        self.attn = GroupedQueryAttention_Gemma(
            d_in=cfg['emb_dim'],
            num_heads=cfg['n_heads'],
            num_kv_groups=cfg['n_kv_groups'],
            head_dim=cfg['head_dim'],
            query_pre_attn_scalar=cfg['query_pre_attn_scalar'],
            dtype=cfg['dtype'],
        )

        self.ff = FeedForward_Gemma(cfg)

        self.input_norm = RMSNorm(cfg['emb_dim'], eps=1e-6)
        self.post_attn_norm = RMSNorm(cfg['emb_dim'], eps=1e-6)
        self.pre_ff_norm = RMSNorm(cfg['emb_dim'], eps=1e-6)
        self.post_ff_norm = RMSNorm(cfg['emb_dim'], eps=1e-6)

    def forward(
            self, x,
            mask_global, mask_local,
            cos_global, cos_local,
            sin_global, sin_local,
            ):
        """Apply the attention and feed-forward sub-layers to ``x``.

        The ``*_local`` arguments are used when ``self.attn_type == 'sliding'``;
        otherwise the ``*_global`` ones are used. The chosen mask, cos and sin
        are passed straight to the attention module.

        Returns:
            ``x`` after both residual sub-layers.
        """
        # --- attention sub-layer ---
        shortcut = x
        x = self.input_norm(x)

        if self.attn_type == 'sliding':
            attn_mask = mask_local
            cos = cos_local
            sin = sin_local
        else:
            attn_mask = mask_global
            cos = cos_global
            sin = sin_global

        x_attn = self.attn(x, attn_mask, cos, sin)
        # Post-norm must act on the attention output; normalising `x` here
        # would discard the attention result entirely.
        x_attn = self.post_attn_norm(x_attn)
        x = shortcut + x_attn

        # --- feed-forward sub-layer ---
        shortcut = x
        x_ffn = self.pre_ff_norm(x)
        x_ffn = self.ff(x_ffn)
        # Likewise, post-norm the feed-forward output rather than its input.
        x_ffn = self.post_ff_norm(x_ffn)
        # Residual add (the earlier chained assignment dropped the shortcut).
        x = shortcut + x_ffn

        return x
