In this notebook I implement Llama 2 from scratch, load the pre-trained weights, and run inference. Everything runs on the CPU (because of the model's size). The implementation includes grouped-query attention (GQA) and a KV cache.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


In [None]:
from dataclasses import dataclass
from typing import Optional

In [None]:
@dataclass
class ModelArgs:
  """Hyper-parameters used to build the Llama model in this notebook.

  NOTE(review): the defaults (dim=4096, 32 layers, 32 heads) look like the
  Llama 2 7B configuration -- confirm against the checkpoint's params.json.
  """
  dim: int = 4096  # model / embedding dimension
  n_layers: int = 32  # number of transformer layers
  n_heads: int = 32  # no. of heads for queries
  # No. of heads for keys and values; None means "same as n_heads" (plain MHA).
  n_kv_heads: Optional[int] = None
  # -1 is a sentinel for "not set yet"; the Transformer refuses to build with it.
  vocab_size: int = -1
  # The feed-forward hidden size is rounded up to a multiple of this value.
  multiple_of: int = 256
  # Optional extra scale applied to the feed-forward hidden size.
  ffn_dim_multiplier: Optional[float] = None
  norm_eps: float = 1e-5  # epsilon used by RMSNorm
  # These two fix the size of the pre-allocated KV cache.
  max_batch_size: int = 32
  max_seq_len: int = 2048
  # Device for precomputed tensors; None lets torch use its default device.
  device: Optional[str] = None

NameError: name 'dataclass' is not defined

In [None]:
def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000):
  """Precompute the rotary-embedding (RoPE) rotations as unit complex numbers.

  For every position ``m`` in ``[0, seq_len)`` and every pair index ``i`` in
  ``[0, head_dim / 2)``, the result holds ``exp(1j * m * theta_i)`` with
  ``theta_i = 1 / theta ** (2 * i / head_dim)``.

  Args:
    head_dim: per-head dimension (dim / n_heads); must be even.
    seq_len: number of positions to precompute.
    device: device for the result (None uses torch's default device).
    theta: base of the frequency geometric progression.

  Returns:
    Complex tensor of shape (seq_len, head_dim // 2).
  """
  assert head_dim % 2 == 0, "head_dim must be divisible by 2"
  # Exponents 2i / head_dim for i = 0 .. head_dim/2 - 1.
  theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
  inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))  # (head_dim / 2,)
  m = torch.arange(seq_len, device=device).float()  # positions, (seq_len,)
  # Angle m * theta_i for every (position, frequency) pair: (seq_len, head_dim / 2).
  freqs = torch.outer(m, inv_freq).float()
  # Unit-magnitude complex numbers: cos(angle) + 1j * sin(angle).
  freq_complex = torch.polar(torch.ones_like(freqs), freqs)
  return freq_complex

SyntaxError: invalid syntax (<ipython-input-2-6dd62ba7c001>, line 5)

In [None]:
def apply_rotatry_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
  """Rotate consecutive feature pairs of ``x`` by the precomputed RoPE angles.

  Args:
    x: (B, Seq_Len, H, Head_Dim) queries or keys; Head_Dim must be even.
    freqs_complex: (Seq_Len, Head_Dim / 2) unit complex rotations for the
      positions being processed.
    device: device the result is moved to.

  Returns:
    Tensor with the same shape and dtype as ``x``.
  """
  # (B, Seq_Len, H, Head_Dim) -> (B, Seq_Len, H, Head_Dim/2) complex: each
  # consecutive pair (x0, x1) becomes x0 + 1j * x1.
  x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
  # (Seq_Len, Head_Dim/2) -> (1, Seq_Len, 1, Head_Dim/2) so it broadcasts over batch and heads.
  freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
  # Complex multiplication by a unit complex number is a 2-D rotation.
  x_rotated = x_complex * freqs_complex
  # Back to real pairs: (B, Seq_Len, H, Head_Dim/2, 2) -> (B, Seq_Len, H, Head_Dim).
  x_out = torch.view_as_real(x_rotated)
  x_out = x_out.reshape(*x.shape)
  return x_out.type_as(x).to(device)

In [None]:
class RMSNorm(nn.Module):
  """Root-mean-square normalization over the last dimension.

  Computes ``weight * x / sqrt(mean(x ** 2, dim=-1) + eps)``: unlike LayerNorm,
  there is no mean subtraction and no bias.
  """

  def __init__(self, dim: int, eps: float = 1e-6):
    super().__init__()
    self.dim = dim
    self.eps = eps
    # Learnable per-channel gain, using PyTorch's conventional name `weight`.
    # NOTE(review): presumably this is also the key pretrained checkpoints use
    # (e.g. `norm.weight`) -- verify when loading the state dict.
    self.weight = nn.Parameter(torch.ones(dim))

  @property
  def weights(self) -> nn.Parameter:
    """Read-only alias of ``weight``, kept for code that used the old name."""
    return self.weight

  def _norm(self, x: torch.Tensor) -> torch.Tensor:
    # rsqrt(mean(x^2) + eps) == 1 / RMS(x); keepdim so it broadcasts over the last axis.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    # Normalize in float32 for numerical stability, then cast back to x's dtype.
    return self.weight * self._norm(x.float()).type_as(x)


In [None]:
class FeedForward(nn.Module):
  """SwiGLU feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

  The hidden size is computed as follows:
    1. Start at ``4 * dim``.
    2. Scale by 2/3.
    3. Optionally scale by ``ffn_dim_multiplier``.
    4. Round up to a multiple of ``multiple_of``.
  """

  def __init__(self, args: "ModelArgs"):
    super().__init__()
    hidden_dim = 4 * args.dim
    hidden_dim = int(2 * hidden_dim / 3)
    if args.ffn_dim_multiplier is not None:
      hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)
    # Round up to the nearest multiple of `multiple_of`.
    hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of)

    self.w1 = nn.Linear(args.dim, hidden_dim, bias=False)  # gated branch (goes through SiLU)
    self.w3 = nn.Linear(args.dim, hidden_dim, bias=False)  # linear branch
    self.w2 = nn.Linear(hidden_dim, args.dim, bias=False)  # projection back to dim

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    swish = F.silu(self.w1(x))
    x_v = self.w3(x)
    return self.w2(swish * x_v)




In [None]:
def repeat_kv(x: torch.Tensor, n_rep: int = 1) -> torch.Tensor:
  """Repeat each key/value head ``n_rep`` times along the head axis.

  Used by grouped-query attention so that the K/V tensors have as many heads
  as the queries. Heads are repeated in place (h0, h0, h1, h1, ...), which is
  equivalent to ``torch.repeat_interleave(x, n_rep, dim=2)``.

  Args:
    x: (B, Seq_Len, N_KV_Heads, Head_Dim).
    n_rep: how many times to repeat each head; 1 returns ``x`` unchanged.

  Returns:
    (B, Seq_Len, N_KV_Heads * n_rep, Head_Dim).
  """
  batch_size, seq_len, n_kv_heads, head_dim = x.shape
  if n_rep == 1:
    return x
  # (B, Seq_Len, N_KV_Heads, 1, Head_Dim) -> expand (no copy) -> merge the two head axes.
  return (
      x[:, :, :, None, :]
      .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim)
      .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim)
  )

In [None]:
class SelfAttention(nn.Module):
  """Multi-head self-attention with grouped-query attention (GQA) and a KV cache.

  Queries use ``args.n_heads`` heads. Keys and values use ``args.n_kv_heads``
  heads (``n_heads`` when None), and each K/V head is shared by ``n_rep`` query
  heads. Keys and values of every processed position are written into a
  pre-allocated cache of shape (max_batch_size, max_seq_len, n_kv_heads, head_dim).
  As a result, each call only projects the new token(s).
  """

  def __init__(self, args: "ModelArgs"):
    super().__init__()
    self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
    self.n_heads_q = args.n_heads
    if self.n_heads_q % self.n_kv_heads != 0:
      raise ValueError(
          f"n_heads ({self.n_heads_q}) must be a multiple of n_kv_heads ({self.n_kv_heads})"
      )
    self.head_dim = args.dim // args.n_heads
    # Number of query heads that share each key/value head.
    self.n_rep = self.n_heads_q // self.n_kv_heads

    self.wq = nn.Linear(args.dim, self.n_heads_q * self.head_dim, bias=False)
    self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
    self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
    self.wo = nn.Linear(self.n_heads_q * self.head_dim, args.dim, bias=False)

    # Non-persistent buffers follow model.to(...) but stay out of state_dict(),
    # so loading a checkpoint does not expect cache entries.
    cache_shape = (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)
    self.register_buffer("cache_k", torch.zeros(cache_shape), persistent=False)
    self.register_buffer("cache_v", torch.zeros(cache_shape), persistent=False)

  def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor) -> torch.Tensor:
    """Attend the ``seq_len`` new positions to every cached position.

    Args:
      x: (B, Seq_Len, Dim) hidden states of the new token(s).
      start_pos: absolute position of the first new token.
      freqs_complex: (Seq_Len, Head_Dim / 2) RoPE rotations for those positions.

    Returns:
      (B, Seq_Len, Dim).
    """
    batch_size, seq_len, _ = x.shape

    # (B, Seq_Len, Dim) -> (B, Seq_Len, H, Head_Dim)
    xq = self.wq(x).view(batch_size, seq_len, self.n_heads_q, self.head_dim)
    xk = self.wk(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
    xv = self.wv(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim)

    # Rotary position embeddings are applied to queries and keys only.
    xq = apply_rotatry_embeddings(xq, freqs_complex, device=x.device)
    xk = apply_rotatry_embeddings(xk, freqs_complex, device=x.device)

    # Keep the cache on the activations' device/dtype (no-op when it already matches).
    self.cache_k = self.cache_k.to(xk)
    self.cache_v = self.cache_v.to(xv)

    # Store this step's keys/values in the cache.
    end_pos = start_pos + seq_len
    self.cache_k[:batch_size, start_pos:end_pos] = xk
    self.cache_v[:batch_size, start_pos:end_pos] = xv

    # Retrieve all cached keys/values so far: (B, Seq_Len_KV, N_KV_Heads, Head_Dim).
    keys = self.cache_k[:batch_size, :end_pos]
    values = self.cache_v[:batch_size, :end_pos]

    # Repeat the K/V heads to match the number of query heads (as in the Llama
    # reference; simple, though not the most memory-efficient).
    keys = repeat_kv(keys, self.n_rep)
    values = repeat_kv(values, self.n_rep)

    # (B, Seq, H_Q, Head_Dim) -> (B, H_Q, Seq, Head_Dim)
    xq = xq.transpose(1, 2)
    keys = keys.transpose(1, 2)
    values = values.transpose(1, 2)

    # (B, H_Q, Seq_Len, Seq_Len_KV)
    scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
    if seq_len > 1:
      # Causal mask: new token i (absolute position start_pos + i) must not see later positions.
      mask = torch.full((seq_len, end_pos), float("-inf"), device=x.device)
      mask = torch.triu(mask, diagonal=start_pos + 1)
      scores = scores + mask
    scores = F.softmax(scores.float(), dim=-1).type_as(xq)

    out = torch.matmul(scores, values)  # (B, H_Q, Seq_Len, Head_Dim)
    # Back to (B, Seq_Len, H_Q * Head_Dim) before the output projection.
    out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
    return self.wo(out)












In [None]:
class EncoderBlock(nn.Module):
  """One pre-norm transformer layer.

      h   = x + attention(attention_norm(x))
      out = h + feed_forward(ffn_norm(h))
  """

  def __init__(self, args: "ModelArgs"):
    super().__init__()
    self.dim = args.dim
    self.n_heads = args.n_heads
    self.head_dim = args.dim // args.n_heads

    # Named `attention` because that is the attribute forward() uses.
    # NOTE(review): presumably this also matches `layers.N.attention.*` keys of
    # pretrained checkpoints -- verify when loading the state dict.
    self.attention = SelfAttention(args)
    self.feed_forward = FeedForward(args)
    self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
    self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

  def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor) -> torch.Tensor:
    # Residual around the attention sub-layer.
    h = x + self.attention(self.attention_norm(x), start_pos, freqs_complex)
    # Residual around the feed-forward sub-layer.
    return h + self.feed_forward(self.ffn_norm(h))


In [None]:
class Transformer(nn.Module):
  """Llama decoder model.

  The pipeline is: token embedding -> ``n_layers`` EncoderBlocks -> RMSNorm ->
  vocabulary projection. ``forward`` processes a single new token per call,
  relying on each layer's KV cache for the earlier positions.
  """

  def __init__(self, args: "ModelArgs") -> None:
    super().__init__()
    assert args.vocab_size != -1 , " set the vocab size "
    self.args = args
    self.vocab_size = args.vocab_size
    self.n_layers = args.n_layers
    self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
    self.layers = nn.ModuleList(EncoderBlock(args) for _ in range(args.n_layers))
    self.norm = RMSNorm(dim=args.dim, eps=args.norm_eps)
    self.output = nn.Linear(args.dim, self.vocab_size, bias=False)
    # RoPE rotations for 2 * max_seq_len positions, shape (2 * max_seq_len, head_dim / 2).
    # Kept as a plain attribute: a complex buffer would be mangled by model.to(dtype).
    self.freqs_complex = precompute_theta_pos_frequencies(
        args.dim // args.n_heads, args.max_seq_len * 2, device=args.device
    )

  def forward(self, tokens: torch.Tensor, start_pos: int) -> torch.Tensor:
    """Return float32 logits of shape (B, 1, vocab_size) for a (B, 1) token batch."""
    _, seq_len = tokens.shape
    assert seq_len == 1, "Only one token at a time can be processed"

    h = self.tok_embeddings(tokens)  # (B, Seq_Len, Dim)

    # Rotations for the positions [start_pos, start_pos + seq_len).
    freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len].to(h.device)

    for layer in self.layers:
      h = layer(h, start_pos, freqs_complex)
    h = self.norm(h)
    return self.output(h).float()



Inference

In [None]:
# Extra dependencies for the inference section.
import json
import os
import time
from pathlib import Path
from typing import Optional

import torch
from sentencepiece import SentencePieceProcessor
from tqdm import tqdm

In [None]:
class LLaMa:
  def __init__(self, model: Transformer, tokenizer: SentencePieceProcessor, model_args: ModelArgs) -> None:
    """Bundle a model with its tokenizer and the arguments it was built from.

    Args:
      model: the Transformer used for generation.
      tokenizer: the SentencePiece tokenizer paired with the model.
      model_args: the ModelArgs the model was constructed with (stored as ``self.args``).
    """
    self.args = model_args
    self.tokenizer = tokenizer
    self.model = model
  @staticmethod
  def build(checkpint)