In [1]:
from contextlib import nullcontext

In [2]:
from typing import Optional,Tuple

In [3]:
import math

In [4]:
import torch

In [5]:
import torch.nn as nn

In [6]:
import torch.nn.functional as F

In [7]:
from base_llama import LlamaPreTrainedModel, LlamaConfig
from rope import apply_rotary_emb
from utils import *

In [100]:
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation with a learnable per-feature gain.

    Computes ``weight * x * rsqrt(mean(x ** 2, dim=-1) + eps)``.  ``eps`` is
    added *inside* the square root (the same formula ``torch.nn.RMSNorm``
    documents), so the result stays finite for all-zero rows and matches
    checkpoints trained with that formulation.

    Args:
        dim: size of the last (feature) dimension; ``weight`` has this shape.
        eps: small constant added to the mean square for numerical stability.
    """

    def __init__(self,dim,eps):
        super().__init__()
        # Gain initialised to ones, so a fresh module is a pure normalisation.
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self,x):
        """Normalise ``x`` over its last dimension and apply the learned gain."""
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        # eps goes under the root: 1 / sqrt(mean(x^2) + eps).
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return self.weight * (x * inv_rms)

In [101]:
# Random (5, 4) input used to sanity-check RMSNorm in the cells below.
x = torch.rand(5,4)

In [102]:
# RMSNorm over a feature dimension of 4 with eps = 1e-9.
rms = RMSNorm(4,0.000000001)

In [103]:
# Per-row mean of the squared normalised values; the output below shows 1.0 for every row.
(rms(x) ** 2).mean(dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<MeanBackward1>)

In [104]:
def repeat(x,dim,n):
    """Repeat every slice of ``x`` along ``dim`` ``n`` times, keeping copies adjacent.

    For the 4-D ``(B, T, head, head_dim)`` tensors used by ``Attention`` with
    ``dim=2`` this turns heads ``[h0, h1]`` into ``[h0, h0, ..., h1, h1, ...]``,
    exactly what the previous expand/reshape implementation produced.  The
    ``dim`` argument used to be ignored; it is now honoured, and inputs of any
    rank are accepted.

    Args:
        x: input tensor.
        dim: dimension along which slices are repeated.
        n: number of copies of each slice.

    Returns:
        Tensor whose size along ``dim`` is ``n`` times that of ``x``.
    """
    if n == 1:
        # Nothing to duplicate; avoid an unnecessary copy.
        return x
    return torch.repeat_interleave(x, n, dim=dim)

In [201]:
class Attention(nn.Module):
    """Causal multi-head self-attention with grouped key/value heads.

    Queries use ``n_heads`` heads while keys and values use ``n_kv_heads``
    heads; each key/value head is repeated ``n_heads // n_kv_heads`` times so
    it lines up with its group of query heads.

    The sub-module names determine the state-dict keys, e.g.
        layers.0.attention.compute_query.weight
        layers.0.attention.compute_key.weight
        layers.0.attention.compute_value.weight
        layers.0.attention.compute_output.weight
    """

    def __init__(self,config:"LlamaConfig"):
        super().__init__()
        self.dim = config.dim
        self.hidden_dim = config.hidden_dim
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        # Every key/value head must serve a whole number of query heads.
        assert self.n_heads % self.n_kv_heads == 0
        self.n_rep = self.n_heads // self.n_kv_heads
        self.head_dim = self.dim // self.n_heads
        self.max_seq_len = config.max_seq_len
        self.compute_query = nn.Linear(self.dim,self.n_heads * self.head_dim,bias=False)
        self.compute_key = nn.Linear(self.dim,self.n_kv_heads * self.head_dim,bias=False)
        self.compute_value = nn.Linear(self.dim,self.n_kv_heads * self.head_dim,bias=False)
        self.compute_output = nn.Linear(self.n_heads * self.head_dim,self.dim,bias=False)

    def compute(self,q,k,v):
        """Scaled dot-product attention with a causal (lower-triangular) mask.

        Args:
            q, k, v: tensors of shape (B, head, T, head_dim).

        Returns:
            Tensor of shape (B, head, T, head_dim); position ``t`` only mixes
            values from positions ``<= t``.
        """
        B,head,T,head_dim = q.shape
        scores = q @ k.transpose(2,3) / math.sqrt(head_dim)  # B,head,T,T
        # Build the mask on the same device as the scores; a CPU mask makes
        # masked_fill fail once the module lives on an accelerator.
        causal = torch.ones(T,T,dtype=torch.bool,device=scores.device).tril()
        scores = scores.masked_fill(~causal,float('-inf'))
        weights = F.softmax(scores,dim=-1)  # B,head,T,T
        return weights @ v  # B,head,T,head_dim

    def forward(self,x):
        """Self-attend over ``x`` of shape (B, T, dim) and return (B, T, dim)."""
        B,T,_ = x.shape
        q = self.compute_query(x).view(B,T,self.n_heads,self.head_dim)
        k = self.compute_key(x).view(B,T,self.n_kv_heads,self.head_dim)
        v = self.compute_value(x).view(B,T,self.n_kv_heads,self.head_dim)
        # Expand the key/value heads so they match the number of query heads.
        k = repeat(k,2,self.n_rep)
        v = repeat(v,2,self.n_rep)
        # apply_rotary_emb comes from the project's `rope` module; presumably it
        # applies rotary position embeddings to q and k — TODO confirm.
        q, k = apply_rotary_emb(q, k, self.head_dim, self.max_seq_len)
        # B,T,head,head_dim -> B,head,T,head_dim
        q = q.transpose(1,2)
        k = k.transpose(1,2)
        v = v.transpose(1,2)
        # Merge the heads back: B,head,T,head_dim -> B,T,head*head_dim
        x = self.compute(q,k,v).transpose(1,2).reshape(B,T,-1)
        return self.compute_output(x)

        

In [202]:
class FeedForward(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

    Args:
        dim: input and output feature size.
        hidden_dim: inner width.  When ``None`` it is derived from ``dim`` as
            ``int(2 * (4 * dim) / 3)`` rounded up to a multiple of ``multiple_of``.
        multiple_of: rounding granularity used only when ``hidden_dim`` is ``None``.
        dropout: accepted for interface compatibility; not used by this block.
    """

    def __init__(self,dim:int, hidden_dim:int, multiple_of:int, dropout:float):
        super().__init__()
        inner = hidden_dim
        if inner is None:
            # Two thirds of the 4x expansion, then round up to the next multiple.
            inner = int(2 * (4 * dim) / 3)
            n_chunks = (inner + multiple_of - 1) // multiple_of
            inner = multiple_of * n_chunks
        # Creation order (w1, w2, w3) is kept so random initialisation is unchanged.
        self.w1 = nn.Linear(dim, inner, bias=False)
        self.w2 = nn.Linear(inner, dim, bias=False)
        self.w3 = nn.Linear(dim, inner, bias=False)

    def SiluGlu(self,x):
        """Return the gated hidden activation ``silu(w1(x)) * w3(x)``."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        return gate * value

    def forward(self,x):
        """Project the gated hidden activation back to ``dim`` features."""
        hidden = self.SiluGlu(x)
        return self.w2(hidden)

In [203]:
class LlamaLayer(nn.Module):
    """One transformer block: pre-norm attention and pre-norm feed-forward,
    each wrapped in a residual connection.

    Args:
        layer_id: index of this block; stored as ``self.layer_id``.
        config: supplies ``dim``, ``hidden_dim``, ``multiple_of`` and ``dropout``
            here, plus whatever ``Attention`` reads from it.
    """

    def __init__(self,layer_id:int,config:LlamaConfig):
        super().__init__()
        self.layer_id = layer_id
        width = config.dim
        # Both normalisations use the same fixed epsilon.
        norm_eps = 0.00001
        self.attention_norm = RMSNorm(width, eps=norm_eps)
        self.ffn_norm = RMSNorm(width, eps=norm_eps)
        self.attention = Attention(config)
        self.feed_forward = FeedForward(
            width,
            config.hidden_dim,
            config.multiple_of,
            config.dropout,
        )

    def forward(self,x):
        """Apply the attention and feed-forward residual branches to ``x``."""
        attended = self.attention(self.attention_norm(x))
        h = x + attended
        transformed = self.feed_forward(self.ffn_norm(h))
        return h + transformed

In [204]:
class Llama(LlamaPreTrainedModel):
    """Decoder-only transformer: a stack of ``LlamaLayer`` blocks, a final
    ``RMSNorm`` and a linear projection to vocabulary logits.

    ``forward`` operates on embeddings, not token ids; ``generate`` performs
    the embedding lookup itself before calling ``forward``.
    """

    def __init__(self,config: LlamaConfig):
        super().__init__(config)
        # Context window; generate() crops longer sequences to this length.
        self.max_seq_len = config.max_seq_len
        self.layers = nn.ModuleList()
        for i in range(config.n_layers):
            self.layers.append(LlamaLayer(i,config))
        self.norm = RMSNorm(config.dim,eps=0.000001)
        self.output = nn.Linear(config.dim,config.vocab_size,bias=False)
        self.tok_embeddings = nn.Embedding(config.vocab_size,config.dim)

    def forward(self,x):
        """Map embedded inputs (B, T, dim) to logits (B, T, vocab_size)."""
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        logits = self.output(x)
        return logits

    @torch.no_grad()  # inference only: do not record autograd graphs
    def generate(self,tokens,max_len,temperature=1.0):
        """Autoregressively append ``max_len`` tokens to ``tokens``.

        Args:
            tokens: LongTensor of token ids with shape (B, T).
            max_len: number of new tokens to generate.
            temperature: softmax temperature.  The default ``1.0`` samples from
                the unscaled distribution (the previous behaviour); ``0``
                selects the most likely token greedily.

        Returns:
            LongTensor of shape (B, T + max_len): the prompt followed by the
            generated ids.
        """
        for _ in range(max_len):
            # Feed at most max_seq_len of the most recent tokens to the model;
            # shorter sequences pass through unchanged.
            if tokens.size(1) > self.max_seq_len:
                context = tokens[:, -self.max_seq_len:]
            else:
                context = tokens
            x = self.tok_embeddings(context) # B,T,dim
            logits = self.forward(x) # B,T,vocabsize
            logits = logits[:,-1,:] # B, vocabsize
            if temperature == 0:
                # Greedy decoding.
                idx = torch.argmax(logits,dim=-1,keepdim=True) # B,1
            else:
                probs = F.softmax(logits / temperature,dim=-1)
                idx = torch.multinomial(probs,num_samples=1) # B,1
            tokens = torch.cat((tokens,idx),dim=-1)
        return tokens
        
        

In [205]:
# Configuration built with LlamaConfig's defaults (the next cell shows dim == 512).
config = LlamaConfig()

In [206]:
config.dim

512

In [207]:
# Randomly initialised model with the default config (no checkpoint is loaded here).
# Note the spelling `llamma`; the checkpoint-loaded model further down is named `llama`.
llamma = Llama(LlamaConfig())

In [208]:
# A single prompt of three token ids, shape (1, 3).
tokens = torch.LongTensor([[1,2,3]])

In [209]:
tokens

tensor([[1, 2, 3]])

In [210]:
# Smoke test: sample 10 new tokens from the untrained model.
llamma.generate(tokens,10)

tensor([[    1,     2,     3, 26166,  2582, 11375, 10368, 31222, 17974,  9217,
          5680, 22261,  7229]])

In [211]:
def load_pretrained(checkpoint):
    """Build a ``Llama`` model from a saved checkpoint file.

    The checkpoint must be a dict containing ``'model_args'`` (keyword
    arguments for ``LlamaConfig``) and ``'model'`` (the state dict).  Keys
    carrying the ``'_orig_mod.'`` prefix are renamed with the prefix removed
    before loading.

    Side effects: enables TF32 for CUDA matmul and cuDNN globally.

    Args:
        checkpoint: path (or file-like object) accepted by ``torch.load``.

    Returns:
        The ``Llama`` model with the checkpoint weights loaded (strictly).
        The model is not moved to another device by this function.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc.

    torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn

    # init from a model saved in a specific directory
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code; only load checkpoints from trusted sources (or pass
    # weights_only=True once the checkpoint is confirmed to be compatible).
    checkpoint_dict = torch.load(checkpoint, map_location=device)
    config = LlamaConfig(**checkpoint_dict['model_args'])
    model = Llama(config)

    # Drop the '_orig_mod.' prefix from any key that has it so the names match
    # this model's parameter names.
    unwanted_prefix = '_orig_mod.'
    state_dict = {
        (k[len(unwanted_prefix):] if k.startswith(unwanted_prefix) else k): v
        for k, v in checkpoint_dict['model'].items()
    }
    model.load_state_dict(state_dict, strict=True)
    return model

In [212]:
import torch

# Fix the random seeds so sampling in the cells below is repeatable.
seed = 1337
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

# Reference data for the sanity check; it is only used by the disabled block at the end of this cell.
sanity_data = torch.load("./sanity_check.data")
# text_batch = ["hello world", "hello neural network for NLP"]
# tokenizer here
# Two hard-coded rows of 8 token ids each (presumably the tokenised text_batch above — TODO confirm).
sent_ids = torch.tensor([[101, 7592, 2088, 102, 0, 0, 0, 0],
                         [101, 7592, 15756, 2897, 2005, 17953, 2361, 102]])

# load our model
print('load....')
llama = load_pretrained("stories42M.pt")
# NOTE(review): the check below is disabled. As written it passes token ids to
# llama(...) and unpacks two return values, whereas Llama.forward in this notebook
# takes embedded inputs and returns only logits — reconcile before re-enabling.
# with torch.no_grad():
#     logits, hidden_states = llama(sent_ids)
#     print('logits',logits)
#     print('santiy', sanity_data["logits"])
#     assert torch.allclose(logits, sanity_data["logits"], atol=1e-5, rtol=1e-3)
#     print(hidden_states, sanity_data["hidden_states"])
#     assert torch.allclose(hidden_states, sanity_data["hidden_states"], atol=1e-5, rtol=1e-3)
#     print("Your Llama implementation is correct!")

  sanity_data = torch.load("./sanity_check.data")
  checkpoint_dict = torch.load(checkpoint, map_location=device)


load....
hhhhh


In [213]:
# Sample 20 new tokens for each of the two rows using the checkpoint-loaded model.
llama.generate(sent_ids,20)

tensor([[  101,  7592,  2088,   102,     0,     0,     0,     0,  8227,  5832,
          1699,  1183,  1497, 29889,   376, 12024, 29915, 29879,   437,   372,
          1699,   263,  4802, 29892, 18881,  5076, 29889,   323],
        [  101,  7592, 15756,  2897,  2005, 17953,  2361,   102,   294,  6773,
           304, 11230, 29889,  2688,   674,   367, 28773,  7205, 19773, 29889,
          1205,   937, 29892,   306,  2367,   366,  1286,  3850]])

In [214]:
from tokenizer import Tokenizer

In [215]:
def generate_sentence(prefix, max_new_tokens = 75, temperature = 0.0):
    """Load the ``stories42M.pt`` checkpoint, continue ``prefix`` and print the text.

    Args:
        prefix: prompt string to continue.
        max_new_tokens: number of tokens passed to ``generate``.
        temperature: value echoed in the printed header.
            NOTE(review): it is only printed — it is not forwarded to
            ``generate``, which is called with just the prompt and
            ``max_new_tokens``.

    Returns:
        None; the decoded sentence is printed.
    """
    device = 'cpu'
    # A single no_grad scope covers loading, encoding and generation.
    with torch.no_grad():
        model = load_pretrained("stories42M.pt").to(device)
        print("load model")
        # The meaning of the argument 30 is defined by the project's Tokenizer — TODO confirm.
        tokenizer = Tokenizer(30)

        prompt_ids = tokenizer.encode(prefix, bos=True, eos=False)
        # Add a leading batch dimension: (T,) -> (1, T).
        prompt = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)

        # run generation
        generated = model.generate(prompt, max_new_tokens)
        text = tokenizer.decode(generated[0].tolist())
        print(f"Temperature is {temperature}")
        print(text)
        print('---------------')

In [217]:
# Continue the prompt 'She is' with the default settings and print the result.
generate_sentence('She is')

hhhhh


  checkpoint_dict = torch.load(checkpoint, map_location=device)


load model
Temperature is 0.0
She is a big bird. He likes to fly very fast. He looks for children to fly in the forest. He talks with the children to fly again. He pushes the other birds and joins them to fly too. He looks at them with different animals on altogether. He thanks them with the skies around him.
One day, hearsoks and see.
---------------
