In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from llama_modules import compute_rope, FeedForward, RMSNorm

In [None]:
def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None):
    """Build the cosine and sine lookup tables used for rotary position embeddings.

    Args:
        head_dim: Per-head embedding size; must be even.
        theta_base: Base of the geometric progression of rotation frequencies.
        context_length: Number of positions to tabulate.
        freq_config: Currently unused (frequency scaling is still a TODO).

    Returns:
        A ``(cos, sin)`` tuple, each of shape ``[context_length, head_dim]``.
    """
    assert head_dim % 2 == 0, "Embedding dimension should be even"

    # One inverse frequency per pair of dimensions: theta_base^(-2i / head_dim).
    exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (theta_base ** exponents)

    # TODO: when freq_config is not None, apply RoPE frequency scaling to inv_freq here.

    token_positions = torch.arange(context_length)

    # Outer product position x frequency -> [context_length, head_dim / 2].
    half_angles = token_positions[:, None] * inv_freq[None, :]
    # Duplicate along the feature axis so the table spans the full head -> [context_length, head_dim].
    full_angles = torch.cat([half_angles, half_angles], dim=1)

    return torch.cos(full_angles), torch.sin(full_angles)

In [None]:
# TODO: add a RoPE frequency-scaling helper, adapted from torchtune.
# def rope_scaling()

In [None]:
# RoPE settings for Llama 3, with the corresponding Llama 2 values noted for comparison.
llama_3_context_len = 8192 # 4096 for llama 2
llama_3_theta_base = 500_000 # 10K for llama 2

In [None]:
# Small toy dimensions for exercising the RoPE helpers.
batch_size, num_heads, head_dim = 2, 4, 16

# Cos/sin tables for the Llama 3 base and context length, without frequency scaling.
cos, sin = precompute_rope_params(
    head_dim,
    theta_base=llama_3_theta_base,
    context_length=llama_3_context_len,
    freq_config=None,
)


In [None]:
# Random queries and keys with dimensions (batch, heads, tokens, head_dim); q is drawn before k.
q, k = (
    torch.randn(batch_size, num_heads, llama_3_context_len, head_dim)
    for _ in range(2)
)
q_rot, k_rot = compute_rope(q, cos, sin), compute_rope(k, cos, sin)
# Show the shapes after and before applying compute_rope.
q_rot.shape, k_rot.shape, q.shape, k.shape

In [None]:
class SharedBuffers:
    """Class-level cache of causal masks and RoPE cos/sin tables.

    Entries are keyed by the full configuration, so every caller asking for the
    same configuration receives the very same tensor objects.
    """

    # Maps configuration key -> (mask, cos, sin).
    _buffers = {}

    @staticmethod
    def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):
        """Return the cached ``(mask, cos, sin)`` triple, building it on first use.

        Args:
            context_length: Side length of the square mask and row count of cos/sin.
            head_dim: Per-head dimension passed to ``precompute_rope_params``.
            rope_base: RoPE base passed to ``precompute_rope_params``.
            freq_config: Optional dict forwarded to ``precompute_rope_params``;
                only its values take part in the cache key.
            dtype: Target dtype for cos/sin; ``None`` leaves them as computed.

        Returns:
            Tuple ``(mask, cos, sin)`` where ``mask`` is a strictly upper-triangular
            matrix of ones with shape ``[context_length, context_length]``.
        """
        freq_key = tuple(freq_config.values()) if freq_config else freq_config
        key = (context_length, head_dim, rope_base, freq_key, dtype)

        cached = SharedBuffers._buffers.get(key)
        if cached is not None:
            return cached

        # Ones above the diagonal mark the future positions to be masked out.
        all_ones = torch.ones(context_length, context_length)
        mask = torch.triu(all_ones, diagonal=1)

        cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)
        if dtype is not None:
            cos, sin = cos.to(dtype), sin.to(dtype)

        entry = (mask, cos, sin)
        SharedBuffers._buffers[key] = entry
        return entry

In [None]:
# Smoke-test the cache with a tiny configuration; entry 0 of the result is the causal mask.
buf = SharedBuffers.get_buffers(
    context_length=12, head_dim=96, rope_base=10_000, freq_config=None
)
buf[0].shape

In [None]:
class GroupedQueryAttention(nn.Module):
    """Causal attention in which several query heads share one key/value head.

    Queries are projected to ``num_heads`` heads, while keys and values are
    projected to only ``num_kv_groups`` heads. Each key/value head is then
    repeated ``num_heads // num_kv_groups`` times so it lines up with the query
    heads. ``compute_rope`` is applied to queries and keys before attention.

    Args:
        d_in: Size of the last dimension of the input.
        d_out: Width of the query projection and of the output; must be
            divisible by ``num_heads``.
        context_length: Maximum sequence length; sizes the causal mask and the
            cos/sin tables.
        num_heads: Number of query heads.
        num_kv_groups: Number of key/value heads; must divide ``num_heads``.
        rope_base: RoPE base forwarded to ``SharedBuffers.get_buffers``.
        rope_config: Optional RoPE frequency config forwarded to
            ``SharedBuffers.get_buffers``.
        dtype: dtype of the linear layers and cos/sin tables (``None`` keeps
            the defaults).
    """

    def __init__(self, d_in, d_out, context_length, num_heads, num_kv_groups, rope_base=10_000, rope_config=None, dtype=None):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Keys and values are only projected to num_kv_groups heads.
        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups

        # not grouped
        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        # The output projection consumes the concatenated heads, whose width is
        # d_out (not d_in), so its input size must be d_out.
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

        # Mask and RoPE tables come from the shared cache, so modules with the
        # same configuration register the same tensor objects.
        mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)
        self.register_buffer("mask", mask)
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)

    def forward(self, x):
        """Apply causal grouped-query attention.

        Args:
            x: Tensor of shape ``[b, num_tokens, d_in]``.

        Returns:
            Tensor of shape ``[b, num_tokens, d_out]``.
        """
        b, num_tokens, d_in = x.shape
        queries = self.W_query(x) # [b, num_tokens, d_out]
        keys = self.W_key(x) # [b, num_tokens, num_kv_groups * head_dim]
        values = self.W_value(x) # [b, num_tokens, num_kv_groups * head_dim]

        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2) # [b, num_heads, num_tokens, head_dim]
        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2) # [b, num_kv_groups, num_tokens, head_dim]
        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2) # [b, num_kv_groups, num_tokens, head_dim]

        # Apply ROPE
        # NOTE(review): cos/sin cover the full context_length; presumably
        # compute_rope trims them to num_tokens - confirm in llama_modules.
        keys = compute_rope(keys, self.cos, self.sin)
        queries = compute_rope(queries, self.cos, self.sin)

        # Expand each key/value head group_size times to match the query heads
        # [b, num_heads, num_tokens, head_dim]
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        # [b, num_heads, num_tokens, head_dim] [b, num_heads, head_dim, num_tokens] -> [b, num_heads, num_tokens, num_tokens]
        attn_scores = torch.matmul(queries, keys.transpose(2, 3))
        attn_scores = attn_scores / self.head_dim ** 0.5
        # Slice first, then convert: avoids casting the whole
        # [context_length, context_length] mask on every call.
        mask_bool = self.mask[:num_tokens, :num_tokens].bool()
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        # [b, num_heads, num_tokens, num_tokens]
        attn_weights = torch.softmax(attn_scores, dim=-1)
        # [b, num_heads, num_tokens, head_dim] -> [b, num_tokens, d_out]
        context_vec = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, self.d_out)
        return self.out_proj(context_vec)

In [None]:
# Toy integer tensor of shape [1, 4, 4] for illustrating repeat_interleave.
a = torch.arange(16).reshape(1, 4, 4)
a.shape

In [None]:
# Repeat every slice along dim 1 twice, the same operation used to expand key/value heads.
b = torch.repeat_interleave(a, 2, dim=1)
print(a, b, sep="\n")

In [None]:
# Full-size smoke test of GroupedQueryAttention.
# NOTE: num_heads and batch_size rebind the toy values assigned in the earlier RoPE cell.
embed_dim = 4096
num_heads = 32
max_context_length = 8192
context_len = 3000  # shorter than max_context_length, so only a slice of the mask is used
batch_size = 2

# Random input of shape [batch_size, context_len, embed_dim]
example_batch = torch.randn(batch_size, context_len, embed_dim)
print(example_batch.shape)

# d_in == d_out here; 32 query heads share 8 key/value heads
grouped_query_attention = GroupedQueryAttention(
    d_in=embed_dim,
    d_out=embed_dim,
    context_length=max_context_length,
    num_heads=num_heads,
    num_kv_groups=8,
    rope_base=llama_3_theta_base
)

# Output shape, then the key and query projection weight shapes
# (W_key projects to num_kv_groups * head_dim, W_query to d_out)
print(grouped_query_attention(example_batch).shape)
print(grouped_query_attention.W_key.weight.shape)
print(grouped_query_attention.W_query.weight.shape)

In [None]:
# Drop the references to the demo module and its input batch.
del grouped_query_attention, example_batch

In [None]:
class SublayerConnection(nn.Module):
    """
    Pre-norm residual wrapper: normalise the input with RMSNorm, run the
    sublayer on the normalised tensor, and add the result back to the input.
    """

    def __init__(self, size):
        super().__init__()
        self.norm = RMSNorm(size)

    def forward(self, x, sublayer):
        """Return ``x + sublayer(norm(x))``; ``sublayer`` is any callable on tensors."""
        normed = self.norm(x)
        residual_update = sublayer(normed)
        return x + residual_update

In [None]:
class TransformerBlock(nn.Module):
    """One decoder block: grouped-query attention followed by a feed-forward
    network, each wrapped in a SublayerConnection (RMSNorm + residual).

    ``cfg`` is a mapping that must provide "emb_dim", "context_length",
    "n_heads", "n_kv_groups", "rope_base", "rope_freq" and "dtype"; it is also
    handed as-is to ``FeedForward``.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]

        self.att = GroupedQueryAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            num_kv_groups=cfg["n_kv_groups"],
            rope_base=cfg["rope_base"],
            rope_config=cfg["rope_freq"],
            dtype=cfg["dtype"],
        )
        self.ff = FeedForward(cfg)
        self.sublayer1 = SublayerConnection(emb_dim)
        self.sublayer2 = SublayerConnection(emb_dim)

    def forward(self, x):
        # NOTE: the norms live inside sublayer1/sublayer2, so their parameters
        # are named "sublayerN.norm.*" - this might have some interesting
        # consequences when we load weights.
        after_attention = self.sublayer1(x, self.att)
        # Feed-forward sublayer on top of the attention output.
        return self.sublayer2(after_attention, self.ff)

In [None]:
class Llama3Model(nn.Module):
    """Decoder-only language model: token embedding, a stack of
    TransformerBlocks, a final RMSNorm and a linear output head.

    ``cfg`` must provide "vocab_size", "emb_dim", "dtype" and "n_layers", plus
    whatever ``TransformerBlock`` reads from it.
    """

    def __init__(self, cfg):
        super().__init__()
        vocab_size, emb_dim, dtype = cfg["vocab_size"], cfg["emb_dim"], cfg["dtype"]

        self.tok_emb = nn.Embedding(vocab_size, emb_dim, dtype=dtype)

        blocks = [TransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        self.trf_blocks = nn.Sequential(*blocks)

        self.final_norm = RMSNorm(emb_dim)
        self.out_head = nn.Linear(emb_dim, vocab_size, bias=False, dtype=dtype)

    def forward(self, in_idx, targets=None):
        """Compute logits and, when ``targets`` is given, the cross-entropy loss.

        Args:
            in_idx: Token ids fed to the embedding layer.
            targets: Optional target token ids; flattened before the loss.

        Returns:
            ``(logits, loss)`` where ``loss`` is ``None`` if no targets were passed.
        """
        hidden = self.final_norm(self.trf_blocks(self.tok_emb(in_idx)))
        logits = self.out_head(hidden)

        if targets is None:
            return logits, None

        # Flatten all leading dimensions so each position is one classification example.
        flat_logits = logits.view(-1, logits.shape[-1])
        loss = F.cross_entropy(flat_logits, targets.view(-1))
        return logits, loss

In [None]:
# Hyperparameters for the 8B configuration, consumed by Llama3Model / TransformerBlock.
LLAMA3_CONFIG_8B = dict(
    vocab_size=128_256,      # rows of the token embedding and width of the output logits
    context_length=8192,     # maximum sequence length (sizes the mask and RoPE tables)
    emb_dim=4096,            # model width
    n_heads=32,              # query heads per attention layer
    n_layers=32,             # number of TransformerBlocks
    hidden_dim=14_336,       # presumably the FeedForward inner width - confirm in llama_modules
    n_kv_groups=8,           # key/value heads shared across the query heads
    rope_base=500_000,       # RoPE theta base
    rope_freq=None,          # optional RoPE frequency-scaling config
    dtype=torch.bfloat16,    # dtype of embeddings, linear layers and cos/sin tables
)

In [None]:
# Build the full 8B-configuration model; no checkpoint is loaded here.
model = Llama3Model(cfg=LLAMA3_CONFIG_8B)

In [None]:
# model  # uncomment to display the full module hierarchy

In [None]:
# Check that the first and last blocks hold the very same mask/cos/sin tensor objects.
print(
    model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask,
    model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos,
    model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin,
    sep="\n",
)

In [None]:
def get_model_params(model):
    """Return the total number of elements across everything in ``model.parameters()``."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total

In [None]:
# Total number of parameter elements in the instantiated model.
get_model_params(model)

In [None]:
from model_utils import total_memory_size

In [None]:
# Memory estimate with the helper's default settings
# (exact definition and units live in model_utils, not shown here).
total_memory_size(model)

In [None]:
# Same estimate with torch.bfloat16 passed as the second argument
# (presumably the dtype to size the tensors with - confirm in model_utils).
total_memory_size(model, torch.bfloat16)

In [None]:
# Use the GPU when one is available; otherwise stay on the CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device);