In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# -------------------------------
# LLAMA32 Configuration
# -------------------------------
# Hyperparameter dictionary passed to the model classes defined below.
# Keys are looked up by name, so every class only sees the entries it reads.
LLAMA32_CONFIG = {
    # Rows of the token embedding table and width of the output logits.
    "vocab_size": 128_256,
    # Maximum context; the model classes in this section do not read this key
    # (the test script below only overrides it).
    "context_length": 131_072,  # max context
    # Width of the residual stream (embedding, attention and norm dimension).
    "emb_dim": 3072,
    # Number of attention heads; head size is emb_dim // n_heads.
    "n_heads": 24,
    # Number of stacked transformer blocks.
    "n_layers": 28,
    # Inner width of the feed-forward network.
    "hidden_dim": 8192,
    # NOTE(review): not read by any class in this section; presumably the number
    # of key/value groups for grouped-query attention -- confirm before relying on it.
    "n_kv_groups": 8,
    # Base of the rotary positional encoding frequency schedule.
    "rope_base": 500_000.0,
    # dtype the test script casts the model parameters to.
    "dtype": torch.bfloat16,
    # NOTE(review): not read by any code in this section; presumably RoPE
    # frequency-rescaling settings for long contexts -- confirm against the caller.
    "rope_freq": {
        "factor": 32.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_context_length": 8192,
    }
}

# -------------------------------
# Rotary Positional Encoding (RoPE)
# -------------------------------
def apply_rope(x, base=500_000.0, offset=0):
    """Apply rotary positional encoding (RoPE) to the last dimension of ``x``.

    The last dimension is split into two halves ``(x1, x2)``. Channel ``i`` of
    both halves at position ``p`` is rotated by the angle
    ``p * base ** (-i / half_dim)``.

    Args:
        x: Floating-point tensor of shape ``[..., seq_len, dim]`` with an even
            ``dim``. Any number of leading dimensions is accepted (for example
            ``[batch, seq_len, dim]`` or ``[batch, heads, seq_len, head_dim]``).
        base: Base of the geometric frequency schedule.
        offset: Absolute position of the first row of ``x``. Pass the number of
            already-processed tokens when encoding a continuation.

    Returns:
        A tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: If the last dimension of ``x`` is odd.
    """
    seq_len, dim = x.shape[-2], x.shape[-1]
    if dim % 2 != 0:
        raise ValueError(f"RoPE needs an even last dimension, got {dim}")
    half_dim = dim // 2

    # Build the angles in at least float32: in bfloat16/float16 integer
    # positions beyond a few hundred are not representable, which would
    # silently corrupt the rotation for long sequences.
    compute_dtype = torch.promote_types(x.dtype, torch.float32)
    channel = torch.arange(half_dim, dtype=compute_dtype, device=x.device)
    inv_freq = 1.0 / (base ** (channel / half_dim))
    positions = torch.arange(
        offset, offset + seq_len, dtype=compute_dtype, device=x.device
    )
    angles = torch.outer(positions, inv_freq)  # [seq_len, half_dim]
    cos = torch.cos(angles)
    sin = torch.sin(angles)

    x1 = x[..., :half_dim].to(compute_dtype)
    x2 = x[..., half_dim:].to(compute_dtype)
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return rotated.to(x.dtype)

# -------------------------------
# Multi-Head Attention with KV cache and causal masking
# -------------------------------
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary positions and a KV cache.

    Queries and keys are rotated per head with rotary positional encoding
    (RoPE) using their absolute position, so processing a sequence in one call
    or token-by-token through the cache yields the same result. Every head has
    its own key/value projection; ``config["n_kv_groups"]`` is not read.

    Args:
        config: Mapping with ``"emb_dim"`` and ``"n_heads"``; ``"rope_base"``
            is optional and defaults to ``500_000.0``.

    Raises:
        ValueError: If ``emb_dim`` is not divisible by ``n_heads`` or the
            resulting head dimension is odd (RoPE rotates channel pairs).
    """

    def __init__(self, config):
        super().__init__()
        self.emb_dim = config["emb_dim"]
        self.n_heads = config["n_heads"]
        if self.emb_dim % self.n_heads != 0:
            raise ValueError(
                f"emb_dim ({self.emb_dim}) must be divisible by n_heads ({self.n_heads})"
            )
        self.head_dim = self.emb_dim // self.n_heads
        if self.head_dim % 2 != 0:
            raise ValueError(f"RoPE needs an even head_dim, got {self.head_dim}")
        # Read from the config this instance was built with, not a module global.
        self.rope_base = config.get("rope_base", 500_000.0)
        self.qkv_proj = nn.Linear(self.emb_dim, self.emb_dim * 3)
        self.out_proj = nn.Linear(self.emb_dim, self.emb_dim)

    def _rope(self, t, offset):
        """Rotate ``t`` ([batch, heads, seq, head_dim]) for positions starting at ``offset``."""
        half = self.head_dim // 2
        # Angles are computed in at least float32; half-precision cannot
        # represent large integer positions exactly.
        compute_dtype = torch.promote_types(t.dtype, torch.float32)
        channel = torch.arange(half, dtype=compute_dtype, device=t.device)
        inv_freq = 1.0 / (self.rope_base ** (channel / half))
        positions = torch.arange(
            offset, offset + t.size(-2), dtype=compute_dtype, device=t.device
        )
        angles = torch.outer(positions, inv_freq)  # [seq, half]
        cos, sin = torch.cos(angles), torch.sin(angles)
        t1 = t[..., :half].to(compute_dtype)
        t2 = t[..., half:].to(compute_dtype)
        rotated = torch.cat([t1 * cos - t2 * sin, t1 * sin + t2 * cos], dim=-1)
        return rotated.to(t.dtype)

    def forward(self, x, kv_cache=None):
        """Attend over ``x`` and any previously cached keys/values.

        Args:
            x: Input of shape ``[batch, seq_len, emb_dim]``.
            kv_cache: ``None`` or a dict with ``"k"`` and ``"v"`` tensors of
                shape ``[batch, heads, past_len, head_dim]`` returned by an
                earlier call.

        Returns:
            ``(out, kv_cache)`` where ``out`` is ``[batch, seq_len, emb_dim]``
            and ``kv_cache`` holds the keys/values for all positions seen so
            far (cached keys are stored already rotated).
        """
        batch, seq_len, emb_dim = x.shape

        # Project once, then split into per-head q, k, v: [batch, heads, seq_len, head_dim]
        qkv = self.qkv_proj(x).view(batch, seq_len, 3, self.n_heads, self.head_dim)
        q = qkv[:, :, 0].transpose(1, 2)
        k = qkv[:, :, 1].transpose(1, 2)
        v = qkv[:, :, 2].transpose(1, 2)

        # New tokens continue where the cache ends, so their positions must be
        # offset by the cached length; values are never rotated.
        past_len = 0 if kv_cache is None else kv_cache["k"].size(2)
        q = self._rope(q, past_len)
        k = self._rope(k, past_len)

        # Append to KV cache
        if kv_cache is not None:
            k = torch.cat([kv_cache["k"], k], dim=2)
            v = torch.cat([kv_cache["v"], v], dim=2)
        kv_cache = {"k": k, "v": v}

        # Causal mask: query at absolute position p may see keys at positions <= p.
        # Only the rows for the current queries are materialised.
        total_len = k.size(2)
        key_pos = torch.arange(total_len, device=x.device)
        query_pos = torch.arange(past_len, total_len, device=x.device).unsqueeze(1)
        allowed = key_pos <= query_pos  # [seq_len, total_len]

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_scores = attn_scores.masked_fill(~allowed, float("-inf"))

        # Attention probabilities
        attn_probs = F.softmax(attn_scores, dim=-1)
        out = torch.matmul(attn_probs, v)

        # Merge heads
        out = out.transpose(1, 2).reshape(batch, seq_len, emb_dim)
        out = self.out_proj(out)
        return out, kv_cache

# -------------------------------
# Feedforward network
# -------------------------------
class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension: Linear -> GELU -> Linear.

    Args:
        config: Mapping providing ``"emb_dim"`` (input/output width) and
            ``"hidden_dim"`` (inner width).
    """

    def __init__(self, config):
        super().__init__()
        model_width = config["emb_dim"]
        inner_width = config["hidden_dim"]
        # Expand to the inner width, apply the non-linearity, project back.
        self.fc1 = nn.Linear(model_width, inner_width)
        self.activation = nn.GELU()
        self.fc2 = nn.Linear(inner_width, model_width)

    def forward(self, x):
        """Return the MLP output for ``x``; the last dimension keeps size ``emb_dim``."""
        expanded = self.fc1(x)
        activated = self.activation(expanded)
        return self.fc2(activated)

# -------------------------------
# Transformer block
# -------------------------------
class TransformerBlock(nn.Module):
    """Residual block: normalised self-attention, then a normalised feed-forward net.

    Each sub-layer receives a LayerNorm-ed copy of the stream and its output is
    added back onto the un-normalised stream.

    Args:
        config: Mapping forwarded to ``MultiHeadAttention`` and ``FeedForward``;
            ``"emb_dim"`` also sizes the two LayerNorms.
    """

    def __init__(self, config):
        super().__init__()
        self.attn = MultiHeadAttention(config)
        self.ffn = FeedForward(config)
        width = config["emb_dim"]
        self.ln1 = nn.LayerNorm(width, eps=1e-5)
        self.ln2 = nn.LayerNorm(width, eps=1e-5)

    def forward(self, x, kv_cache=None):
        """Run the block on ``x``.

        Args:
            x: Input tensor passed through ``ln1`` into the attention layer.
            kv_cache: Cache object handed to the attention layer unchanged
                (``None`` for no cache).

        Returns:
            ``(x, kv_cache)``: the updated stream and the cache returned by the
            attention layer.
        """
        normed_in = self.ln1(x)
        attn_update, kv_cache = self.attn(normed_in, kv_cache)
        after_attn = x + attn_update

        ffn_update = self.ffn(self.ln2(after_attn))
        return after_attn + ffn_update, kv_cache

# -------------------------------
# Full LLAMA32 Model
# -------------------------------
class LLAMA32(nn.Module):
    """Decoder stack: token embedding, ``n_layers`` transformer blocks, final norm, LM head.

    Args:
        config: Mapping with ``"vocab_size"``, ``"emb_dim"`` and ``"n_layers"``;
            it is also forwarded to every ``TransformerBlock`` and stored as
            ``self.config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        vocab, width = config["vocab_size"], config["emb_dim"]
        self.token_emb = nn.Embedding(vocab, width)
        self.blocks = nn.ModuleList(
            TransformerBlock(config) for _ in range(config["n_layers"])
        )
        self.ln_final = nn.LayerNorm(width, eps=1e-5)
        self.head = nn.Linear(width, vocab, bias=False)

    def forward(self, input_ids, kv_caches=None):
        """Compute logits for ``input_ids``.

        Args:
            input_ids: Token ids fed to the embedding table.
            kv_caches: ``None`` or a per-block indexable collection of caches;
                entry ``i`` is passed to block ``i``.

        Returns:
            ``(logits, new_caches)`` where ``new_caches`` is a list holding the
            cache returned by each block, in block order.
        """
        hidden = self.token_emb(input_ids)
        new_caches = []
        for layer_idx, layer in enumerate(self.blocks):
            # Without incoming caches every block starts from scratch.
            past = kv_caches[layer_idx] if kv_caches is not None else None
            hidden, updated = layer(hidden, past)
            new_caches.append(updated)
        return self.head(self.ln_final(hidden)), new_caches

# -------------------------------
# Testing
# -------------------------------
if __name__ == "__main__":
    # Smoke test: build the model, push one random batch through it and
    # report the shapes of the logits and of the first block's key cache.

    # Prefer the GPU when one is visible, otherwise fall back to the CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print(f"Using device: {device}")

    # Shallow copy of the config with a short context length for the test.
    test_context_length = 64
    cfg = {**LLAMA32_CONFIG, "context_length": test_context_length}

    # Build the model on the target device/dtype and switch to eval mode.
    model = LLAMA32(cfg).to(device=device, dtype=cfg["dtype"]).eval()

    # Random token ids of shape [batch_size, seq_len].
    batch_size, seq_len = 2, test_context_length
    dummy_input_ids = torch.randint(
        cfg["vocab_size"], (batch_size, seq_len), device=device
    )

    # Gradients are not needed for a shape check.
    with torch.no_grad():
        logits, caches = model(dummy_input_ids)

    first_block_keys = caches[0]["k"]
    print("Logits shape:", logits.shape)  # [batch, seq_len, vocab_size]
    print("KV cache for first block k shape:", first_block_keys.shape)  # [batch, heads, seq_len, head_dim]


Using device: cuda
Logits shape: torch.Size([2, 64, 128256])
KV cache for first block k shape: torch.Size([2, 24, 64, 128])
