In [1]:
#libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional
from dataclasses import dataclass

In [2]:
class RMSNorm(nn.Module):
  """Root-mean-square normalisation over the last dimension with a learnable gain.

  y = weight * x / sqrt(mean(x**2, last dim) + eps)
  The statistic is computed in float32 and cast back to the input dtype before
  the gain is applied.
  """

  def __init__(self, dim: int, eps: float = 1e-6):
    super().__init__()
    self.eps = eps  # keeps rsqrt finite when a row is all zeros
    self.weight = nn.Parameter(torch.ones(dim))  # per-feature gain, starts at 1

  def _norm(self, x: torch.Tensor):
    # x: (..., dim) -> mean of squares over dim -> (..., 1), broadcast back over x
    mean_sq = x.pow(2).mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_sq + self.eps)  # rsqrt(v) == 1 / sqrt(v)
    return x * inv_rms

  def forward(self, x: torch.Tensor):
    normed = self._norm(x.float()).type_as(x)
    return self.weight * normed

In [3]:
def precompute_theta_pos_frequencies(head_dim:int,seq_len:int,device:str,theta:float=10000.0):
  """Build the table of unit complex rotations used by rotary embeddings.

  Entry [m, i] equals exp(j * m * theta_i), where theta_i = 1 / theta**(2i / head_dim)
  and i runs over head_dim // 2 frequencies.

  Args:
    head_dim: per-head feature size; must be even because features are rotated in pairs.
    seq_len: number of positions m = 0 .. seq_len - 1 to precompute.
    device: device for the frequency vector and the position index.
    theta: base of the geometric frequency progression.

  Returns:
    Complex tensor of shape [seq_len, head_dim // 2].
  """
  assert head_dim%2==0,"head dim must be even number as defined in original repo"
  # Exponents 2i / head_dim for i = 0 .. head_dim/2 - 1, i.e. [0, 2, 4, ...] / head_dim.
  exponents = torch.arange(0, head_dim, 2).float() / head_dim
  inv_freq = 1.0 / (theta ** exponents).to(device)

  positions = torch.arange(seq_len, device=device)  # m = 0, 1, 2, ...
  # The outer product pairs every position with every frequency -> [seq_len, head_dim/2].
  angles = torch.outer(positions, inv_freq).float()
  # Magnitude 1 and angle m*theta_i -> cos(m*theta_i) + j*sin(m*theta_i).
  unit = torch.ones_like(angles)
  return torch.polar(unit, angles)

In [4]:
def apply_rotary_embeddings(x:torch.Tensor,freqs_complex:torch.Tensor,device:str):
  """Rotate adjacent feature pairs of `x` by the per-position complex factors.

  Args:
    x: tensor whose last dim is the (even) head dim, e.g. [b, s, h, h_dim].
    freqs_complex: complex rotations, e.g. [s, h_dim/2]; it is expanded to
      [1, s, 1, h_dim/2] so that it broadcasts over batch and heads.
    device: device the result is moved to.

  Returns:
    Tensor with the same shape and dtype as `x`.
  """
  # [.., h_dim] -> [.., h_dim/2, 2]; each (even, odd) pair becomes one complex number.
  pairs = x.float().reshape(*x.shape[:-1], -1, 2)
  as_complex = torch.view_as_complex(pairs)
  # Insert singleton batch and head axes: [s, h_dim/2] -> [1, s, 1, h_dim/2].
  rotation = freqs_complex[None, :, None]
  rotated = as_complex * rotation
  # Back to real pairs and flatten them into the original last dim.
  flat = torch.view_as_real(rotated).reshape(*x.shape)
  return flat.type_as(x).to(device)

In [5]:
def repeat_kv(x:torch.Tensor,n_rep:int):
  """Repeat each key/value head `n_rep` times along the head axis.

  [b, s, n_kv_heads, d] -> [b, s, n_kv_heads * n_rep, d]. The copies of one
  head are adjacent (head k occupies slots k*n_rep .. k*n_rep + n_rep - 1).
  When n_rep == 1 the input object itself is returned.
  """
  bsz, slen, kv_heads, dim = x.shape
  if n_rep != 1:
    # Add a repeat axis after the head axis, broadcast it, then fold it into the heads.
    widened = x.unsqueeze(3).expand(bsz, slen, kv_heads, n_rep, dim)
    return widened.reshape(bsz, slen, kv_heads * n_rep, dim)
  return x

In [6]:
#saving model args
@dataclass
class ModelArgs:
  # Hyper-parameters for the Transformer. LLaMA.build fills most fields from
  # params.json (via **params) and passes max_seq_len / max_batch_size / device.
  dim:int=4096
  n_layers:int=32
  n_heads:int=32
  n_kv_heads:Optional[int]=None # None -> same as n_heads (no grouped-query sharing)
  vocab_size:int=-1 #will be loaded by build function from original weights of model
  multiple_of:int=256 #hidden dim of ff_layer will be rounded up to a multiple of this
  # Optional scale on the FFN hidden size. Named `ffn_dim_multiplier` so it matches
  # what FeedForward reads and the key used in params.json.
  ffn_dim_multiplier:Optional[float]=None
  norm_eps:float=1e-5

  device:Optional[str]=None

  # KV-cache limits. They must be annotated to be dataclass fields; otherwise they
  # are plain class attributes and ModelArgs(max_seq_len=...) raises TypeError.
  # They are placed after `device` so existing positional construction is unchanged.
  max_batch_size:int=32
  max_seq_len:int=2048

In [7]:
class SelfAttention(nn.Module):
  """Multi-head self-attention with rotary embeddings, grouped KV heads and a KV cache.

  Query heads: args.n_heads. Key/value heads: args.n_kv_heads (defaults to n_heads).
  Each KV head is shared by n_rep = n_heads // n_kv_heads query heads.
  Keys/values for every processed position are written into cache_k / cache_v
  at [batch, start_pos : start_pos + seq_len] and re-read on later calls.
  """
  def __init__(self,args:ModelArgs):
    super().__init__()
    self.n_heads_q=args.n_heads
    self.n_kv_heads=args.n_heads if args.n_kv_heads is None else args.n_kv_heads
    assert self.n_heads_q%self.n_kv_heads==0,"n_heads must be a multiple of n_kv_heads"
    self.head_dim=args.dim//args.n_heads
    self.n_rep=self.n_heads_q//self.n_kv_heads

    self.wq=nn.Linear(args.dim,args.n_heads*self.head_dim,bias=False)
    self.wk=nn.Linear(args.dim,self.n_kv_heads*self.head_dim,bias=False)
    self.wv=nn.Linear(args.dim,self.n_kv_heads*self.head_dim,bias=False)
    self.wo=nn.Linear(args.n_heads*self.head_dim,args.dim,bias=False)

    # KV cache: plain tensors (not buffers) so they do not appear in state_dict.
    self.cache_k=torch.zeros((args.max_batch_size,args.max_seq_len,self.n_kv_heads,self.head_dim))
    self.cache_v=torch.zeros((args.max_batch_size,args.max_seq_len,self.n_kv_heads,self.head_dim))

  def forward(self,x:torch.Tensor,start_pos:int,freqs_complex:torch.Tensor):
    """Attend the new positions of `x` ([b, s, dim]) to all cached positions.

    Args:
      x: activations for positions start_pos .. start_pos + s - 1.
      start_pos: absolute position of x[:, 0]; also the first KV-cache slot written.
      freqs_complex: rotary factors for exactly those s positions.
    Returns:
      Tensor of shape [b, s, dim].
    """
    batch_size,seq_len,_=x.shape #x=[b,s,d]

    # [b,s,d] -> [b,s,h,h_dim] (queries) and [b,s,h_kv,h_dim] (keys/values)
    xq=self.wq(x).view(batch_size,seq_len,self.n_heads_q,self.head_dim)
    xk=self.wk(x).view(batch_size,seq_len,self.n_kv_heads,self.head_dim)
    xv=self.wv(x).view(batch_size,seq_len,self.n_kv_heads,self.head_dim)

    # Rotary position encoding is applied to queries and keys only.
    xq=apply_rotary_embeddings(xq,freqs_complex,x.device)
    xk=apply_rotary_embeddings(xk,freqs_complex,x.device)

    # Keep the cache on the same device/dtype as the activations (no-op when they already match).
    self.cache_k=self.cache_k.to(xk)
    self.cache_v=self.cache_v.to(xv)

    # Store the new entries, then read back everything up to the current position.
    self.cache_k[:batch_size,start_pos:start_pos+seq_len]=xk
    self.cache_v[:batch_size,start_pos:start_pos+seq_len]=xv
    keys=self.cache_k[:batch_size,:start_pos+seq_len]   #[b,s_kv,h_kv,h_dim]
    values=self.cache_v[:batch_size,:start_pos+seq_len] #[b,s_kv,h_kv,h_dim]

    # Share each KV head across its group of query heads: h_kv -> h_q.
    keys=repeat_kv(keys,self.n_rep)
    values=repeat_kv(values,self.n_rep)

    # Move heads before sequence: [b,s,h,h_dim] -> [b,h,s,h_dim]
    xq=xq.transpose(1,2)
    keys=keys.transpose(1,2)
    values=values.transpose(1,2)

    #[b,h_q,s,h_dim] @ [b,h_q,h_dim,s_kv] -> [b,h_q,s,s_kv]
    scores=torch.matmul(xq,keys.transpose(2,3))/math.sqrt(self.head_dim)
    if seq_len>1:
      # Causal mask: query i (absolute position start_pos+i) may only see keys 0..start_pos+i.
      # With seq_len == 1 every cached key is in the past, so no mask is needed.
      mask=torch.full((seq_len,start_pos+seq_len),float("-inf"),dtype=scores.dtype,device=scores.device)
      mask=torch.triu(mask,diagonal=start_pos+1)
      scores=scores+mask
    # Softmax in float32 for stability, then back to the activation dtype.
    scores=F.softmax(scores.float(),dim=-1).type_as(xq)

    #[b,h_q,s,s_kv] @ [b,h_q,s_kv,h_dim] -> [b,h,s,h_dim] -> [b,s,h,h_dim] -> [b,s,dim]
    output=torch.matmul(scores,values)
    output=output.transpose(1,2).contiguous().view(batch_size,seq_len,-1)
    return self.wo(output)

In [8]:
class FeedForward(nn.Module):
  """SwiGLU feed-forward block: w2(silu(w1(x)) * w3(x)).

  w1 and w3 project dim -> hidden_dim, and w2 projects hidden_dim -> dim.
  """
  def __init__(self,args:"ModelArgs"):
    super().__init__()
    # Hidden size follows the LLaMA reference: 2/3 of 4*dim, optionally scaled by
    # ffn_dim_multiplier, then rounded up to a multiple of `multiple_of`.
    hidden_dim=4*args.dim
    hidden_dim=int(2*hidden_dim/3)
    # Read defensively so that configs without this field still work.
    ffn_dim_multiplier=getattr(args,"ffn_dim_multiplier",None)
    if ffn_dim_multiplier is not None:
      hidden_dim=int(ffn_dim_multiplier*hidden_dim)
    # Round up to the next multiple, e.g. 7 with multiple 5: (7+5-1)//5*5 = 10.
    hidden_dim=args.multiple_of*((hidden_dim+args.multiple_of-1)//args.multiple_of)

    self.w1=nn.Linear(args.dim,hidden_dim,bias=False)  # gate projection
    self.w2=nn.Linear(hidden_dim,args.dim,bias=False)  # output projection
    self.w3=nn.Linear(args.dim,hidden_dim,bias=False)  # value projection

  def forward(self,x:torch.Tensor):
    swish=F.silu(self.w1(x))
    x_V=self.w3(x)
    return self.w2(swish*x_V)

In [9]:
class EncoderBlock(nn.Module):
  """One pre-norm Transformer layer.

  h   = x + attention(attention_norm(x))
  out = h + feed_forward(ffn_norm(h))
  """
  def __init__(self,args:ModelArgs):
    super().__init__()
    self.n_heads=args.n_heads
    self.dim=args.dim
    self.head_dim=args.dim//args.n_heads

    self.attention=SelfAttention(args)
    self.feed_forward=FeedForward(args)

    # RMSNorm applied to the input of each sub-layer (pre-norm).
    self.attention_norm=RMSNorm(self.dim,eps=args.norm_eps)
    self.ffn_norm=RMSNorm(self.dim,eps=args.norm_eps)

  def forward(self,x:torch.Tensor,start_pos:int,freqs_complex:torch.Tensor):
    # Call the submodules themselves (not `.forward`) so nn.Module hooks run.
    h=x+self.attention(self.attention_norm(x),start_pos,freqs_complex)
    out=h+self.feed_forward(self.ffn_norm(h))
    return out

In [10]:
class Transformer(nn.Module):
  """Decoder stack: token embedding -> n_layers EncoderBlocks -> RMSNorm -> vocab logits."""
  def __init__(self,args:ModelArgs):
    super().__init__()
    assert args.vocab_size!=-1,"vocab size must be set"

    self.args=args
    self.vocab_size=args.vocab_size
    self.n_layers=args.n_layers

    # Must be named `tok_embeddings`: forward() reads this name, and the strict
    # load_state_dict in LLaMA.build matches parameters by name.
    self.tok_embeddings=nn.Embedding(self.vocab_size,args.dim)

    self.layers=nn.ModuleList(EncoderBlock(args) for _ in range(args.n_layers))

    self.norm = RMSNorm(args.dim, eps=args.norm_eps)
    self.output = nn.Linear(args.dim, self.vocab_size, bias=False)

    # Rotary factors for 2 * max_seq_len positions, head_dim = dim // n_heads.
    self.freqs_complex = precompute_theta_pos_frequencies(self.args.dim // self.args.n_heads, self.args.max_seq_len * 2, device=self.args.device)

  def forward(self,tokens:torch.Tensor,start_pos:int):
    """Return float32 logits [b, 1, vocab] for one token per sequence at `start_pos`."""
    _, seq_len = tokens.shape
    assert seq_len == 1, "Only one token at a time can be processed"

    h = self.tok_embeddings(tokens)  # [b, s] -> [b, s, dim]

    # Rotary factors for positions [start_pos, start_pos + seq_len)
    freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len]

    for layer in self.layers:
      h = layer(h, start_pos, freqs_complex)
    h = self.norm(h)
    return self.output(h).float()

In [12]:
import tqdm
import json
from pathlib import Path
from sentencepiece import SentencePieceProcessor

In [20]:
class LLaMA:
  """Pairs a Transformer with a SentencePiece tokenizer and runs batched text completion."""

  def __init__(self,model:Transformer,tokenizer:SentencePieceProcessor,model_args:ModelArgs):
    # No trailing commas: `self.model = model,` would store a 1-tuple instead of the module.
    self.model=model
    self.tokenizer=tokenizer
    self.args=model_args

  @staticmethod
  def build(checkpoints_dir:str,tokenizer_path:str,load_model:bool,max_seq_len:int,max_batch_size:int,device:str):
    """Create a LLaMA from `checkpoints_dir` and a SentencePiece model file.

    Reads `params.json` from `checkpoints_dir`. When `load_model` is True, it also loads
    the first `*.pth` file (sorted by name) as the state dict.

    Side effect: sets the global default tensor type before building the model
    (cuda.HalfTensor when device == "cuda", BFloat16Tensor otherwise). As a result,
    the parameters and KV caches are created with that type.
    """
    checkpoint=None
    if load_model:
      checkpoints=sorted(Path(checkpoints_dir).glob("*.pth"))#loads all files with .pth and sort them
      assert len(checkpoints) > 0, f"no checkpoint files found in {checkpoints_dir}"
      ckpt_path=checkpoints[0]
      print("loading checkpoints")
      # torch.load unpickles the file, so only load checkpoints from a trusted source.
      checkpoint=torch.load(ckpt_path,map_location="cpu")
    with open(Path(checkpoints_dir) / "params.json", "r") as f:
      params = json.load(f)

    model_args=ModelArgs(
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        device=device,
        **params
    )
    tokenizer=SentencePieceProcessor()
    tokenizer.load(tokenizer_path)
    # Whatever params.json says, use the tokenizer's vocabulary size.
    model_args.vocab_size=tokenizer.vocab_size()

    # NOTE(review): set_default_tensor_type is deprecated in recent PyTorch releases
    # (set_default_dtype / set_default_device replace it). Kept here for identical behavior.
    if device == "cuda":
      torch.set_default_tensor_type(torch.cuda.HalfTensor)
    else:
      torch.set_default_tensor_type(torch.BFloat16Tensor)

    model=Transformer(model_args)

    if load_model:
      # rope.freqs has no matching parameter in this model. Remove it when present so
      # the strict load still checks every other name.
      checkpoint.pop('rope.freqs', None)
      model.load_state_dict(checkpoint, strict=True)

    return LLaMA(model,tokenizer,model_args)

  def text_completion(self, prompts: list[str], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None):
    """Complete each prompt token by token, using the KV cache.

    Args:
      prompts: input strings; at most args.max_batch_size of them.
      temperature: softmax temperature; <= 0 selects greedy argmax decoding.
      top_p: nucleus-sampling mass used when temperature > 0.
      max_gen_len: maximum number of new tokens (defaults to max_seq_len - 1).
    Returns:
      (out_tokens, out_text): per prompt, the token ids (prompt + completion,
      cut at the first EOS) and their decoded text.
    """
    if not prompts:
      return ([], [])
    device = self.args.device
    if max_gen_len is None:
      max_gen_len = self.args.max_seq_len - 1
    prompt_tokens = [self.tokenizer.encode(prompt, out_type=int, add_bos=True, add_eos=False) for prompt in prompts]
    batch_size = len(prompt_tokens)
    assert batch_size <= self.args.max_batch_size, f"batch size must be less than or equal to {self.args.max_batch_size}"
    max_prompt_len = max(len(prompt) for prompt in prompt_tokens)
    assert max_prompt_len <= self.args.max_seq_len, f"prompt length must be less than or equal to {self.args.max_seq_len}"
    total_len = min(self.args.max_seq_len, max_gen_len + max_prompt_len)

    # eos_id / pad_id are tokenizer methods and must be called to get the ids.
    pad_id = self.tokenizer.pad_id()
    eos_id = self.tokenizer.eos_id()
    tokens = torch.full((batch_size, total_len), pad_id, dtype=torch.long, device=device)
    for k, t in enumerate(prompt_tokens):
      tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)

    eos_reached = torch.tensor([False] * batch_size, device=device)
    prompt_tokens_mask = tokens != pad_id  # True where the token comes from the prompt
    cur_iterator = tqdm.tqdm(range(1, total_len), desc="Generating tokens")
    for cur_pos in cur_iterator:
      with torch.no_grad():
        # The token fed in sits at index cur_pos-1, so its position (and its
        # KV-cache slot) is cur_pos-1. Using cur_pos would leave slot 0 as
        # zeros and still attend to it.
        logits = self.model(tokens[:, cur_pos-1:cur_pos], cur_pos-1)
      if temperature > 0:
        # Temperature is applied before the softmax.
        probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
        next_token = self._sample_top_p(probs, top_p)
      else:
        next_token = torch.argmax(logits[:, -1], dim=-1)

      next_token = next_token.reshape(-1)
      # Keep prompt tokens; only fill positions that were padding.
      next_token = torch.where(prompt_tokens_mask[:, cur_pos], tokens[:, cur_pos], next_token)
      tokens[:, cur_pos] = next_token
      # EOS counts only when it was generated, not when it came from the prompt.
      eos_reached |= (~prompt_tokens_mask[:, cur_pos]) & (next_token == eos_id)
      if eos_reached.all():
        break

    out_tokens = []
    out_text = []
    for current_prompt_tokens in tokens.tolist():
      # Cut at the first EOS token, if present.
      if eos_id in current_prompt_tokens:
        current_prompt_tokens = current_prompt_tokens[:current_prompt_tokens.index(eos_id)]
      out_tokens.append(current_prompt_tokens)
      out_text.append(self.tokenizer.decode(current_prompt_tokens))
    return (out_tokens, out_text)

  def _sample_top_p(self, probs, p):
    """Nucleus sampling: draw one token id per row from the smallest prefix of mass >= p.

    probs: [B, vocab_size] probabilities. Returns [B, 1] token ids.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # Subtracting probs_sort gives the mass *before* each token, so the token that
    # crosses p is still kept.
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    # Renormalise the kept probabilities so that they sum to 1.
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    # Map the sampled sorted index back to its vocabulary id.
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


In [17]:
if __name__ == '__main__':
    # Fixed seed so that top-p sampling gives the same completions on each run.
    torch.manual_seed(0)

    # Set allow_cuda = True to use a GPU when one is present.
    allow_cuda = False
    cuda_usable = torch.cuda.is_available() and allow_cuda
    device = 'cuda' if cuda_usable else 'cpu'

    prompts = [
        "What is Mark Zuckerberg was born in india ",
        # Few-shot prompt
        """Translate English to French:

        goodmorning=>bonjour
        peppermint => menthe poivrée
        plush girafe => girafe peluche
        cheese =>""",
    ]

    build_kwargs = dict(
        checkpoints_dir='llama-2-7b/',
        tokenizer_path='tokenizer.model',
        load_model=True,
        max_seq_len=1024,
        max_batch_size=len(prompts),  # the KV cache only needs room for this batch
        device=device,
    )
    model = LLaMA.build(**build_kwargs)

    out_tokens, out_texts = model.text_completion(prompts, max_gen_len=64)
    assert len(out_texts) == len(prompts)
    separator = '-' * 50
    for text in out_texts:
        print(f'{text}')
        print(separator)

torch.Size([2, 7])
