# Background reading

* This colab is made by following [Coding LLaMA2 from scratch in Pytorch](https://youtu.be/oM4VmoabDAI?si=wqO0DodGhpvQjZK4)
  * [Slides](https://github.com/hkproj/pytorch-llama/blob/main/Slides.pdf)
  * [Code](https://github.com/hkproj/pytorch-llama)
* Rotary embedding: the following video introduces the topic simply: [rotary](https://youtu.be/o29P0Kpobz0?si=PNAqtmp33uJ2Gozv)
* KV Cache in transformers [KV Cache](https://youtu.be/80bIUggRJf4?si=_rE35Q9BMMA9Ge31)

# Model


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
from dataclasses import dataclass
import sentencepiece
import tqdm
from typing import Optional

@dataclass
class ModelArgs:
  dim: int = 4096 # embeddings dimension?
  n_layers: int = 32
  n_heads: int = 32 # Number of heads for the queries
  n_kv_heads: Optional[int] = None # Number of heads for the k and v
  vocab_size: int = -1 # This will be set when we load the tokenizer
  # hidden dimension of the FFN layer its to compensate for the grouped query attention
  multiple_of : int = 256
  ffn_dim_multiplier: Optional[float] = None
  # we will see why we need it
  norm_eps: float = 1e-5
  ## needed for KV Cache
  max_batch_size: int = 32
  max_seq_len: int = 2048

  device: str = None


def _precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float=10000.0) -> torch.Tensor:
  """precomputed theta frequencies for rotary embeddings."""
  # as written in the paper
  assert head_dim % 2 == 0, "head_dim must be divisible by 2 for this to work"
  # Build the theta parameters according to the formula
  # theta_i = 10000 ^ (-2(i-1)/dim) for i = [1, 2, ... dim/2]
  # Shape: (Head_dim /2)
  theta_numerator = torch.arange(0, head_dim, 2).float()
  # Shape: (head_dim/2)
  theta = 1.0 / (theta ** (theta_numerator / head_dim)).to(device)
  # construct the posiitons (the "m" parameters)
  # Shape: (Seq_Len)
  m = torch.arrange(seq_len, device=device)
  # Multiply each theta by each position using the outer product.
  # Shape: (Seq_len) outer_product * (Head_dim/2) -> (seq_len, head_dim/2)
  freqs = torch.outer(m, theta).float()
  # we can compute complex numbers in the complex form c= R * exp(i * m * theta), where R = 1 as follows:
  # (seq_len, head_dim/2) -> (seq_len, head_dim/2)
  freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
  return freqs_complex

def apply_rotary_pos_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str) -> torch.Tensor:
  """Apply rotary positional embeddings to a tensor."""
  # (B, Seq_len, H, Head_dim) -> (B, seq_len, H, Head_dim/2)
  x_complex = torch.view_as_complex(x.float().reshape(x.shape[0], x.shape[1], -1, 2))
  # (Seq_len, Head_dim/2 ) -> (1, Seq_len, 1, Head_dim/2)
  freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
  # (B, Seq_len, H, Head_dim/2) * (1, Seq_len, 1, Head_dim/2) = (B, Seq_len, H, Head_dim/2)
  x_rotated = x_complex * freqs_complex
  # (B, Seq_len, H, Head_dim/2) -> (B, Seq_len, H, Head_dim/2, 2)
  x_out = torch.view_as_real(x_rotated)
  # (B, Seq_len, H, Head_dim/2, 2) -> (B, Seq_len, H, Head_dim)
  x_out = x_out.reshape(*x.shape)
  return x_out.type_as(x).to(device)

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
  """Repeat the values n_rep times."""
  batch_size, seq_len, n_kv_heads, head_dim = x.shape
  if n_rep == 1:
    return x
  else:
    return (
      # (B, seq_len, N_kv_heads, 1, heads_dim)
      x[:, :, :, None, :].expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
      .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
  )



class RMSNorm(nn.Module):
  def __init__(self, dim: int, eps: float = 1e-6) -> None:
    super().__init__()
    # epsilon to avoid a division by zero
    self.eps = eps
    # the gamma parameter
    self.weight = nn.Parameter(torch.ones(dim))

  def _norm(self, x: torch.Tensor) -> torch.Tensor:
    # (B, seq_len, dim) * (B, seq_len, 1) -> (B, seq_len, Dim)
    # rsqrt: 1/sqrt(x)
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    # (Dim) * (B, Seq_len, Dim) = (B, Seq_len, Dim)
    return self.weight * self._norm(x.float()).type_as(x)

class FeedForward(nn.Module):
  def __init__(self, args: ModelArgs) -> None:
    super().__init__()

    hidden_dim = 4 * args.dim
    hidden_dim = int(2 * hidden_dim / 3)
    if args.ffn_dim_multiplier is not None:
      hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
    # Round the hidden_dim to the nearest multiple of the multiple_of parameter
    hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)

    self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)
    self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)
    self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    swish = F.silu(self.w1(x))
    x_V = self.w3(x)
    x = swish * x_V
    x = self.w2(x)
    return x

class SelfAttention(nn.Module):
  # Only used for inference
  def __init__(self, args: ModelArgs) -> None:
    super().__init__()

    # code simplified, the parallelisation is removed
    # Number of heads for the key and the value
    self.n_kv_heads = args.n_kv_heads if args.n_kv_heads is not None else args.n_heads
    # number of heads for the query
    self.n_heads_q = args.n_heads
    # how many times the heads of the keys and values should be repeated to match the head of the queries
    self.n_rep = self.n_heads_q // self.n_kv_heads
    # the part of the embedding that will be vizualized by each head
    # the dimension of each head
    self.head_dim = args.dim // args.n_heads

    self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
    self.wk = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
    self.wv = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False)
    self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

    self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))
    self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))

  def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor)  -> torch.Tensor:
    batch_size, seq_len, _ = x.shape #(B, 1, dim)
    # Apply the Wq, Wk, and Wv matrices to queries, keys, and values
    # (B, 1, dim) -> (B, 1, H_Q* Head_dim)
    xq = self.wq(x)
    # (B, 1, dim) -> (B, 1, H_KV * Head_dim)
    xk = self.wk(x)
    xv = self.wv(x)

    # Split into the various heads for each Q, K, V
    # (B, 1, H_Q * Head_dim) -> (B, 1, H_Q, Head_dim)
    xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim)
    # (B, 1, H_KV * Head_dim) -> (B, 1, H_KV, Head_dim)
    xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
    xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim)

    # Apply rotary positional encodings on Q, K but not V
    xq = apply_rotary_pos_embeddings(xq, freqs_complex, x.device)
    xk = apply_rotary_pos_embeddings(xk, freqs_complex, x.device)

    # Replace the entry in the cache for this token
    self.cache_k[:batch_size, start_pos:start_pos + seq_len] = xk
    self.cache_v[:batch_size, start_pos:start_pos + seq_len] = xv

    # Retrieve all the cached keys and values for the pos
    # (B, seq_len_kv, , H_KV, Head_dim)
    keys = self.cache_k[:batch_size, 0:start_pos+seq_len]
    values = self.cache_v[:batch_size, 0:start_pos+seq_len]

    # Repeat the heads of the K and V to reach the number of heads of the queries to treat it as a vanilla multihead
    keys = repeat_kv(keys, self.n_rep)
    values = repeat_kv(values, self.n_rep)

    # (B, 1, H_Q, Head_dim) --> (B, H_Q, 1, Head_dim)
    xq = xq.transpose(1, 2)
    keys = keys.transpose(1, 2)
    values = values.transpose(1, 2)

    # (B, H_Q, 1, Head_dim) @ (B, H_Q, HEad_Dim, Seq_len_kv) --> (B, H_Q, 1, Seq_len_kv)
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)
    # (B, H_Q, 1, Seq_len) @ (B, H_Q, Seq_len_KV, Head_dim) --> (B, H_Q, 1, Head_dim)
    output = torch.matmul(scores, values)

    # (B, H_Q, 1, Head_dim) -> (B, 1, H_Q, Head_dim) -> (B, 1, Dim)
    output = (output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1))
    return self.wo(output) # (B, 1, Dim) -> (B, 1, Dim)

class EncoderBlock(nn.Module):
  def __init__(self, args: ModelArgs) -> None:
    super().__init__()
    self.n_heads = args.n_heads
    self.dim = args.dim
    # the embedding has a dimension dim, but each head will see dim/n_heads items per token
    self.head_dim = args.dim // args.n_heads

    self.attention = SelfAttention(args)
    self.feed_forward = FeedForward(args)

    # Normalization before the self attention
    # Look at the slides for visual:
    # https://github.com/hkproj/pytorch-llama/blob/main/Slides.pdf
    self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
    # Normalization Before the feed forward block
    self.ffw_norm = RMSNorm(args.dim, eps=args.norm_eps)

  def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor) -> torch.Tensor:
    # (B, Seq_len, Dim) + (B, Seq_len, Dim) -> (B, Seq_len, Dim)
    rms_norm_before_attention = self.attention_norm(x)
    h = x + self.attention.forward(rms_norm_before_attention, start_pos, freqs_complex)

    rsm_norm_before_ffw = self.ffw_norm(h)
    out = h + self.feed_forward.forward(rsm_norm_before_ffw)
    return out

class Transformer(nn.Module):
  def __init__(self, args: ModelArgs) -> None:
    super().__init__()

    assert args.vocab_size > 0, "Vocab size must be positive"

    self.args = args
    self.vocab_size = args.vocab_size
    self.n_layers = args.n_layers
    self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)

    self.layers = nn.ModuleList()
    for layer in range(args.n_layers):
      self.layers.append(EncoderBlock(args))

    # eps: epsilon is used for the normalization so we are never dividing by zero.
    self.norm = RMSNorm(args.dim, eps=args.norm_eps)
    # the output is always the vocab size since we will be sampling from it.
    # see it as a classification one from vocab_size
    self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    # frequency of the rotary positional embeddings
    head_dim = self.args.dim // self.args.n_heads
    seq_len = self.args.max_seq_len * 2
    self.freqs_complex = _precompute_theta_pos_frequencies(head_dim,
                                                           seq_len=seq_len,
                                                           device=self.args.device)

  def forward(self, tokens: torch.Tensor, start_pos: int):
    # (batch B, seq_len)
    batch_size, seq_len = tokens.shape
    assert seq_len == 1, "Only one token at a time can be processed."

    # (B, Seq_len) -> (B, Seq_len, dim)
    h = self.tok_embeddings(tokens)

    # Retrieve the pairs (m, theta) corresponding to the positions [start_pos, start_pos + seq_len]
    freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len]

    # Consecutively apply all the encoder layers blocks
    for layer in self.layers:
      h = layer(h, start_pos, freqs_complex)

    # Normalize the output
    h = self.norm(h)
    output = self.output(h).float()

    return output

# Inference

In [14]:
import time
from pathlib import Path
import json
from sentencepiece import SentencePieceProcessor

class LLaMA:
  def __init__(self, model: Transformer, tokenizer: SentencePieceProcessor, model_args : ModelArgs) -> None:
    self.model = model
    self.tokenizer = tokenizer
    self.model_args = model_args

  @staticmethod
  def build(checkpoints_dir: str, tokenizer_path: str, load_model: bool, max_seq_len: int, max_batch_size: int, device: str) -> 'LLaMA':
    start_time = time.time()
    if load_model:
      checkpoints = sorted(Path(checkpoints_dir).glob('*.pth'))
      assert len(checkpoints) > 0, f"No checkpoints found in {checkpoints_dir}"
      ckpt_path = checkpoints[0]
      print(f"Loading model from {ckpt_path}")
      checkpoint = torch.load(ckpt_path, map_location='cpu',weights_only=True)
      print(f"Time to load model: {time.time() - start_time:.2f}")
      start_time = time.time()
    with open(Path(checkpoints_dir) / 'params.json', 'r') as f:
      params = json.loads(f.read())
    model_args : ModelArgs = ModelArgs(
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        device=device,
        **params
    )
    tokenizer = SentencePieceProcessor()
    tokenizer.load(tokenizer_path)
    model_args.vocab_size = tokenizer.vocab_size()

    if device == "cuda":
      torch.set_default_tensor_type(torch.cuda.HalfTensor)
    else:
      torch.set_default_tensor_type(torch.BFloat16Tensor)

    model = Transformer(model_args).to(device)
    if load_model:
      del checkpoint['rope.freqs']
      model.load_state_dict(checkpoint, strict=True)
    print(f"Time to load state dict in: {time.time() - start_time:.2f}")
    return LLaMA(model, tokenizer, model_args)


## Download the model

In order to download the model, you need to acknowledge the usage on HF and then follow these steps:

In [3]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [4]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    A token is already saved on your machine. Run `huggingface-cli whoami` to get more information or `huggingface-cli logout` if you want to log out.
    Setting a new token will erase the existing one.
    To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Enter your token (input will not be visible): 
Add token as git credential? (Y/n) Y
Token is valid (permission: fineG

In [5]:
!huggingface-cli download meta-llama/Llama-2-7b

Fetching 10 files:   0% 0/10 [00:00<?, ?it/s]Downloading 'LICENSE.txt' to '/root/.cache/huggingface/hub/models--meta-llama--Llama-2-7b/blobs/51089e27e6764fb9f72c06a0f3710699fb6c9448.incomplete'
Downloading 'consolidated.00.pth' to '/root/.cache/huggingface/hub/models--meta-llama--Llama-2-7b/blobs/d67a91807d5879d193a694da57f28ff85092e92dc9fbef4888bd05e22b15ab75.incomplete'
Downloading 'checklist.chk' to '/root/.cache/huggingface/hub/models--meta-llama--Llama-2-7b/blobs/510da489e04e615c1b50e690e03a62d3bbff9fd9.incomplete'
Downloading 'Responsible-Use-Guide.pdf' to '/root/.cache/huggingface/hub/models--meta-llama--Llama-2-7b/blobs/525dc349d71fe257fce4098c146446df6fef4247174f351381e4c3214af126f0.incomplete'
Downloading '.gitattributes' to '/root/.cache/huggingface/hub/models--meta-llama--Llama-2-7b/blobs/432012a5e6ec946e6c1cb318f256223889e3ab44.incomplete'
Downloading 'USE_POLICY.md' to '/root/.cache/huggingface/hub/models--meta-llama--Llama-2-7b/blobs/abbcc199b2d1e4feb5d7e40c0bd67e1b0ce2

In [7]:
!mkdir ./llama-2-7b

In [8]:
!mv /root/.cache/huggingface/hub/models--meta-llama--Llama-2-7b/snapshots/69656aac4cb47911a639f5890ff35b41ceb82e98/* ./llama-2-7b/

In [15]:
checkpoints_dir = "/content/llama-2-7b"
checkpoints = sorted(Path(checkpoints_dir).glob('*.pth'))
for c in checkpoints:
  print(c)
assert len(checkpoints) > 0, f"No checkpoints found in {checkpoints_dir}"
# from huggingface_hub import snapshot_download
# snapshot_download(repo_id="meta-llama/Llama-2-7b", local_dir=local_dir)

/content/llama-2-7b/consolidated.00.pth


In [10]:
!ls -a /content/llama-2-7b/

.	       consolidated.00.pth  README.md		       tokenizer.model
..	       LICENSE.txt	    Responsible-Use-Guide.pdf  USE_POLICY.md
checklist.chk  params.json	    tokenizer_checklist.chk


In [16]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = LLaMA.build(checkpoints_dir="/content/llama-2-7b/", tokenizer_path='tokenizer.model', load_model=True, max_seq_len=1024, max_batch_size=3, device=device)

Loading model from /content/llama-2-7b/consolidated.00.pth


FileNotFoundError: [Errno 2] No such file or directory: '/content/llama-2-7b/consolidated.00.pth'

In [None]:
prompts = [""]

# Inference of the model