In [74]:
#libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional
from dataclasses import dataclass

In [76]:
class RMSNorm(nn.Module):
  """Root-mean-square normalisation over the last dimension with a learned per-feature gain.

  y = weight * x / sqrt(mean(x^2, last dim) + eps)
  """
  def __init__(self,dim:int,eps:float=1e-6):
    super().__init__()
    self.eps=eps
    # one gain per feature of the last dimension, initialised to 1
    self.weight=nn.Parameter(torch.ones(dim))

  def _norm(self,x:torch.Tensor):
    # x: [..., d]; keepdim leaves [..., 1] so the factor broadcasts back over d
    mean_square=x.pow(2).mean(dim=-1,keepdim=True)
    inverse_rms=torch.rsqrt(mean_square+self.eps) # rsqrt(v) == 1/sqrt(v)
    return x*inverse_rms

  def forward(self,x:torch.Tensor):
    # normalise in float32, cast back to the input dtype, then apply the gain
    normalised=self._norm(x.float()).type_as(x)
    return self.weight*normalised

In [77]:
def precompute_theta_pos_frequencies(head_dim:int,seq_len:int,device:str,theta:float=10000.0):
  """Build the rotary-embedding table of unit complex numbers.

  Args:
    head_dim: per-head feature size (after splitting dim across heads); must be even.
    seq_len: number of positions m = 0 .. seq_len-1 to precompute.
    device: device the table is created on.
    theta: base of the frequency schedule.

  Returns:
    complex tensor [seq_len, head_dim/2] with entry (m, i) = exp(1j * m * theta^(-2i/head_dim)).
  """
  assert head_dim%2==0,"head dim must be even number as defined in original repo"
  # exponents 2i for i = 0 .. head_dim/2 - 1  -> [0, 2, 4, ...]
  even_indices=torch.arange(0,head_dim,2).float()
  # per-pair angular frequency theta^(-2i/head_dim)
  inverse_freq=1.0/(theta**(even_indices/head_dim)).to(device)
  positions=torch.arange(seq_len,device=device)
  # every position times every frequency -> [seq_len, head_dim/2] rotation angles
  angles=torch.outer(positions,inverse_freq).float()
  # magnitude 1, phase = angle  ->  cos(angle) + i*sin(angle)
  unit_magnitude=torch.ones_like(angles)
  return torch.polar(unit_magnitude,angles)

In [78]:
def apply_rotary_embeddings(x:torch.Tensor,freqs_complex:torch.Tensor,device:str):
  """Rotate adjacent feature pairs of x by the per-position complex factors.

  Args:
    x: [b, s, h, h_dim] with even h_dim.
    freqs_complex: [s, h_dim/2] complex rotation factors for the s positions.
    device: device the result is moved to.

  Returns:
    tensor shaped and typed like x (computed internally in float32).
  """
  # [b,s,h,h_dim] -> [b,s,h,h_dim/2,2]: features (2k, 2k+1) become one complex number
  paired=x.float().reshape(*x.shape[:-1],-1,2)
  as_complex=torch.view_as_complex(paired) # [b,s,h,h_dim/2]
  # [s,h_dim/2] -> [1,s,1,h_dim/2] so it broadcasts over batch and heads
  rotation=freqs_complex.unsqueeze(0).unsqueeze(2)
  rotated=as_complex*rotation
  # complex [b,s,h,h_dim/2] -> real [b,s,h,h_dim/2,2] -> [b,s,h,h_dim]
  flattened=torch.view_as_real(rotated).reshape(*x.shape)
  return flattened.type_as(x).to(device)

In [79]:
def repeat_kv(x:torch.Tensor,n_rep:int):
  """Repeat each key/value head n_rep times along the head axis.

  [b, s, h_kv, d] -> [b, s, h_kv * n_rep, d]; copies of a head are adjacent
  (head j of the output comes from input head j // n_rep). Returns x itself when n_rep == 1.
  """
  batch_size,seq_len,n_kv_heads,head_dim=x.shape
  if n_rep==1:
    return x
  # [b,s,h,d] -> [b,s,h,1,d] -> (broadcast view) [b,s,h,n_rep,d]
  widened=x.unsqueeze(3).expand(batch_size,seq_len,n_kv_heads,n_rep,head_dim)
  # merge the head and repeat axes -> [b,s,h*n_rep,d]
  return widened.reshape(batch_size,seq_len,n_kv_heads*n_rep,head_dim)

In [75]:
#saving model args
@dataclass
class ModelArgs:
  """Hyper-parameters shared by every module of the Transformer.

  Only annotated attributes are dataclass fields, so every setting a caller may
  want to override (including the KV-cache sizes) is annotated.
  """
  dim:int=4096
  n_layers:int=32
  n_heads:int=32
  n_kv_heads:Optional[int]=None #None -> same number of key/value heads as query heads
  vocab_size:int=-1 #will be loaded by build function from original weights of model
  multiple_of:int=256 #hidden dim of ff_layer is rounded up to a multiple of this
  ff_dim_multiplier:Optional[float]=None #legacy spelling, kept in sync with ffn_dim_multiplier
  norm_eps:float=1e-5

  device:Optional[str]=None

  #for kv cache (annotated so they are real fields and can be passed to __init__)
  max_batch_size:int=32
  max_seq_len:int=2048

  #optional scale of the feed-forward hidden dim; this is the name FeedForward reads
  ffn_dim_multiplier:Optional[float]=None

  def __post_init__(self):
    # accept either spelling of the FFN multiplier and expose the value under both
    if self.ffn_dim_multiplier is None:
      self.ffn_dim_multiplier=self.ff_dim_multiplier
    elif self.ff_dim_multiplier is None:
      self.ff_dim_multiplier=self.ffn_dim_multiplier

In [81]:
class SelfAttention(nn.Module):
  """Multi-head self-attention with rotary embeddings, grouped key/value heads and a KV cache.

  Each call writes the new keys/values at positions [start_pos, start_pos+seq_len) of the
  cache and attends over all cached positions [0, start_pos+seq_len).
  """
  def __init__(self,args:"ModelArgs"):
    super().__init__()
    # number of query heads
    self.n_heads_q=args.n_heads
    self.heads_q=self.n_heads_q # kept for backward compatibility with the old attribute name
    # number of key/value heads; defaults to one per query head
    self.n_kv_heads=args.n_heads if args.n_kv_heads is None else args.n_kv_heads
    assert self.n_heads_q%self.n_kv_heads==0,"n_heads must be divisible by n_kv_heads"
    self.head_dim=args.dim//args.n_heads
    # how many query heads share one key/value head
    self.n_rep=self.n_heads_q//self.n_kv_heads

    self.wq=nn.Linear(args.dim,self.n_heads_q*self.head_dim,bias=False)
    self.wk=nn.Linear(args.dim,self.n_kv_heads*self.head_dim,bias=False)
    self.wv=nn.Linear(args.dim,self.n_kv_heads*self.head_dim,bias=False)
    self.wo=nn.Linear(self.n_heads_q*self.head_dim,args.dim,bias=False)

    # KV cache [max_b, max_s, h_kv, h_dim]. Registered as non-persistent buffers so that
    # module.to(...)/.half() move them along with the weights, while state_dict() excludes them.
    cache_shape=(args.max_batch_size,args.max_seq_len,self.n_kv_heads,self.head_dim)
    self.register_buffer("cache_k",torch.zeros(cache_shape),persistent=False)
    self.register_buffer("cache_v",torch.zeros(cache_shape),persistent=False)

  def forward(self,x:torch.Tensor,start_pos:int,freqs_complex:torch.Tensor):
    """x: [b,s,dim]; freqs_complex: rotary factors for positions [start_pos,start_pos+s). Returns [b,s,dim]."""
    batch_size,seq_len,_=x.shape
    end_pos=start_pos+seq_len

    # project and split into heads
    xq=self.wq(x).view(batch_size,seq_len,self.n_heads_q,self.head_dim) #[b,s,h_q,h_dim]
    xk=self.wk(x).view(batch_size,seq_len,self.n_kv_heads,self.head_dim) #[b,s,h_kv,h_dim]
    xv=self.wv(x).view(batch_size,seq_len,self.n_kv_heads,self.head_dim) #[b,s,h_kv,h_dim]

    # rotary position encoding on queries and keys only
    xq=apply_rotary_embeddings(xq,freqs_complex,x.device)
    xk=apply_rotary_embeddings(xk,freqs_complex,x.device)

    # store the new entries, then read back everything up to the current position
    self.cache_k[:batch_size,start_pos:end_pos]=xk
    self.cache_v[:batch_size,start_pos:end_pos]=xv
    keys=self.cache_k[:batch_size,:end_pos] #[b,s_kv,h_kv,h_dim]
    values=self.cache_v[:batch_size,:end_pos] #[b,s_kv,h_kv,h_dim]

    # share each key/value head across its group of query heads -> [b,s_kv,h_q,h_dim]
    keys=repeat_kv(keys,self.n_rep)
    values=repeat_kv(values,self.n_rep)

    # move heads in front of the sequence axis
    xq=xq.transpose(1,2) #[b,h_q,s,h_dim]
    keys=keys.transpose(1,2) #[b,h_q,s_kv,h_dim]
    values=values.transpose(1,2) #[b,h_q,s_kv,h_dim]

    # [b,h_q,s,h_dim] @ [b,h_q,h_dim,s_kv] -> [b,h_q,s,s_kv]
    scores=torch.matmul(xq,keys.transpose(2,3))/math.sqrt(self.head_dim)
    if seq_len>1:
      # causal mask: query i sits at absolute position start_pos+i and must not see
      # keys at positions > start_pos+i. With seq_len == 1 nothing is masked.
      mask=torch.full((seq_len,end_pos),float("-inf"),device=scores.device).triu(start_pos+1)
      scores=scores+mask
    scores=F.softmax(scores.float(),dim=-1).type_as(xq)

    # [b,h_q,s,s_kv] @ [b,h_q,s_kv,h_dim] -> [b,h_q,s,h_dim] -> [b,s,h_q*h_dim]
    output=torch.matmul(scores,values)
    output=output.transpose(1,2).contiguous().view(batch_size,seq_len,-1)
    return self.wo(output)

In [83]:
class FeedForward(nn.Module):
  """Gated (SwiGLU) feed-forward: w2(silu(w1(x)) * w3(x)), mapping [..., dim] -> [..., dim]."""
  def __init__(self,args:"ModelArgs"):
    super().__init__()
    # Start from the classic 4*dim hidden size and scale by 2/3: the gated FFN has three
    # dim x hidden matrices instead of two, so 3*(8/3)d^2 keeps roughly the same 8d^2 parameters.
    hidden_dim=4*args.dim
    hidden_dim=int(2*hidden_dim/3)
    multiplier=getattr(args,"ffn_dim_multiplier",None)
    if multiplier is None:
      # fall back to the ff_dim_multiplier spelling used by ModelArgs
      multiplier=getattr(args,"ff_dim_multiplier",None)
    if multiplier is not None:
      hidden_dim=int(multiplier*hidden_dim)
    # round up to the next multiple of multiple_of, e.g. hidden 7, multiple 5: (7+5-1)//5*5 = 10
    hidden_dim=args.multiple_of*((hidden_dim+args.multiple_of-1)//args.multiple_of)

    self.w1=nn.Linear(args.dim,hidden_dim,bias=False) # gate branch (goes through SiLU)
    self.w2=nn.Linear(hidden_dim,args.dim,bias=False) # projection back to dim
    self.w3=nn.Linear(args.dim,hidden_dim,bias=False) # linear branch multiplied by the gate

  def forward(self,x:torch.Tensor):
    swish=F.silu(self.w1(x))
    x_V=self.w3(x)
    return self.w2(swish*x_V)

In [84]:
class EncoderBlock(nn.Module):
  """One pre-norm transformer layer (used here as a decoder layer with a KV cache):

  h   = x + attention(RMSNorm(x))
  out = h + feed_forward(RMSNorm(h))
  """
  def __init__(self,args:"ModelArgs"):
    super().__init__()
    self.n_heads=args.n_heads
    self.dim=args.dim
    self.head_dim=args.dim//args.n_heads

    self.attention=SelfAttention(args)
    self.feed_forward=FeedForward(args)

    # separate norms before the attention and the feed-forward sub-layers
    self.attention_norm=RMSNorm(self.dim,eps=args.norm_eps)
    self.ffn_norm=RMSNorm(self.dim,eps=args.norm_eps)

  def forward(self,x:torch.Tensor,start_pos:int,freqs_complex:torch.Tensor):
    # call the submodules (not .forward) so any registered forward hooks still run
    h=x+self.attention(self.attention_norm(x),start_pos,freqs_complex)
    out=h+self.feed_forward(self.ffn_norm(h))
    return out

In [85]:
class Transformer(nn.Module):
  """LLaMA-style model: token embedding -> n_layers EncoderBlocks -> RMSNorm -> vocab projection.

  forward() handles one token per call; earlier positions are read from each layer's KV cache.
  """
  def __init__(self,args:"ModelArgs"):
    super().__init__()
    assert args.vocab_size!=-1,"vocab size must be set"

    self.args=args
    self.vocab_size=args.vocab_size
    self.n_layers=args.n_layers

    # must be named tok_embeddings: forward() looks the embedding up under this name
    self.tok_embeddings=nn.Embedding(self.vocab_size,args.dim)

    self.layers=nn.ModuleList(EncoderBlock(args) for _ in range(args.n_layers))

    self.norm=RMSNorm(args.dim,eps=args.norm_eps)
    self.output=nn.Linear(args.dim,self.vocab_size,bias=False)

    # rotary factors for 2*max_seq_len positions; a plain attribute (not a buffer), so a
    # later model.to(...) does not move it -- forward() moves the slice it needs instead
    self.freqs_complex=precompute_theta_pos_frequencies(args.dim//args.n_heads,args.max_seq_len*2,device=args.device)

  def forward(self,tokens:torch.Tensor,start_pos:int):
    """tokens: [b, 1] token ids at position start_pos. Returns float32 logits [b, 1, vocab_size]."""
    batch_size,seq_len=tokens.shape
    assert seq_len==1,"Only one token at a time can be processed"
    # fail with a readable message instead of a shape error inside the KV-cache write
    assert batch_size<=self.args.max_batch_size,"batch size exceeds max_batch_size"
    assert start_pos+seq_len<=self.args.max_seq_len,"start_pos exceeds max_seq_len"

    h=self.tok_embeddings(tokens) #[b,s]->[b,s,dim]

    # rotary factors for positions [start_pos, start_pos+seq_len), on the activations' device
    freqs_complex=self.freqs_complex[start_pos:start_pos+seq_len].to(h.device)

    # consecutively apply all the layers
    for layer in self.layers:
      h=layer(h,start_pos,freqs_complex)
    h=self.norm(h)
    return self.output(h).float()