In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import jax.numpy as jnp

from lmkit.model import transformer, config as config_lib
from lmkit.tools import compat

# Hugging Face checkpoint to mirror, and the local directory that holds it.
repo = "meta-llama/Meta-Llama-3-8B-Instruct"
model_dir = "models/llama3"

# Download only when the directory is missing or empty. The HF token comes
# from the environment, optionally populated from a .env file by python-dotenv.
if not (os.path.exists(model_dir) and os.listdir(model_dir)):
    from dotenv import load_dotenv

    load_dotenv()
    compat.from_hf(repo, model_dir, token=os.environ["HF_API_TOKEN"])

# Load the safetensors shards and convert them into lmkit's parameter format,
# then read the HF config and extend it with lmkit's Llama settings.
params = compat.params_to_lmkit(compat.gather_for_jax(model_dir))
config = config_lib.extend_llama(compat.load_lmkit_config(f"{model_dir}/config.json"))

tokenizer = compat.load_lmkit_tokenizer(
    f"{model_dir}/tokenizer.json",
    f"{model_dir}/generation_config.json",
)


Loading safetensors: 100%|██████████| 4/4 [06:40<00:00, 100.15s/it]


In [137]:
# Embedding table: rows are looked up by token id (jnp.take(..., axis=0) in run).
jnp.shape(params["embed_table"])

(128256, 4096)

In [None]:
import jax
import jax.numpy as jnp
from einops import rearrange

def rms_norm(x, weight, eps=1e-6):
    """RMS-normalize ``x`` over its last axis, then scale by ``weight``.

    The mean-square reduction runs in float32. The normalized values are cast
    back to ``x``'s original dtype *before* the multiply by ``weight``.

    Args:
        x: input activations; normalization is over the last axis.
        weight: per-feature scale, broadcast against the last axis.
        eps: added to the mean square before the reciprocal square root.

    Returns:
        ``weight * normalized(x)``. Its dtype follows JAX promotion of
        ``weight.dtype`` and ``x.dtype``.
    """
    input_dtype = x.dtype
    x_f32 = x.astype(jnp.float32)
    mean_square = jnp.mean(x_f32**2, axis=-1, keepdims=True)
    inv_rms = jax.lax.rsqrt(mean_square + eps)
    return weight * (x_f32 * inv_rms).astype(input_dtype)

def ffn(x, params, act_fn):
    """Gated feed-forward: ``(act_fn(x @ W_gate) * (x @ W_up)) @ W_down``.

    Args:
        x: input activations.
        params: mapping with "W_gate", "W_up" and "W_down" matrices.
        act_fn: activation applied to the gate projection only.

    Returns:
        The down-projected output.
    """
    gated = act_fn(x @ params["W_gate"])
    hidden = gated * (x @ params["W_up"])
    return hidden @ params["W_down"]


def build_rope(positions, head_dim, base):
    """Precompute rotary-embedding sin/cos tables for a batch of positions.

    Args:
        positions: integer array indexed as ``positions[:, :]`` (2-D);
            negative entries mark padding.
        head_dim: per-head feature size; must be even.
        base: frequency base used in ``base ** (2i / head_dim)``.

    Returns:
        ``(sin, cos)``, each shaped ``(*positions.shape, head_dim)``. The angle
        for feature ``j`` is ``position * inv_freq[j % (head_dim // 2)]``, i.e.
        the half-width frequency table is repeated twice. Padded slots use
        angle 0, so ``sin == 0`` and ``cos == 1`` there.

    Raises:
        ValueError: if ``head_dim`` is odd.
    """
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even, got {head_dim}")

    exponents = jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim
    inv_freq = 1.0 / (base ** exponents)

    pos = positions.astype(jnp.float32)
    half_angles = pos[:, :, None] * inv_freq
    angles = jnp.concatenate([half_angles, half_angles], axis=-1)

    # The (batch, seq, 1) mask broadcasts across head_dim inside jnp.where.
    is_real = (pos >= 0)[:, :, None]
    angles = jnp.where(is_real, angles, 0.0)

    return jnp.sin(angles), jnp.cos(angles)


def rotate_half(x):
    """Map the last axis ``[a, b]`` to ``[-b, a]``.

    ``a`` is the first ``n // 2`` features and ``b`` is the remainder, so for
    odd ``n`` the second part is one element longer.
    """
    split = x.shape[-1] // 2
    front, back = x[..., :split], x[..., split:]
    return jnp.concatenate((-back, front), axis=-1)


def rope(x, sin, cos):
    """Apply rotary position embeddings: ``x * cos + rotate_half(x) * sin``.

    Args:
        x: activations whose last axis is the rotary feature axis. In
            ``attention`` this is ``(batch, seq, heads, head_dim)``.
        sin, cos: angle tables shaped like ``x`` without the heads axis, e.g.
            ``(batch, seq, head_dim)`` as returned by ``build_rope``.

    When ``x`` has more axes than a table, exactly one axis is inserted at
    position -2 (the heads axis) so the table broadcasts across heads. Any
    other missing leading axes (e.g. an absent batch axis) broadcast under the
    usual NumPy rules.

    Returns:
        The rotated array, cast back to ``x.dtype``.

    Raises:
        ValueError: if ``sin`` or ``cos`` cannot broadcast to ``x.shape``.
    """
    if x.ndim > sin.ndim:
        sin = sin[..., None, :]
    if x.ndim > cos.ndim:
        cos = cos[..., None, :]

    # Fail loudly instead of silently broadcasting x up to a larger shape.
    for table in (sin, cos):
        try:
            fits = tuple(jnp.broadcast_shapes(x.shape, table.shape)) == tuple(x.shape)
        except ValueError:
            fits = False
        if not fits:
            raise ValueError(
                f"Cannot broadcast sin/cos shapes {sin.shape}/{cos.shape} to x shape {x.shape}"
            )

    rotated_x = (x * cos) + (rotate_half(x) * sin)
    return rotated_x.astype(x.dtype)

def attention(inputs, positions, params, config, implementation="cudnn"):
    """Causal self-attention with rotary embeddings and padding-aware lengths.

    Args:
        inputs: normalized hidden states; the last axis is the model width
            (presumably ``(batch, seq, hidden_dim)`` — TODO confirm for other callers).
        positions: per-token positions with the same leading shape as the
            tokens; negative entries mark padding.
        params: mapping with "W_q", "W_k", "W_v" and "W_o" projection matrices.
        config: mapping that must provide "hidden_dim", "num_heads",
            "num_kv_heads" and "rope_base".
        implementation: forwarded to ``jax.nn.dot_product_attention``. The
            default ``"cudnn"`` matches the previous hard-coded behavior. Pass
            ``None`` to let JAX choose, or ``"xla"`` where cuDNN is unavailable
            (e.g. CPU).

    Returns:
        The output projection, with the same leading shape as ``inputs``.
    """
    head_dim = config["hidden_dim"] // config["num_heads"]
    # Valid length = count of non-negative positions. This assumes padding is
    # on the right, so that the valid tokens form a prefix of each row.
    seq_lens = jnp.sum(positions >= 0, axis=-1).astype(jnp.int32)
    sin, cos = build_rope(positions, head_dim, config["rope_base"])

    query = inputs @ params["W_q"]
    key = inputs @ params["W_k"]
    value = inputs @ params["W_v"]

    # Split fused projections into heads. K/V use num_kv_heads, which may be
    # smaller than num_heads; dot_product_attention accepts that grouping.
    query = rearrange(query, "... t (n h) -> ... t n h", n=config["num_heads"])
    query = rope(query, sin, cos)
    key = rearrange(key, "... t (n h) -> ... t n h", n=config["num_kv_heads"])
    key = rope(key, sin, cos)
    value = rearrange(value, "... t (n h) -> ... t n h", n=config["num_kv_heads"])

    x = jax.nn.dot_product_attention(
        query=query,
        key=key,
        value=value,
        is_causal=True,
        query_seq_lengths=seq_lens,
        key_value_seq_lengths=seq_lens,
        implementation=implementation,
    )

    x = rearrange(x, "... t n h -> ... t (n h)")
    return x @ params["W_o"]

def run(inputs, positions, params, config):
    """Forward pass of the decoder stack: token ids -> logits over the vocab.

    Each layer applies a pre-norm residual attention sub-block, then a
    pre-norm residual feed-forward sub-block. A final RMS norm and the
    ``lm_head`` projection follow.

    Args:
        inputs: integer token ids.
        positions: per-token positions; negative entries mark padding.
        params: parameter tree with "embed_table", "layers", "out_norm" and
            "lm_head".
        config: model config; reads "norm_eps" and "act_fn" here and passes
            the whole mapping on to ``attention``.

    Returns:
        ``final_hidden @ params["lm_head"]``.
    """
    eps = config["norm_eps"]

    # NOTE(review): jnp.take's default "fill" mode appears to wrap negative
    # indices Python-style, so the -1 ids that generate() writes may pick up
    # the last embedding row rather than fill_value. Verify against the JAX
    # docs. Those slots are excluded by the seq lengths in attention, and
    # generate() never reads logits from them.
    hidden = jnp.take(params["embed_table"], inputs, axis=0, fill_value=-1e6)

    for layer in params["layers"]:
        attn_in = rms_norm(hidden, layer["input_norm"], eps=eps)
        hidden = hidden + attention(attn_in, positions, layer["attn"], config)

        ffn_in = rms_norm(hidden, layer["post_attn_norm"], eps=eps)
        hidden = hidden + ffn(ffn_in, layer["ffn"], config["act_fn"])

    final = rms_norm(hidden, params["out_norm"], eps=eps)
    return final @ params["lm_head"]

texts = ["What is the capital of the United", "How are you today?"]
batch_size = len(texts)
tokens = jnp.array([enc.ids for enc in tokenizer.encode_batch_fast(texts)])

# Real tokens get their index as position; padded slots (pad_token_id) get -1.
positions = jnp.where(
    tokens != tokenizer.pad_token_id, jnp.arange(tokens.shape[-1]), -1
)

print(positions)

logits = run(tokens, positions, params, config)

logits.shape

[[ 0  1  2  3  4  5  6  7]
 [ 0  1  2  3  4  5 -1 -1]]


(2, 8, 128256)

In [231]:
def expand_or_set_indices(arr, indices, fill):
    """Write ``indices[r]`` into row ``r`` at column ``indices[r]``, growing ``arr`` as needed.

    If any index is past the last column, ``arr`` is first right-padded along
    its last axis with ``fill`` until every index is in bounds.

    Args:
        arr: 2-D array; rows are addressed with ``jnp.arange(arr.shape[0])``.
        indices: one integer column index per row. Each index is also the value
            written.
        fill: value for newly added columns, cast to ``arr.dtype``.

    Returns:
        A new array with the same dtype as ``arr``. ``arr`` itself is returned
        unchanged when ``indices`` is empty.
    """
    if jnp.size(indices) == 0:
        return arr
    extra = int(jnp.max(indices)) + 1 - arr.shape[-1]
    if extra > 0:
        # Build the padding in arr's own dtype. A float filler would promote
        # the concatenation to float32 and corrupt integers above 2**24.
        filler = jnp.full((*arr.shape[:-1], extra), fill, dtype=arr.dtype)
        arr = jnp.concatenate([arr, filler], axis=-1)
    return arr.at[jnp.arange(arr.shape[0]), indices].set(indices)


expand_or_set_indices(jnp.array([[0,1,2,3,4,5], [0,1,2,3,-1,-1]]), jnp.array([6, 4]), -1)

Array([[ 0,  1,  2,  3,  4,  5,  6],
       [ 0,  1,  2,  3,  4, -1, -1]], dtype=int32)

In [237]:
from tqdm.auto import tqdm

def expand_and_set(arr, indices, values, fill):
    """Write ``values[r]`` into row ``r`` at column ``indices[r]``, growing ``arr`` as needed.

    If any index is past the last column, ``arr`` is first right-padded along
    its last axis with ``fill`` until every index is in bounds.

    Args:
        arr: 2-D array; rows are addressed with ``jnp.arange(arr.shape[0])``.
        indices: one integer column index per row.
        values: one value per row; cast to ``arr.dtype`` by ``.at[].set``.
        fill: value for newly added columns, cast to ``arr.dtype``.

    Returns:
        A new array with the same dtype as ``arr``. ``arr`` itself is returned
        unchanged when ``indices`` is empty.
    """
    if jnp.size(indices) == 0:
        return arr
    extra = int(jnp.max(indices)) + 1 - arr.shape[-1]
    if extra > 0:
        # Pad once, in arr's own dtype. A float filler would promote the
        # concatenation to float32 and corrupt integers above 2**24.
        filler = jnp.full((*arr.shape[:-1], extra), fill, dtype=arr.dtype)
        arr = jnp.concatenate([arr, filler], axis=-1)
    return arr.at[jnp.arange(arr.shape[0]), indices].set(values)


def generate(inputs, max_new_tokens, tokenizer, params, config):
    """Greedy decoding without a KV cache.

    Every step re-runs ``run`` on the whole (growing) batch. It takes the
    argmax of the logits at each sequence's last real token and writes that
    token into the sequence's next slot. Newly added slots hold token -1 and
    position -1. Slots freed by shorter prompts are overwritten in place.
    This assumes the tokenizer right-pads.

    Args:
        inputs: list of prompt strings.
        max_new_tokens: number of decoding steps. There is no early stop on EOS.
        tokenizer: provides ``encode_batch_fast``, ``pad_token_id`` and
            ``decode_batch``.
        params, config: forwarded to ``run``.

    Returns:
        The decoded strings (prompt followed by continuation), one per input.
    """
    rows = jnp.arange(len(inputs))

    encodings = tokenizer.encode_batch_fast(inputs)
    model_inputs = jnp.array([enc.ids for enc in encodings])
    positions = jnp.where(
        model_inputs != tokenizer.pad_token_id, jnp.arange(model_inputs.shape[-1]), -1
    )
    seq_lens = jnp.sum(positions >= 0, axis=-1)

    for _ in tqdm(range(max_new_tokens)):
        logits = run(model_inputs, positions, params, config)
        last_logits = logits[rows, seq_lens - 1, :]
        next_tokens = jnp.argmax(last_logits, axis=-1)

        model_inputs = expand_and_set(model_inputs, seq_lens, next_tokens, fill=-1)
        positions = expand_and_set(positions, seq_lens, seq_lens, fill=-1)
        seq_lens = seq_lens + 1

    # Map the -1 filler back to the pad id before decoding.
    decodable = jnp.where(model_inputs >= 0, model_inputs, tokenizer.pad_token_id)
    return tokenizer.decode_batch(decodable)

prompts = [
    "What is New York's",
    "What is the capital city of the United States of",
]

generate(prompts, max_new_tokens=30, tokenizer=tokenizer, params=params, config=config)

  0%|          | 0/30 [00:00<?, ?it/s]

["What is New York's most popular tourist destination?\nThe Statue of Liberty and Ellis Island are among the most popular tourist destinations in New York City. However, the most popular tourist",
 'What is the capital city of the United States of America?\nA. Washington D.C.\nB. New York City\nC. Los Angeles\nD. Chicago\nAnswer: A\nExplanation: The']