In [None]:
from pathlib import Path
import torch
import json

In [None]:
from transformers import PreTrainedTokenizerFast
from safetensors.torch import load_file

In [None]:
# Load the fast tokenizer from the local model directory.
tokenizer_path = "../model"
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    tokenizer_path,
)

In [None]:
# Parse the model's config.json into a plain dict; later cells index it by key.
config = json.loads(Path("../model/config.json").read_text())

In [None]:
# Quick sanity check shown as the cell output: a few decoded ids plus the
# vocabulary size reported by the tokenizer next to the one in the config.
(
    tokenizer.decode([1, 2, 3]),
    tokenizer.vocab_size,
    config['vocab_size'],
)

In [None]:
# Read the checkpoint tensors onto the CPU; `model` is indexed by tensor name below.
model = load_file("../model/model.safetensors", device='cpu')

# Build the token-embedding table and copy the checkpoint's embedding matrix into it.
embd = torch.nn.Embedding(
    num_embeddings=config['vocab_size'],
    embedding_dim=config['hidden_size'],
)
embd.load_state_dict({'weight': model['model.embed_tokens.weight']})

In [None]:
# Chat prompt used for the generation experiment below.
messages = [
    {"role": "system", "content": 'Only output 42'},
    {"role": "user",   "content": "Repeat after me, 42"}
]

# Render the chat template straight to token ids (with the generation prompt
# appended) and wrap the ids in a tensor.
tokens_i = torch.tensor(
    tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True
    )
)

# Decode every id on its own to see how the prompt was split into tokens.
prompt_split_as_tokens = [tokenizer.decode([token_id]) for token_id in tokens_i.tolist()]
print(prompt_split_as_tokens)

In [None]:
def apply_rope(x: torch.Tensor, theta=None):
    """Apply rotary position embeddings (RoPE) to ``x``.

    Positions are ``0 .. S-1`` along the second-to-last dimension. Channels
    are rotated in the "half-split" layout used by Hugging Face-format
    checkpoints (``rotate_half``): channel ``i`` is paired with channel
    ``i + d/2`` and both are rotated by ``pos * theta ** (-2i / d)``.
    The previous implementation paired adjacent channels ``(2i, 2i+1)``,
    which does not match how such checkpoints lay out q/k projections.

    Args:
        x: tensor of shape ``(..., S, d)`` with an even ``d``
           (the notebook passes ``(heads, S, head_dim)``).
        theta: RoPE base. Defaults to ``config['rope_theta']`` when omitted.

    Returns:
        A tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: if the last dimension of ``x`` is odd.
    """
    if theta is None:
        theta = config['rope_theta']

    seq_len, head_dim = x.shape[-2], x.shape[-1]
    if head_dim % 2 != 0:
        raise ValueError(f"RoPE needs an even last dimension, got {head_dim}")
    half = head_dim // 2

    # One inverse frequency per channel pair, one angle per (position, pair).
    inv_freq = 1.0 / theta ** (torch.arange(0, head_dim, 2, device=x.device).float() / head_dim)
    positions = torch.arange(seq_len, dtype=torch.float32, device=x.device)
    angles = torch.outer(positions, inv_freq)  # (S, d/2)

    # Both halves of the channel axis share the same angles.
    cos = torch.cat((angles.cos(), angles.cos()), dim=-1)  # (S, d)
    sin = torch.cat((angles.sin(), angles.sin()), dim=-1)  # (S, d)

    # Do the rotation in float32 and cast back to the input dtype at the end.
    x_f = x.float()
    x_rot_half = torch.cat((-x_f[..., half:], x_f[..., :half]), dim=-1)
    return (x_f * cos + x_rot_half * sin).type_as(x)

In [None]:
# Peek at the first 20 tensor names in the checkpoint.
for tensor_name in list(model)[:20]:
    print(tensor_name)

In [None]:
# Work on a copy so the generation cell below can extend `tokens`
# while `tokens_i` keeps the original prompt.
tokens = torch.clone(tokens_i)

In [None]:
# Greedy decoding for 10 steps without a KV cache: every step re-runs the full
# forward pass over all tokens so far and appends the most likely next token.

# Number of query heads that share one key/value head (grouped-query attention).
gqa_groups = config['num_attention_heads'] // config['num_key_value_heads']
# Number of experts each token is routed to.
k_top = int(config['num_experts_per_tok'])

# Inference only: `embd.weight` requires grad, so without no_grad() every step
# would record an autograd graph through all layers.
with torch.no_grad():
    for step in range(10):
        x = embd(tokens).type(torch.bfloat16)
        S = x.shape[0]

        # Additive causal mask: -inf above the diagonal, 0 elsewhere. It only
        # depends on the sequence length, so build it once per step.
        causal_mask = torch.triu(torch.full((S, S), float('-inf'), dtype=torch.float32), diagonal=1)

        for layer in range(config['num_hidden_layers']):
            prefix = f'model.layers.{layer}'

            # ---- self-attention -------------------------------------------
            rms_attn = torch.nn.functional.rms_norm(
                x, normalized_shape=(x.shape[-1],),
                weight=model[f'{prefix}.input_layernorm.weight'],
                eps=config['rms_norm_eps']
            )

            q = rms_attn @ model[f'{prefix}.self_attn.q_proj.weight'].T
            k = rms_attn @ model[f'{prefix}.self_attn.k_proj.weight'].T
            v = rms_attn @ model[f'{prefix}.self_attn.v_proj.weight'].T

            # Split into heads -> (heads, S, head_dim); k/v heads are repeated
            # so that every query head has a matching key/value head.
            q = q.view(S, config['num_attention_heads'], -1).transpose(0, 1).type(torch.bfloat16)
            k = k.view(S, config['num_key_value_heads'], -1).repeat_interleave(gqa_groups, dim=1).transpose(0, 1).type(torch.bfloat16)
            v = v.view(S, config['num_key_value_heads'], -1).repeat_interleave(gqa_groups, dim=1).transpose(0, 1).type(torch.bfloat16)

            # Per-head RMS norm of q and k comes BEFORE the rotary embedding
            # (this is the order used by Qwen3-style attention in Hugging Face
            # Transformers). The learned per-channel weight does not commute
            # with the rotation, so the order matters.
            q = torch.nn.functional.rms_norm(q, normalized_shape=(q.shape[-1],), weight=model[f'{prefix}.self_attn.q_norm.weight'], eps=config['rms_norm_eps'])
            k = torch.nn.functional.rms_norm(k, normalized_shape=(k.shape[-1],), weight=model[f'{prefix}.self_attn.k_norm.weight'], eps=config['rms_norm_eps'])

            q = apply_rope(q)
            k = apply_rope(k)

            # Scaled dot-product attention in float32 with the causal mask.
            score = (q.float() @ k.float().transpose(-1, -2)) / (k.shape[-1] ** 0.5)
            score = score + causal_mask
            attn = torch.softmax(score, dim=-1) @ v.float()
            # Merge heads back to (S, heads * head_dim) and project out.
            out = attn.to(torch.bfloat16).transpose(0, 1).reshape(S, -1) @ model[f'{prefix}.self_attn.o_proj.weight'].T

            x = x + out

            # ---- mixture-of-experts feed-forward --------------------------
            rms_ffn = torch.nn.functional.rms_norm(
                x, normalized_shape=(x.shape[-1],),
                weight=model[f'{prefix}.post_attention_layernorm.weight'],
                eps=config['rms_norm_eps']
            )

            # Router: softmax over all experts, keep the top-k per token and
            # optionally renormalise the kept probabilities to sum to 1.
            s = rms_ffn @ model[f'{prefix}.mlp.gate.weight'].T
            router_probs = torch.softmax(s.float(), dim=-1)
            top_vals, top_idx = torch.topk(router_probs, k=k_top, dim=-1)
            if config.get('norm_topk_prob', False):
                top_vals = top_vals / top_vals.sum(dim=-1, keepdim=True)

            # Token-by-token expert evaluation: weighted sum of the selected
            # experts' gated (SiLU) MLP outputs.
            ffn = torch.zeros_like(rms_ffn)
            for token_idx in range(rms_ffn.shape[0]):
                acc = 0.0
                x_t = rms_ffn[token_idx]
                for j in range(k_top):
                    e = int(top_idx[token_idx, j])
                    up = torch.nn.functional.silu(x_t @ model[f'{prefix}.mlp.experts.{e}.gate_proj.weight'].T) * (x_t @ model[f'{prefix}.mlp.experts.{e}.up_proj.weight'].T)
                    down = up @ model[f'{prefix}.mlp.experts.{e}.down_proj.weight'].T
                    acc = acc + top_vals[token_idx, j].to(down.dtype) * down
                ffn[token_idx, :] = acc.to(rms_ffn.dtype)

            x = x + ffn

        # Final norm + LM head; only the last position's logits are needed.
        rms_x = torch.nn.functional.rms_norm(x, normalized_shape=(x.shape[-1],), weight=model['model.norm.weight'], eps=config['rms_norm_eps'])
        out = rms_x @ model['lm_head.weight'].T

        # Greedy pick: most probable next token, appended to the running sequence.
        out_softmax = torch.nn.functional.softmax(out[-1].float(), dim=-1)
        values, indices = torch.topk(out_softmax, k=1)
        tokens = torch.cat((tokens, indices), dim=-1)
        print(tokenizer.decode(indices))