In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import jax.numpy as jnp

from lmkit.model import transformer, config as config_lib
from lmkit.tools import compat
from flax.core import FrozenDict

# Hugging Face checkpoint to use and the local directory it is stored in.
repo = "meta-llama/Meta-Llama-3-8B-Instruct"
model_dir = "models/llama3"

# Download only when the local directory is missing or empty.
if not (os.path.exists(model_dir) and os.listdir(model_dir)):
    from dotenv import load_dotenv

    # The access token is read from the environment right after this call;
    # presumably it comes from a local .env file -- verify for your setup.
    load_dotenv()
    compat.from_hf(repo, model_dir, token=os.environ["HF_API_TOKEN"])

# Weights: gather from disk, convert to the lmkit layout, then freeze.
params = FrozenDict(compat.params_to_lmkit(compat.gather_for_jax(model_dir)))

# Config: load the checkpoint config and extend it with the Llama-specific entries.
config = config_lib.extend_llama(
    compat.load_lmkit_config(f"{model_dir}/config.json")
)

tokenizer = compat.load_lmkit_tokenizer(
    f"{model_dir}/tokenizer.json", f"{model_dir}/generation_config.json"
)


Cuda processing allowed: True


Loading safetensors: 100%|██████████| 4/4 [02:44<00:00, 41.18s/it]


In [141]:
import jax
import jax.numpy as jnp
from einops import rearrange
from functools import partial

def rms_norm(x, weight, eps=1e-6):
    """Apply RMS normalisation over the last axis and scale by ``weight``.

    The statistics are computed in float32; the normalised activations are
    cast back to the dtype of ``x`` before being multiplied by ``weight``.

    Args:
        x: Array to normalise along its last axis.
        weight: Scale broadcast against the normalised array.
        eps: Constant added to the mean square before the reciprocal sqrt.

    Returns:
        ``weight * normalised_x``.
    """
    input_dtype = x.dtype
    x32 = x.astype(jnp.float32)
    mean_square = jnp.mean(x32**2, axis=-1, keepdims=True)
    inv_rms = jax.lax.rsqrt(mean_square + eps)
    return weight * (x32 * inv_rms).astype(input_dtype)

def ffn(x, params, act_fn):
    """Gated feed-forward block.

    Computes ``(act_fn(x @ W_gate) * (x @ W_up)) @ W_down``.

    Args:
        x: Input activations.
        params: Mapping holding the "W_gate", "W_up" and "W_down" matrices.
        act_fn: Callable applied to the gate projection.

    Returns:
        The down-projected product of the activated gate and the up projection.
    """
    gated = act_fn(x @ params["W_gate"])
    projected = x @ params["W_up"]
    return (gated * projected) @ params["W_down"]


def build_rope(positions, head_dim, base):
    """Build rotary-embedding sine and cosine tables.

    Args:
        positions: 2-D array of position indices; negative entries mark
            padding and get a zero angle (sin 0, cos 1).
        head_dim: Size of the last axis of the returned tables; must be even.
        base: Base of the geometric frequency progression.

    Returns:
        ``(sin, cos)``, each of shape ``positions.shape + (head_dim,)``.

    Raises:
        ValueError: If ``head_dim`` is odd.
    """
    if head_dim % 2:
        raise ValueError(f"head_dim must be even, got {head_dim}")

    exponents = jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim
    inv_freq = 1.0 / (base**exponents)

    pos = positions.astype(jnp.float32)
    half_angles = pos[:, :, None] * inv_freq[None, None, :]
    # The same frequencies serve both halves of the head dimension.
    angles = jnp.concatenate((half_angles, half_angles), axis=-1)

    # Zero the angle wherever the position is negative (padding).
    is_real = (pos >= 0)[:, :, None]
    angles = jnp.where(is_real, angles, 0.0)

    return jnp.sin(angles), jnp.cos(angles)


def rotate_half(x):
    """Swap the two halves of the last axis, negating the second half.

    For a last axis ``[a, b]`` (two equal halves) the result is ``[-b, a]``.
    """
    half = x.shape[-1] // 2
    return jnp.concatenate((-x[..., half:], x[..., :half]), axis=-1)


def rope(x, sin, cos):
    """Apply rotary position embeddings to ``x``.

    Computes ``x * cos + rotate_half(x) * sin`` and returns it in ``x``'s dtype.

    Supported layouts:
      * ``x`` is 4-D and ``sin``/``cos`` are 3-D: a singleton axis is inserted
        at position 2 of the tables so they broadcast over that axis of ``x``.
      * ``x`` has more dimensions than ``sin`` and the last axes match:
        singleton axes are inserted just before the last axis of the tables
        until the ranks agree.
      * Otherwise the tables are used as given.

    Args:
        x: Array to rotate; the rotation acts on its last axis.
        sin: Sine table whose last axis matches ``x``.
        cos: Cosine table with the same shape as ``sin``.

    Returns:
        Array with the same shape and dtype as ``x``.

    Raises:
        ValueError: In the generic layout, if the reshaped tables cannot be
            broadcast to exactly ``x.shape``.
    """
    if x.ndim == 4 and sin.ndim == 3:
        sin = sin[:, :, None, :]
        cos = cos[:, :, None, :]
    elif x.ndim > sin.ndim and x.shape[-1] == sin.shape[-1]:
        missing_dims = x.ndim - sin.ndim
        table_shape = (*sin.shape[:-1], *([1] * missing_dims), sin.shape[-1])
        sin = jnp.reshape(sin, table_shape)
        cos = jnp.reshape(cos, table_shape)

        # The tables now have the same rank as x. Adding any further axis
        # would make the product broadcast to a higher rank than x, so the
        # shapes are validated instead of being padded again.
        try:
            broadcast_shape = jnp.broadcast_shapes(tuple(x.shape), tuple(sin.shape))
        except ValueError:
            broadcast_shape = None
        if broadcast_shape != tuple(x.shape):
            raise ValueError(
                f"Cannot broadcast sin/cos shapes {sin.shape} to x shape {x.shape}"
            )

    rotated_x = (x * cos) + (rotate_half(x) * sin)
    return rotated_x.astype(x.dtype)

@partial(jax.vmap, in_axes=0)
def update_2d(arr1, arr2, update1, update2, start_idx):
    """Write ``update1``/``update2`` into ``arr1``/``arr2`` at a per-row offset.

    Every argument is mapped over its leading axis, so each batch row uses its
    own ``start_idx``. Within a row the update is written along axis 0.

    Returns:
        Tuple ``(updated_arr1, updated_arr2)``.
    """
    return tuple(
        jax.lax.dynamic_update_slice_in_dim(buffer, patch, start_idx, axis=0)
        for buffer, patch in ((arr1, update1), (arr2, update2))
    )

def attention(inputs, cache, params, config):
    """Attention for one layer with rotary embeddings and an optional KV cache.

    Args:
        inputs: Activations multiplied by W_q / W_k / W_v; assumed to be
            (batch, tokens, hidden) -- TODO confirm against callers.
        cache: Per-layer state exposing ``positions``, ``sin``, ``cos``,
            ``cached_lens``, ``keys``, ``values`` and ``replace``.
        params: Mapping with the "W_q", "W_k", "W_v" and "W_o" matrices.
        config: Mapping with "num_heads" and "num_kv_heads".

    Returns:
        ``(output, cache)`` where ``output`` is the attention result projected
        by W_o and ``cache`` is the input cache with ``keys``/``values``
        replaced by the keys/values used in this call.
    """
    positions = cache.positions
    # Largest position per row plus one is used as the valid sequence length.
    seq_lens = jnp.max(positions, axis=-1).astype(jnp.int32) + 1

    sin, cos = cache.sin, cache.cos

    query = inputs @ params["W_q"]
    key = inputs @ params["W_k"]
    value = inputs @ params["W_v"]
    
    # Split the projection into heads; rotary embeddings go on queries and keys only.
    query = rearrange(query, "... t (n h) -> ... t n h", n=config["num_heads"])
    query = rope(query, sin, cos)
    key = rearrange(key, "... t (n h) -> ... t n h", n=config["num_kv_heads"])
    key = rope(key, sin, cos)
    value = rearrange(value, "... t (n h) -> ... t n h", n=config["num_kv_heads"])


    # Without stored keys, attend over this call's keys/values only.
    full_key, full_value = key, value
    query_seq_lens = seq_lens
    if cache.keys is not None:
        # Write the new keys/values into the stored buffers at each row's cached length.
        full_key, full_value = update_2d(cache.keys, cache.values, key, value, cache.cached_lens)
        # NOTE(review): query_seq_lens stays equal to seq_lens on this path even
        # though the query may hold fewer tokens than that; the disabled line
        # below suggests this was being revisited -- confirm against the
        # masking semantics of dot_product_attention.
        # query_seq_lens = jnp.ones((inputs.shape[0],)).astype(jnp.int32)

    # Causal masking is requested only when there are no stored keys.
    # The cuDNN implementation is requested explicitly.
    x = jax.nn.dot_product_attention(
        query=query,
        key=full_key,
        value=full_value,
        is_causal=cache.keys is None,
        query_seq_lengths=query_seq_lens,
        key_value_seq_lengths=seq_lens,
        implementation="cudnn",
    )

    # Merge the heads back together and apply the output projection.
    x = rearrange(x, "... t n h -> ... t (n h)")
    x = x @ params["W_o"]

    return x, cache.replace(keys=full_key, values=full_value)

@partial(jax.jit, static_argnums=(3,))
def run(inputs, cache, params, config):
    """Run the full transformer stack on ``inputs`` and return logits plus cache.

    ``config`` is a static jit argument. Each layer applies pre-norm attention
    and a pre-norm feed-forward block, both with residual connections.

    Args:
        inputs: Token ids used to index ``params["embed_table"]``.
        cache: Object exposing ``layers`` (one entry per layer), ``use_kv``
            and ``replace``.
        params: Mapping with "embed_table", "layers", "out_norm" and "lm_head".
        config: Mapping with "norm_eps" and "act_fn", plus whatever
            ``attention`` reads.

    Returns:
        ``(logits, cache)``; the cache carries the per-layer states returned by
        ``attention`` only when ``cache.use_kv`` is set, otherwise it is the
        input cache unchanged.
    """
    norm_eps = config["norm_eps"]
    hidden = jnp.take(params["embed_table"], inputs, axis=0, fill_value=-1e6)

    updated_layers = []
    for idx, block in enumerate(params["layers"]):
        # Attention sub-block with residual connection.
        attn_in = rms_norm(hidden, block["input_norm"], eps=norm_eps)
        attn_out, layer_state = attention(attn_in, cache.layers[idx], block["attn"], config)
        updated_layers.append(layer_state)
        hidden = hidden + attn_out

        # Feed-forward sub-block with residual connection.
        ffn_in = rms_norm(hidden, block["post_attn_norm"], eps=norm_eps)
        hidden = hidden + ffn(ffn_in, block["ffn"], config["act_fn"])

    logits = rms_norm(hidden, params["out_norm"], eps=norm_eps) @ params["lm_head"]

    if not cache.use_kv:
        return logits, cache
    return logits, cache.replace(layers=updated_layers)

In [None]:
from tqdm.auto import tqdm
from flax import struct
from typing import Optional, List

@struct.dataclass
class LayerCache:
    """Per-layer state passed to ``attention``.

    ``keys`` and ``values`` default to None; ``attention`` returns a copy of
    this object with both fields filled in.
    """

    # Rotary sine / cosine tables applied to this layer's queries and keys.
    sin: jnp.array
    cos: jnp.array
    # Per-row offset at which new keys/values are written into ``keys``/``values``.
    cached_lens: jnp.array 
    # Position indices; attention uses max(positions) + 1 as the sequence length.
    positions: jnp.array
    # Stored keys / values, or None when nothing has been stored yet.
    keys: Optional[jnp.array] = None
    values: Optional[jnp.array] = None


@struct.dataclass
class TransformerCache:
    """Model-wide cache: one ``LayerCache`` per layer plus the full rotary tables.

    ``full_positions``, ``full_sin`` and ``full_cos`` hold the tables built for
    the whole sequence budget; ``roll`` hands each layer either the full tables
    or a single-position slice of them.
    """

    # Declared with pytree_node=False, i.e. not a pytree leaf (flax treats such
    # fields as static metadata). Selects single-token stepping in ``roll``.
    use_kv: bool = struct.field(pytree_node=False)
    # One LayerCache per transformer layer.
    layers: List[LayerCache]
    # Position indices and rotary tables for every slot of the sequence budget.
    full_positions: jnp.array
    full_sin: jnp.array
    full_cos: jnp.array

    @classmethod
    def initialize(
        cls, batch_size, current_positions, config, max_total_length=0, use_kv=False
    ):
        """Create a cache with empty key/value storage.

        Args:
            batch_size: Number of rows in the batch.
            current_positions: Position indices stored on every layer as
                ``positions`` for the first forward pass.
            config: Mapping with "hidden_size", "num_heads", "rope_base" and
                "num_layers".
            max_total_length: Number of positions for which rotary tables are built.
            use_kv: Stored on the cache; see ``roll``.

        Returns:
            A ``TransformerCache`` whose layers all share the same sin/cos
            tables, have ``cached_lens`` of zero and no stored keys/values.
        """

        head_dim = config["hidden_size"] // config["num_heads"]
        # Positions 0..max_total_length-1, identical for every batch row.
        positions = jnp.arange(max_total_length).astype(jnp.int32)
        positions = jnp.broadcast_to(positions, (batch_size, max_total_length))
        sin, cos = build_rope(positions, head_dim, config["rope_base"])
            
        layers = [
            LayerCache(
                sin=sin,
                cos=cos,
                cached_lens=jnp.zeros((batch_size,)).astype(jnp.int32),
                positions=current_positions,
                keys=None,
                values=None,
            )
            for _ in range(config["num_layers"])
        ]
        return cls(layers=layers, use_kv=use_kv, full_sin=sin, full_cos=cos, full_positions=positions)

    def roll(self):
        """Advance every layer's bookkeeping by one decoding step.

        The current sequence length per row is taken from the first layer's
        ``positions`` (max + 1). With ``use_kv`` set, each layer receives only
        the position / sin / cos entry at that index and ``cached_lens`` is set
        to that length; otherwise the layers receive the full tables and keep
        the first layer's ``cached_lens``.

        Returns:
            A new ``TransformerCache`` with updated layers and ``full_positions``.
        """
        batch_indices = jnp.arange(self.full_positions.shape[0]).astype(jnp.int32)
        first_layer = self.layers[0]
        seq_lens = jnp.max(first_layer.positions, axis=-1).astype(jnp.int32) + 1

        # Store seq_lens at column seq_lens of each row, growing the table if needed.
        # NOTE(review): full_sin / full_cos are not grown here, so if
        # full_positions is extended the lookups below would index past the
        # rotary tables -- confirm the budget always covers seq_lens.
        full_positions = expand_and_set(self.full_positions, seq_lens, seq_lens, fill=-1)

        if self.use_kv:
            # Single-token step: keep only the entry for the position being generated.
            cached_lens = seq_lens
            new_positions = full_positions[batch_indices, seq_lens][..., None]
            new_sin = self.full_sin[batch_indices, seq_lens][:, None, :]
            new_cos = self.full_cos[batch_indices, seq_lens][:, None, :]
        else:
            # Full-sequence step: every layer sees the complete tables.
            cached_lens = first_layer.cached_lens
            new_positions = full_positions
            new_sin = self.full_sin
            new_cos = self.full_cos

        new_layers = []
        for layer in self.layers:
            new_layers.append(layer.replace(positions=new_positions,
                cached_lens=cached_lens, sin=new_sin, cos=new_cos))
                        
        return self.replace(layers=new_layers, full_positions=full_positions)


def expand_and_set(arr, indices, values, fill):
    """Set ``arr[i, indices[i]] = values[i]``, widening the last axis if needed.

    If any index lies beyond the last axis of ``arr``, the array is padded once
    with ``fill`` so the largest index fits. The padding is created directly in
    ``arr``'s dtype, so integer contents never pass through floating point.

    Args:
        arr: Array whose leading axis matches ``indices``.
        indices: One column index per row. Must hold concrete values, since
            the required width is computed eagerly.
        values: One value per row to write.
        fill: Value used for newly added columns.

    Returns:
        The updated array, with the same dtype as ``arr``. When ``indices`` is
        empty, ``arr`` is returned unchanged.
    """
    indices = jnp.asarray(indices)
    if indices.size == 0:
        # Nothing to write; also avoids reducing over an empty array.
        return arr

    missing = int(jnp.max(indices)) + 1 - arr.shape[-1]
    if missing > 0:
        padding = jnp.full((*arr.shape[:-1], missing), fill, dtype=arr.dtype)
        arr = jnp.concatenate([arr, padding], axis=-1)

    return arr.at[jnp.arange(arr.shape[0]), indices].set(values)


def generate(inputs, max_new_tokens, tokenizer, params, config):
    """Greedy-decode ``max_new_tokens`` tokens for a batch of prompts.

    The loop always runs ``max_new_tokens`` steps; there is no end-of-sequence
    check. At each step the highest-scoring token is appended to every row.

    Args:
        inputs: Sequence of prompt strings.
        max_new_tokens: Number of decoding steps to run.
        tokenizer: Object providing ``encode_batch_fast``, ``decode_batch``
            and ``pad_token_id``.
        params: Model parameters forwarded to ``run``.
        config: Model configuration forwarded to ``run`` and to the cache.

    Returns:
        Whatever ``tokenizer.decode_batch`` returns for the final token buffer.
    """
    batch_size = len(inputs)

    encodings = tokenizer.encode_batch_fast(inputs)
    # jnp.array needs rows of equal length, so presumably the tokenizer pads
    # or all prompts tokenize to the same length -- verify.
    tokens = jnp.array([enc.ids for enc in encodings])
    # Reserve max_new_tokens pad slots after each prompt.
    tokens = jnp.concatenate([tokens, tokenizer.pad_token_id * jnp.ones((batch_size, max_new_tokens))], axis=-1).astype(jnp.int32)
    # Column index for real tokens, -1 for pad tokens.
    positions = jnp.where(tokens != tokenizer.pad_token_id, jnp.arange(tokens.shape[-1]), -1)
    seq_lens = jnp.sum(positions >= 0, axis=-1)

    model_inputs = tokens

    cache = TransformerCache.initialize(
        batch_size=batch_size,
        current_positions=positions,
        config=config,
        max_total_length=jnp.max(seq_lens + max_new_tokens),
        use_kv=True,
    )

    for _ in tqdm(range(max_new_tokens)):
        logits, cache = run(model_inputs, cache, params, config)
        # Logits at the last real token of each row.
        # NOTE(review): once model_inputs shrinks to one token per row the
        # token axis of logits is presumably length 1 while seq_lens - 1 keeps
        # growing -- confirm this indexing relies on out-of-bounds clamping.
        next_token_logits = logits[jnp.arange(batch_size), seq_lens-1, :]
        next_tokens = jnp.argmax(next_token_logits, axis=-1)        
        # Write the chosen token into the first free slot of each row.
        tokens = expand_and_set(tokens, seq_lens, next_tokens, fill=-1)
        # With the KV cache enabled only the new token is fed to the next step.
        model_inputs = next_tokens[..., None] if cache.use_kv else tokens
        cache = cache.roll()
        seq_lens += 1

    # Negative ids (the fill value used when the buffer grows) become pad ids before decoding.
    return tokenizer.decode_batch(
        jnp.where(tokens >= 0, tokens, tokenizer.pad_token_id)
    )

# Single-prompt smoke test. The call is the last expression of the cell, so
# its return value is what the notebook displays. It requests 10000 new tokens.
prompts = ["What is a Josephson junction?"]

generate(
    prompts,
    max_new_tokens=10000,
    tokenizer=tokenizer,
    params=params,
    config=config,
)

  0%|          | 0/10000 [00:00<?, ?it/s]

In [93]:
import jax
import jax.numpy as jnp
from functools import partial


# start_idx is mapped over the batch axis like the other arguments, so every
# batch row gets its own write offset.
@partial(jax.vmap, in_axes=0)
def update_2d(arr1, arr2, update1, update2, start_idx):
    """Write ``update1``/``update2`` into ``arr1``/``arr2`` at a per-row offset.

    Seen from inside the mapped function (one batch row):
      * ``arr1``/``arr2`` are e.g. (seq_len, num_heads, head_dim) = (128, 32, 64)
      * ``update1``/``update2`` are e.g. (1, num_heads, head_dim) = (1, 32, 64)
      * ``start_idx`` is a scalar offset along axis 0, e.g. 1

    Returns:
        Tuple ``(updated_arr1, updated_arr2)``.
    """
    return tuple(
        jax.lax.dynamic_update_slice_in_dim(buffer, patch, start_idx, axis=0)
        for buffer, patch in ((arr1, update1), (arr2, update2))
    )


# Toy dimensions for the smoke test: two batch rows with 128 cache slots each.
batch_size, seq_len, num_heads, head_dim = 2, 128, 32, 64

# An all-zero cache and an all-ones single-slot update make the written slot easy to spot.
cache = jnp.zeros((batch_size, seq_len, num_heads, head_dim))
updates = jnp.ones((batch_size, 1, num_heads, head_dim))

# One write offset per batch row. Both are 1 here, but they may differ,
# e.g. jnp.array([1, 5]).
start_indices = jnp.array([1, 1])

# Report the input shapes.
for _label, _tensor in (
    ("cache shape:", cache),
    ("updates shape:", updates),
    ("start_indices shape:", start_indices),
):
    print(_label, _tensor.shape)

# Apply the same update to two copies of the cache.
updated_cache1, updated_cache2 = update_2d(
    cache, cache, updates, updates, start_indices
)

# Report the output shapes.
print("\nupdated_cache1 shape:", updated_cache1.shape)
print("updated_cache2 shape:", updated_cache2.shape)

# Spot-check slots 0, 1 and 2 of each batch row: only slot 1 should be written.
for _message, _index in (
    ("\nValue at index [0, 0, 0, 0] (should be 0):", (0, 0, 0, 0)),
    ("Value at index [0, 1, 0, 0] (should be 1):", (0, 1, 0, 0)),
    ("Value at index [0, 2, 0, 0] (should be 0):", (0, 2, 0, 0)),
    ("\nValue at index [1, 0, 0, 0] (should be 0):", (1, 0, 0, 0)),
    ("Value at index [1, 1, 0, 0] (should be 1):", (1, 1, 0, 0)),
    ("Value at index [1, 2, 0, 0] (should be 0):", (1, 2, 0, 0)),
):
    print(_message, updated_cache1[_index])


cache shape: (2, 128, 32, 64)
updates shape: (2, 1, 32, 64)
start_indices shape: (2,)

updated_cache1 shape: (2, 128, 32, 64)
updated_cache2 shape: (2, 128, 32, 64)

Value at index [0, 0, 0, 0] (should be 0): 0.0
Value at index [0, 1, 0, 0] (should be 1): 1.0
Value at index [0, 2, 0, 0] (should be 0): 0.0

Value at index [1, 0, 0, 0] (should be 0): 0.0
Value at index [1, 1, 0, 0] (should be 1): 1.0
Value at index [1, 2, 0, 0] (should be 0): 0.0
