In [1]:
# Re-import edited modules (e.g. lmkit) automatically before each cell runs,
# so library changes take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import jax.numpy as jnp

from lmkit.model import transformer, config as config_lib
from lmkit.tools import compat

In [3]:
# Download the checkpoint on first use, then load weights, config and the
# Llama-specific config extensions into lmkit's format.
repo = "meta-llama/Llama-3.1-8B-Instruct"
model_dir = "models/llama3"

# Only contact the Hugging Face Hub when the local directory is missing or empty.
if not os.path.exists(model_dir) or not os.listdir(model_dir):
    from dotenv import load_dotenv

    # Load variables from a local .env file (if any) into os.environ.
    load_dotenv()
    hf_token = os.environ.get("HF_API_TOKEN")
    if not hf_token:
        # Fail with an actionable message instead of a bare KeyError.
        raise RuntimeError(
            f"{model_dir!r} is missing or empty and HF_API_TOKEN is not set; "
            f"export it or add it to a .env file to download {repo}."
        )

    compat.from_hf(repo, model_dir, token=hf_token)

params = compat.params_to_lmkit(compat.gather_for_jax(model_dir))
config = compat.load_lmkit_config(f"{model_dir}/config.json")
config = config_lib.extend_llama(config)

Loading safetensors:   0%|          | 0/4 [00:00<?, ?it/s]

Loading safetensors: 100%|██████████| 4/4 [00:08<00:00,  2.18s/it]


In [None]:
import jax
import jax.numpy as jnp
from flax.core import FrozenDict
from IPython.display import clear_output

# Autoregressively extend a prompt with the model loaded above, streaming the
# decoded text into the cell output after every new token.

tokenizer = compat.load_lmkit_tokenizer(
    f"{model_dir}/tokenizer.json", f"{model_dir}/generation_config.json"
)

initial_text = "Bro..."
sequences = [encoding.ids for encoding in tokenizer.encode_batch([initial_text])]

max_new_tokens = 100
# Softmax temperature. Values <= 0 select greedy (argmax) decoding instead of
# dividing the logits by zero.
temperature = 0.3
key = jax.random.key(80)

model_inputs = jnp.array(sequences)
# NOTE(review): run_decoder receives lengths shaped (batch, 1) -- confirm this
# is the layout lmkit expects.
seq_lengths = jnp.array([len(seq) for seq in sequences])[..., None]
prompt_len = model_inputs.shape[1]

# Decode the prompt once with the tokenizer's default options; generated
# tokens are decoded separately (keeping special tokens) inside the loop.
prompt_text = tokenizer.decode_batch(sequences)[0]
current_output = prompt_text

# Print the initial output
clear_output(wait=True)
print(current_output, end="", flush=True)


def sample_next_tokens(logits, temperature, step_key):
    """Choose one token id per batch row from last-position logits.

    Args:
        logits: array whose last axis is the vocabulary (here (batch, vocab)).
        temperature: softmax temperature; <= 0 means greedy argmax decoding.
        step_key: JAX PRNG key; only consumed when sampling (temperature > 0).

    Returns:
        Integer array of token ids with the vocabulary axis removed.
    """
    if temperature <= 0:
        return jnp.argmax(logits, axis=-1)
    # categorical takes unnormalized logits, so no explicit softmax is needed.
    return jax.random.categorical(step_key, logits / temperature, axis=-1)


for i in range(max_new_tokens):
    # No KV cache: the full sequence is re-run through the decoder every step.
    output = transformer.run_decoder(model_inputs, seq_lengths, params, config)
    logits = output[:, -1, :]

    step_key = jax.random.fold_in(key, i)
    next_tokens = sample_next_tokens(logits, temperature, step_key)

    model_inputs = jnp.concatenate([model_inputs, next_tokens[:, None]], axis=1)
    seq_lengths = seq_lengths + 1

    # Re-decode the whole generated suffix rather than only the newest token:
    # with a byte-level BPE tokenizer one multi-byte UTF-8 character can span
    # several tokens, and decoding them one at a time yields replacement chars.
    generated_ids = model_inputs[:1, prompt_len:].tolist()
    generated_text = tokenizer.decode_batch(
        generated_ids, skip_special_tokens=False
    )[0]
    current_output = prompt_text + generated_text

    # Clear the previous output and print the updated output
    clear_output(wait=True)
    print(current_output, end="", flush=True)

# After the loop, print a final newline to finalize the output.
print()
