In [1]:
from pathlib import Path
import tiktoken
from tiktoken.load import load_tiktoken_bpe
import torch
import json

In [None]:
# Llama 3 tokenizer: tiktoken BPE ranks loaded from the model directory, followed by
# 256 special tokens whose ids start immediately after the last BPE rank.
tokenizer_path = "../model/tokenizer.model"
special_tokens = [
    "<|begin_of_text|>",
    "<|end_of_text|>",
    "<|reserved_special_token_0|>",
    "<|reserved_special_token_1|>",
    "<|reserved_special_token_2|>",
    "<|reserved_special_token_3|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|reserved_special_token_4|>",
    "<|eot_id|>",  # end of turn
    # Remaining reserved slots, numbered 5..250, pad the table to 256 entries.
    *(f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)),
]
mergeable_ranks = load_tiktoken_bpe(tokenizer_path)
tokenizer = tiktoken.Encoding(
    name=Path(tokenizer_path).name,
    pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
    mergeable_ranks=mergeable_ranks,
    special_tokens={
        token: token_id
        for token_id, token in enumerate(special_tokens, start=len(mergeable_ranks))
    },
)

In [4]:
# Model hyper-parameters (dim, n_heads, n_kv_heads, n_layers, norm_eps, rope_theta, ...).
# JSON is UTF-8 by spec; pin the encoding instead of relying on the platform locale default.
with open("../model/params.json", "r", encoding="utf-8") as f:
    config = json.load(f)

In [None]:
# Load the raw checkpoint (a mapping from parameter name to tensor) onto the CPU.
# weights_only=True restricts unpickling to tensors and plain containers, so a tampered
# checkpoint file cannot execute arbitrary code during loading.
model = torch.load("../model/consolidated.00.pth", map_location=torch.device('cpu'), weights_only=True)
# Token-embedding lookup. The checkpoint weights are copied into a fresh nn.Embedding,
# which uses torch's default dtype (float32 unless changed elsewhere).
embd = torch.nn.Embedding(tokenizer.n_vocab, config['dim'])
embd.load_state_dict({'weight': model['tok_embeddings.weight']});  # ';' hides the "All keys matched" repr

In [None]:
prompt = "the answer to the ultimate question of life, the universe, and everything is "
# Prepend <|begin_of_text|>. Look its id up in the tokenizer's special-token table
# instead of hard-coding the number.
bos_id = tokenizer.encode_single_token("<|begin_of_text|>")
tokens_i = torch.tensor([bos_id] + tokenizer.encode(prompt))
# Decode each id on its own to show how the prompt was split into tokens.
prompt_split_as_tokens = [tokenizer.decode([token]) for token in tokens_i.tolist()]
print(prompt_split_as_tokens)

In [None]:
# Scratch check: concatenation returns a new tensor with one extra id appended; tokens_i is not modified.
# NOTE(review): nothing downstream uses this output; candidate for deletion.
torch.concat([tokens_i, torch.tensor([1])])

In [5]:
# Implement ROPE
def apply_rope(x, start_pos=0, theta=None):
    """Apply rotary positional embeddings (RoPE) to queries or keys.

    Adjacent channel pairs (x[..., 2j], x[..., 2j + 1]) are treated as complex numbers
    and rotated by angle ``pos * theta ** (-2j / head_dim)``.

    Args:
        x: tensor whose last two dimensions are (seq_len, head_dim). head_dim must be
            even. Leading dimensions (e.g. heads) are broadcast over.
        start_pos: absolute position of the first row. Rows get positions
            start_pos, start_pos + 1, ..., start_pos + seq_len - 1.
        theta: RoPE base. Defaults to the notebook-level ``config['rope_theta']``.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    if theta is None:
        theta = config['rope_theta']
    seq_len, head_dim = x.shape[-2], x.shape[-1]

    # One base frequency per channel pair.
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=x.device).float() / head_dim))
    # Rotation angle for every (position, pair), shape (seq_len, head_dim // 2).
    positions = torch.arange(start_pos, start_pos + seq_len, device=x.device).float()
    angles = torch.outer(positions, freqs)
    freqs_cis = torch.polar(torch.ones_like(angles), angles)
    # NOTE(review): if params.json enables scaled RoPE (e.g. `use_scaled_rope`), the
    # frequency scaling is NOT applied here. Confirm against the checkpoint's config.

    # reshape (not view) so non-contiguous inputs are accepted too.
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    x_rotated = x_complex * freqs_cis
    return torch.view_as_real(x_rotated).flatten(-2).type_as(x)

In [None]:
# Preallocated per-layer key/value cache used by the generation loop.
SEQ_LEN = 1024  # maximum number of positions (prompt + generated tokens) the cache can hold
# Width of the K/V projections: n_kv_heads * head_dim. This is derived from the config
# instead of hard-coded, so it matches whatever checkpoint is loaded.
EMBD_DIM = config['n_kv_heads'] * (config['dim'] // config['n_heads'])
kv_cache = [
    {'k': torch.zeros([SEQ_LEN, EMBD_DIM]), 'v': torch.zeros([SEQ_LEN, EMBD_DIM])}
    for _ in range(config['n_layers'])
]

tokens = tokens_i.clone()

In [None]:
# Greedy generation with a per-layer KV cache. The first iteration (prefill) runs the
# whole prompt; each later iteration runs only the newest token and attends to the
# cached keys/values. Generation stops at the first special token or when the cache is full.
tokens = tokens_i.clone()  # restart from the prompt so re-running this cell is idempotent
first_special_id = len(mergeable_ranks)  # special-token ids start right after the BPE ranks
head_dim = config['dim'] // config['n_heads']
n_rep = config['n_heads'] // config['n_kv_heads']  # query heads sharing each K/V head (GQA)
max_len = kv_cache[0]['k'].shape[0]  # number of positions the preallocated cache can hold
if len(tokens) > max_len:
    raise ValueError(f"prompt has {len(tokens)} tokens but the KV cache holds only {max_len}")

is_prefill_stage = True
kv_counter = 0  # number of positions already written to the KV cache
while tokens[-1] < first_special_id and len(tokens) <= max_len:
    with torch.no_grad():
        # Only positions not yet cached need a forward pass: the whole prompt during
        # prefill, just the newest token afterwards.
        x = embd(tokens[kv_counter:])
        seq_len = x.shape[0]
        total_len = kv_counter + seq_len  # positions visible to attention in this step

        # Causal mask rows for the new positions: row i (absolute position kv_counter + i)
        # may attend to columns 0..kv_counter + i. Same for every layer, so built once.
        mask = torch.triu(torch.full((total_len, total_len), float('-inf')), diagonal=1)[-seq_len:, :]

        for layer in range(config['n_layers']):
            # RMS 1 (weights cast to float32 like every other weight in this loop)
            rms_1 = torch.nn.functional.rms_norm(x, normalized_shape=(x.shape[-1],), weight=model[f"layers.{layer}.attention_norm.weight"].type(torch.float32), eps=config["norm_eps"])

            # GQA projections
            xq = rms_1 @ model[f"layers.{layer}.attention.wq.weight"].type(torch.float32).T
            xk = rms_1 @ model[f"layers.{layer}.attention.wk.weight"].type(torch.float32).T
            xv = rms_1 @ model[f"layers.{layer}.attention.wv.weight"].type(torch.float32).T

            # Store the new (un-rotated) keys/values, then read back every position so far.
            kv_cache[layer]['k'][kv_counter:total_len] = xk
            kv_cache[layer]['v'][kv_counter:total_len] = xv
            xk = kv_cache[layer]['k'][:total_len]
            xv = kv_cache[layer]['v'][:total_len]

            # (positions, heads * head_dim) -> (heads, positions, head_dim). Each K/V head is
            # repeated n_rep times so every query head has a matching key/value head.
            xq = xq.view(seq_len, config['n_heads'], head_dim).transpose(0, 1).contiguous()
            xk = xk.view(total_len, config['n_kv_heads'], head_dim) \
                .repeat_interleave(n_rep, dim=1) \
                .transpose(0, 1) \
                .contiguous()
            xv = xv.view(total_len, config['n_kv_heads'], head_dim) \
                .repeat_interleave(n_rep, dim=1) \
                .transpose(0, 1) \
                .contiguous()

            xq = apply_rope(xq, start_pos=kv_counter)
            xk = apply_rope(xk)  # the cache holds un-rotated keys, so rotate from position 0

            logits = (xq @ xk.transpose(-2, -1)) / (head_dim ** 0.5)
            attn_i = (logits + mask).softmax(dim=-1)
            attn = attn_i @ xv
            attn_o = attn.transpose(0, 1).reshape(seq_len, -1) @ model[f"layers.{layer}.attention.wo.weight"].type(torch.float32).T

            # Residuals
            x_2 = x + attn_o

            # RMS 2
            rms_2 = torch.nn.functional.rms_norm(x_2, normalized_shape=(x_2.shape[-1],), weight=model[f"layers.{layer}.ffn_norm.weight"].type(torch.float32), eps=config["norm_eps"])

            # FFN (SwiGLU): w2(silu(w1 x) * w3 x)
            ffn_w1 = model[f"layers.{layer}.feed_forward.w1.weight"].type(torch.float32)
            ffn_w2 = model[f"layers.{layer}.feed_forward.w2.weight"].type(torch.float32)
            ffn_w3 = model[f"layers.{layer}.feed_forward.w3.weight"].type(torch.float32)
            ffn_output = (torch.nn.functional.silu(rms_2 @ ffn_w1.T) * (rms_2 @ ffn_w3.T)) @ ffn_w2.T

            # Residuals
            x = x_2 + ffn_output

        kv_counter += seq_len

        # Only the last position's logits are needed to choose the next token.
        rms_f = torch.nn.functional.rms_norm(x[-1:], normalized_shape=(x.shape[-1],), weight=model["norm.weight"].type(torch.float32), eps=config["norm_eps"])
        linear_f = rms_f @ model["output.weight"].type(torch.float32).T
        out = torch.nn.functional.softmax(linear_f, dim=-1)[-1]
        values, indices = torch.topk(out, k=1)  # greedy decoding
        next_token = indices[0].item()
        print(tokenizer.decode([next_token]), end="", flush=True)  # stream tokens on one line
        tokens = torch.cat((tokens, torch.tensor([next_token])))
        is_prefill_stage = False
print()