# LLaMA From Scratch

**References**
- *Coding LLaMA 2 from scratch in PyTorch - KV Cache, Grouped Query Attention, Rotary PE, RMSNorm: [Youtube Video](https://youtu.be/oM4VmoabDAI?si=JtlNl00nZeIOkWxx), [Code](https://github.com/hkproj/pytorch-llama)*
- *LLaMA explained: KV-Cache, Rotary Positional Embedding, RMS Norm, Grouped Query Attention, SwiGLU: [Youtube Video](https://youtu.be/Mn_9W1nCFLo?si=4xJy4OzpPX5YxGqx)*
- *RoFormer: Enhanced Transformer with Rotary Position Embedding: [Paper](https://arxiv.org/abs/2104.09864)*
- *Root Mean Square Layer Normalization: [Paper](https://arxiv.org/abs/1910.07467)*
- *GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints: [Paper](https://arxiv.org/abs/2305.13245)*

## Imports

In [2]:
from dataclasses import dataclass
from typing import Optional, List, Dict, Any
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from pathlib import Path
import json
from sentencepiece import SentencePieceProcessor
from tqdm import tqdm
import time

# LLaMA Model

### LLaMA Family

**LLaMA 1**
![LLaMA 1 Parameters](images/llama-1-parameters.png)

**LLaMA 2**
![LLaMA 2 Parameters](images/llama-2-parameters.png)

**LLaMA 3**



### Model Arguments



In [4]:
@dataclass
class ModelArgs:
    dim: int = 4096
    n_layers: int = 32
    # * Unlike the original Transformer, LLaMA allows a different number of heads for Q than for K and V
    n_heads: int = 32  # number of heads for the queries
    n_kv_heads: Optional[int] = None  # Number of heads for the keys and values
    vocab_size: int = -1  # will be set when we load the tokenizer
    # * Grouped-query attention uses fewer K/V heads (and so fewer parameters),
    # * so the FFN hidden size is increased to keep the total parameter count the same
    multiple_of: int = 256
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5  # epsilon for layer norm

    # needed for KV cache
    max_batch_size: int = 32
    max_seq_len: int = 2048

    device: Optional[str] = None

## Rotary Positional Embedding

### Precompute Theta Positional Frequencies

Below are the steps involved in precomputing theta positional frequencies:

![Precompute Theta Positional Frequencies Steps](images/theta-pos-freq-steps.png)

In [5]:
def precompute_theta_pos_frequencies(
    head_dim: int, seq_len: int, device: str, theta: float = 10000.0
):
    # theta 10000.0 is the default value in the paper
    # As written in the paragraph 3.2.2 of the paper
    # >> In order to generalize our results in 2D to any xi ∈ Rd where **d is even**, [...]
    assert (
        head_dim % 2 == 0
    ), "Dimension must be even since rotary embedding can't be applied to odd."

    # Build the theta parameter
    # According to the formula theta_i = 10000^(-2(i-1)/dim) for i = [1, 2, ..., dim/2]
    theta_numerator = torch.arange(0, head_dim, 2).float()  # (head_dim / 2)
    theta = 1.0 / (theta ** (theta_numerator / head_dim)).to(device)  # (head_dim / 2)
    # construct the positions (the "m" parameter)
    m = torch.arange(seq_len, device=device)  # (seq_len)
    # Multiply each theta by each position using the outer product.
    # (seq_len), outer_product*(head_dim/2) -> (seq_len,head_dim/2)
    freqs = torch.outer(m, theta).float()
    # we can compute complex numbers in the polar form c = R*exp(i*m*theta), where R=1, as follows:
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_complex
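
To make the shapes and values concrete, here is a dependency-free sketch of the same computation for a tiny head. The helper name `rope_angles` and the sizes `head_dim=8`, `seq_len=4` are illustrative assumptions and not part of the original code; plain Python lists stand in for tensors.

```python
import cmath


def rope_angles(head_dim: int, seq_len: int, base: float = 10000.0):
    """Return a (seq_len x head_dim/2) table of rotation angles m * theta_i."""
    assert head_dim % 2 == 0, "head_dim must be even"
    # theta_i = base^(-2i / head_dim) for i = 0, 1, ..., head_dim/2 - 1
    thetas = [base ** (-(2 * i) / head_dim) for i in range(head_dim // 2)]
    # Outer product of the positions m and the frequencies theta_i
    return [[m * t for t in thetas] for m in range(seq_len)]


angles = rope_angles(head_dim=8, seq_len=4)
# Polar form with magnitude 1: exp(i*m*theta) = cos(m*theta) + i*sin(m*theta)
freqs_complex = [[cmath.rect(1.0, a) for a in row] for row in angles]

print(len(angles), len(angles[0]))  # 4 4
print(angles[1])  # position 1 holds the thetas: roughly 1.0, 0.1, 0.01, 0.001
```

Position 0 always maps to angle 0 (no rotation), and the later pairs in a head rotate much more slowly than the first pair.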

### Rotary Embeddings

The steps in calculating the rotary embedding:
![The Steps in calculating the Rotary Embedding](images/rotary-embedding-steps.png)

Figure 1: Implementation of Rotary Position Embedding (RoPE):
![Implementation of Rotary Position Embedding](images/implementation-of-rope.png)


In [6]:
def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    # * OP 1 & 2 >>
    # separate the last dimension into pairs of values, representing the real & imaginary parts of a complex number
    # two consecutive values will become a single complex number
    # (B,seq_len,H,head_dim) -> (B,seq_len,H,head_dim/2)
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    # reshape the freqs_complex tensor to match the shape of the x_complex tensor.
    # So we need to add the batch dimension and the head dimension.
    # (seq_len,head_dim/2) -> (1,seq_len,1,head_dim/2)
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
    # * OP 3 >>
    # Multiply each complex number in the x_complex tensor by the corresponding complex number in the freqs_complex tensor
    # which results in the rotation of the complex number as shown in the Figure 1 of the paper.
    # (B,seq_len,H,head_dim/2)*(1,seq_len,1,head_dim/2) -> (B,seq_len,H,head_dim/2)
    x_rotated = x_complex * freqs_complex
    # * OP 4 >> convert the complex number back to the real number
    # (B,seq_len,H,head_dim/2) -> (B,seq_len,H,head_dim/2,2)
    x_out = torch.view_as_real(x_rotated)
    # * OP 5 >> Flattening to the shape of the original tensor
    # (B,seq_len,H,head_dim/2,2) -> (B,seq_len,H,head_dim)
    x_out = x_out.reshape(*x.shape)
    return x_out.type_as(x).to(device)
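
The reason for rotating queries and keys this way is that their dot product then depends only on the relative distance between the two positions. The sketch below checks this numerically with plain Python complex numbers; `rotate` and `dot` are hypothetical helpers written for this illustration, and the vectors are arbitrary.

```python
import cmath


def rotate(vec, pos, base=10000.0):
    """Apply RoPE to a real vector of even length at position `pos`."""
    d = len(vec)
    out = []
    for i in range(d // 2):
        theta = base ** (-(2 * i) / d)
        # Treat the consecutive pair (x_2i, x_2i+1) as one complex number
        # and multiply it by exp(i * pos * theta)
        z = complex(vec[2 * i], vec[2 * i + 1]) * cmath.rect(1.0, pos * theta)
        out.extend([z.real, z.imag])
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.25]

# Same relative offset (2), different absolute positions
score_near = dot(rotate(q, 3), rotate(k, 1))
score_far = dot(rotate(q, 10), rotate(k, 8))
print(abs(score_near - score_far) < 1e-9)  # True
```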

## Root Mean Square Normalization

LayerNorm works because of its re-centering and re-scaling invariance properties. Re-centering makes the model insensitive to shift noise on both inputs and weights, and re-scaling keeps the output representations intact when both inputs and weights are randomly scaled.
The RMSNorm paper hypothesizes that re-scaling invariance, rather than re-centering invariance, is the reason for LayerNorm's success. It proposes RMSNorm, which keeps only the re-scaling invariance and regularizes the summed inputs according to the root mean square (RMS) statistic:

$$
\bar{a}_i = \frac{a_i}{\mathrm{RMS}(a)} g_i, \quad \text{where} \quad \mathrm{RMS}(a) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} a_i^2}
$$

Intuitively, RMSNorm simplifies LayerNorm by totally removing the mean statistic at the cost of sacrificing the invariance that mean normalization affords.
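
A small numerical check of that trade-off, written as a plain-Python sketch: `rms_norm` is a hypothetical list-based helper that follows the formula above, with `eps` added inside the square root as an assumption for numerical stability.

```python
import math


def rms_norm(a, g=None, eps=1e-6):
    """Divide by the root mean square of `a`, then apply the gain `g`."""
    g = g if g is not None else [1.0] * len(a)
    rms = math.sqrt(sum(x * x for x in a) / len(a) + eps)
    return [x / rms * gi for x, gi in zip(a, g)]


a = [1.0, -2.0, 3.0, -4.0]
scaled = [10.0 * x for x in a]  # re-scaled input
shifted = [x + 5.0 for x in a]  # re-centered (shifted) input

same_after_scaling = all(
    abs(u - v) < 1e-4 for u, v in zip(rms_norm(a), rms_norm(scaled))
)
same_after_shifting = all(
    abs(u - v) < 1e-4 for u, v in zip(rms_norm(a), rms_norm(shifted))
)
print(same_after_scaling, same_after_shifting)  # True False
```

Scaling the input leaves the normalized output unchanged, while shifting it does not, which is exactly the invariance RMSNorm gives up.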

In [None]:
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # The gamma parameter

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # rsqrt: 1 / sqrt(x)
        # (B,seq_len,dim)*(B,seq_len,1) -> (B,seq_len,dim)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (dim)*(B,seq_len,dim) -> (B,seq_len,dim)
        return self.weight * self._norm(x.float()).type_as(x)

## Feed Forward

### SwiGLU Activation Function

In [None]:
class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs) -> None:
        super().__init__()

        hidden_dim = 4 * args.dim
        hidden_dim = int(2 * hidden_dim / 3)
        if args.ffn_dim_multiplier is not None:
            hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
        # Round hidden_dim up to the next multiple of the multiple_of parameter
        hidden_dim = args.multiple_of * (
            (hidden_dim + args.multiple_of - 1) // args.multiple_of
        )

        self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        swish = F.silu(self.w1(x))  # (B, seq_len, dim) -> (B, seq_len, hidden_dim)
        x_v = self.w3(x)  # (B, seq_len, dim) -> (B, seq_len, hidden_dim)
        # (B, seq_len, hidden_dim) * (B, seq_len, hidden_dim) -> (B, seq_len, hidden_dim)
        x = swish * x_v
        x = self.w2(x)  # (B, seq_len, hidden_dim) -> (B, seq_len, dim)
        return x
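
The hidden size arithmetic in `__init__` is easy to trace by hand. The sketch below repeats it as a standalone function (`ffn_hidden_dim` is a hypothetical name) for the default `ModelArgs` values and for a second, made-up configuration that uses `ffn_dim_multiplier`.

```python
from typing import Optional


def ffn_hidden_dim(
    dim: int, multiple_of: int, ffn_dim_multiplier: Optional[float] = None
) -> int:
    hidden_dim = 4 * dim
    # SwiGLU uses three weight matrices instead of two, so shrink by 2/3
    hidden_dim = int(2 * hidden_dim / 3)
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # Round up to the next multiple of `multiple_of`
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)


# dim=4096, multiple_of=256: 16384 -> 10922 -> rounded up to 11008
print(ffn_hidden_dim(4096, 256))  # 11008
# Illustrative larger configuration with a multiplier
print(ffn_hidden_dim(8192, 4096, 1.3))  # 28672
```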

## Attention

### Repeat KV Cache

Self-Attention during Next Token Prediction Task at Inference T=1:

![Self-Attention during Next Token Prediction Task at T1](images/Self-Attention-during-NTP-Task-T1.png)

Self-Attention during Next Token Prediction Task at Inference T=4:

![Self-Attention during Next Token Prediction Task at T4](images/Self-Attention-during-NTP-Task-T4.png)

Where KV Cache is useful:

![Where KV Cache is useful](images/where-kv-cache-come-in.png)

Self-Attention with KV-Cache at Inference T=1:

![Self-Attention with KV-Cache at T1](images/Self-Attention-with-KV-Cache-T1.png)

Self-Attention with KV-Cache at Inference T=4:

![Self-Attention with KV-Cache at T4](images/Self-Attention-with-KV-Cache-T4.png)
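
The figures can be reproduced on a toy example. The sketch below is a single-head, plain-Python illustration (no batching, no RoPE); the per-token vectors in `qs`, `ks`, `vs` are made-up numbers. It compares recomputing causal attention for the whole prefix at every step against appending to a cache and attending with only the newest query.

```python
import math

stats = {"dot_products": 0}


def attend(q, keys, values):
    """Single-head attention of one query over the given keys and values."""
    scale = math.sqrt(len(q))
    scores = []
    for k in keys:
        stats["dot_products"] += 1
        scores.append(sum(qi * ki for qi, ki in zip(q, k)) / scale)
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]


qs = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0], [1.0, 1.0]]
ks = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
vs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]

# Without a cache: recompute every row for the prefix, keep only the last one
stats["dot_products"] = 0
no_cache = []
for t in range(len(qs)):
    rows = [attend(qs[i], ks[: i + 1], vs[: i + 1]) for i in range(t + 1)]
    no_cache.append(rows[-1])
cost_no_cache = stats["dot_products"]

# With a cache: store each key/value once, attend with the newest query only
stats["dot_products"] = 0
cache_k, cache_v, with_cache = [], [], []
for q, k, v in zip(qs, ks, vs):
    cache_k.append(k)
    cache_v.append(v)
    with_cache.append(attend(q, cache_k, cache_v))
cost_with_cache = stats["dot_products"]

print(no_cache == with_cache)  # True
print(cost_no_cache, cost_with_cache)  # 20 10
```

The outputs are identical, but the cached version performs far fewer query-key dot products, and the gap widens as the sequence grows.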

In [None]:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch_size, seq_len, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]  # (B, seq_len, n_kv_heads, 1, head_dim)
        .expand(
            batch_size, seq_len, n_kv_heads, n_rep, head_dim
        )  # (B, seq_len, n_kv_heads, n_rep, head_dim)
        .reshape(
            batch_size, seq_len, n_kv_heads * n_rep, head_dim
        )  # (B, seq_len, n_kv_heads * n_rep, head_dim)
    )
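
The expand-then-reshape pattern is meant to place the copies of each KV head next to each other, so that query head `i` reads from KV head `i // n_rep`. A list-based sketch of that ordering, where `repeat_kv_heads` and the head labels are hypothetical stand-ins for the tensor version:

```python
def repeat_kv_heads(kv_heads, n_rep):
    """Repeat each KV head n_rep times, keeping the repeats adjacent."""
    return [head for head in kv_heads for _ in range(n_rep)]


n_heads_q, n_kv_heads = 8, 2
n_rep = n_heads_q // n_kv_heads

expanded = repeat_kv_heads(["kv0", "kv1"], n_rep)
print(expanded)
# ['kv0', 'kv0', 'kv0', 'kv0', 'kv1', 'kv1', 'kv1', 'kv1']

# Query head i therefore reads from KV head i // n_rep
print([i // n_rep for i in range(n_heads_q)])  # [0, 0, 0, 0, 1, 1, 1, 1]
```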

### Grouped Query Attention

#### Comparing different attention algorithms

**Vanilla batched Multi-Head Attention**


**Batched Multi-Head Attention with KV cache**


**<span style="color:red">Multi-Query</span> Attention with KV cache**



#### Grouped Multi-Query Attention

Grouped Multi-Query Attention is a compromise between Multi-Head Attention and Multi-Query Attention:

![Grouped Multi-Query Attention compared](images/Grouped-Multi-Query-Attention-compared.png)
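
One way to see the compromise is to estimate the KV cache size for each variant. The sketch below is back-of-the-envelope arithmetic, assuming 2 bytes per value (fp16 or bf16) and the default `ModelArgs` sizes with `head_dim = 4096 // 32 = 128`; `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(
    n_layers, max_batch_size, max_seq_len, n_kv_heads, head_dim, bytes_per_value=2
):
    """Size of the K and V caches summed over all layers."""
    per_layer = max_batch_size * max_seq_len * n_kv_heads * head_dim
    # Factor 2: one cache for keys and one for values
    return 2 * n_layers * per_layer * bytes_per_value


GiB = 1024**3
shared = dict(n_layers=32, max_batch_size=32, max_seq_len=2048, head_dim=128)

print(kv_cache_bytes(n_kv_heads=32, **shared) / GiB)  # 32.0 (multi-head)
print(kv_cache_bytes(n_kv_heads=8, **shared) / GiB)  # 8.0 (grouped-query)
print(kv_cache_bytes(n_kv_heads=1, **shared) / GiB)  # 1.0 (multi-query)
```

The cache shrinks linearly with the number of KV heads, which is why grouped-query attention sits between the other two in both memory use and quality.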


In [None]:
class Attention(nn.Module):
    def __init__(self, args: ModelArgs) -> None:
        super().__init__()

        # indicates the number of heads for the keys and values
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        # indicates the number of heads for the queries
        self.n_heads_q = args.n_heads
        # indicates how many times the keys and values should be repeated
        self.n_rep = self.n_heads_q // self.n_kv_heads
        # indicates the dimension of each head, i.e the part of the embedding that each head will be responsible for
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        self.cache_k = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)
        )
        self.cache_v = torch.zeros(
            (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)
        )

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        batch_size, seq_len, _ = x.shape  # (B, 1, dim)
        xq = self.wq(x)  # (B, 1, dim) -> (B, 1, H_Q * head_dim)
        xk = self.wk(x)  # (B, 1, dim) -> (B, 1, H_KV * head_dim)
        xv = self.wv(x)  # (B, 1, dim) -> (B, 1, H_KV * head_dim)

        # (B, 1, H_Q * Head_Dim) -> (B, 1, H_Q, Head_Dim)
        xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)
        # Size is the same for xk & xv: (B, 1, H_KV * Head_Dim) -> (B, 1, H_KV, Head_Dim)
        xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)

        # Size doesn't change for xq & xk: (B, 1, H_Q, head_dim) -> (B, 1, H_Q, head_dim)
        xq = apply_rotary_embeddings(xq, freqs_complex, x.device)
        xk = apply_rotary_embeddings(xk, freqs_complex, x.device)

        # replace the entry in the cache for this token
        self.cache_k[:batch_size, start_pos : start_pos + seq_len] = xk
        self.cache_v[:batch_size, start_pos : start_pos + seq_len] = xv

        # retrieve all the cached keys and values so far
        # Size is the same for keys & values: (B, seq_len_kv, H_KV, head_dim)
        keys = self.cache_k[:batch_size, : start_pos + seq_len]
        values = self.cache_v[:batch_size, : start_pos + seq_len]

        # since every group of Q heads shares the same K & V heads,
        # just repeat the K & V heads for every Q head in the same group.
        # Among the LLaMA 2 checkpoints only 70B uses grouped query attention; for the
        # others n_rep == 1, repeat_kv returns its input, and this is plain multi-head attention.
        # Size is the same for keys & values:
        # (B, seq_len_kv, H_KV, head_dim) -> (B, seq_len_kv, H_Q, head_dim)
        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        xq = xq.transpose(1, 2)  # (B, 1, H_Q, head_dim) -> (B, H_Q, 1, head_dim)
        # Size is the same for keys & values:
        # (B, seq_len_kv, H_Q, head_dim) -> (B, H_Q, seq_len_kv, head_dim)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # (B, H_Q, 1, head_dim) @ (B, H_Q, head_dim, seq_len_kv) -> (B, H_Q, 1, seq_len_kv)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        # (B, H_Q, 1, seq_len_kv) -> (B, H_Q, 1, seq_len_kv)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        # (B, H_Q, 1, seq_len_kv) @ (B, H_Q, seq_len_kv, head_dim) -> (B, H_Q, 1, head_dim)
        output = torch.matmul(scores, values)
        # (B, H_Q, 1, head_dim) -> (B, 1, H_Q, head_dim) -> (B, 1, dim)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.wo(output)  # (B, 1, dim)

## Encoder Block

In [None]:
class EncoderBlock(nn.Module):
    def __init__(self, args: ModelArgs) -> None:
        super().__init__()

        self.n_head = args.n_heads
        self.dim = args.dim
        self.head_dim = self.dim // self.n_head

        self.attention = Attention(args)
        self.feed_forward = FeedForward(args)

        # Normalization before the attention block
        self.attention_norm = RMSNorm(self.dim, eps=args.norm_eps)
        # Normalization before the feed-forward block
        self.ffn_norm = RMSNorm(self.dim, eps=args.norm_eps)

    def forward(
        self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor
    ) -> torch.Tensor:
        # (B,seq_len,dim) + (B,seq_len,dim) -> (B,seq_len,dim)
        h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_complex)
        # (B,seq_len,dim) + (B,seq_len,dim) -> (B,seq_len,dim)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out

## Transformer

![Transformer vs LLaMA](images/Transformer-vs-LLaMA.png)

In [None]:
class Transformer(nn.Module):
    def __init__(self, args: ModelArgs) -> None:
        super().__init__()
        assert args.vocab_size != -1, "vocab_size must be set"

        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers  # represents Nx in the figure above: 32 layers
        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)

        self.layers = nn.ModuleList()
        for _ in range(args.n_layers):
            self.layers.append(EncoderBlock(args))

        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, self.vocab_size, bias=False)

        self.freqs_complex = precompute_theta_pos_frequencies(
            self.args.dim // self.args.n_heads,
            self.args.max_seq_len * 2,
            device=self.args.device,
        )

    def forward(self, tokens: torch.Tensor, start_pos: int):
        batch_size, seq_len = tokens.shape  # (B, seq_len)
        assert seq_len == 1, "Only one token at a time can be processed."

        h = self.tok_embeddings(tokens)  # (B, seq_len) -> (B, seq_len, dim)
        # retrieve the pairs (m, theta) corresponding to the positions [start_pos, start_pos + seq_len]
        freqs_complex = self.freqs_complex[start_pos : start_pos + seq_len]

        # consecutively apply all the encoder layers
        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)
        h = self.norm(h)
        output = self.output(h).float()
        return output

# Inference

## LLaMA Inference

In [None]:
class LLaMA:
    def __init__(
        self,
        model: Transformer,
        tokenizer: SentencePieceProcessor,
        model_args: ModelArgs,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.args = model_args

    @staticmethod
    def build(
        checkpoints_dir: str,
        tokenizer_path: str,
        load_model: bool,
        max_seq_len: int,
        max_batch_size: int,
        device: str,
    ):
        prev_time = time.time()
        if load_model:
            checkpoints = sorted(Path(checkpoints_dir).glob("*.pth"))
            assert len(checkpoints) > 0, f"No checkpoints found in {checkpoints_dir}"
            ckpt_path = checkpoints[0]
            print(f"Loading model from checkpoint: {ckpt_path}")
            checkpoint = torch.load(ckpt_path, map_location="cpu")
            print(f"Loaded checkpoint in {time.time() - prev_time:.2f} seconds")
            prev_time = time.time()
        with open(Path(checkpoints_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            device=device,
            **params,
        )

        tokenizer = SentencePieceProcessor()
        tokenizer.load(tokenizer_path)
        model_args.vocab_size = tokenizer.vocab_size()

        if device == "cuda":
            torch.set_default_tensor_type(torch.cuda.HalfTensor)
        else:
            torch.set_default_tensor_type(torch.BFloat16Tensor)

        model = Transformer(model_args).to(device)

        if load_model:
            # The only unmatched key in the checkpoint is rope.freqs. Remove it
            del checkpoint["rope.freqs"]
            model.load_state_dict(checkpoint, strict=True)
            print(f"Loaded state dict in {time.time() - prev_time:.2f} seconds")

        return LLaMA(model, tokenizer, model_args)

    def text_completion(
        self,
        prompts: list[str],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
    ):
        if max_gen_len is None:
            max_gen_len = self.args.max_seq_len - 1

        # convert each prompt into tokens
        prompt_tokens = [
            self.tokenizer.encode(prompt, out_type=int, add_bos=True, add_eos=False)
            for prompt in prompts
        ]
        # Make sure the batch size is not too large
        batch_size = len(prompt_tokens)
        assert (
            batch_size <= self.args.max_batch_size
        ), f"Batch size {batch_size} must be less than or equal to the Max batch size: {self.args.max_batch_size}"
        max_prompt_len = max(len(prompt) for prompt in prompt_tokens)
        # Make sure the prompt length is not larger than the max sequence length
        assert (
            max_prompt_len <= self.args.max_seq_len
        ), f"Prompt length {max_prompt_len} must be less than or equal to the Max sequence length: {self.args.max_seq_len}"
        total_len = min(self.args.max_seq_len, max_gen_len + max_prompt_len)

        # create the list that will contain the generated tokens, along with the initial prompt tokens
        device = self.args.device  # use the model's device instead of relying on a global
        pad_id = self.tokenizer.pad_id()
        tokens = torch.full(
            (batch_size, total_len), pad_id, dtype=torch.long, device=device
        )
        for k, t in enumerate(prompt_tokens):
            # populate the initial tokens with the prompt tokens
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)

        eos_reached = torch.tensor([False] * batch_size, device=device)
        # True if the token is a prompt token, False otherwise
        prompt_tokens_mask = tokens != pad_id
        cur_iterator = tqdm(range(1, total_len), desc="Generating tokens...")
        for cur_pos in cur_iterator:
            with torch.no_grad():
                logits = self.model.forward(tokens[:, cur_pos - 1 : cur_pos], cur_pos)
            if temperature > 0:
                # The temperature is applied before the softmax
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = self._sample_top_p(probs, top_p)
            else:
                # greedily select the token with the max probability
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # only replace token if it is a padding token
            next_token = torch.where(
                prompt_tokens_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            # EOS is reached only if we found an EOS token for a padding position
            eos_reached |= (~prompt_tokens_mask[:, cur_pos]) & (
                next_token == self.tokenizer.eos_id()
            )
            if all(eos_reached):
                break

        out_tokens = []
        out_text = []
        for prompt_index, current_prompt_tokens in enumerate(tokens.tolist()):
            # cut to the EOS token, if present
            if self.tokenizer.eos_id() in current_prompt_tokens:
                eos_idx = current_prompt_tokens.index(self.tokenizer.eos_id())
                current_prompt_tokens = current_prompt_tokens[:eos_idx]
            out_tokens.append(current_prompt_tokens)
            out_text.append(self.tokenizer.decode(current_prompt_tokens))
        return (out_tokens, out_text)

    def _sample_top_p(self, probs: torch.Tensor, p: float):
        probs_sort, probs_idx = torch.sort(
            probs, dim=-1, descending=True
        )  # (B, vocab_size)
        probs_sum = torch.cumsum(probs_sort, dim=-1)  # (B, vocab_size)
        # (subtracting "probs_sort" shifts the cumulative sum by 1 position to the right before masking)
        mask = probs_sum - probs_sort > p  # (B, vocab_size)
        # zero out all the probabilities of tokens that are not selected by the top_p
        probs_sort[mask] = 0.0
        # redistribute the probabilities so that they sum up to 1.
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        # sample a token (its index) from the top p distribution
        next_token = torch.multinomial(probs_sort, num_samples=1)
        # get the token position in the vocabulary corresponding to the sampled index
        next_token = torch.gather(probs_idx, -1, next_token)
        return next_token
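
To see what the mask in `_sample_top_p` keeps, here is a list-based sketch of the same idea on four made-up probabilities. `top_p_filter` is a hypothetical helper; it walks the tokens in descending order and stops once the mass of the higher-ranked tokens exceeds `p`, which corresponds to the `probs_sum - probs_sort > p` test.

```python
import random


def top_p_filter(probs, p):
    """Zero out tokens outside the nucleus and renormalize the rest."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept = [0.0] * len(probs)
    cumulative = 0.0  # mass of the strictly higher-ranked tokens
    for i in order:
        if cumulative > p:
            break
        kept[i] = probs[i]
        cumulative += probs[i]
    total = sum(kept)
    return [x / total for x in kept]


probs = [0.05, 0.5, 0.15, 0.3]
filtered = top_p_filter(probs, p=0.9)
print([round(x, 4) for x in filtered])  # [0.0, 0.5263, 0.1579, 0.3158]

# Sample a token index from the renormalized distribution
random.seed(0)
token = random.choices(range(len(filtered)), weights=filtered, k=1)[0]
```

The token that first pushes the cumulative mass past `p` is still kept, so the nucleus always contains at least one token.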

## Inference Run

In [None]:
torch.manual_seed(42)

allow_cuda = False
device = "cuda" if torch.cuda.is_available() and allow_cuda else "cpu"

prompts = [
    "Simply put, the theory of relativity states that ",
    "If Google was an Italian company founded in Milan, it would",
    # Few shot prompt
    """Translate English to French:
    sea otter => loutre de mer
    peppermint => menthe poivrée
    plush girafe => girafe peluche
    cheese =>""",
    # Zero shot prompt
    """Tell me if the following person is actually a Jedi night disguised as human:
    Name: Mukesh Mithrakumar
    Decision: 
    """,
]

model = LLaMA.build(
    checkpoints_dir="models/llama-2-7b",
    tokenizer_path="models/tokenizer.model",
    load_model=True,
    max_seq_len=1024,
    max_batch_size=len(prompts),
    device=device,
)

out_tokens, out_texts = model.text_completion(prompts, max_gen_len=64)
assert len(out_texts) == len(prompts)
for i in range(len(out_texts)):
    print(f"{out_texts[i]}\n'-'*50")