In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import jax.numpy as jnp

from lmkit.model import transformer, config as config_lib
from lmkit.tools import compat

repo = "meta-llama/Meta-Llama-3-8B-Instruct"
model_dir = "models/llama3"

if not os.path.exists(model_dir) or not os.listdir(model_dir):
    from dotenv import load_dotenv
    load_dotenv()

    compat.from_hf(repo, model_dir, token=os.environ["HF_API_TOKEN"])

params = compat.params_to_lmkit(compat.gather_for_jax(model_dir))
config = compat.load_lmkit_config(f"{model_dir}/config.json")
config = config_lib.extend_llama(config)

tokenizer = compat.load_lmkit_tokenizer(
    f"{model_dir}/tokenizer.json", f"{model_dir}/generation_config.json"
)


Loading safetensors: 100%|██████████| 4/4 [00:08<00:00,  2.20s/it]


In [17]:
for seq_logits, last_idx in zip(output_logits, seq_lens-1):
    print(last_idx)
    print(seq_logits[last_idx].shape)


21
(128256,)
25
(128256,)
13
(128256,)


In [None]:
next_token_logits = [seq_logits[last_idx] for seq_logits, last_idx in zip(output_logits, seq_lens-1)]

In [27]:
import jax
import jax.numpy as jnp
from IPython.display import clear_output, display, HTML


# --- Initial Setup ---
initial_texts = [
    "Question: Please give me a haiku about writing JAX code!",
    "Question: Please give me an elaborate python code to train a transformer with pytorch.",
    "Question: Capital of France!",
]

encoded_batch = tokenizer.encode_batch(initial_texts)
model_inputs = jnp.array(list(map(lambda x: x.ids, encoded_batch)), dtype=jnp.int32)

# --- Calculate Initial Positions and Sequence Lengths ---
# Create position IDs: 0, 1, 2, ... for non-pad tokens, -1 for pad tokens
positions = jnp.where(
    model_inputs != tokenizer.pad_token_id,
    jnp.arange(model_inputs.shape[1], dtype=jnp.int32),  # Assign index 0, 1, 2...
    -1,  # Assign -1 for padding (or choose another indicator if your model expects something else)
)

seq_lens = jnp.sum(model_inputs != tokenizer.pad_token_id, axis=1, dtype=jnp.int32)

# --- Generation Parameters ---
batch_size = model_inputs.shape[0]
batch_indices = jnp.arange(batch_size)
max_new_tokens = 100
temperature = 0.3
key = jax.random.key(2002)

# --- Tracking ---
completed = [False for _ in range(batch_size)]
# Decode initial prompts correctly, handling padding
current_outputs = []
for i in range(batch_size):
    prompt_tokens = model_inputs[i][: seq_lens[i]]  # Get only non-pad tokens
    prompt_text = tokenizer.decode(prompt_tokens.tolist())  # Decode the actual prompt
    current_outputs.append(
        f"Prompt: {initial_texts[i]}\nCompletion: "
    )  # Use original text for clarity


# --- Initial Display ---
clear_output(wait=True)
for output in current_outputs:
    display(HTML(f"<div style='white-space: pre-wrap;'>{output}</div><hr>"))


# --- Generation Loop ---
for i in range(max_new_tokens):
    print(f"DEBUG: Step {i}, seq_lens: {seq_lens}")  # Optional debug print

    # --- Model Forward Pass ---
    output_logits = transformer.run_decoder(model_inputs, positions, params, config)

    # --- Select Logits for Next Token ---
    print(f"Indices: {seq_lens - 1}")
    logits = jnp.array([
        seq_logits[last_idx]
        for seq_logits, last_idx in zip(output_logits, seq_lens - 1)
    ])

    # --- Sample Next Token ---
    scaled_logits = logits / jnp.maximum(temperature, 1e-6)  # Add epsilon for safety
    step_key = jax.random.fold_in(key, i)
    # Ensure next_tokens are int32
    next_tokens = jax.random.categorical(step_key, scaled_logits, axis=-1).astype(
        jnp.int32
    )
    # Add sequence dimension: shape (batch_size,) -> (batch_size, 1)
    next_tokens_expanded = next_tokens[:, None]
    print(f"DEBUG: Step {i}, sampled tokens: {next_tokens}")  # Optional debug print

    # break

    # --- Update State ---
    # Check if we need to expand the arrays (common in simpler loops, might not be needed with KV caching)
    current_max_len = model_inputs.shape[1]
    if jnp.any(seq_lens >= current_max_len):
        # Expand arrays by one position if any sequence hits the current max length
        print(
            f"DEBUG: Step {i}, Expanding arrays from {current_max_len}"
        )  # Optional debug print
        # Pad model_inputs with PAD_TOKEN_ID
        model_inputs = jnp.concatenate(
            [
                model_inputs,
                jnp.full((batch_size, 1), tokenizer.pad_token_id, dtype=jnp.int32),
            ],
            axis=1,
        )
        # Pad positions with -1 (or your chosen padding indicator)
        positions = jnp.concatenate(
            [positions, jnp.full((batch_size, 1), -1, dtype=jnp.int32)], axis=1
        )

    # Indices to update: (batch_idx, sequence_idx) where sequence_idx is the current length
    update_indices = (batch_indices, seq_lens)
    model_inputs = model_inputs.at[update_indices].set(next_tokens)

    positions = positions.at[update_indices].set(seq_lens)

    seq_lens += 1

    # --- Decode and Check Completion ---
    next_tokens_list = next_tokens.tolist()  # Use the original (B,) shape tokens
    all_sequences_completed = True  # Assume all complete until proven otherwise
    for idx, token_id in enumerate(next_tokens_list):
        if not completed[idx]:
            token_str = tokenizer.decode(
                [token_id], skip_special_tokens=False
            )  # Decode single token
            current_outputs[idx] += token_str  # Append decoded string

            if token_id == tokenizer.eos_token_id:
                completed[idx] = True
                print(
                    f"DEBUG: Step {i}, Sequence {idx} completed (EOS)."
                )  # Optional debug print
            else:
                all_sequences_completed = (
                    False  # At least one sequence is still running
                )
        else:
            # Keep track if all were *already* complete
            pass

    # --- Display Update ---
    clear_output(wait=True)
    for output_idx, output in enumerate(current_outputs):
        completion_marker = " [COMPLETED]" if completed[output_idx] else ""
        display(
            HTML(
                f"<div style='white-space: pre-wrap;'>{output}{completion_marker}</div><hr>"
            )
        )

    # --- Check for Early Exit ---
    if all_sequences_completed or all(
        completed
    ):  # Check if all are newly or previously completed
        print(f"DEBUG: All sequences completed at step {i}.")
        break

print("\n--- Generation Finished ---")
# Final outputs are already stored in current_outputs


DEBUG: Step 94, seq_lens: [109]


: 