# LLaMA From Scratch

**References**
- *Coding LLaMA 2 from scratch in PyTorch - KV Cache, Grouped Query Attention, Rotary PE, RMSNorm: [Youtube Video](https://youtu.be/oM4VmoabDAI?si=JtlNl00nZeIOkWxx), [Code](https://github.com/hkproj/pytorch-llama)*
- *LLaMA explained: KV-Cache, Rotary Positional Embedding, RMS Norm, Grouped Query Attention, SwiGLU: [Youtube Video](https://youtu.be/Mn_9W1nCFLo?si=4xJy4OzpPX5YxGqx)*
- 

## Imports

In [2]:
from dataclasses import dataclass
from typing import Optional, List, Dict, Any
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# LLaMA Model

**LLaMA 1**
![LLaMA 1 Parameters](images/llama-1-parameters.png)

**LLaMA 2**
![LLaMA 2 Parameters](images/llama-2-parameters.png)

**LLaMA 3**



### Model Arguments



In [4]:
@dataclass
class ModelArgs:
    """Hyper-parameters shared by every module of the model."""

    dim: int = 4096  # embedding (model) dimension
    n_layers: int = 32  # number of stacked transformer blocks
    # * Unlike the original Transformer, LLaMA does not need the same number of heads for q, k and v
    n_heads: int = 32  # number of heads for the queries
    n_kv_heads: Optional[int] = None  # number of heads for the keys and values (None -> n_heads)
    vocab_size: int = -1  # will be set when we load the tokenizer
    # * since grouped query attention heads are reduced,
    # * the number of params in the FFN is increased to keep the total number of parameters the same
    multiple_of: int = 256  # FFN hidden size is rounded up to a multiple of this value
    ffn_dim_multiplier: Optional[float] = None  # optional extra scale for the FFN hidden size
    norm_eps: float = 1e-5  # epsilon passed to the RMSNorm layers

    # needed for KV cache
    max_batch_size: int = 32
    max_seq_len: int = 2048

    # device on which the rotary tables are precomputed; None uses torch's default device
    device: Optional[str] = None

## Rotary Positional Embedding

### Precompute Theta Positional Frequencies

Below are the steps involved in precomputing theta positional frequencies:

![Precompute Theta Positional Frequencies Steps](images/theta-pos-freq-steps.png)


In [5]:
def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    """Precompute the rotary-embedding phases exp(i * m * theta_k).

    As in section 3.2.2 of the RoFormer paper, the frequencies are
    theta_k = theta^(-2k / head_dim) for k = 0, 1, ..., head_dim/2 - 1, and each
    position m in [0, seq_len) is paired with each frequency.

    Args:
        head_dim: size of one attention head; must be even because rotary
            embedding rotates consecutive pairs of features.
        seq_len: number of positions to precompute.
        device: device on which the tables are built.
        theta: base of the geometric frequency progression (10000.0 in the paper).

    Returns:
        Complex tensor of shape (seq_len, head_dim / 2) with unit magnitude and
        angle m * theta_k at entry [m, k].
    """
    assert head_dim % 2 == 0, "Dimension must be even since rotary embedding can't be applied to odd."

    # Even exponents 0, 2, ..., head_dim - 2 -> shape (head_dim / 2,)
    exponents = torch.arange(0, head_dim, 2).float()
    inv_freq = 1.0 / (theta ** (exponents / head_dim)).to(device)  # (head_dim / 2,)

    # Positions m = 0 .. seq_len - 1 -> shape (seq_len,)
    positions = torch.arange(seq_len, device=device)

    # Angle for every (position, frequency) pair -> (seq_len, head_dim / 2)
    angles = torch.outer(positions, inv_freq).float()

    # Polar form with radius 1: cos(angle) + i * sin(angle)
    return torch.polar(torch.ones_like(angles), angles)

### Rotary Embeddings

In [6]:
def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device:str):
    """Rotate consecutive feature pairs of ``x`` by precomputed complex phases.

    Each pair (x[..., 2k], x[..., 2k+1]) is read as the complex number
    x[..., 2k] + i * x[..., 2k+1] and multiplied by the matching entry of
    ``freqs_complex``, which rotates it in the plane (Figure 1 of the paper).

    Args:
        x: tensor laid out as (B, seq_len, H, head_dim); head_dim must be even.
        freqs_complex: complex phases of shape (seq_len, head_dim / 2).
        device: device the result is moved to.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    leading_dims = x.shape[:-1]
    # (B, seq_len, H, head_dim) -> (B, seq_len, H, head_dim / 2, 2) -> complex (B, seq_len, H, head_dim / 2)
    pairs = x.float().reshape(*leading_dims, -1, 2)
    as_complex = torch.view_as_complex(pairs)

    # Add broadcast axes for batch and heads: (seq_len, head_dim / 2) -> (1, seq_len, 1, head_dim / 2)
    phases = freqs_complex.unsqueeze(0).unsqueeze(2)

    # Complex multiplication by a unit phase is a 2D rotation of each pair.
    rotated = as_complex * phases

    # complex (B, seq_len, H, head_dim / 2) -> real (B, seq_len, H, head_dim / 2, 2) -> (B, seq_len, H, head_dim)
    restored = torch.view_as_real(rotated).reshape(*x.shape)
    return restored.type_as(x).to(device)

## Root Mean Square Normalization

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dimension.

    Each vector is divided by its RMS (plus ``eps`` under the square root) and
    multiplied by a learnable per-feature gain. There is no mean subtraction and
    no bias, unlike LayerNorm.
    """

    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Learnable gain ("gamma"), one entry per feature, initialised to 1.
        self.weight = nn.Parameter(torch.ones(dim))
        # Added under the square root to avoid division by zero.
        self.eps = eps

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Scale ``x`` by the reciprocal of its RMS along the last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)  # (B, seq_len, dim) -> (B, seq_len, 1)
        inv_rms = torch.rsqrt(mean_square + self.eps)  # 1 / sqrt(mean_square + eps)
        return x * inv_rms  # (B, seq_len, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalise in float32, cast back to the input dtype, then apply the gain.
        normed = self._norm(x.float()).type_as(x)
        return self.weight * normed  # (dim) * (B, seq_len, dim) -> (B, seq_len, dim)

## Feed Forward

In [None]:
class FeedForward(nn.Module):
    """SwiGLU feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

    The hidden width starts at ``int(2 * (4 * dim) / 3)``. It is optionally
    scaled by ``args.ffn_dim_multiplier`` and then rounded UP to the next
    multiple of ``args.multiple_of``.
    """

    def __init__(self, args: ModelArgs) -> None:
        super().__init__()

        hidden_dim = int(2 * (4 * args.dim) / 3)
        if args.ffn_dim_multiplier is not None:
            hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
        # Ceil-round to a multiple of `multiple_of` (e.g. 11008 for dim=4096, multiple_of=256).
        step = args.multiple_of
        hidden_dim = step * ((hidden_dim + step - 1) // step)

        self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)  # down projection
        self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)  # value projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))  # (B, seq_len, dim) -> (B, seq_len, hidden_dim)
        value = self.w3(x)  # (B, seq_len, dim) -> (B, seq_len, hidden_dim)
        return self.w2(gate * value)  # (B, seq_len, hidden_dim) -> (B, seq_len, dim)

## Self-Attention

### Repeat KV

In [None]:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times for grouped-query attention.

    Args:
        x: tensor of shape (B, seq_len, n_kv_heads, head_dim).
        n_rep: number of query heads that share each key/value head.

    Returns:
        Tensor of shape (B, seq_len, n_kv_heads * n_rep, head_dim) in which output
        head ``h`` is a copy of input head ``h // n_rep``. This is the same result as
        ``torch.repeat_interleave(x, n_rep, dim=2)``. When ``n_rep == 1``, ``x``
        itself is returned.
    """
    batch_size, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        # The new axis must come AFTER the head axis so each head is repeated in place
        # (h0, h0, h1, h1, ...) rather than the whole head set being tiled.
        x[:, :, :, None, :]  # (B, seq_len, n_kv_heads, 1, head_dim)
        .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)  # (B, seq_len, n_kv_heads, n_rep, head_dim)
        .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)  # (B, seq_len, n_kv_heads * n_rep, head_dim)
    )

### Self-Attention

In [None]:
class SelfAttention(nn.Module):
    """Grouped-query self-attention with rotary position embeddings and a KV cache.

    Queries are projected into ``args.n_heads`` heads. Keys and values are
    projected into ``args.n_kv_heads`` heads (``n_heads`` when unset). Each
    key/value head is shared by ``n_heads // n_kv_heads`` query heads.

    The rotated keys and the values of every processed position are stored in
    ``cache_k`` / ``cache_v``, whose shape is
    ``(max_batch_size, max_seq_len, n_kv_heads, head_dim)``. As a result, each
    call only has to project the new tokens.
    """

    def __init__(self, args: ModelArgs) -> None:
        super().__init__()

        # Number of heads for the keys and values (falls back to plain multi-head attention).
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # Number of heads for the queries.
        self.n_heads_q = args.n_heads
        # How many times the keys and values are repeated, i.e. query heads per K/V head.
        self.n_rep = self.n_heads_q // self.n_kv_heads
        # Dimension of each head, i.e. the part of the embedding each head is responsible for.
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        # KV cache. These are plain tensor attributes (neither parameters nor buffers),
        # so Module.to()/.half() do not convert them; forward() moves them to the
        # device and dtype of the activations.
        self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))
        self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor) -> torch.Tensor:
        """Attend from the tokens of ``x`` to every cached position up to and including them.

        Args:
            x: input of shape (B, seq_len, dim).
            start_pos: absolute position of the first token of ``x``. Cache slots
                [start_pos, start_pos + seq_len) are overwritten with its keys/values.
            freqs_complex: rotary phases for those positions, shape (seq_len, head_dim / 2).

        Returns:
            Tensor of shape (B, seq_len, dim).
        """
        batch_size, seq_len, _ = x.shape  # (B, seq_len, dim)
        xq = self.wq(x)  # (B, seq_len, dim) -> (B, seq_len, H_Q * head_dim)
        xk = self.wk(x)  # (B, seq_len, dim) -> (B, seq_len, H_KV * head_dim)
        xv = self.wv(x)  # (B, seq_len, dim) -> (B, seq_len, H_KV * head_dim)

        xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)  # (B, seq_len, H_Q, head_dim)
        xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)  # (B, seq_len, H_KV, head_dim)
        xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)  # (B, seq_len, H_KV, head_dim)

        # Rotary embeddings are applied to queries and keys only.
        xq = apply_rotary_embeddings(xq, freqs_complex, x.device)
        xk = apply_rotary_embeddings(xk, freqs_complex, x.device)

        # Keep the cache on the activations' device/dtype. Otherwise, after
        # model.to(...) or .half(), the keys/values read back below would not match xq.
        # This is a no-op once they already match.
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)

        # Replace the entries for the new positions in the cache.
        self.cache_k[:batch_size, start_pos:start_pos + seq_len] = xk
        self.cache_v[:batch_size, start_pos:start_pos + seq_len] = xv

        keys = self.cache_k[:batch_size, :start_pos + seq_len]  # (B, seq_len_kv, H_KV, head_dim)
        values = self.cache_v[:batch_size, :start_pos + seq_len]  # (B, seq_len_kv, H_KV, head_dim)

        # Every group of query heads shares the same K & V head: repeat them to H_Q heads.
        keys = repeat_kv(keys, self.n_rep)  # (B, seq_len_kv, H_KV, head_dim) -> (B, seq_len_kv, H_Q, head_dim)
        values = repeat_kv(values, self.n_rep)  # (B, seq_len_kv, H_KV, head_dim) -> (B, seq_len_kv, H_Q, head_dim)

        xq = xq.transpose(1, 2)  # (B, seq_len, H_Q, head_dim) -> (B, H_Q, seq_len, head_dim)
        keys = keys.transpose(1, 2)  # (B, seq_len_kv, H_Q, head_dim) -> (B, H_Q, seq_len_kv, head_dim)
        values = values.transpose(1, 2)  # (B, seq_len_kv, H_Q, head_dim) -> (B, H_Q, seq_len_kv, head_dim)

        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)  # (B, H_Q, seq_len, seq_len_kv)
        if seq_len > 1:
            # Causal mask: query i (absolute position start_pos + i) may only attend to
            # key positions <= start_pos + i. A single new token may see the whole
            # cache, so the one-token path needs no mask.
            mask = torch.full((seq_len, start_pos + seq_len), float("-inf"), device=x.device)
            mask = torch.triu(mask, diagonal=start_pos + 1)
            scores = scores + mask.type_as(scores)
        # Softmax in float32 for stability, then cast back.
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)  # (B, H_Q, seq_len, seq_len_kv)

        output = torch.matmul(scores, values)  # (B, H_Q, seq_len, seq_len_kv) @ (B, H_Q, seq_len_kv, head_dim) -> (B, H_Q, seq_len, head_dim)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)  # (B, seq_len, H_Q * head_dim)
        return self.wo(output)  # (B, seq_len, dim)


## Encoder Block

In [None]:
class EncoderBlock(nn.Module):
    """One model layer: pre-norm self-attention and a pre-norm feed-forward network.

    Each sub-layer is wrapped in a residual connection.
    """

    def __init__(self, args: ModelArgs) -> None:
        super().__init__()

        self.n_head = args.n_heads
        self.dim = args.dim
        self.head_dim = self.dim // self.n_head

        self.attention = SelfAttention(args)
        self.feed_forward = FeedForward(args)

        # Normalization BEFORE the attention block
        self.attention_norm = RMSNorm(self.dim, eps=args.norm_eps)
        # Normalization BEFORE the feed-forward block
        self.ffn_norm = RMSNorm(self.dim, eps=args.norm_eps)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor) -> torch.Tensor:
        """Apply attention and feed-forward with residuals.

        Args:
            x: input of shape (B, seq_len, dim).
            start_pos: position of the first token of ``x``, forwarded to the attention's KV cache.
            freqs_complex: rotary phases for the positions of ``x``.

        Returns:
            Tensor of shape (B, seq_len, dim).
        """
        # Call the sub-modules (not .forward directly) so nn.Module hooks run.
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_complex)  # (B, seq_len, dim)
        return h + self.feed_forward(self.ffn_norm(h))  # (B, seq_len, dim)

## Transformer

![Transformer vs LLaMA](images/Transformer-vs-LLaMA.png)

In [None]:
class Transformer(nn.Module):
    """Full model: token embedding, ``n_layers`` blocks, final RMSNorm and a vocabulary projection.

    ``forward`` processes one token per sequence per call. It relies on each
    layer's KV cache for the earlier positions.
    """

    def __init__(self, args: ModelArgs) -> None:
        super().__init__()
        assert args.vocab_size != -1, "vocab_size must be set"

        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers  # represents Nx in the figure above: 32 layers
        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)

        self.layers = nn.ModuleList([EncoderBlock(args) for _ in range(args.n_layers)])

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, self.vocab_size, bias=False)

        # Rotary phases for 2 * max_seq_len positions, shape (2 * max_seq_len, head_dim / 2).
        # This is a plain attribute (not a registered buffer), so Module.to() does not
        # move it; forward() moves it to the input's device.
        self.freqs_complex = precompute_theta_pos_frequencies(
            self.args.dim // self.args.n_heads, self.args.max_seq_len * 2, device=self.args.device
        )

    def forward(self, tokens: torch.Tensor, start_pos: int) -> torch.Tensor:
        """Compute logits for a single new token per sequence.

        Args:
            tokens: token ids of shape (B, 1).
            start_pos: position of these tokens in the sequence, i.e. the KV-cache
                slot their keys/values are written to.

        Returns:
            float32 logits of shape (B, 1, vocab_size).
        """
        batch_size, seq_len = tokens.shape  # (B, seq_len)
        assert seq_len == 1, "Only one token at a time can be processed."

        h = self.tok_embeddings(tokens)  # (B, seq_len) -> (B, seq_len, dim)

        # Follow the activations' device. This is a no-op after the first call
        # on that device.
        self.freqs_complex = self.freqs_complex.to(h.device)
        # Retrieve the pairs (m, theta) for positions [start_pos, start_pos + seq_len).
        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len]

        # Apply all the layers in sequence.
        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)
        h = self.norm(h)
        return self.output(h).float()  # (B, seq_len, vocab_size)