## Gemma JAX Multi-Turn Chat Notebook

This notebook demonstrates running batched, multi-turn inference with the Gemma JAX model.

### Imports and other boilerplate


In [None]:
# Uncomment if using Colab

from google.colab import drive
drive.mount('/content/drive')

import os # Git clone the repository if it does not exist, and cd into it.
os.chdir('/content/drive/My Drive')

if not os.path.exists('gemma-jax'):
  !git clone https://github.com/baricev/gemma-jax

os.makedirs('gemma-jax', exist_ok=True)
os.chdir('gemma-jax')

print(f"Current working directory: {os.getcwd()}")

### Install

Installing the dependencies directly with `!pip install jax[tpu] orbax datasets --quiet` should also work.

In [None]:
!pip install -e . --quiet

### Package Imports

Import core gemma-jax functions and data structures.

In [3]:
import time
import argparse
from functools import partial
from pathlib import Path
import jax
from jax import Array
import jax.numpy as jnp

# Assuming gemma_jax is installed in editable mode (`pip install -e .`)
from gemma_jax.core.weights import (
    create_gemma3_config,
    create_device_mesh,
    load_model
)
from gemma_jax.core.model import (
    forward_fn,
    setup_scan_fn,
    scan_generate_step
)
from gemma_jax.core.cache import (
  KVCache,
  LayoutType,
  SEQUENCE_HEADS,
  HEADS_SEQUENCE,
  init_cache,
  layout_map,
)
from gemma_jax.core.rope import load_rope_cache
from gemma_jax.core.sp_tokenizer import SentencePieceTokenizer, process_and_pad_inputs, encode_text, decode_tokens, format_prompt
from gemma_jax.core.inference import greedy_sample

## Configuration Defaults

Set the default configuration values here. These replace the command-line arguments used in the script version.

**Important:** Update `checkpoint_path` and `tokenizer_path` to your actual absolute paths.



In [None]:
# Locate the tokenizer and the checkpoint.
# First guess: the current working directory is the repository root.
root_dir = Path.cwd()
checkpoint_path = root_dir / "4b"               # TODO: Replace with your ABSOLUTE path
tokenizer_path = root_dir / "tokenizer.model"   # TODO: Replace with your ABSOLUTE path

if not tokenizer_path.exists():
  # Fallback: assume the notebook is running from a sub-directory of the
  # repository (e.g. `examples/`), so the files sit one level up.
  # `__file__` is not defined inside a notebook kernel, so the parent of the
  # working directory is used instead.
  root_dir = Path.cwd().parent                    # Adjust this if the notebook is in a different directory
  tokenizer_path = root_dir / "tokenizer.model"
  checkpoint_path = root_dir / "4b"               # TODO: Replace with your ABSOLUTE path

print(f"Using default tokenizer path: {tokenizer_path}")
print(f"Using default checkpoint path: {checkpoint_path}")

### Model Settings

In [5]:
# --- Tunable inference settings ---
model_size = 4               # Gemma model size in billions (4 -> 4B); one of 1, 4, 12, 27.
batch_size = 4               # Number of prompts per inference batch.
padded_input_size = 1024     # Length the input sequences are padded to.
cache_length = 2048          # Length of the KV cache.
window_size = 1024           # Window size for sliding-window attention.
generate_steps = 4           # Tokens to generate once prefill is done.
dtype_str = "bfloat16"       # Parameter data type; one of 'bfloat16', 'float16', 'float32'.

# Lookup from the data-type name to the matching jax.numpy dtype.
dtype_map = dict(
    bfloat16=jnp.bfloat16,
    float16=jnp.float16,
    float32=jnp.float32,
)
model_dtype = dtype_map[dtype_str]

## Setup: Initialization

The cell below creates the model configuration and the device mesh, loads the model parameters, initializes the KV cache, and loads the tokenizer. The RoPE cache is left as `None` so that the embeddings are computed at runtime.



In [None]:
# End-to-end setup: config -> device mesh -> model weights -> caches -> tokenizer.
print("Starting setup...")
start_setup = time.time()

# 1. Model Config
config = create_gemma3_config(
    model_size=model_size,
    batch_size=batch_size,
    padded_input_size=padded_input_size,
    cache_length=cache_length,
    window_size=window_size,
)
print(f"Model Config created for Gemma-{model_size}b")

# 2. Device Mesh
num_devices = len(jax.devices())
# TODO: Configure mesh shape
# The default (2, num_devices // 2) layout only covers every device when the
# device count is even; with a single device it would ask for a zero-sized
# axis. Fall back to a single row in those cases.
# NOTE(review): assumes `create_device_mesh` accepts a (1, N) shape - confirm.
if num_devices >= 2 and num_devices % 2 == 0:
    mesh_shape = (2, num_devices // 2)
else:
    mesh_shape = (1, num_devices)
mesh = create_device_mesh(mesh_shape)
print(f"Device mesh created with shape: {mesh.shape}")


# 3. Load Model
assert checkpoint_path.exists(), f"Checkpoint path {checkpoint_path} does not exist."
assert checkpoint_path.is_absolute(), f"Checkpoint path {checkpoint_path} must be an absolute path."

print(f"Loading model from: {checkpoint_path} (dtype: {dtype_str})...")
load_start = time.time()
model = load_model(checkpoint_path, mesh, config, dtype=model_dtype)
print(f"Model loaded in {time.time() - load_start:.2f}s")

# 4. Initialize Caches
# rope_cache = load_rope_cache(mesh, config)  # RoPE cache dtype is float32 internally
rope_cache = None  # pass None to compute embeddings at runtime

# Configure memory layout, sharding or cache update functions in "cache.py"
# or use pre-configured settings (SEQUENCE_HEADS, HEADS_SEQUENCE)
cache_layout = SEQUENCE_HEADS

# NOTE(review): the cache dtype is fixed to bfloat16 and does not follow
# `dtype_str` - confirm this is intended.
cache = init_cache(
    mesh=mesh,
    config=config,
    dtype=jnp.bfloat16,
    kind=cache_layout,
    layout_map=layout_map,
)

# 5. Create Tokenizer
assert tokenizer_path.exists(), f"Tokenizer path {tokenizer_path} does not exist."
assert tokenizer_path.is_absolute(), f"Tokenizer path {tokenizer_path} must be an absolute path."
tokenizer = SentencePieceTokenizer(tokenizer_path)
print(f"Tokenizer loaded from: {tokenizer_path}")

# Reported once, after every setup step (including the tokenizer) has finished.
print(f"Setup complete in {time.time() - start_setup:.2f}s")

## Input Processing, Prefill and Generate Functions

Tokenize and encode the input text. The tokenizer is a SentencePiece wrapper. Set up the prefill and auto-regressive stages using the functions imported above.

In [None]:

# Process inputs: `process_and_pad_inputs` with the tokenizer and the
# sequence/cache lengths pre-bound, so only the prompts need to be passed.
process_partial = partial(
    process_and_pad_inputs,
    max_sequence_length=padded_input_size,
    cache_len=cache_length,
    tokenizer=tokenizer,
)

# Prefill inputs: `forward_fn` with `write_index=0` and the model, cache,
# RoPE cache, config and layout pre-bound.
prefill_partial = partial(
    forward_fn,
    write_index=0,
    model=model,
    cache=cache,
    rope_cache=rope_cache,
    config=config,
    layout=cache_layout,
)

# Auto-regressive generation: `scan_generate_step` with everything except the
# scan carry pre-bound.
generate_partial = partial(
    scan_generate_step,
    model=model,
    rope_cache=rope_cache,
    config=config,
    layout=cache_layout,
)


# Example prompts: two distinct prompts first, then "I love to" repeated
# `batch_size` times so the list always holds at least `batch_size` entries.
prompt = ["Explaing evolution to a first-century Roman philosopher (Cicero)"]
prompt += ["Explain the significance of the following quote `The only thing we have to fear is fear itself`"]
prompt += ["I love to"] * batch_size  # Repeat the input text for the batch size
input_text = prompt[:batch_size]  # Ensure the input text matches the batch size

input_text = [format_prompt(text) for text in input_text]  # Format for Gemma 3 dialogue

# Encoding with `add_bos_token_only=True`, kept for inspection only.
ids = encode_text(input_text, tokenizer, add_bos_token_only=True)
print(f"BOS-only encoded IDs shape: {ids.shape}")

# Process and pad inputs
raw_input_ids = encode_text(input_text, tokenizer)
# NOTE(review): treats token id 0 as padding - confirm against the tokenizer.
attn_mask = raw_input_ids != 0
padded_input_ids, padded_position_ids, cache_attn_mask = process_partial(input_text)

print(f"Raw Input IDs: {raw_input_ids.shape}")
print(f"Attention Mask: {attn_mask.shape}")
print(f"Padded Input IDs shape: {padded_input_ids.shape}")
print(f"Position IDs shape: {padded_position_ids.shape}")
print(f"Cache attention mask shape: {cache_attn_mask.shape}")


## Inference: Prefill + Auto-Regressive Generation

## Prefill

The `prefill_partial` function is used to prefill the model with the input tokens. It takes the padded input IDs, positions, and attention mask as input and returns the logits and updated cache.

Note: The cache is updated by `prefill_partial`; the updated cache is returned and assigned back to `cache`.

In [None]:
# Run the prefill pass over the padded prompts. The call hands back the logits
# together with a cache, which is rebound to `cache`.
prefill_kwargs = dict(positions=padded_position_ids, attn_mask=cache_attn_mask)
logits, cache = prefill_partial(padded_input_ids, **prefill_kwargs)

print(f"Logits shape: {logits.shape}")
# Report the cache shape when it exposes one; otherwise fall back to its type.
if hasattr(cache, "shape"):
    print(f"Cache shape: {cache.shape}")
else:
    print(f"Cache object: {type(cache)}")


## Generation

Set up the scan function with the model, cache, and other parameters using `setup_scan_fn`.

The `scan_generate_step` function is used to generate tokens in a loop. It takes the model, cache, and other parameters as input and returns the generated tokens and updated cache.

Note: Using different inputs will not trigger recompilation as long as they fit within the padded window size. Generation will still run at post-warmup speed.

In [None]:
# Seed decoding with the tokens sampled greedily from the prefill logits.
first_tokens = greedy_sample(logits, positions=padded_position_ids)

# Build the initial scan state from the prompt tensors and the prefill cache.
scan_state = setup_scan_fn(
    padded_input_ids,
    padded_position_ids,
    first_tokens,
    prefill_cache=cache,
    cache_length=config.cache_length,
)
all_gen, current_index, current_pos, carry = scan_state

# Run the generation step 2048 times; the per-step scan outputs are discarded.
# NOTE(review): the step count is the literal 2048; the `generate_steps`
# setting is not used here - confirm which one is intended.
carry, _ = jax.lax.scan(generate_partial, carry, xs=None, length=2048)

# Pull the results out of the final carry: first element -> generated tokens,
# last element -> cache.
generated_tokens = carry[0]
final_cache = carry[-1]
print(f"Generated tokens shape: {generated_tokens.shape}")

## Decode and Format Output


Decode the generated tokens back to text using the tokenizer. This is generally the *slowest* step in the entire inference loop.
Note that the decode call below passes `skip_special_tokens=True`.

In [None]:
# Decode generated tokens
formatted_output = decode_tokens(generated_tokens, tokenizer, skip_special_tokens=True)

for i, output in enumerate(formatted_output):
    print(f"Output {i + 1}:\n{output}")
    print("-" * 80)


## JIT Compiled Inference

In [None]:
# Second end-to-end run (prefill -> generate -> decode) on a new batch of prompts,
# reusing the partials created earlier in the notebook.
input_text = [
    "The most beautiful thing in the world is",
    # Adjacent string literals are used instead of a backslash continuation,
    # which would embed the next line's indentation in the prompt.
    (
        "Do you know what MLA (multi head latent attention) is?. Recall that it is a recent innovation in LLM architecture, "
        "from DeepSeek labs. If you do, explain what it is. Make sure to focus on implementation details, please!"
    ),
    "Whats's your favorite Matsuo Basho poem? Translate it into Japanese.",
    "Tell me a joke that is not too funny, but still kind of funny. Something that would make a 5 year old laugh. Or Larry David from Curb Your Enthusiasm fame.",
]

input_text = [format_prompt(text) for text in input_text]  # Format for Gemma 3 dialogue

# Tokenize and pad the prompts.
padded_input_ids, padded_position_ids, cache_attn_mask = process_partial(input_text)

# Prefill: returns the logits and the cache used to start generation.
logits, cache = prefill_partial(
    padded_input_ids,
    positions=padded_position_ids,
    attn_mask=cache_attn_mask,
)

# Build the initial scan state, seeded with the greedily sampled tokens.
all_gen, current_index, current_pos, carry = setup_scan_fn(
    padded_input_ids,
    padded_position_ids,
    greedy_sample(logits, positions=padded_position_ids),
    prefill_cache=cache,
    cache_length=config.cache_length,
)

# Run 2048 generation steps; the first carry element is read as the generated
# tokens and the last one as the cache.
carry, _ = jax.lax.scan(generate_partial, carry, xs=None, length=2048)
generated_tokens, final_cache = carry[0], carry[-1]

# Decode and print each sequence, numbered from 1.
formatted_output = decode_tokens(generated_tokens, tokenizer, skip_special_tokens=True)
for i, output in enumerate(formatted_output):
    print(f"Output {i + 1}:\n{output}")
    print("-" * 80)
