In [1]:
from pathlib import Path
import tiktoken
from tiktoken.load import load_tiktoken_bpe
import torch
import json

In [2]:
tokenizer_path = "../model/tokenizer.model"

# 256 special tokens in total: the 10 spelled out below, followed by the
# generated <|reserved_special_token_5|> .. <|reserved_special_token_250|>.
special_tokens = [
    "<|begin_of_text|>",
    "<|end_of_text|>",
    "<|reserved_special_token_0|>",
    "<|reserved_special_token_1|>",
    "<|reserved_special_token_2|>",
    "<|reserved_special_token_3|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|reserved_special_token_4|>",
    "<|eot_id|>",  # end of turn
    *(f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)),
]

mergeable_ranks = load_tiktoken_bpe(tokenizer_path)

# Special-token ids are numbered consecutively, starting at len(mergeable_ranks),
# in the order they appear in `special_tokens`.
tokenizer = tiktoken.Encoding(
    name=Path(tokenizer_path).name,
    pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
    mergeable_ranks=mergeable_ranks,
    special_tokens={
        token: token_id
        for token_id, token in enumerate(special_tokens, start=len(mergeable_ranks))
    },
)

In [3]:
# Model hyper-parameters used below: dim, n_heads, norm_eps, rope_theta.
# JSON is UTF-8 by specification; pass the encoding explicitly so loading
# does not depend on the platform's locale default.
with open("../model/params.json", "r", encoding="utf-8") as f:
    config = json.load(f)

In [4]:
# The checkpoint is a plain state dict of tensors. weights_only=True restricts
# unpickling to tensors and primitive containers, so a tampered .pth file
# cannot execute arbitrary code while it is being loaded.
model = torch.load("../model/consolidated.00.pth", map_location="cpu", weights_only=True)

# One embedding row per token id (BPE ranks + special tokens), width config['dim'].
# load_state_dict copies the checkpoint tensor into embd's own parameter
# (created with the default dtype), converting dtype if necessary.
embd = torch.nn.Embedding(tokenizer.n_vocab, config['dim'])
embd.load_state_dict({'weight': model['tok_embeddings.weight']})

<All keys matched successfully>

In [5]:
# The trailing space is intentional: it is encoded as its own ' ' token.
prompt = "the answer to the ultimate question of life, the universe, and everything is "

# Prepend the BOS id looked up by name rather than hard-coding 128000; with the
# tokenizer built above it resolves to len(mergeable_ranks) + 0.
tokens_i = torch.tensor(
    [tokenizer.encode_single_token("<|begin_of_text|>")] + tokenizer.encode(prompt)
)

# Decode every id on its own to show how the prompt was split into tokens.
prompt_split_as_tokens = [tokenizer.decode([token_id]) for token_id in tokens_i.tolist()]
print(prompt_split_as_tokens)

['<|begin_of_text|>', 'the', ' answer', ' to', ' the', ' ultimate', ' question', ' of', ' life', ',', ' the', ' universe', ',', ' and', ' everything', ' is', ' ']


In [13]:
# NOTE(review): neither constant is read by the cells visible here, and
# EMBD_DIM (512) may not match config['dim'], the width actually used by
# `embd`. Confirm whether they are still needed.
SEQ_LEN, EMBD_DIM = 1024, 512

# Work on a copy of the prompt ids and look up their embedding rows.
tokens = torch.clone(tokens_i)
x = embd(tokens)

In [None]:
def apply_rope(x, start_pos=0, theta=None, verbose=True):
    """Apply rotary positional encoding (RoPE) to queries or keys.

    Adjacent pairs along the last dimension, (x[..., 2k], x[..., 2k+1]), are
    treated as one complex number (real, imag) and rotated by the angle
    ``pos * theta ** (-2k / head_dim)``. Here ``pos`` runs from ``start_pos``
    to ``start_pos + seq_len - 1`` along the second-to-last dimension.

    Args:
        x: tensor whose last two dims are (seq_len, head_dim); any leading
            dims (e.g. heads) are broadcast over. head_dim must be even.
        start_pos: position index of the first row along the seq_len axis.
        theta: RoPE base. Defaults to the notebook's ``config['rope_theta']``.
        verbose: print intermediate shapes (the original behaviour); pass
            False to silence.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    if theta is None:
        theta = config['rope_theta']
    log = print if verbose else (lambda *args: None)

    seq_len, head_dim = x.shape[-2], x.shape[-1]
    log('pre-rope x shape:', x.shape)

    # Base frequencies theta^(-2k/head_dim), k = 0 .. head_dim/2 - 1. Build them on
    # x.device so the function also works for non-CPU tensors.
    # NOTE(review): no frequency scaling is applied; confirm the checkpoint's
    # params.json does not request any rope scaling.
    freqs = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=x.device).float() / head_dim))
    log('freqs shape:', freqs.shape)

    # Rotation angle for every (position, frequency) pair -> (seq_len, head_dim/2).
    positions = torch.arange(start_pos, start_pos + seq_len, device=x.device).float()
    freqs_for_each_token = torch.outer(positions, freqs)
    log('freqs_for_each_token shape:', freqs_for_each_token.shape)

    # Unit complex numbers e^{i * angle}.
    freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)
    log('freqs_cis shape:', freqs_cis.shape)

    # view_as_complex needs a unit stride on the (…, 2) pair axis, so make the
    # float32 copy contiguous first (a no-op for already-contiguous inputs).
    x_split = x.float().contiguous().view(*x.shape[:-1], -1, 2)
    log('x_split shape:', x_split.shape)

    x_complex = torch.view_as_complex(x_split)
    log('x_complex shape:', x_complex.shape)

    # Complex multiply == 2-D rotation of each pair; freqs_cis broadcasts over
    # the leading dims of x.
    x_rotated = x_complex * freqs_cis
    log('x_rotated shape:', x_rotated.shape)

    # Back to interleaved real pairs, then to x's original dtype.
    out = torch.view_as_real(x_rotated).flatten(-2).type_as(x)
    log('out shape:', out.shape)

    return out

In [None]:
with torch.no_grad():
    head_dim = config['dim'] // config['n_heads']
    # Embeddings of the prompt ids: one row per token -> (seq_len, dim).
    x = embd(tokens)
    seq_len = x.shape[0]

    # Only layer 0 is processed so far.
    for layer in range(1):
        # Pre-attention RMSNorm over the model dimension, scaled by the layer's weight.
        rms_1 = torch.nn.functional.rms_norm(
            x,
            normalized_shape=(x.size(-1),),
            weight=model[f"layers.{layer}.attention_norm.weight"],
            eps=config["norm_eps"],
        )

        # Query projection rms_1 @ W_q^T, with W_q upcast to float32 first.
        xq = rms_1 @ model[f"layers.{layer}.attention.wq.weight"].float().T

        # Split the projection into heads, then put the head axis first:
        # (seq_len, n_heads, head_dim) -> (n_heads, seq_len, head_dim).
        xq = xq.view(seq_len, config['n_heads'], head_dim).transpose(0, 1).contiguous()

        xq = apply_rope(xq)

0.19216537475585938 ms
