# Llama-3.2-1B-Instruct

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct
model_name = "meta-llama/Llama-3.2-1B-Instruct"
# Load the reference checkpoint in bfloat16 and move it to the GPU.
hf_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to("cuda")
n_params = sum(p.numel() for p in hf_model.parameters())
print(f"Params: {n_params / (10**9):.2f}B\n")

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Special tokens used below to assemble the chat prompt by hand.
special_tokens = {
    "bot": "<|begin_of_text|>",
    "eot": "<|eot_id|>",
    "sh": "<|start_header_id|>",
    "eh": "<|end_header_id|>",
    "eos": "<|end_of_text|>"
}
for token in special_tokens.values():
    print(token, tokenizer.convert_tokens_to_ids(token))
print()

# https://www.llama.com/docs/model-cards-and-prompt-formats/meta-llama-3/
# NOTE(review): the linked prompt-format page shows a blank line after each
# <|end_header_id|> and no newline after <|eot_id|>; this prompt uses single
# newlines instead. Confirm whether that difference is intentional.
prompt = f"""{special_tokens['sh']}system{special_tokens['eh']}
You are a helpful AI assistant that is an expert in Deep Learning and a detailed and clear teacher.{special_tokens['eot']}
{special_tokens['sh']}user{special_tokens['eh']}
Can you explain the Llama architecture?{special_tokens['eot']}
{special_tokens['sh']}assistant{special_tokens['eh']}"""
print(prompt, "\n")

tokens = tokenizer(prompt, return_tensors="pt").to("cuda")
print(tokens["input_ids"], "\n")

# Each assistant turn ends with <|eot_id|>, so use it as the stop token and as
# the padding token for sequences that finish early.
eot_id = tokenizer.convert_tokens_to_ids(special_tokens["eot"])
generated = hf_model.generate(
    **tokens,
    # `len(tokens)` is the number of entries in the tokenizer output, not the
    # prompt length, so bound the number of *new* tokens directly instead.
    max_new_tokens=156,
    # temperature / top_p only take effect when sampling is enabled.
    do_sample=True,
    num_return_sequences=2,
    temperature=0.7,
    top_p=0.9,
    eos_token_id=eot_id,
    pad_token_id=eot_id,
)
for generation in tokenizer.batch_decode(generated):
    print("\n", "-"*80)
    print(generation)

  from .autonotebook import tqdm as notebook_tqdm


Params: 1.24B



Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


<|begin_of_text|> 128000
<|eot_id|> 128009
<|start_header_id|> 128006
<|end_header_id|> 128007
<|end_of_text|> 128001

<|start_header_id|>system<|end_header_id|>
You are a helpful AI assistant that is an expert in Deep Learning and a detailed and clear teacher.<|eot_id|>
<|start_header_id|>user<|end_header_id|>
Can you explain the Llama architecture?<|eot_id|>
<|start_header_id|>assistant<|end_header_id|> 

tensor([[128000, 128006,   9125, 128007,    198,   2675,    527,    264,  11190,
          15592,  18328,    430,    374,    459,   6335,    304,  18682,  21579,
            323,    264,  11944,    323,   2867,  11326,     13, 128009,    198,
         128006,    882, 128007,    198,   6854,    499,  10552,    279,    445,
          81101,  18112,     30, 128009,    198, 128006,  78191, 128007]],
       device='cuda:0') 


 --------------------------------------------------------------------------------
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful AI 

In [2]:
# Show the reference module tree followed by its configuration.
print(hf_model, hf_model.config, sep="\n")

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [3]:
from dataclasses import dataclass, field

@dataclass
class RopeConfig:
    """Settings for the rotary position embedding and its frequency rescaling."""
    # Base of the geometric progression of rotary frequencies.
    base_freq: float = 500_000.0

    # Divisor applied to the low-frequency (long-wavelength) components.
    scale_factor: float = 32.0
    # og_max_pos / high_freq_factor is the lower wavelength bound of the smoothed band.
    high_freq_factor: float = 4.0
    # og_max_pos / low_freq_factor is the wavelength above which frequencies are fully rescaled.
    low_freq_factor: float = 1.0
    # Reference context length used to derive both wavelength bounds.
    og_max_pos: int = 8_192

@dataclass
class LlamaConfig:
    """Model hyperparameters for the Llama re-implementation in this notebook."""
    vocab_size: int = 128_256
    n_hidden_layers: int = 16
    d_model: int = 2_048
    d_hidden: int = 8_192
    rms_norm_eps: float = 1e-5

    # Number of positions for which rotary sin/cos tables are precomputed.
    max_pos: int = 131_072
    # default_factory gives every LlamaConfig its own RopeConfig. A shared
    # instance as the default is rejected by newer Python versions (unhashable
    # default) and would otherwise be mutated across all configs at once.
    rope_config: RopeConfig = field(default_factory=RopeConfig)

    d_head: int = 64
    n_heads: int = 32
    n_kv_heads: int = 8

# Build the default configuration once; later cells read it by this name.
llama_config = LlamaConfig()
print(f"{llama_config}")

LlamaConfig(vocab_size=128256, n_hidden_layers=16, d_model=2048, d_hidden=8192, rms_norm_eps=1e-05, max_pos=131072, rope_config=RopeConfig(base_freq=500000.0, scale_factor=32.0, high_freq_factor=4.0, low_freq_factor=1.0, og_max_pos=8192), d_head=64, n_heads=32, n_kv_heads=8)


In [4]:
from torch import nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    """
    Gated feed-forward layer: down_proj(silu(gate_proj(x)) * up_proj(x)).

    SwiGLU paper: https://arxiv.org/pdf/2002.05202v1
    SiLU (also called Swish): https://pytorch.org/docs/stable/generated/torch.nn.functional.silu.html
    """
    def __init__(self, d_model: int, d_hidden: int):
        super().__init__()
        # Bias-free projections, created in the order up -> gate -> down.
        projection_shapes = {
            "up_proj": (d_model, d_hidden),
            "gate_proj": (d_model, d_hidden),
            "down_proj": (d_hidden, d_model),
        }
        for proj_name, (n_in, n_out) in projection_shapes.items():
            setattr(self, proj_name, nn.Linear(n_in, n_out, bias=False))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated projection to the last dimension of ``x``."""
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)

# Reference MLP taken from the first HF decoder layer, and our module on the same device/dtype.
hf_swiglu = hf_model.model.layers[0].mlp
my_swiglu = SwiGLU(llama_config.d_model, llama_config.d_hidden).to(device="cuda", dtype=torch.bfloat16)
# Copy each HF weight into the identically named tensor of our module.
my_swiglu_state = my_swiglu.state_dict()
for name, param in hf_swiglu.state_dict().items():
    print(name, param.shape)
    my_swiglu_state[name].copy_(param)

def check_diff(x: torch.Tensor, y: torch.Tensor, name: str = "Tensors", atol: float = 1e-5) -> None:
    """
    Print the largest element-wise absolute difference between two tensors,
    then assert that they are close.

    Args:
        x: First tensor.
        y: Second tensor; it is subtracted from ``x``, so the shapes must broadcast.
        name: Label used in the printed line and in the assertion message.
        atol: Absolute tolerance forwarded to ``torch.allclose``.

    Raises:
        AssertionError: if ``torch.allclose(x, y, atol=atol)`` is false.
    """
    largest_gap = torch.max(torch.abs(x - y))
    print(f"{name} max diff: {largest_gap}")
    tensors_match = torch.allclose(x, y, atol=atol)
    assert tensors_match, f"{name} not allclose"

# Random activations reused by later cells: batch B, sequence length T, width C (= d_model).
B, T = 8, 128
C = llama_config.d_model
rand_x = torch.randn(B, T, C, device="cuda", dtype=torch.bfloat16)
check_diff(hf_swiglu(rand_x), my_swiglu(rand_x), name="SwiGLU")

gate_proj.weight torch.Size([8192, 2048])
up_proj.weight torch.Size([8192, 2048])
down_proj.weight torch.Size([2048, 8192])
SwiGLU max diff: 0.0


In [5]:
import math
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

class RoPE(nn.Module):
    """
    Rotary positional encoding with wavelength-dependent frequency rescaling.

    The sin/cos tables for positions ``0 .. max_pos - 1`` are precomputed once in
    ``__init__`` and applied in ``forward`` to the first ``T`` positions.

    Rotary Positional Encoding: https://arxiv.org/pdf/2104.09864
    HuggingFace Positional Encoding explanation: https://huggingface.co/blog/designing-positional-encoding
    HuggingFace Llama Positional Encoding: https://github.com/huggingface/transformers/blob/41b9b92b52215bed472c9a534a06abbc3a9a95cd/src/transformers/modeling_rope_utils.py#L322
    """
    def __init__(self, llama_config: LlamaConfig) -> None:
        """
        Args:
            llama_config: supplies ``d_head``, ``max_pos`` and the ``rope_config``
                frequency settings.
        """
        super().__init__()
        rope_config = llama_config.rope_config
        # One frequency per pair of head dimensions.
        freqs = 1.0 / rope_config.base_freq ** (torch.arange(0, llama_config.d_head, 2) / llama_config.d_head)
        wavelens = 2 * math.pi / freqs

        # Long wavelengths (low frequencies) are divided by the scale factor.
        low_freq_wavelen = rope_config.og_max_pos / rope_config.low_freq_factor
        freqs[wavelens > low_freq_wavelen] /= rope_config.scale_factor

        # Wavelengths between the two bounds are interpolated between the scaled
        # and unscaled frequency. Entries in this band were not touched by the
        # in-place division above, so `freqs` is still the unscaled value there.
        smooth_factor = ((rope_config.og_max_pos / wavelens) - rope_config.low_freq_factor) / (rope_config.high_freq_factor - rope_config.low_freq_factor)
        smoothed_freqs = ((1.0 - smooth_factor) * freqs / rope_config.scale_factor) + (smooth_factor * freqs)
        high_freq_wavelen = rope_config.og_max_pos / rope_config.high_freq_factor
        is_medium_freq = (wavelens <= low_freq_wavelen) & (wavelens >= high_freq_wavelen)
        freqs = torch.where(is_medium_freq, smoothed_freqs, freqs)

        pos = torch.arange(llama_config.max_pos).unsqueeze(1)
        pthetas = pos * freqs  # (max_pos, d_head / 2)
        # The tables are derived entirely from the config, so keep them out of
        # state_dict (persistent=False). They still follow .to(device/dtype).
        # Each table is tiled twice along the last dim to match rotate_half's
        # half-split layout, giving shape (max_pos, d_head).
        self.register_buffer("sin", pthetas.sin().repeat(1, 2), persistent=False)
        self.register_buffer("cos", pthetas.cos().repeat(1, 2), persistent=False)

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Map the last dimension ``[x1, x2]`` (two halves) to ``[-x2, x1]``."""
        x1 = x[..., :x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2:]
        return torch.cat((-x2, x1), dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Rotate ``x`` using positions ``0 .. T - 1``, where ``T = x.shape[-2]``.

        The tables are indexed as ``[None, None, :T]``, so ``x`` is expected to
        broadcast against ``(1, 1, T, d_head)``.
        """
        T = x.shape[-2]
        return (self.cos[None, None, :T] * x) + (self.sin[None, None, :T] * self.rotate_half(x))

# Random queries/keys shaped (B, n_heads, T, d_head).
Q = torch.rand((B, llama_config.n_heads, T, llama_config.d_head), device="cuda", dtype=torch.bfloat16)
K = torch.rand_like(Q)
affin_no_rope = torch.matmul(Q, K.transpose(-1, -2))

rope = RoPE(llama_config).to("cuda", dtype=torch.bfloat16)
rope_q = rope(Q)
rope_k = rope(K)
print(rope_q.shape, rope_k.shape)

# RoPE is a relative positional encoding, so the affinities on the diagonal should be unchanged.
# This check does not work well at lower precision (bfloat16), so it is left disabled.
# affin_rope = rope_q @ rope_k.transpose(-1, -2)
# affin_rope_diags = torch.diagonal(affin_rope, dim1=-2, dim2=-1)
# affin_no_rope_diags = torch.diagonal(affin_no_rope, dim1=-2, dim2=-1)
# check_diff(affin_rope_diags, affin_no_rope_diags, name="RoPE diags")

# Compare our tables and rotated tensors against the HF rotary embedding.
hf_rope = hf_model.model.rotary_emb
position_ids = torch.arange(T).unsqueeze(0).expand(B, -1).to("cuda")
cos, sin = hf_rope(Q, position_ids)
check_diff(cos, rope.cos[:T], name="Cos")
check_diff(sin, rope.sin[:T], name="Sin")
hf_rope_q, hf_rope_k = apply_rotary_pos_emb(Q, K, cos, sin)
check_diff(hf_rope_q, rope_q, name="RoPE Q")
check_diff(hf_rope_k, rope_k, name="RoPE K")

torch.Size([8, 32, 128, 64]) torch.Size([8, 32, 128, 64])
Cos max diff: 0.0
Sin max diff: 0.0
RoPE Q max diff: 0.0
RoPE K max diff: 0.0


In [6]:
class GroupedQueryAttention(nn.Module):
    """
    Causal self-attention in which ``n_kv_heads`` key/value heads are shared
    by ``n_heads`` query heads.

    GQA paper: https://arxiv.org/pdf/2305.13245
    """
    def __init__(self, d_model: int, n_heads: int, n_kv_heads: int, rope: RoPE):
        """
        Args:
            d_model: width of the input and output.
            n_heads: number of query heads; must divide ``d_model``.
            n_kv_heads: number of key/value heads; must divide ``n_heads``.
            rope: module applied to the per-head queries and keys.
        """
        super().__init__()
        self.rope = rope
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        assert n_heads % n_kv_heads == 0, "n_heads must be divisible by n_kv_heads"
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_size = d_model // n_heads
        # Keys and values only get n_kv_heads heads, so their projections are narrower.
        kv_width = n_kv_heads * self.head_size
        self.k_proj = nn.Linear(d_model, kv_width, bias=False)
        self.v_proj = nn.Linear(d_model, kv_width, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` of shape (B, T, C) and return a tensor of the same shape."""
        batch, seq_len, width = x.shape

        def split_heads(projected: torch.Tensor, head_count: int) -> torch.Tensor:
            # (B, T, heads * head_size) -> (B, heads, T, head_size)
            return projected.view(batch, seq_len, head_count, self.head_size).transpose(1, 2)

        queries = self.rope(split_heads(self.q_proj(x), self.n_heads))
        keys = self.rope(split_heads(self.k_proj(x), self.n_kv_heads))
        values = split_heads(self.v_proj(x), self.n_kv_heads)

        # enable_gqa lets the kernel pair the smaller set of k/v heads with the query heads.
        attended = F.scaled_dot_product_attention(queries, keys, values, is_causal=True, enable_gqa=True)
        merged = attended.transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.o_proj(merged)

# Reference attention from the first HF decoder layer.
hf_gqa = hf_model.model.layers[0].self_attn
print(hf_gqa)
my_gqa = GroupedQueryAttention(
    llama_config.d_model, llama_config.n_heads, llama_config.n_kv_heads, rope
).to("cuda", dtype=torch.bfloat16)
# Copy each HF projection weight into the identically named tensor of our module.
my_gqa_state = my_gqa.state_dict()
for name, param in hf_gqa.state_dict().items():
    print(name, param.shape)
    my_gqa_state[name].copy_(param)

# The HF layer returns a tuple; its first element is the attention output.
hf_attn_out = hf_gqa(rand_x, (cos, sin), None)[0]
my_attn_out = my_gqa(rand_x)
check_diff(hf_attn_out, my_attn_out, name="Attention")

LlamaAttention(
  (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
  (k_proj): Linear(in_features=2048, out_features=512, bias=False)
  (v_proj): Linear(in_features=2048, out_features=512, bias=False)
  (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
)
q_proj.weight torch.Size([2048, 2048])
k_proj.weight torch.Size([512, 2048])
v_proj.weight torch.Size([512, 2048])
o_proj.weight torch.Size([2048, 2048])
Attention max diff: 0.0


In [7]:
class TransformerBlock(nn.Module):
    """
    Pre-norm decoder block: RMSNorm -> attention -> residual add, then
    RMSNorm -> SwiGLU MLP -> residual add.
    """
    def __init__(self, config: LlamaConfig, rope: RoPE) -> None:
        """
        Args:
            config: supplies the model width, head counts, MLP width and norm epsilon.
            rope: rotary embedding module handed to the attention layer.
        """
        super().__init__()
        width, eps = config.d_model, config.rms_norm_eps
        self.self_attn = GroupedQueryAttention(width, config.n_heads, config.n_kv_heads, rope)
        self.mlp = SwiGLU(width, config.d_hidden)
        self.input_layernorm = nn.RMSNorm(width, eps=eps)
        self.post_attention_layernorm = nn.RMSNorm(width, eps=eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both residual sub-layers on ``x`` and return the result."""
        after_attn = x + self.self_attn(self.input_layernorm(x))
        return after_attn + self.mlp(self.post_attention_layernorm(after_attn))

my_block = TransformerBlock(llama_config, rope).to("cuda", dtype=torch.bfloat16)
# Reference: the first HF decoder layer.
hf_block = hf_model.model.layers[0]
print(hf_block)
# Copy each HF weight into the identically named tensor of our block.
my_block_state = my_block.state_dict()
for name, param in hf_block.state_dict().items():
    print(name, param.shape)
    my_block_state[name].copy_(param)

# The HF layer returns a tuple; its first element is the hidden state.
hf_block_out = hf_block(rand_x, position_embeddings=(cos, sin))[0]
my_block_out = my_block(rand_x)
check_diff(hf_block_out, my_block_out, name="Block")

LlamaDecoderLayer(
  (self_attn): LlamaAttention(
    (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
    (k_proj): Linear(in_features=2048, out_features=512, bias=False)
    (v_proj): Linear(in_features=2048, out_features=512, bias=False)
    (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
  )
  (mlp): LlamaMLP(
    (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
    (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
    (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
  (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
)
self_attn.q_proj.weight torch.Size([2048, 2048])
self_attn.k_proj.weight torch.Size([512, 2048])
self_attn.v_proj.weight torch.Size([512, 2048])
self_attn.o_proj.weight torch.Size([2048, 2048])
mlp.gate_proj.weight torch.Size([8192, 2048])
mlp.up_proj.weight torch.Size([8192, 2048])
mlp.

In [8]:
class Llama(nn.Module):
    """
    Decoder-only language model: token embedding, a stack of TransformerBlocks
    sharing one RoPE module, a final RMSNorm, and an output head whose weight
    is tied to the embedding.
    """
    def __init__(self, config: LlamaConfig) -> None:
        """
        Args:
            config: model hyperparameters (vocabulary size, width, depth, ...).
        """
        super().__init__()
        self.rope = RoPE(config)
        # Submodule names mirror the keys printed from the HF state_dict so weights can be copied by name.
        embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
        layers = nn.ModuleList([TransformerBlock(config, self.rope) for _ in range(config.n_hidden_layers)])
        norm = nn.RMSNorm(config.d_model, eps=config.rms_norm_eps)
        self.model = nn.ModuleDict({"embed_tokens": embed_tokens, "layers": layers, "norm": norm})
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        self.lm_head.weight = embed_tokens.weight

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x``, run every block, normalise, and return the logits."""
        hidden = self.model["embed_tokens"](x)
        for block in self.model["layers"]:
            hidden = block(hidden)
        return self.lm_head(self.model["norm"](hidden))
    
my_llama = Llama(llama_config).to("cuda", dtype=torch.bfloat16)
# Build the name -> tensor mapping once instead of on every loop iteration.
my_llama_state = my_llama.state_dict()
for (name, param) in hf_model.state_dict().items():
    print(name, param.shape)
    my_llama_state[name].copy_(param)

# Inference-only comparison, so skip autograd bookkeeping.
with torch.no_grad():
    hf_logits = hf_model(tokens["input_ids"]).logits
    my_logits = my_llama(tokens["input_ids"])

# Compare the full, like-shaped tensors. Indexing `my_logits[-1]` dropped the
# batch dimension and relied on broadcasting, which would silently compare the
# wrong rows for any batch larger than one.
assert hf_logits.shape == my_logits.shape, f"Logits shape mismatch: {hf_logits.shape} vs {my_logits.shape}"
# TODO: logits allclose fails.
# NOTE(review): check_diff's default atol (1e-5) is far tighter than bfloat16
# resolution. If the printed max diff is small, rounding accumulated over the
# layer stack may be the cause rather than a logic error. Confirm before
# loosening the tolerance.
check_diff(hf_logits, my_logits, name="Logits")

model.embed_tokens.weight torch.Size([128256, 2048])
model.layers.0.self_attn.q_proj.weight torch.Size([2048, 2048])
model.layers.0.self_attn.k_proj.weight torch.Size([512, 2048])
model.layers.0.self_attn.v_proj.weight torch.Size([512, 2048])
model.layers.0.self_attn.o_proj.weight torch.Size([2048, 2048])
model.layers.0.mlp.gate_proj.weight torch.Size([8192, 2048])
model.layers.0.mlp.up_proj.weight torch.Size([8192, 2048])
model.layers.0.mlp.down_proj.weight torch.Size([2048, 8192])
model.layers.0.input_layernorm.weight torch.Size([2048])
model.layers.0.post_attention_layernorm.weight torch.Size([2048])
model.layers.1.self_attn.q_proj.weight torch.Size([2048, 2048])
model.layers.1.self_attn.k_proj.weight torch.Size([512, 2048])
model.layers.1.self_attn.v_proj.weight torch.Size([512, 2048])
model.layers.1.self_attn.o_proj.weight torch.Size([2048, 2048])
model.layers.1.mlp.gate_proj.weight torch.Size([8192, 2048])
model.layers.1.mlp.up_proj.weight torch.Size([8192, 2048])
model.layers.1.

AssertionError: Logits not allclose