In [2]:
from contextlib import nullcontext

In [3]:
from typing import Optional,Tuple

In [4]:
import math

In [5]:
import torch

In [6]:
import torch.nn as nn

In [7]:
import torch.nn.functional as F

In [8]:
from base_llama import LlamaPreTrainedModel, LlamaConfig
from rope import apply_rotary_emb
from utils import *

In [37]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    The output is ``weights * x / (rms(x) + eps)``, where ``rms(x)`` is the
    square root of the mean of ``x ** 2`` along the final dimension. Note that
    ``eps`` is added to the RMS itself, i.e. outside the square root.
    """

    def __init__(self, dim, eps):
        """
        Args:
            dim: size of the last dimension of the inputs (one gain per feature).
            eps: small constant added to the denominator to avoid division by zero.
        """
        super().__init__()
        self.eps = eps
        # Per-feature multiplicative gain, initialised to one.
        self.weights = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalise ``x`` along its last dimension and apply the gain.

        Args:
            x: tensor whose last dimension broadcasts against ``weights``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        denominator = mean_square.pow(0.5) + self.eps
        # Gain is applied before the division, matching the original evaluation order.
        scaled = self.weights * x
        return scaled / denominator

In [29]:
# Sample input for a quick RMSNorm sanity check: 5 rows of 4 random values.
x = torch.rand(5,4)

In [33]:
# RMSNorm over 4 features with a tiny eps (1e-9) so it barely perturbs the result.
rms = RMSNorm(4,0.000000001)

In [34]:
# Per-row mean of the squared normalised values; with the gain at its initial
# value of ones this should be ~1 for every row.
(rms(x) ** 2).mean(dim=-1)

tensor([[0.5115],
        [0.6114],
        [0.5034],
        [0.5390],
        [0.7196]])


tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<MeanBackward1>)

In [42]:
def repeat(x, dim, n):
    """Duplicate every head of ``x`` ``n`` times along the head axis.

    Args:
        x: 4-D tensor unpacked as (B, T, head, head_dim).
        dim: unused by this implementation; kept for the existing signature.
        n: number of consecutive copies to make of each head.

    Returns:
        Tensor of shape (B, T, head * n, head_dim) in which the ``n`` copies of
        a given head sit next to each other.
    """
    batch, seq_len, n_heads, width = x.shape
    # Add a singleton axis right after the head axis, broadcast it to ``n``,
    # then fold it back into the head axis.
    widened = x.unsqueeze(3).expand(batch, seq_len, n_heads, n, width)
    return widened.reshape(batch, seq_len, n_heads * n, width)

In [48]:
class Attention(nn.Module):
    """Causal grouped-query self-attention.

    Queries use ``n_heads`` heads while keys/values use ``n_kv_heads`` heads;
    each key/value head is shared by ``n_heads // n_kv_heads`` query heads.

    NOTE(review): ``apply_rotary_emb`` is imported at the top of the file but is
    not used here, so no positional information is injected into q/k in this
    block — confirm whether rotary embeddings should be applied before attention.
    """

    def __init__(self,config:LlamaConfig):
        """Build the projection layers from ``config``.

        Reads ``config.dim``, ``config.hidden_dim``, ``config.n_heads`` and
        ``config.n_kv_heads``.

        Raises:
            AssertionError: if ``n_heads`` is not a multiple of ``n_kv_heads``.
        """
        ### group attention
        super().__init__()
        self.dim = config.dim
        self.hidden_dim = config.hidden_dim
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        assert self.n_heads % self.n_kv_heads == 0
        # Number of query heads served by each key/value head.
        self.n_rep = self.n_heads // self.n_kv_heads
        self.head_dim = self.dim // self.n_heads
        self.wq = nn.Linear(self.dim,self.n_heads * self.head_dim,bias=False)
        self.wk = nn.Linear(self.dim,self.n_kv_heads * self.head_dim,bias=False)
        self.wv = nn.Linear(self.dim,self.n_kv_heads * self.head_dim,bias=False)
        self.wo = nn.Linear(self.n_heads * self.head_dim,self.dim,bias=False)

    def compute(self,q,k,v):
        """Scaled dot-product attention with a causal (lower-triangular) mask.

        Args:
            q, k, v: tensors of shape (B, head, T, head_dim).

        Returns:
            Tensor of shape (B, head, T, head_dim).
        """
        B,head,T,head_dim = q.shape
        # Scale by 1/sqrt(head_dim) so the logits do not grow with head size.
        scores = (q @ k.transpose(2,3)) * head_dim ** -0.5 # B,head,T,T
        # Build the mask on the same device as the scores; True marks allowed positions.
        mask = torch.ones(T,T,dtype=torch.bool,device=scores.device).tril()
        # The fill value must be a number: float('-inf'), not the string '-inf'.
        scores = scores.masked_fill(~mask,float('-inf'))
        scores = F.softmax(scores,dim=-1)
        return scores @ v

    def forward(self,x):
        """Apply self-attention to ``x``.

        Args:
            x: tensor of shape (B, T, dim).

        Returns:
            Tensor of shape (B, T, dim).
        """
        ### x (B,T,dim)
        B,T,dim = x.shape
        q = self.wq(x).view(B,T,self.n_heads,self.head_dim)
        k = self.wk(x).view(B,T,self.n_kv_heads,self.head_dim)
        v = self.wv(x).view(B,T,self.n_kv_heads,self.head_dim)
        # Expand the kv heads so that k/v have n_heads heads like q
        # (axis 2 is the head axis; `repeat` requires all three arguments).
        k = repeat(k,2,self.n_rep)
        v = repeat(v,2,self.n_rep)
        # B,T,head,head_dim
        q = q.transpose(1,2)
        k = k.transpose(1,2)
        v = v.transpose(1,2)
        # B,head,T,head_dim
        x = self.compute(q,k,v)
        # Move heads back next to head_dim and merge them so `wo` receives
        # n_heads * head_dim features per position.
        x = x.transpose(1,2).contiguous().view(B,T,self.n_heads * self.head_dim)
        # B,T dim
        return self.wo(x)

        

In [49]:
class FeedForward(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))`` followed by dropout."""

    def __init__(self,dim:int, hidden_dim:Optional[int], multiple_of:int, dropout:float):
        """
        Args:
            dim: input and output feature size.
            hidden_dim: width of the inner projection. When ``None`` it is
                derived as ``int(2 * (4 * dim) / 3)`` rounded up to the next
                multiple of ``multiple_of``.
            multiple_of: rounding granularity used only when ``hidden_dim`` is None.
            dropout: dropout probability applied to the block output
                (``None`` or ``0`` disables it).
        """
        super().__init__()
        if hidden_dim is None:
            # Previously a None hidden size crashed inside nn.Linear; derive it
            # from dim instead and round up to a multiple of `multiple_of`.
            hidden_dim = int(2 * (4 * dim) / 3)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        # NOTE(review): these layers keep their bias terms while the attention
        # projections in this file use bias=False — confirm which is intended
        # before loading pretrained weights.
        self.w1 = nn.Linear(dim,hidden_dim)
        self.w2 = nn.Linear(hidden_dim,dim)
        self.w3 = nn.Linear(dim,hidden_dim)
        # The dropout argument used to be ignored; nn.Dropout has no parameters,
        # so the state_dict keys are unchanged.
        self.dropout = nn.Dropout(dropout or 0.0)

    def SiluGlu(self,x):
        """Return the gated activation ``silu(w1(x)) * w3(x)``."""
        return F.silu(self.w1(x)) * self.w3(x)

    def forward(self,x):
        """Map ``x`` (..., dim) to a tensor of the same shape."""
        return self.dropout(self.w2(self.SiluGlu(x)))

In [53]:
class LlamaLayer(nn.Module):
    """One transformer block: pre-norm attention and pre-norm feed-forward,
    each wrapped in a residual connection."""

    def __init__(self, layer_id: int, config: LlamaConfig):
        """
        Args:
            layer_id: index of this layer, stored as ``self.layer_id``.
            config: supplies ``dim``, ``hidden_dim``, ``multiple_of`` and
                ``dropout`` (plus whatever ``Attention`` reads from it).
        """
        super().__init__()
        self.layer_id = layer_id
        # Both norms share the same hard-coded epsilon.
        norm_eps = 1e-5
        self.attn_norm = RMSNorm(config.dim, eps=norm_eps)
        self.feed_norm = RMSNorm(config.dim, eps=norm_eps)
        self.attn = Attention(config)
        self.feed = FeedForward(
            config.dim,
            config.hidden_dim,
            config.multiple_of,
            config.dropout,
        )

    def forward(self, x):
        """Run the attention and feed-forward sub-blocks on ``x``, adding each
        sub-block's output back onto its input."""
        attn_out = self.attn(self.attn_norm(x))
        hidden = x + attn_out
        feed_out = self.feed(self.feed_norm(hidden))
        return hidden + feed_out

In [56]:
class Llama(LlamaPreTrainedModel):
    """Stack of ``LlamaLayer`` blocks followed by a final norm and a vocabulary
    projection, with a simple sampling-based ``generate`` loop."""

    def __init__(self,config: LlamaConfig):
        """Build the model from ``config`` (reads ``n_layers``, ``dim`` and
        ``vocab_size`` here; the layers read further fields)."""
        # NOTE(review): LlamaPreTrainedModel lives in base_llama and is not
        # visible here; if its __init__ expects the config, pass it — confirm.
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(config.n_layers):
            self.layers.append(LlamaLayer(i,config))
        self.norm = RMSNorm(config.dim,eps=0.000001)
        self.wo = nn.Linear(config.dim,config.vocab_size)
        self.word_embedding = nn.Embedding(config.vocab_size,config.dim)

    def forward(self,x):
        """Map already-embedded inputs to vocabulary logits.

        Args:
            x: embeddings of shape (B, T, dim). Token ids must be passed
                through ``word_embedding`` first, as ``generate`` does.

        Returns:
            Logits of shape (B, T, vocab_size).
        """
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        logits = self.wo(x)
        return logits

    @torch.no_grad()  # sampling never needs gradients; avoids building autograd graphs
    def generate(self,tokens,max_len):
        """Sample ``max_len`` additional tokens autoregressively.

        Args:
            tokens: token ids of shape (B, T).
            max_len: number of new tokens to append.

        Returns:
            Token ids of shape (B, T + max_len).
        """
        # B,T
        for _ in range(max_len):
            x = self.word_embedding(tokens) # B,T,dim
            logits = self.forward(x) # B,T,vocabsize
            logits = logits[:,-1,:] # B, vocabsize
            probs = F.softmax(logits,dim=-1)
            idx = torch.multinomial(probs,num_samples=1) # B,1
            # Append the sampled ids to the running sequence (the original
            # passed the `torch` module itself to torch.cat here).
            tokens = torch.cat((tokens,idx),dim=-1)
        return  tokens
        
        

In [15]:
def load_pretrained(checkpoint):
    """Placeholder for loading pretrained weights; not implemented yet.

    The ``checkpoint`` argument is currently ignored and ``None`` is returned.
    """
    return None