In [1]:
import torch
import torch.nn as nn
import torch.functional as F
import math
from dataclasses import dataclass
from typing import Optional


In [2]:
@dataclass
class ModelArgs:
    """Hyper-parameters consumed by the model's building blocks.

    All fields have defaults so the object can be created with ``ModelArgs()``
    and selectively overridden.
    """

    # Model / embedding width; also the input and output size of every block.
    dim: int = 4096
    # Number of transformer blocks stacked by the model.
    n_layers: int = 32
    # Number of query heads in self-attention.
    n_heads: int = 32
    # Number of key/value heads; None means "same as n_heads".
    n_kv_heads: Optional[int] = None
    # Vocabulary size used for the embedding table and the output projection.
    # NOTE(review): -1 looks like a placeholder that must be overridden before
    # a model is built - confirm.
    vocab_size: int = -1
    # Epsilon forwarded to the RMSNorm layers for numerical stability.
    norm_eps: float = 1e-5
    # Multiplier applied to the feed-forward hidden width (skipped when None).
    ffn_dim_multiplier: int = 4

    # Sizes of the pre-allocated key/value cache.
    max_batch_size: int = 32
    max_seq_len: int = 2048

    # Device used when precomputing the rotary frequency table; None defers to
    # PyTorch's default device.
    device: Optional[str] = None

In [3]:
# Default to the CPU and switch to the GPU only when CUDA is actually usable.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [26]:
def precompute_theta_pos_frequencies(head_dim:int, seq_len: int, device: str, theta: float=10000.0):
    """Precompute the complex rotary-embedding (RoPE) rotation factors.

    For every position ``m`` in ``[0, seq_len)`` and every feature pair ``k`` in
    ``[0, head_dim / 2)`` the result holds ``exp(1j * m * theta_k)`` with
    ``theta_k = theta ** (-2 * k / head_dim)``.

    Args:
        head_dim: Size of one attention head; must be even because features are
            rotated in pairs.
        seq_len: Number of positions to precompute.
        device: Device on which the table is created.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(seq_len, head_dim // 2)`` with unit modulus.

    Raises:
        AssertionError: If ``head_dim`` is odd.
    """
    assert head_dim % 2 == 0, "head_dim must be even: RoPE rotates features in pairs"

    # theta_i = theta^(-2(i-1)/d_head) for i in [1, 2, ..., d_head/2].
    # arange(0, head_dim, 2) already equals 2(i-1), so it must NOT be doubled again.
    even_idx = torch.arange(0, head_dim, 2, device=device).float()  # 0, 2, 4, ...
    inv_freq = theta ** (-even_idx / head_dim)  # (head_dim/2,)

    # Positions 0, 1, ..., seq_len - 1
    positions = torch.arange(seq_len, device=device).float()  # (seq_len,)

    # Rotation angle for every (position, frequency) combination.
    angles = torch.outer(positions, inv_freq).float()  # (seq_len, head_dim/2)

    # Unit-length complex numbers cos(angle) + i * sin(angle).
    freq_complex = torch.polar(torch.ones_like(angles), angles)  # (seq_len, head_dim/2)

    return freq_complex

def apply_rotary_embeddings(x: torch.Tensor, freq_complex: torch.Tensor, device:str):
    """Rotate consecutive feature pairs of ``x`` by precomputed complex factors.

    Args:
        x: Tensor laid out as (N, seq_len, h, head_dim); the last dimension must
            be even so it can be split into (real, imaginary) pairs.
        freq_complex: Complex rotation factors of shape (seq_len, head_dim/2).
        device: Device the result is moved to before it is returned.

    Returns:
        Tensor with the shape and dtype of ``x``, placed on ``device``.
    """
    # Group the last axis into pairs and reinterpret each pair as one complex
    # number: (N, seq_len, h, head_dim) -> (N, seq_len, h, head_dim/2)
    paired = x.float().reshape(*x.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(paired)

    # Insert broadcast axes for batch and heads: (1, seq_len, 1, head_dim/2)
    rotation = freq_complex[None, :, None]

    # A complex multiplication performs the 2-D rotation of every pair; going
    # back to the real view restores the trailing axis of size 2.
    rotated = torch.view_as_real(as_complex * rotation)  # (N, seq_len, h, head_dim/2, 2)

    # Flatten the pairs back into head_dim and restore the caller's dtype.
    return rotated.reshape(*x.shape).type_as(x).to(device)
        
    
def test_precompute_theta_pos_frequencies():
    """Check the output shape of precompute_theta_pos_frequencies on the CPU.

    Covers a small case, a larger case and the minimal (head_dim=2, seq_len=1)
    edge case, then prints a confirmation line.
    """
    device = "cpu"

    # (head_dim, seq_len): small, larger, minimal edge case
    cases = [(4, 5), (8, 10), (2, 1)]

    for head_dim, seq_len in cases:
        freq_complex = precompute_theta_pos_frequencies(head_dim, seq_len, device)
        expected_shape = (seq_len, head_dim // 2)
        assert freq_complex.shape == expected_shape, f"Unexpected shape: {freq_complex.shape}"

    print("✅ All test cases passed for precompute_theta_pos_frequencies!")

def test_apply_rotary_embeddings():
    """Check that apply_rotary_embeddings preserves the input shape on the CPU.

    Runs a small tensor, a larger tensor and the minimal edge case, then prints
    a confirmation line.
    """
    device = "cpu"

    # Input layouts as (batch, seq_len, heads, head_dim): small, larger, minimal
    input_shapes = [(1, 5, 2, 4), (2, 10, 4, 8), (1, 1, 1, 2)]

    for shape in input_shapes:
        _, seq_len, _, head_dim = shape
        x = torch.randn(*shape)
        freq_complex = precompute_theta_pos_frequencies(head_dim, seq_len, device)
        x_out = apply_rotary_embeddings(x, freq_complex, device)
        assert x_out.shape == x.shape, f"Unexpected shape: {x_out.shape}"

    print("✅ All test cases passed for apply_rotary_embeddings!")

# Run the RoPE smoke tests when this cell executes; each one prints a
# confirmation line once all of its assertions have passed.
test_precompute_theta_pos_frequencies()
test_apply_rotary_embeddings()

 

✅ All test cases passed for precompute_theta_pos_frequencies!
✅ All test cases passed for apply_rotary_embeddings!


In [27]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable gain."""

    def __init__(self,dim:int, eps:float=1e-6):
        """Create the layer.

        Args:
            dim: Size of the last (feature) axis being normalised.
            eps: Constant added to the mean square before the reciprocal square
                root, guarding against division by zero.
        """
        super().__init__()
        # gamma: per-feature scale, initialised to ones
        self.w = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def norm(self,x:torch.Tensor):
        """Scale ``x`` by the reciprocal RMS of its last axis (no learnable gain)."""
        # (N, seq_len, d_model) -> mean square of shape (N, seq_len, 1)
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self,x:torch.Tensor):
        """Normalise in float32, cast back to the dtype of ``x``, apply the gain."""
        normalised = self.norm(x.float())
        return self.w * normalised.type_as(x)
    


- Grouped Query Attention: Multiple Query heads share the same Key-Value pairs
    - How to determine which group shares which Key-Val?
- Apply RoPE (rotary position embeddings) to the Q and K matrices to encode positional information
- Key Value cache

- GQA applies during both training and inference, unlike KV caching (which is used only during inference)

In [28]:
# GQA
import torch.nn as nn
class SelfAttention(nn.Module):
    """Grouped-query self-attention with rotary embeddings and a key/value cache.

    ``n_heads`` query heads share ``num_kv_heads`` key/value heads
    (``num_kv_heads == n_heads`` is ordinary multi-head attention, ``1`` is
    multi-query attention). Keys and values of already processed positions are
    kept in a cache so that later calls only need to project the new tokens.
    """

    def __init__(self,args:ModelArgs):
        """Build the projections and allocate the key/value cache.

        Args:
            args: Model hyper-parameters; reads ``dim``, ``n_heads``,
                ``n_kv_heads``, ``max_batch_size`` and ``max_seq_len``.

        Raises:
            ValueError: If ``n_heads`` is not a multiple of the number of
                key/value heads.
        """
        super(SelfAttention, self).__init__()
        self.dim = args.dim
        self.n_heads = args.n_heads
        # n_kv_heads=None means every query head gets its own key/value head.
        self.num_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        if self.n_heads % self.num_kv_heads != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be a multiple of the number of "
                f"key/value heads ({self.num_kv_heads})"
            )
        # Number of query heads served by one key/value head. Derived from the
        # resolved value so that the default n_kv_heads=None does not crash.
        self.num_heads_per_kv = self.n_heads // self.num_kv_heads
        self.head_dim = args.dim // args.n_heads

        # Query projection -> (batch_size, seq_len, n_heads * head_dim)
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        # Key and value projections (output dims are reduced for num_kv_heads)
        self.wk = nn.Linear(args.dim, self.num_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.num_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        # Non-persistent buffers follow the module across devices but stay out
        # of the state_dict, so checkpoints are unchanged.
        cache_shape = (args.max_batch_size, args.max_seq_len, self.num_kv_heads, self.head_dim)
        self.register_buffer("cache_k", torch.zeros(cache_shape), persistent=False)
        self.register_buffer("cache_v", torch.zeros(cache_shape), persistent=False)

    def forward(self,x:torch.Tensor, freqs_complex:torch.Tensor, start_pos: int):
        """Attend from the new tokens to all positions seen so far.

        Args:
            x: Input of shape (batch_size, seq_len, dim).
            freqs_complex: Rotary factors for positions
                ``start_pos .. start_pos + seq_len - 1``, shape
                (seq_len, head_dim / 2).
            start_pos: Absolute position of the first token of ``x``; positions
                before it are read from the cache.

        Returns:
            Tensor of shape (batch_size, seq_len, dim).
        """
        batch_size, seq_len, _ = x.shape
        end_pos = start_pos + seq_len

        q = self.wq(x).view(batch_size, seq_len, self.n_heads, self.head_dim)
        k = self.wk(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)
        v = self.wv(x).view(batch_size, seq_len, self.num_kv_heads, self.head_dim)

        # Keep everything on the device of the activations instead of relying
        # on a module-level device setting that may not match the model.
        freqs_complex = freqs_complex.to(x.device)
        q = apply_rotary_embeddings(q, freqs_complex, device=x.device)
        k = apply_rotary_embeddings(k, freqs_complex, device=x.device)

        # Make sure the cache matches the device/dtype of the new entries.
        self.cache_k = self.cache_k.to(k)
        self.cache_v = self.cache_v.to(v)

        # Store detached copies so the cache never holds on to an autograd graph.
        self.cache_k[:batch_size, start_pos:end_pos] = k.detach()
        self.cache_v[:batch_size, start_pos:end_pos] = v.detach()

        # Past positions come from the cache, the current ones from k / v
        # directly (keeps them differentiable): (batch_size, end_pos, kv_heads, head_dim)
        keys = torch.cat([self.cache_k[:batch_size, :start_pos], k], dim=1)
        values = torch.cat([self.cache_v[:batch_size, :start_pos], v], dim=1)

        if self.num_kv_heads < self.n_heads:
            # Repeat each key/value head for the query heads that share it.
            keys = keys.repeat_interleave(self.num_heads_per_kv, dim=2)
            values = values.repeat_interleave(self.num_heads_per_kv, dim=2)

        q = q.transpose(1, 2)            # (batch_size, n_heads, seq_len, head_dim)
        keys = keys.transpose(1, 2)      # (batch_size, n_heads, end_pos, head_dim)
        values = values.transpose(1, 2)  # (batch_size, n_heads, end_pos, head_dim)

        attn_scores = torch.matmul(q, keys.transpose(-1, -2)) / (self.head_dim ** 0.5)

        if seq_len > 1:
            # Causal mask: the query at absolute position start_pos + i must
            # not see keys at later positions. A single-token call needs no
            # mask because every cached key lies in its past.
            causal_mask = torch.full(
                (seq_len, end_pos), float("-inf"),
                device=attn_scores.device, dtype=attn_scores.dtype,
            )
            causal_mask = torch.triu(causal_mask, diagonal=start_pos + 1)
            attn_scores = attn_scores + causal_mask

        attn_probs = torch.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, values)  # (batch_size, n_heads, seq_len, head_dim)

        # Merge the heads back into the model dimension.
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.wo(output)

In [29]:

class FeedForward(nn.Module):
    """Gated (SwiGLU-style) position-wise feed-forward network."""

    def __init__(self, args: ModelArgs):
        """Create the three bias-free projections.

        Args:
            args: Model hyper-parameters; reads ``dim`` and
                ``ffn_dim_multiplier``. The hidden width is ``4 * dim``, scaled
                by ``ffn_dim_multiplier`` when that is not None.
        """
        super(FeedForward, self).__init__()
        hidden_dim = args.dim * 4
        if args.ffn_dim_multiplier is not None:
            hidden_dim = int(hidden_dim * args.ffn_dim_multiplier)

        self.w1 = nn.Linear(in_features=args.dim, out_features=hidden_dim, bias=False)  # gate
        self.w2 = nn.Linear(in_features=hidden_dim, out_features=args.dim, bias=False)  # down
        self.w3 = nn.Linear(in_features=args.dim, out_features=hidden_dim, bias=False)  # up

    def forward(self,x:torch.Tensor):
        """Compute ``w2(silu(w1(x)) * w3(x))``.

        Args:
            x: Tensor whose last axis has size ``dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # silu lives in torch.nn.functional; the notebook-level alias ``F``
        # points at torch.functional, which has no silu, so go through ``nn``.
        swish = nn.functional.silu(self.w1(x))
        x_V = self.w3(x)
        return self.w2(swish * x_V)


        
        

In [30]:
class EncoderBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual."""

    def __init__(self,args):
        """Build the attention, feed-forward and normalisation sub-layers.

        Args:
            args: Model hyper-parameters; reads ``n_heads``, ``dim`` and
                ``norm_eps`` and is forwarded to the sub-layers.
        """
        super(EncoderBlock, self).__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads

        self.attention = SelfAttention(args)
        self.feed_forward = FeedForward(args)

        # Norm BEFORE attention
        self.attention_norm = RMSNorm(dim=args.dim, eps=args.norm_eps)

        # Norm BEFORE ffn
        self.ffn_norm = RMSNorm(args.dim, args.norm_eps)

    def forward(self,x:torch.Tensor, start_pos:int, freqs_complex:torch.Tensor):
        """Run one block.

        Args:
            x: Input activations.
            start_pos: Absolute position of the first token in ``x``.
            freqs_complex: Rotary factors for the positions covered by ``x``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # SelfAttention.forward takes (x, freqs_complex, start_pos), i.e. the
        # last two arguments in the opposite order from this method. Passing
        # them by keyword keeps the two from being swapped.
        attn_out = self.attention(
            self.attention_norm(x),
            freqs_complex=freqs_complex,
            start_pos=start_pos,
        )
        h = x + attn_out  # sublayer (residual) connection

        out = h + self.feed_forward(self.ffn_norm(h))

        return out

In [31]:
import copy
def clones(module,N):
    """Return an ``nn.ModuleList`` holding ``N`` independent deep copies of ``module``."""
    replicas = nn.ModuleList()
    for _ in range(N):
        # deepcopy gives every replica its own parameters
        replicas.append(copy.deepcopy(module))
    return replicas

In [36]:
class Llama(nn.Module):
    """Token embedding, a stack of transformer blocks, final norm and output head."""

    def __init__(self,args:ModelArgs):
        """Build the model.

        Args:
            args: Model hyper-parameters. ``vocab_size`` sizes the embedding
                table and the output projection; ``n_layers`` sets the number
                of blocks.
        """
        super(Llama, self).__init__()

        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)

        self.layers = clones(EncoderBlock(args), self.n_layers)
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

        # Rotary factors for twice max_seq_len positions, computed once up front.
        self.freqs_complex = precompute_theta_pos_frequencies(
            self.args.dim // self.args.n_heads,
            self.args.max_seq_len * 2,
            device=self.args.device,
        )

    def forward(self,tokens:torch.Tensor,start_pos: int):
        """Compute logits for ``tokens``.

        Args:
            tokens: Integer token ids of shape (batch_size, seq_len).
            start_pos: Absolute position of the first token in ``tokens``.

        Returns:
            Float tensor of logits with shape (batch_size, seq_len, vocab_size).
        """
        batch_size, seq_len = tokens.shape
        # tokens -> (batch_size, seq_len)
        h = self.tok_embeddings(tokens)  # (batch_size, seq_len, d_model)

        # The table is a plain attribute, so it does not move with the module;
        # align it with the activations explicitly.
        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len].to(h.device)

        for layer in self.layers:
            h = layer(h, start_pos, freqs_complex)

        h = self.norm(h)
        output = self.output(h).float()
        return output
        
    