In [1]:
%load_ext autoreload
%autoreload 2

In [271]:
import os
import jax.numpy as jnp

from lmkit.model import transformer, config as config_lib
from lmkit.tools import compat
from flax.core import FrozenDict

# Hub repository id and the local directory the checkpoint files are read from.
repo = "meta-llama/Meta-Llama-3-8B-Instruct"
model_dir = "models/llama3"

# Fetch the checkpoint only when the local directory is missing or empty.
if not (os.path.exists(model_dir) and os.listdir(model_dir)):
    from dotenv import load_dotenv

    # load_dotenv() runs before HF_API_TOKEN is read from the environment,
    # presumably so the token can live in a local .env file.
    load_dotenv()
    compat.from_hf(repo, model_dir, token=os.environ["HF_API_TOKEN"])

# Weights: gathered from model_dir, converted to the lmkit layout, then frozen.
params = FrozenDict(compat.params_to_lmkit(compat.gather_for_jax(model_dir)))

# Config: loaded from config.json and passed through extend_llama.
config = config_lib.extend_llama(
    compat.load_lmkit_config(f"{model_dir}/config.json")
)

# Tokenizer: built from tokenizer.json together with generation_config.json.
tokenizer = compat.load_lmkit_tokenizer(
    f"{model_dir}/tokenizer.json",
    f"{model_dir}/generation_config.json",
)


Loading safetensors:   0%|          | 0/4 [00:00<?, ?it/s]

Loading safetensors: 100%|██████████| 4/4 [00:02<00:00,  1.54it/s]


In [285]:
import jax
import jax.numpy as jnp
from einops import rearrange
from functools import partial

def rms_norm(x, weight, eps=1e-6):
    """Root-mean-square normalise `x` over its last axis and scale by `weight`.

    The mean of squares is computed in float32. The normalised values are cast
    back to the dtype `x` arrived with *before* being multiplied by `weight`.

    Args:
        x: Array to normalise; statistics are taken over axis -1.
        weight: Scale multiplied onto the normalised values (broadcast by JAX).
        eps: Constant added to the mean of squares before the reciprocal sqrt.

    Returns:
        `weight * normalised_x`, with the dtype that multiplication produces.
    """
    input_dtype = x.dtype
    x_f32 = x.astype(jnp.float32)
    mean_square = jnp.mean(x_f32**2, axis=-1, keepdims=True)
    inv_rms = jax.lax.rsqrt(mean_square + eps)
    return weight * (x_f32 * inv_rms).astype(input_dtype)

def ffn(x, params, act_fn):
    """Gated feed-forward block: `(act_fn(x @ W_gate) * (x @ W_up)) @ W_down`.

    Args:
        x: Input activations; the last axis is contracted with the weights.
        params: Mapping providing the "W_gate", "W_up" and "W_down" matrices.
        act_fn: Callable applied elementwise to the gate projection.

    Returns:
        The down-projected product of the activated gate and the up projection.
    """
    gated = act_fn(x @ params["W_gate"]) * (x @ params["W_up"])
    return gated @ params["W_down"]


def build_rope(positions, head_dim, base):
    """Build the sine and cosine tables used for rotary position embeddings.

    Args:
        positions: 2-D array of position indices (it is indexed as
            `positions[:, :, None]`). Negative entries are treated as padding.
        head_dim: Size of the last axis of the returned tables; must be even.
        base: Base of the geometric frequency progression.

    Returns:
        `(sin, cos)`, each of shape `positions.shape + (head_dim,)`. Padding
        positions get angle 0, i.e. sin = 0 and cos = 1.

    Raises:
        ValueError: If `head_dim` is odd.
    """
    if head_dim % 2:
        raise ValueError(f"head_dim must be even, got {head_dim}")

    # One inverse frequency per pair of channels: base ** -(2i / head_dim).
    exponents = jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim
    inv_freq = 1.0 / (base**exponents)

    pos_f32 = positions.astype(jnp.float32)
    half_angles = pos_f32[:, :, None] * inv_freq[None, None, :]
    # The same angles are used for both halves of the head dimension.
    angles = jnp.concatenate((half_angles, half_angles), axis=-1)

    # Zero the angle wherever the position is negative (padding); the mask
    # broadcasts across the head dimension.
    is_real = (pos_f32 >= 0)[:, :, None]
    angles = jnp.where(is_real, angles, 0.0)

    return jnp.sin(angles), jnp.cos(angles)


def rotate_half(x):
    """Swap the two halves of the last axis, negating the second: `[-x2, x1]`.

    The split point is `x.shape[-1] // 2`.
    """
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    return jnp.concatenate((-x2, x1), axis=-1)


def rope(x, sin, cos):
    """Apply rotary position embeddings: `x * cos + rotate_half(x) * sin`.

    Args:
        x: Array to rotate. The attention code in this notebook passes a 4-D
            array whose last axis matches the tables' last axis.
        sin: Sine table, e.g. as produced by `build_rope`.
        cos: Cosine table with the same layout as `sin`.

    Returns:
        The rotated array, cast back to `x.dtype`.

    Raises:
        ValueError: In the generic fallback path, if the tables cannot be
            broadcast onto `x` once the singleton axes have been inserted.
    """
    missing_axes = x.ndim - sin.ndim

    if x.ndim == 4 and sin.ndim == 3:
        # Common case: add one singleton axis in front of the feature axis so
        # the same table is shared by every head.
        sin = sin[:, :, None, :]
        cos = cos[:, :, None, :]
    elif missing_axes > 0 and x.shape[-1] == sin.shape[-1]:
        # Generic case: insert exactly as many singleton axes (just before the
        # feature axis) as `x` has extra dimensions. Adding any more would
        # make the product broadcast to a higher rank than `x`.
        for _ in range(missing_axes):
            sin = sin[..., None, :]
            cos = cos[..., None, :]

        try:
            broadcast_shape = jnp.broadcast_shapes(x.shape, sin.shape, cos.shape)
        except ValueError:
            broadcast_shape = None
        if broadcast_shape != tuple(x.shape):
            raise ValueError(
                f"Cannot broadcast sin/cos shapes {sin.shape} to x shape {x.shape}"
            )

    rotated_x = (x * cos) + (rotate_half(x) * sin)
    # sin/cos may be wider than x (e.g. float32 vs. bfloat16); restore x's dtype.
    return rotated_x.astype(x.dtype)

def attention(inputs, cache, params, config):
    """Causal self-attention with rotary embeddings on queries and keys.

    Args:
        inputs: Activations projected by W_q / W_k / W_v.
            Assumed (batch, seq, hidden) -- TODO confirm against caller.
        cache: Object exposing `positions`, `sin` and `cos`. Negative entries
            in `positions` are counted as padding.
        params: Mapping with the "W_q", "W_k", "W_v" and "W_o" matrices.
        config: Mapping providing "num_heads" and "num_kv_heads".

    Returns:
        The attention output after the "W_o" projection.

    NOTE(review): the backend is hard-coded to "cudnn", which presumably
    requires a cuDNN-capable GPU -- confirm before running elsewhere.
    """
    positions = cache.positions
    # Per-row count of non-padding slots, used to mask keys and queries.
    seq_lens = jnp.sum(positions >= 0, axis=-1).astype(jnp.int32)

    # Trim the rotary tables to the current input width.
    width = positions.shape[1]
    sin = cache.sin[:, :width, :]
    cos = cache.cos[:, :width, :]

    def project(weight_name, head_count):
        # Linear projection followed by splitting the feature axis into heads.
        projected = inputs @ params[weight_name]
        return rearrange(projected, "... t (n h) -> ... t n h", n=head_count)

    query = rope(project("W_q", config["num_heads"]), sin, cos)
    key = rope(project("W_k", config["num_kv_heads"]), sin, cos)
    value = project("W_v", config["num_kv_heads"])

    attended = jax.nn.dot_product_attention(
        query=query,
        key=key,
        value=value,
        is_causal=True,
        query_seq_lengths=seq_lens,
        key_value_seq_lengths=seq_lens,
        implementation="cudnn",
    )

    # Merge the heads back into one feature axis and apply the output projection.
    merged = rearrange(attended, "... t n h -> ... t (n h)")
    return merged @ params["W_o"]

@partial(jax.jit, static_argnums=(3,))
def run(inputs, cache, params, config):
    """Run one jitted forward pass of the decoder stack and return the logits.

    Args:
        inputs: Integer token ids used to index `params["embed_table"]`.
        cache: Cache object forwarded unchanged to every attention call.
        params: Mapping with "embed_table", "layers", "out_norm" and "lm_head".
        config: Static (hence hashable) mapping providing "norm_eps" and
            "act_fn", plus whatever `attention` reads from it.

    Returns:
        `(logits, cache)`; the cache is returned exactly as it was passed in.
    """
    norm_eps = config["norm_eps"]
    hidden = jnp.take(params["embed_table"], inputs, axis=0, fill_value=-1e6)

    for block in params["layers"]:
        # Pre-norm residual attention sub-layer.
        attn_in = rms_norm(hidden, block["input_norm"], eps=norm_eps)
        hidden = hidden + attention(attn_in, cache, block["attn"], config)
        # Pre-norm residual feed-forward sub-layer.
        ffn_in = rms_norm(hidden, block["post_attn_norm"], eps=norm_eps)
        hidden = hidden + ffn(ffn_in, block["ffn"], config["act_fn"])

    hidden = rms_norm(hidden, params["out_norm"], eps=norm_eps)
    return hidden @ params["lm_head"], cache

In [None]:
from tqdm.auto import tqdm
from flax import struct
from typing import Optional, List

@struct.dataclass
class LayerCache:
    """Per-layer cache record holding rotary tables and optional key/value arrays.

    NOTE(review): in the visible cells this class is only referenced by the
    `TransformerCache.layers` annotation and is never instantiated.
    """
    # Rotary sine/cosine tables. They are declared with pytree_node=False, i.e.
    # excluded from the pytree leaves -- TODO confirm this is intended for arrays.
    sin: jnp.array = struct.field(pytree_node=False)
    cos: jnp.array = struct.field(pytree_node=False)
    # Cached attention keys / values; None when nothing has been stored.
    keys: Optional[jnp.array] = None
    values: Optional[jnp.array] = None

@struct.dataclass
class TransformerCache:
    """Decoding state: rotary tables plus the current per-slot positions.

    `sin` and `cos` are declared with pytree_node=False, `positions` is a
    regular field, and `layers` is an optional list of `LayerCache` entries
    that `initialize` leaves at its default of None.
    """

    sin: jnp.array = struct.field(pytree_node=False)
    cos: jnp.array = struct.field(pytree_node=False)
    positions: jnp.array
    layers: Optional[List[LayerCache]] = None

    @classmethod
    def initialize(cls, batch_size, max_total_length, current_positions, config, use_kv=False):
        """Create a cache whose rotary tables cover `max_total_length` positions.

        Args:
            batch_size: Number of rows in the rotary tables.
            max_total_length: Number of positions (0 .. max_total_length - 1)
                for which sin/cos are precomputed.
            current_positions: Stored as-is in the `positions` field.
            config: Mapping providing "hidden_size", "num_heads" and "rope_base".
            use_kv: When true, only prints a notice; no key/value storage is
                allocated.

        Returns:
            A new `TransformerCache`.
        """
        rope_dim = config["hidden_size"] // config["num_heads"]
        # Positions 0 .. max_total_length - 1, repeated for every batch row.
        all_positions = jnp.broadcast_to(
            jnp.arange(max_total_length).astype(jnp.int32),
            (batch_size, max_total_length),
        )
        sin_table, cos_table = build_rope(all_positions, rope_dim, config["rope_base"])

        if use_kv:
            print("KV caching not implemented!")

        return cls(positions=current_positions, sin=sin_table, cos=cos_table)

    def roll(self):
        """Return a copy in which each row's next free slot gets its own index.

        The number of non-negative positions in a row is both the slot that is
        written and the value written there; the array is widened with -1 via
        `expand_and_set` if that slot does not exist yet.
        """
        filled = jnp.sum(self.positions >= 0, axis=-1)
        return self.replace(
            positions=expand_and_set(self.positions, filled, filled, fill=-1)
        )


def expand_and_set(arr, indices, values, fill):
    """Write `values[i]` to `arr[i, indices[i]]`, widening `arr` if needed.

    If the largest index lies beyond the last axis, `arr` is first extended
    along that axis with `fill` so the index becomes valid. The padding is
    created directly in `arr.dtype`, so integer contents never pass through
    float32 (which would corrupt values above 2**24).

    Args:
        arr: JAX array whose first axis is paired with `indices`.
        indices: One target index (into the last axis) per row of `arr`.
        values: One value per row to store at the target index.
        fill: Value used for newly added columns.

    Returns:
        A new array with the values written; `arr` itself is not modified.

    Note:
        The required width is read back from `indices` as a Python int, so
        this helper cannot be called from inside `jax.jit`.
    """
    indices = jnp.asarray(indices)

    # Grow once to the required width instead of one column per iteration.
    required_width = int(jnp.max(indices)) + 1 if indices.size else 0
    shortfall = required_width - arr.shape[-1]
    if shortfall > 0:
        padding = jnp.full((*arr.shape[:-1], shortfall), fill, dtype=arr.dtype)
        arr = jnp.concatenate([arr, padding], axis=-1)

    return arr.at[jnp.arange(arr.shape[0]), indices].set(values)


def generate(inputs, max_new_tokens, tokenizer, params, config):
    """Greedily decode `max_new_tokens` tokens for every prompt in `inputs`.

    Each step re-runs the full forward pass over the whole token buffer (no
    key/value cache is used) and appends the arg-max token to every row.
    Decoding always runs for `max_new_tokens` steps; there is no early stop.

    Args:
        inputs: Sequence of prompt strings.
        max_new_tokens: Number of decoding steps to run.
        tokenizer: Object providing `encode_batch_fast`, `decode_batch` and
            `pad_token_id`. Its encodings are assumed to be padded to a common
            length with `pad_token_id` on the right -- TODO confirm.
        params: Model parameters forwarded to `run`.
        config: Model configuration forwarded to `run` and to the cache.

    Returns:
        The decoded text for every row (prompt plus continuation), as returned
        by `tokenizer.decode_batch`; an empty list when `inputs` is empty.
    """
    batch_size = len(inputs)
    if batch_size == 0:
        # Nothing to encode or decode.
        return []

    encodings = tokenizer.encode_batch_fast(inputs)
    prompt_tokens = jnp.array([enc.ids for enc in encodings], dtype=jnp.int32)
    # Reserve room for the generated tokens. The padding is built as int32
    # so token ids are never routed through float32.
    padding = jnp.full(
        (batch_size, max_new_tokens), tokenizer.pad_token_id, dtype=jnp.int32
    )
    tokens = jnp.concatenate([prompt_tokens, padding], axis=-1)

    # Real tokens keep their column index as position; padding slots get -1.
    positions = jnp.where(
        tokens != tokenizer.pad_token_id, jnp.arange(tokens.shape[-1]), -1
    )
    seq_lens = jnp.sum(positions >= 0, axis=-1)

    model_inputs = tokens

    cache = TransformerCache.initialize(
        batch_size=batch_size,
        # One device reduction, converted to a plain int usable as a shape.
        max_total_length=int(jnp.max(seq_lens)) + max_new_tokens,
        current_positions=positions,
        config=config,
        use_kv=True,
    )

    for _ in tqdm(range(max_new_tokens)):
        logits, cache = run(model_inputs, cache, params, config)
        # Logits at the last real token of every row.
        next_token_logits = logits[jnp.arange(batch_size), seq_lens - 1, :]
        next_tokens = jnp.argmax(next_token_logits, axis=-1)

        model_inputs = expand_and_set(model_inputs, seq_lens, next_tokens, fill=-1)
        cache = cache.roll()
        seq_lens = seq_lens + 1

    # Any -1 filler left in the buffer is mapped back to the pad token.
    return tokenizer.decode_batch(
        jnp.where(model_inputs >= 0, model_inputs, tokenizer.pad_token_id)
    )

# Demo prompts of different lengths, decoded together as one batch.
prompts = [
    "What is New York's",
    "What is the capital city of the United States of",
    "How many states are",
    "Write a haiku about how hard but rewarding it is to write JAX code!",
]

# Last expression of the cell: the decoded strings are shown as the cell output.
generate(
    prompts,
    max_new_tokens=1000,
    tokenizer=tokenizer,
    params=params,
    config=config,
)

KV caching not implemented!


  0%|          | 0/1000 [00:00<?, ?it/s]

["What is New York's most popular tourist destination?\nNew York City is home to many iconic landmarks and attractions, but the most popular tourist destination in the state is actually Niagara Falls. Located on the border with Canada, Niagara Falls is a breathtaking natural wonder that attracts millions of visitors each year.\nNiagara Falls is a three-part waterfall system that consists of the American Falls, Bridal Veil Falls, and the Horseshoe Falls (also known as the Canadian Falls). The falls are surrounded by beautiful parks and gardens, and visitors can take in the views from various observation decks, including the Cave of the Winds tour, which takes you down into the Niagara Gorge for a thrilling up-close experience.\nIn addition to the natural beauty of the falls, the surrounding area offers a range of activities and attractions, including casinos, wineries, and outdoor adventures like hiking and whitewater rafting. Visitors can also take a scenic drive along the Niagara Scen