In [None]:
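# Observed: this KV/conv-cached generation loop diverged slightly from the non-cached (full-recompute) version.
# Thing to check: during decode, the new query must attend to the cached keys/values AND to its own key/value,
# exactly as the diagonal of the causal mask allows in the non-cached path.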

In [None]:
from pathlib import Path
import torch
import json
from transformers import PreTrainedTokenizerFast
from safetensors.torch import load_file

# All computation in this notebook runs on the CPU. Setting the float32 matmul
# precision to "high" lets PyTorch pick faster internal matmul paths where the
# backend supports them.
torch.set_float32_matmul_precision('high')
device = torch.device("cpu")

In [23]:
# Download the checkpoint into ./model (the directory every cell below reads from), e.g.:
#   HF_HUB_ENABLE_HF_TRANSFER=1 uv run hf download LiquidAI/LFM2-1.2B --local-dir ./model
tokenizer_path = "./model"
tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_path)

In [24]:
# Model hyper-parameters (hidden size, layer layout, RoPE theta, norm eps, ...)
# read from the config.json shipped with the checkpoint. JSON text is UTF-8, so
# decode it explicitly instead of relying on the platform's locale encoding.
with open("./model/config.json", encoding="utf-8") as f:
    config = json.load(f)

In [25]:
# Load every checkpoint tensor into a name -> tensor dict, then build a token
# embedding layer whose weight is the checkpoint's `model.embed_tokens.weight`.
model = load_file("./model/model.safetensors", device=device.type)
embd = torch.nn.Embedding(
    num_embeddings=config['vocab_size'],
    embedding_dim=config['hidden_size'],
)
embd.load_state_dict(dict(weight=model['model.embed_tokens.weight']))
embd = embd.to(device)

In [26]:
# Build the chat prompt (system + user turn, followed by the assistant header)
# and print how it is split into individual tokens.
messages = [
    dict(role="system", content='Follow the instructions.'),
    dict(role="user", content="Repeat this sentence 'hi this is lfm' and end your turn"),
]
tokens_i = torch.tensor(
    tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True),
    device=device,
)
prompt_split_as_tokens = [tokenizer.decode([token_id]) for token_id in tokens_i.tolist()]
print(prompt_split_as_tokens)

['<|startoftext|>', '<|im_start|>', 'system', '\n', 'Follow', ' the', ' instructions', '.', '<|im_end|>', '\n', '<|im_start|>', 'user', '\n', 'Re', 'peat', ' this', ' sentence', " '", 'hi', ' this', ' is', ' l', 'fm', "'", ' and', ' end', ' your', ' turn', '<|im_end|>', '\n', '<|im_start|>', 'assistant', '\n']


In [31]:
def apply_rope(x: torch.Tensor, start: int = 0, theta: "float | None" = None) -> torch.Tensor:
    """Apply rotary position embeddings (RoPE) along the last dimension of ``x``.

    Uses the half-split ("rotate_half") pairing of Hugging Face's LLaMA-family
    ``apply_rotary_pos_emb``, which LFM2 uses: feature ``i`` is rotated together
    with feature ``i + head_dim/2`` (NOT with its neighbour ``i + 1``).

    Args:
        x: tensor whose last two dims are (seq, head_dim), e.g. (batch, heads, seq, head_dim).
            head_dim must be even.
        start: absolute position of ``x[..., 0, :]``. When decoding with a KV cache,
            pass the number of tokens already cached.
        theta: RoPE base frequency; defaults to ``config['rope_theta']``.

    Returns:
        Tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: if head_dim is odd.
    """
    if theta is None:
        theta = config['rope_theta']
    seq_len, head_dim = x.shape[-2], x.shape[-1]
    if head_dim % 2:
        raise ValueError(f"RoPE needs an even head_dim, got {head_dim}")
    half = head_dim // 2

    inv_freq = 1.0 / theta ** (torch.arange(0, head_dim, 2, device=x.device, dtype=torch.float32) / head_dim)
    positions = torch.arange(start, start + seq_len, device=x.device, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)  # (seq, head_dim/2), broadcasts over leading dims
    cos, sin = angles.cos().to(x.dtype), angles.sin().to(x.dtype)

    # Equivalent to x * cat(cos, cos) + rotate_half(x) * cat(sin, sin).
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos - x2 * sin, x2 * cos + x1 * sin), dim=-1)

In [32]:
# Fresh decoding state: a copy of the prompt tokens plus empty per-layer caches.
# Layers listed in config['full_attn_idxs'] get a key/value cache; every other
# layer is a conv layer and gets a conv-input cache.
tokens = tokens_i.clone()
kv_cache = {idx: dict(k=None, v=None) for idx in config['full_attn_idxs']}
conv_cache = dict.fromkeys(
    idx for idx in range(config['num_hidden_layers'])
    if idx not in config['full_attn_idxs']
)

In [33]:
# Greedy decoding with caches. The first pass ("prefill") runs the whole prompt and
# fills kv_cache / conv_cache; every later pass feeds only the newest token and reads
# earlier context from those caches. Prefill and decode share the SAME attention and
# conv code, so cached generation matches the non-cached (full-recompute) result.
MAX_NEW_TOKENS = 256  # safety stop in case EOS is never produced

prefill_done = False
n_new_tokens = 0

n_heads = config['num_attention_heads']
n_kv_heads = config['num_key_value_heads']
conv_history = config['conv_L_cache'] - 1  # past inputs a causal conv of width conv_L_cache needs

with torch.inference_mode():
    while tokens[-1].item() != tokenizer.eos_token_id and n_new_tokens < MAX_NEW_TOKENS:

        step_tokens = tokens[-1:] if prefill_done else tokens
        x = embd(step_tokens).unsqueeze(0)  # (1, S, D)

        for layer in range(config['num_hidden_layers']):
            S = x.shape[1]

            x_norm = torch.nn.functional.rms_norm(x, normalized_shape=(config['hidden_size'],), weight=model[f'model.layers.{layer}.operator_norm.weight'], eps=config['norm_eps']).type(torch.bfloat16)

            if layer in config['full_attn_idxs']:  # Attention
                xq = x_norm @ model[f'model.layers.{layer}.self_attn.q_proj.weight'].type(torch.bfloat16).T
                xk = x_norm @ model[f'model.layers.{layer}.self_attn.k_proj.weight'].type(torch.bfloat16).T
                xv = x_norm @ model[f'model.layers.{layer}.self_attn.v_proj.weight'].type(torch.bfloat16).T

                # (1, S, H*Dh) -> (1, S, H, Dh); K/V heads are repeated to match the query heads.
                xq = xq.view(1, S, n_heads, -1)
                xk = xk.view(1, S, n_kv_heads, -1).repeat_interleave(n_heads // n_kv_heads, dim=2)
                xv = xv.view(1, S, n_kv_heads, -1).repeat_interleave(n_heads // n_kv_heads, dim=2)

                # Per-head QK RMSNorm is applied BEFORE RoPE (same order as HF's Lfm2Attention).
                xq = torch.nn.functional.rms_norm(xq.float(), normalized_shape=(xq.shape[-1],), weight=model[f'model.layers.{layer}.self_attn.q_layernorm.weight'], eps=config['norm_eps'])
                xk = torch.nn.functional.rms_norm(xk.float(), normalized_shape=(xk.shape[-1],), weight=model[f'model.layers.{layer}.self_attn.k_layernorm.weight'], eps=config['norm_eps'])

                # (1, S, H, Dh) -> (1, H, S, Dh)
                xq, xk, xv = xq.transpose(1, 2), xk.transpose(1, 2), xv.float().transpose(1, 2)

                # Absolute position of this step's first token = number of keys already cached.
                cached_k = kv_cache[layer]['k'] if prefill_done else None
                past_len = 0 if cached_k is None else cached_k.shape[-2]
                xq = apply_rope(xq, past_len)
                xk = apply_rope(xk, past_len)

                # Append this step's keys/values BEFORE attending: the newest query must also
                # see its own key/value (the diagonal of the causal mask). Attending only to the
                # previously cached keys makes cached decoding diverge from the non-cached path.
                if cached_k is None:
                    K, V = xk, xv
                else:
                    K = torch.cat((cached_k, xk), dim=-2)
                    V = torch.cat((kv_cache[layer]['v'], xv), dim=-2)
                kv_cache[layer]['k'], kv_cache[layer]['v'] = K, V

                # Query i sits at absolute position past_len + i and may see keys 0 .. past_len + i.
                causal_mask = torch.triu(torch.full((S, K.shape[-2]), float('-inf'), device=x.device), diagonal=past_len + 1)
                score = (xq @ K.transpose(-1, -2)) / (K.shape[-1] ** 0.5) + causal_mask
                attn = torch.softmax(score, dim=-1) @ V
                x_operator = attn.to(torch.bfloat16).transpose(1, 2).reshape(1, S, -1) @ model[f'model.layers.{layer}.self_attn.out_proj.weight'].type(torch.bfloat16).T
            else:  # Conv layer
                # (1, S, D) @ (D, 3D) = (1, S, 3D) -> T -> (1, 3D, S)
                BCx = (x_norm @ model[f'model.layers.{layer}.conv.in_proj.weight'].type(torch.bfloat16).T).transpose(-1, -2)

                # (1, 3D, S) -> (1, D, S), (1, D, S), (1, D, S)
                B, C, x_c = BCx.chunk(3, dim=-2)
                Bx = B * x_c

                # Causal depthwise conv of width conv_L_cache: prepend the last conv_L_cache - 1
                # inputs from the previous step (zeros for a fresh sequence) and run an unpadded
                # conv, which yields exactly S outputs. Prefill and decode share this path, and
                # prompts shorter than conv_L_cache - 1 still leave a full-length cache.
                history = conv_cache[layer] if prefill_done else None
                if history is None:
                    history = Bx.new_zeros(Bx.shape[0], Bx.shape[1], conv_history)
                combined = torch.cat((history, Bx), dim=-1)  # (1, D, conv_history + S)
                conv_out = torch.nn.functional.conv1d(
                    combined,
                    weight=model[f'model.layers.{layer}.conv.conv.weight'].type(torch.bfloat16),
                    groups=config['hidden_size'],
                )  # (1, D, S)
                conv_cache[layer] = combined[:, :, combined.shape[-1] - conv_history:]

                # (1, D, S) * (1, D, S) -> (1, D, S)
                x_c = C * conv_out

                # (1, S, D) @ (D, D) -> (1, S, D)
                x_operator = x_c.transpose(-1, -2) @ model[f'model.layers.{layer}.conv.out_proj.weight'].type(torch.bfloat16).T

            x = x + x_operator

            x_norm = torch.nn.functional.rms_norm(x, normalized_shape=(config['hidden_size'],), weight=model[f'model.layers.{layer}.ffn_norm.weight'], eps=config['norm_eps']).type(torch.bfloat16)

            ffn_w1 = model[f"model.layers.{layer}.feed_forward.w1.weight"].type(torch.bfloat16)
            ffn_w2 = model[f"model.layers.{layer}.feed_forward.w2.weight"].type(torch.bfloat16)
            ffn_w3 = model[f"model.layers.{layer}.feed_forward.w3.weight"].type(torch.bfloat16)
            ffn_o = (torch.nn.functional.silu(x_norm @ ffn_w1.T) * (x_norm @ ffn_w3.T)) @ ffn_w2.T

            x = x + ffn_o

        # Only the last position's logits are needed to choose the next token; the input
        # embedding matrix is reused as the output projection.
        x = torch.nn.functional.rms_norm(x[:, -1:, :], normalized_shape=(config['hidden_size'],), weight=model['model.embedding_norm.weight'], eps=config['norm_eps']).type(torch.bfloat16)
        out = x @ model['model.embed_tokens.weight'].T  # (1, 1, vocab)

        next_id = out[:, -1, :].argmax(dim=-1)
        tokens = torch.cat((tokens, next_id.to(device)), dim=-1)
        print(tokenizer.decode([next_id.item()]), end='', flush=True)
        prefill_done = True
        n_new_tokens += 1


Hi, this is lfm.

(Note: The last part "Hi, this is lfm

KeyboardInterrupt: 