# LLaMA From Scratch

**References**
- *Coding LLaMA 2 from scratch in PyTorch - KV Cache, Grouped Query Attention, Rotary PE, RMSNorm: [Youtube Video](https://youtu.be/oM4VmoabDAI?si=JtlNl00nZeIOkWxx), [Code](https://github.com/hkproj/pytorch-llama)*
- *LLaMA explained: KV-Cache, Rotary Positional Embedding, RMS Norm, Grouped Query Attention, SwiGLU: [Youtube Video](https://youtu.be/Mn_9W1nCFLo?si=4xJy4OzpPX5YxGqx)*
- 

## Imports

In [2]:
from dataclasses import dataclass
from typing import Optional, List, Dict, Any
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# LLaMA Model

**LLaMA 1**
![LLaMA 1 Parameters](images/llama-1-parameters.png)

**LLaMA 2**
![LLaMA 2 Parameters](images/llama-2-parameters.png)

**LLaMA 2**



### Model Arguments



In [4]:
@dataclass
class ModelArgs:
    """Configuration (hyper-parameters) for the LLaMA Transformer."""

    dim: int = 4096  # model / embedding dimension
    n_layers: int = 32  # number of encoder blocks stacked in the Transformer
    # * Unlike the original Transformer, LLaMA may use fewer key/value heads than
    # * query heads (grouped-query attention), so the two counts are set separately.
    n_heads: int = 32  # number of heads for the queries
    # number of heads for the keys and values; presumably falls back to n_heads
    # when None — TODO confirm in the attention block
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1  # will be set when we load the tokenizer
    # * Grouped-query attention shrinks the K/V projections, so the FFN hidden
    # * size is enlarged to keep the total number of parameters comparable.
    multiple_of: int = 256  # presumably the FFN hidden size is rounded to a multiple of this — TODO confirm
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5  # epsilon for RMSNorm

    # needed for KV cache
    max_batch_size: int = 32
    max_seq_len: int = 2048

    # device passed to precompute_theta_pos_frequencies; None -> torch's default device
    device: Optional[str] = None

## Rotary Positional Embedding

### Precompute Theta Positional Frequencies

In [5]:
def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    """Precompute the complex rotary factors ``exp(i * m * theta_k)``.

    Args:
        head_dim: Size of one attention head. Must be even, because the rotation
            acts on consecutive pairs of features.
        seq_len: Number of positions ``m = 0, 1, ..., seq_len - 1`` to precompute.
        device: Device on which the returned tensor is placed.
        theta: Base of the frequency schedule (default 10000.0).

    Returns:
        Complex tensor of shape ``(seq_len, head_dim // 2)``. Entry ``[m, k]``
        has modulus 1 and angle ``m * theta_k``.

    Raises:
        AssertionError: if ``head_dim`` is odd.
    """
    # As written in the paragraph 3.2.2 of the paper
    # >> In order to generalize our results in 2D to any xi ∈ Rd where **d is even**, [...]
    assert head_dim % 2 == 0, "Dimension must be divisible by 2."

    # theta_k = base^(-2k / head_dim) for k = 0, 1, ..., head_dim/2 - 1
    # (i.e. theta_i = base^(-2(i-1) / head_dim) for i = 1, ..., head_dim/2).
    # The result is stored in `inv_freq` so the `theta` base argument is not shadowed.
    exponents = torch.arange(0, head_dim, 2).float()  # (head_dim / 2)
    inv_freq = 1.0 / (theta ** (exponents / head_dim)).to(device)  # (head_dim / 2)
    # The positions (the "m" parameter).
    positions = torch.arange(seq_len, device=device)  # (seq_len)
    # Outer product gives every angle m * theta_k: (seq_len) x (head_dim/2) -> (seq_len, head_dim/2).
    angles = torch.outer(positions, inv_freq).float()
    # Polar form c = R * exp(i * m * theta_k) with R = 1.
    return torch.polar(torch.ones_like(angles), angles)

### Rotary Embeddings

In [6]:
def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device:str):
    """Rotate consecutive feature pairs of ``x`` by the given complex factors.

    Args:
        x: Input tensor, assumed to be laid out as ``(B, seq_len, H, head_dim)``
            with an even ``head_dim``.
        freqs_complex: Complex factors, assumed to be ``(seq_len, head_dim / 2)``.
        device: Device the result is moved to.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    # Separate the last axis into (real, imag) pairs and treat each pair as one
    # complex number: (B, seq_len, H, head_dim) -> (B, seq_len, H, head_dim/2).
    paired = x.float().reshape(*x.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(paired)

    # Insert broadcast axes for the batch and the heads:
    # (seq_len, head_dim/2) -> (1, seq_len, 1, head_dim/2).
    rotation = freqs_complex[None, :, None]

    # Multiplying by a modulus-1 complex factor rotates each pair (Figure 1 of
    # the paper). Then go back to real pairs and flatten them:
    # (..., head_dim/2) -> (..., head_dim/2, 2) -> (..., head_dim).
    rotated = torch.view_as_real(as_complex * rotation).reshape(*x.shape)
    return rotated.type_as(x).to(device)

## Transformer

![Transformer vs LLaMA](images/Transformer-vs-LLaMA.png)

In [None]:
class Transformer(nn.Module):
    """LLaMA stack: token embedding -> N encoder blocks -> RMSNorm -> vocabulary projection.

    ``forward`` processes exactly one token per call (``seq_len == 1``).
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        assert args.vocab_size != -1, "vocab_size must be set"

        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers  # Nx in the figure above (32 with the default ModelArgs)
        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)

        # One encoder block per layer, constructed in order.
        self.layers = nn.ModuleList(EncoderBlock(args) for _ in range(args.n_layers))

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, self.vocab_size, bias=False)

        # Rotary factors for 2 * max_seq_len positions, created on args.device.
        # This is a plain attribute, not a registered buffer, so Module.to()
        # does not move it; forward() moves the slice it needs instead.
        self.freqs_complex = precompute_theta_pos_frequencies(
            self.args.dim // self.args.n_heads,
            self.args.max_seq_len * 2,
            device=self.args.device,
        )

    def forward(self, tokens: torch.Tensor, start_pos: int) -> torch.Tensor:
        """Compute logits for ``tokens`` located at absolute position ``start_pos``.

        Args:
            tokens: Token ids of shape (B, seq_len); seq_len must be 1.
            start_pos: Position of the token. It selects the rotary factors and
                is passed through to every layer.

        Returns:
            Float32 logits, (B, seq_len, vocab_size) assuming every layer
            preserves the (B, seq_len, dim) shape.

        Raises:
            AssertionError: if more than one token is passed.
        """
        _, seq_len = tokens.shape  # (B, seq_len)
        assert seq_len == 1, "Only one token at a time can be processed."

        h = self.tok_embeddings(tokens)  # (B, seq_len) -> (B, seq_len, dim)
        # Rotary factors for positions start_pos, ..., start_pos + seq_len - 1.
        # Moved onto the activations' device: the precomputed table lives on
        # args.device (torch's default device when None) and is not moved
        # together with the module.
        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len].to(h.device)

        # Apply all the encoder layers sequentially.
        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)
        h = self.norm(h)
        output = self.output(h).float()
        return output