In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import jax.numpy as jnp

from lmkit.model import transformer, config as config_lib
from lmkit.tools import compat

In [3]:
repo = "meta-llama/Llama-3.1-8B-Instruct"
model_dir = "models/llama3"

if not os.path.exists(model_dir) or not os.listdir(model_dir):
    from dotenv import load_dotenv
    load_dotenv()

    compat.from_hf(repo, model_dir, token=os.environ["HF_API_TOKEN"])

params = compat.params_to_lmkit(compat.gather_for_jax(model_dir))
config = compat.load_lmkit_config(f"{model_dir}/config.json")
config = config_lib.extend_llama(config)

Loading safetensors:   0%|          | 0/4 [00:00<?, ?it/s]

Loading safetensors: 100%|██████████| 4/4 [00:08<00:00,  2.20s/it]


In [None]:
import jax
import jax.numpy as jnp


def init_cache(config):
    return [{"k": None, "v": None} for _ in range(config["num_layers"])]


tokenizer = compat.load_lmkit_tokenizer(
    f"{model_dir}/tokenizer.json", f"{model_dir}/generation_config.json"
)

initial_text = "Hello!"
sequence = tokenizer.encode(initial_text).ids

cache = init_cache(config)

print("Initial sequence: " + tokenizer.decode(sequence))

max_new_tokens = 20

key = jax.random.PRNGKey(0)
temperature = 0.6

model_input = jnp.array([sequence], dtype=jnp.int32) 

for step in range(max_new_tokens):
    lengths = jnp.array([[len(sequence)]], dtype=jnp.int32)

    output, cache = transformer.run_decoder(model_input, lengths, cache, params, config)

    logits = output[0, -1, :]

    if temperature != 1.0:
        logits = logits / temperature

    key, subkey = jax.random.split(key)

    next_token = int(jax.random.categorical(subkey, logits))
    model_input = next_token
    if hasattr(tokenizer, "eos_token_id") and next_token == tokenizer.eos_token_id:
        break

    sequence.append(next_token)

    decoded_text = tokenizer.decode(sequence)
    print("Step {}: {}".format(step + 1, decoded_text))


Initial sequence: Hello!
Step 1: Hello!://
Step 2: Hello!://://
Step 3: Hello!://://://
Step 4: Hello!://://://://
Step 5: Hello!://://://://://
Step 6: Hello!://://://://://://
Step 7: Hello!://://://://://://://
Step 8: Hello!://://://://://://://://
Step 9: Hello!://://://://://://://://://
Step 10: Hello!://://://://://://://://://://
Step 11: Hello!://://://://://://://://://://://
Step 12: Hello!://://://://://://://://://://://://
Step 13: Hello!://://://://://://://://://://://://://
Step 14: Hello!://://://://://://://://://://://://://://


KeyboardInterrupt: 