In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell runs and edits to project code take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import jax.numpy as jnp

from lmkit.model import transformer, config as config_lib
from lmkit.tools import compat

# Repository id handed to compat.from_hf, and the local directory that all
# model, config and tokenizer files are read from below.
repo = "meta-llama/Meta-Llama-3-8B-Instruct"
model_dir = "models/llama3"

# Only call compat.from_hf when the local directory is missing or empty.
if not (os.path.exists(model_dir) and os.listdir(model_dir)):
    # Imported lazily: dotenv is needed solely on this branch, to populate
    # the environment before HF_API_TOKEN is read.
    from dotenv import load_dotenv

    load_dotenv()
    compat.from_hf(repo, model_dir, token=os.environ["HF_API_TOKEN"])

params = compat.params_to_lmkit(compat.gather_for_jax(model_dir))
config = config_lib.extend_llama(
    compat.load_lmkit_config(f"{model_dir}/config.json")
)

tokenizer = compat.load_lmkit_tokenizer(
    f"{model_dir}/tokenizer.json",
    f"{model_dir}/generation_config.json",
)


Loading safetensors: 100%|██████████| 4/4 [06:40<00:00, 100.15s/it]


In [130]:
from tqdm.auto import tqdm

def generate(texts, max_new_tokens, tokenizer, params, config):
    """Greedily extend each prompt in ``texts`` by ``max_new_tokens`` tokens.

    The prompts are encoded as one batch and passed to
    ``transformer.run_decoder``. After that, every step feeds only the newly
    chosen (argmax) token of each sequence back in, together with the cache
    returned by the previous call.

    Args:
        texts: Sized sequence of prompt strings.
        max_new_tokens: Number of decoding steps to run.
        tokenizer: Object providing ``encode_batch``, ``decode_batch`` and
            ``pad_token_id``. ``encode_batch`` must yield ``ids`` of equal
            length for every prompt, otherwise they cannot be packed into a
            single array.
        params: Model parameters, passed through to ``transformer.run_decoder``.
        config: Mapping with at least ``"hidden_size"``, ``"num_heads"`` and
            ``"rope_base"``; also passed through to ``run_decoder``.

    Returns:
        Whatever ``tokenizer.decode_batch(..., skip_special_tokens=False)``
        returns for the prompt ids (padding stripped) followed by the
        generated ids, or ``[]`` when ``texts`` is empty.
    """
    batch_size = len(texts)
    if batch_size == 0:
        # Nothing to encode or decode; avoid building zero-sized model inputs.
        return []

    head_dim = config["hidden_size"] // config["num_heads"]

    encoded_batch = tokenizer.encode_batch(texts)
    padded_tokens = jnp.array([enc.ids for enc in encoded_batch], dtype=jnp.int32)

    # True wherever a position holds something other than the pad token.
    mask = padded_tokens != tokenizer.pad_token_id
    # Per-prompt Python lists of ids with padding removed; generated ids are
    # appended to these lists as decoding proceeds.
    real_tokens = [row[keep].tolist() for row, keep in zip(padded_tokens, mask)]

    max_current_len = padded_tokens.shape[-1]
    max_total_len = max_current_len + max_new_tokens

    # The rotary cache is built once for every position a sequence can reach.
    full_positions = jnp.arange(max_total_len, dtype=jnp.int32)
    sin, cos = transformer.build_rope_cache(full_positions, head_dim, config["rope_base"])

    # Column index for non-pad tokens, -1 for padding.
    batch_positions = jnp.where(mask, jnp.arange(max_current_len), -1)
    seq_lens = jnp.sum(batch_positions >= 0, axis=-1)
    # NOTE(review): treating seq_lens - 1 as the index of the last prompt
    # token assumes the tokenizer pads on the right — confirm.
    last_prompt_idx = seq_lens - 1
    batch_rows = jnp.arange(batch_size)

    cache = transformer.TransformerCache(
        use_kv=True,
        sin=sin,
        cos=cos,
        positions=batch_positions,
        keys=None,
        values=None,
    )

    model_inputs = padded_tokens
    for step in tqdm(range(max_new_tokens)):
        batch_logits, cache = transformer.run_decoder(model_inputs, cache, params, config)

        # assumes batch_logits is (batch, seq, vocab) — TODO confirm
        if step == 0:
            # The whole prompt was fed in: read the logits at each
            # sequence's last non-pad position.
            step_logits = batch_logits[batch_rows, last_prompt_idx]
        else:
            # Only one token per sequence was fed in.
            step_logits = batch_logits[:, 0]

        # One batched argmax instead of a per-row Python loop.
        next_token_ids = jnp.argmax(step_logits, axis=-1).astype(jnp.int32)

        # Store plain Python ints so each list stays homogeneous with the
        # ids produced by ``.tolist()`` above.
        for row, token_id in zip(real_tokens, next_token_ids.tolist()):
            row.append(token_id)

        model_inputs = next_token_ids[:, None]
        cache = cache.next()

    return tokenizer.decode_batch(real_tokens, skip_special_tokens=False)

# Prompts to complete; the two differ in length.
initial_texts = [
    "Question: What is New York's largest",
    "Question: Give me a haiku about how hard but rewarding writing JAX code is!",
]

# Number of tokens to generate for every prompt.
max_new_tokens = 7
output_batch = generate(
    texts=initial_texts,
    max_new_tokens=max_new_tokens,
    tokenizer=tokenizer,
    params=params,
    config=config,
)

# Print one decoded sequence (prompt plus continuation) per line.
for completed_text in output_batch:
    print(completed_text)

Cuda processing allowed: True


  0%|          | 0/7 [00:00<?, ?it/s]

TypeError: cannot reshape array of shape (1,) (size 1) into shape (2,) (size 2)