# Testing Gemma3 text-only models in Penzai

This colab shows how to load text-only Gemma3 models and run a forward pass with
our new package `gemma_penzai`. The original Penzai supports only Gemma1 and
Gemma2 models; `gemma_penzai` extends that support to Gemma3. It also extends
the decoding methods with top-p and top-k sampling.

NOTE: we run this colab on a TPU **v5e-1** runtime. See our notebook
`./notebooks/gemma3_multimodal_penzai.ipynb` for instructions on building a
local runtime.

## Import packages

First, we install `jax[tpu]`, the `gemma_penzai` package, and its dependencies.

In [None]:
# Clone the gemma_penzai package
!git clone https://github.com/google-deepmind/gemma_penzai.git

# Upgrade pip in case the installed version is outdated
!pip install --upgrade pip

# Installs JAX with TPU support
!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Install the package in editable mode (-e)
# This installs dependencies defined in your pyproject.toml
print("Installing gemma_penzai and dependencies...")
%cd gemma_penzai
!pip install -e .

Import miscellaneous packages.

In [None]:
import gc
import os
from gemma import gm
from IPython.display import clear_output
import kagglehub

Import JAX-related packages.

In [None]:
import jax
import jax.numpy as jnp
import orbax.checkpoint

# Check that JAX is connected to the TPU
jax.devices()

Import `penzai`-related packages (NOTE: we use the most up-to-date version).

In [None]:
from penzai import pz
from penzai.toolshed import jit_wrapper
from penzai.toolshed import token_visualization
import treescope

treescope.basic_interactive_setup(autovisualize_arrays=True)

Import `gemma_penzai` package to use Gemma3 models.

In [None]:
from gemma_penzai import mllm

gemma_from_pretrained_checkpoint = (
    mllm.load_gemma.gemma_from_pretrained_checkpoint
)
sampling_mode = mllm.sampling_mode
simple_decoding_loop = mllm.simple_decoding_loop

## Load Gemma3 models into Penzai

### Load model parameters

You can download the Gemma checkpoints using a Kaggle account and an API key. If
you don't have an API key already, you can:

1.  Visit https://www.kaggle.com/ and create an account if needed.

2.  Go to your account settings, then the 'API' section.

3.  Click 'Create new token' to download your key.

Next, enter your `KAGGLE_USERNAME` and `KAGGLE_KEY` below.

In [None]:
KAGGLE_USERNAME = "<KAGGLE_USERNAME>"
KAGGLE_KEY = "<KAGGLE_KEY>"
try:
  kagglehub.config.set_kaggle_credentials(KAGGLE_USERNAME, KAGGLE_KEY)
except ImportError:
  kagglehub.login()

We load the Gemma3-4B instruction-tuned model. Checkpoint paths can be found in
[Gemma's Documentation](https://gemma-llm.readthedocs.io/en/latest/checkpoints.html).

In [None]:
weights_dir = kagglehub.model_download("google/gemma-3/flax/gemma3-4b-it")
clear_output()

In [None]:
ckpt_path = os.path.join(weights_dir, "gemma3-4b-it")
checkpointer = orbax.checkpoint.PyTreeCheckpointer()

Here, we restore the model parameters without sharding them. Optionally, the
parameters can be sharded across multiple devices.

In [None]:
flat_params = checkpointer.restore(ckpt_path)

### Bind with Penzai model

Now we prepare the Gemma3 language model definition and bind it with the
parameters.

In [None]:
model = gemma_from_pretrained_checkpoint(
    flat_params,
    upcast_activations_to_float32=False,
)

### Model visualization

Visualizing the model definition together with its parameters takes a long
time. Therefore, we first use `pz.unbind_params` to separate the model
architecture from its parameters, and then visualize only the architecture.

In [None]:
model_unbound, _ = pz.unbind_params(model)
model_unbound

Delete the raw checkpoint parameters, which are no longer needed once the model
is built, to free memory.

In [None]:
del flat_params
gc.collect()

## Text generation for Gemma3 models

### Prepare the inputs

We use the Gemma3 tokenizer from the `gemma` package directly to prepare the
input:

In [None]:
tokenizer = gm.text.Gemma3Tokenizer()

The vocabulary size is available through `.vocab_size`:

In [None]:
tokenizer.vocab_size

Since we use an instruction-tuned model, we format the prompt as a conversation
turn using Gemma's chat markers.

In [None]:
prompt = """<start_of_turn>user
Share one metaphor linking "shadow" and "laughter".<end_of_turn>
<start_of_turn>model
"""

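The prompt wraps the user message in Gemma's turn markers and leaves a model turn
open, so the reply is generated as a continuation of the prompt. As a sketch, a
small helper can build prompts in the same format; `format_user_turn` is a
hypothetical name used only for illustration, not part of `gemma_penzai` or
`gemma`.

```python
def format_user_turn(message: str) -> str:
  """Hypothetical helper: wraps a user message in Gemma chat turn markers.

  The result ends with an open model turn, so the model's reply is generated
  as the continuation of this string.
  """
  return (
      "<start_of_turn>user\n"
      f"{message}<end_of_turn>\n"
      "<start_of_turn>model\n"
  )


example_prompt = format_user_turn(
    'Share one metaphor linking "shadow" and "laughter".'
)
print(example_prompt)
```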
Next, we tokenize the prompt. Note that `add_bos=True` must be passed
explicitly, because Gemma3 does not work properly without the `<bos>` token.

In [None]:
token_ids = tokenizer.encode(prompt, add_bos=True)
token_ids
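
To make the effect of `add_bos=True` concrete, the sketch below prepends a
beginning-of-sequence id when it is missing. The helper `ensure_bos` is
hypothetical, the token ids are placeholders, and it assumes `<bos>` has id 2 in
the Gemma vocabulary; in practice, rely on `add_bos=True` instead.

```python
BOS_ID = 2  # Assumption: id of <bos> in the Gemma vocabulary.


def ensure_bos(token_ids: list[int], bos_id: int = BOS_ID) -> list[int]:
  """Hypothetical helper: returns token_ids with bos_id as the first id."""
  if token_ids and token_ids[0] == bos_id:
    return list(token_ids)  # Already starts with <bos>; do not duplicate it.
  return [bos_id] + list(token_ids)


print(ensure_bos([10, 20, 30]))  # [2, 10, 20, 30]
print(ensure_bos([2, 10, 20]))  # [2, 10, 20]
```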

Next, we convert `token_ids` into a named JAX array with axes `batch` and
`seq`:

In [None]:
tokens = jnp.asarray(token_ids)[None, :]
tokens = pz.nx.wrap(tokens).tag("batch", "seq")
tokens
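
The `[None, :]` indexing adds a leading batch axis of size 1, so a prompt of
length L becomes an array of shape `(1, L)`, and `.tag("batch", "seq")` names
those two axes in order. The NumPy sketch below shows only this shape
bookkeeping; the `named_shape` dict mimics what a Penzai named array tracks and
is not Penzai's API, and the token ids are placeholders.

```python
import numpy as np

token_ids = [2, 10, 20, 30]  # Placeholder token ids.
tokens = np.asarray(token_ids)[None, :]  # Add a leading batch axis of size 1.

# Pair each positional axis with the name passed to .tag(...), in order.
named_shape = dict(zip(("batch", "seq"), tokens.shape))
print(tokens.shape)  # (1, 4)
print(named_shape)  # {'batch': 1, 'seq': 4}
```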

We can also use Penzai's `token_visualization` to visualize the input token ids.
Note that `show_token_array` expects a SentencePiece processor object as its
vocabulary argument, so we pass `tokenizer._sp`.

In [None]:
token_visualization.show_token_array(tokens, tokenizer._sp)  # pylint: disable=protected-access

### Test model forward

Run a forward pass through the model:

In [None]:
logits = model(tokens)

We take the logits at the last token position:

In [None]:
last_logits_penzai = logits.untag("seq")[-1]
last_logits_penzai
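
The logits at position i score the token at position i + 1, so only the final
position is needed to predict the next token. The NumPy sketch below uses toy
sizes and random logits to show the positional equivalent of
`logits.untag("seq")[-1]` and the greedy next-token choice. The
`(batch, seq, vocabulary)` axis order is an assumption for illustration, since
Penzai's named axes have no fixed positional order.

```python
import numpy as np

batch, seq_len, vocab_size = 1, 5, 8  # Toy sizes for illustration.
rng = np.random.default_rng(0)
logits = rng.normal(size=(batch, seq_len, vocab_size))

# Positional analogue of logits.untag("seq")[-1]: keep the batch and
# vocabulary axes and select the final sequence position.
last_logits = logits[:, -1, :]

# The greedy choice for the next token is the highest-scoring vocabulary id.
next_token = np.argmax(last_logits, axis=-1)
print(last_logits.shape, next_token.shape)  # (1, 8) (1,)
```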

### Prepare the model with KV cache

Before inference, we wrap the model with a KV cache. `cache_len` sets the
maximum number of cached tokens, and `batch_axes` gives the size of each batch
axis.

In [None]:
inference_model = sampling_mode.KVCachingTransformerLM.from_uncached(
    model,
    cache_len=1024,
    batch_axes={"batch": 1},
)
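
A KV cache stores the attention keys and values of tokens that were already
processed, so each decoding step only computes keys and values for the newest
token instead of re-running the whole prefix. The toy single-head NumPy sketch
below illustrates the idea. `ToyKVCache` is hypothetical and much simpler than
Penzai's per-layer cache, but it shows why `cache_len` caps the prompt length
plus the number of generated tokens.

```python
import numpy as np


class ToyKVCache:
  """Hypothetical fixed-size key/value cache for one attention head."""

  def __init__(self, cache_len: int, head_dim: int):
    self.keys = np.zeros((cache_len, head_dim))
    self.values = np.zeros((cache_len, head_dim))
    self.end = 0  # Number of positions filled so far.

  def append(self, k: np.ndarray, v: np.ndarray) -> None:
    """Writes keys/values of shape (n, head_dim) after the filled prefix."""
    n = k.shape[0]
    if self.end + n > self.keys.shape[0]:
      raise ValueError("cache_len exceeded")
    self.keys[self.end:self.end + n] = k
    self.values[self.end:self.end + n] = v
    self.end += n

  def attend(self, q: np.ndarray) -> np.ndarray:
    """Attends one query of shape (head_dim,) over the cached positions."""
    k, v = self.keys[: self.end], self.values[: self.end]
    scores = k @ q / np.sqrt(q.shape[-1])
    weights = np.exp(scores - scores.max())  # Numerically stable softmax.
    weights /= weights.sum()
    return weights @ v


rng = np.random.default_rng(0)
cache = ToyKVCache(cache_len=8, head_dim=4)
cache.append(rng.normal(size=(5, 4)), rng.normal(size=(5, 4)))  # Prompt.
cache.append(rng.normal(size=(1, 4)), rng.normal(size=(1, 4)))  # New token.
out = cache.attend(rng.normal(size=4))
print(cache.end, out.shape)  # 6 (4,)
```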

### Greedy sampling

We JIT-compile the model body with `jit_wrapper.Jitted` and sample the output
with a Python decoding loop. Setting `temperature=0.0` gives greedy decoding.

In [None]:
samples = simple_decoding_loop.temperature_sample_pyloop(
    (
        pz.select(inference_model)
        .at(lambda root: root.body)
        .apply(jit_wrapper.Jitted)
    ),
    prompt=tokens,
    temperature=0.0,
    rng=jax.random.key(3),
    max_sampling_steps=512,
)
sample_tokens = samples.untag("batch", "seq").unwrap()[0]
penzai_out1 = tokenizer.decode(sample_tokens)
penzai_out1
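
With `temperature=0.0`, each decoding step picks the highest-scoring token. The
sketch below shows a greedy Python loop of that kind. It assumes a hypothetical
`next_token_logits` function in place of the KV-cached model and a hypothetical
end-of-sequence id; it does not reproduce the stopping logic of
`temperature_sample_pyloop`.

```python
import numpy as np


def greedy_decode(next_token_logits, prompt_ids, max_sampling_steps, eos_id):
  """Appends argmax tokens until eos_id or the step budget is reached.

  Unlike the KV-cached model, the toy next_token_logits function receives
  the full id list at every step.
  """
  ids = list(prompt_ids)
  generated = []
  for _ in range(max_sampling_steps):
    token = int(np.argmax(next_token_logits(ids)))
    generated.append(token)
    if token == eos_id:
      break
    ids.append(token)
  return generated


def toy_logits(ids):
  """Toy "model" that always favors (last id + 1) in a 6-token vocabulary."""
  logits = np.zeros(6)
  logits[(ids[-1] + 1) % 6] = 1.0
  return logits


print(greedy_decode(toy_logits, [2], max_sampling_steps=10, eos_id=5))
# [3, 4, 5]
```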

### Random sampling with temperature

Random sampling with temperature 0.8. We rebuild the KV-cached model first so
that sampling starts from a fresh cache.

In [None]:
inference_model = sampling_mode.KVCachingTransformerLM.from_uncached(
    model,
    cache_len=1024,
    batch_axes={"batch": 1},
)
samples = simple_decoding_loop.temperature_sample_pyloop(
    (
        pz.select(inference_model)
        .at(lambda root: root.body)
        .apply(jit_wrapper.Jitted)
    ),
    prompt=tokens,
    temperature=0.8,
    rng=jax.random.key(3),
    max_sampling_steps=512,
)
sample_tokens = samples.untag("batch", "seq").unwrap()[0]
penzai_out2 = tokenizer.decode(sample_tokens)
penzai_out2
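
Temperature divides the logits before the softmax: values below 1.0 sharpen the
distribution toward the most likely tokens, and values above 1.0 flatten it. A
minimal NumPy sketch of temperature sampling follows; it is illustrative and not
the `gemma_penzai` implementation.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
  z = np.exp(x - np.max(x))  # Subtract the max for numerical stability.
  return z / z.sum()


def sample_with_temperature(logits, temperature, rng):
  """Samples a token id from softmax(logits / temperature)."""
  if temperature == 0.0:
    return int(np.argmax(logits))  # Greedy limit of temperature sampling.
  probs = softmax(np.asarray(logits, dtype=np.float64) / temperature)
  return int(rng.choice(len(probs), p=probs))


logits = np.array([2.0, 1.0, 0.0])
print(softmax(logits / 0.8).round(3))  # Sharper than temperature 1.0.
print(softmax(logits / 1.0).round(3))
rng = np.random.default_rng(3)
print(sample_with_temperature(logits, 0.8, rng))
```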

### Top-p sampling

Top-p sampling with temperature 1.0 and `top_p=0.95`.

In [None]:
inference_model = sampling_mode.KVCachingTransformerLM.from_uncached(
    model,
    cache_len=1024,
    batch_axes={"batch": 1},
)
samples = simple_decoding_loop.temperature_sample_pyloop(
    (
        pz.select(inference_model)
        .at(lambda root: root.body)
        .apply(jit_wrapper.Jitted)
    ),
    prompt=tokens,
    temperature=1.0,
    top_p=0.95,
    rng=jax.random.key(3),
    max_sampling_steps=512,
)
sample_tokens = samples.untag("batch", "seq").unwrap()[0]
penzai_out3 = tokenizer.decode(sample_tokens)
penzai_out3
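
Top-p (nucleus) sampling keeps the smallest set of most likely tokens whose
cumulative probability reaches `top_p`, then samples from that set after
renormalizing. The NumPy sketch below implements the filtering step under that
standard definition; `gemma_penzai` may differ in details such as tie handling.

```python
import numpy as np


def top_p_filter(logits: np.ndarray, top_p: float) -> np.ndarray:
  """Sets logits outside the top-p nucleus to -inf."""
  logits = np.asarray(logits, dtype=np.float64)
  probs = np.exp(logits - logits.max())
  probs /= probs.sum()
  order = np.argsort(-probs)  # Token ids, most likely first.
  sorted_probs = probs[order]
  # Probability mass of strictly more likely tokens. A token is kept while
  # that mass is still below top_p, so the top token is always kept.
  mass_before = np.cumsum(sorted_probs) - sorted_probs
  keep_sorted = mass_before < top_p
  keep = np.zeros_like(keep_sorted)
  keep[order] = keep_sorted  # Scatter back to the original token order.
  return np.where(keep, logits, -np.inf)


logits = np.log(np.array([0.5, 0.3, 0.15, 0.05]))
print(top_p_filter(logits, 0.6))  # Keeps the two most likely tokens.
print(top_p_filter(logits, 0.9))  # Keeps the three most likely tokens.
```

Sampling then proceeds as usual on the filtered logits, since `-inf` entries
receive zero probability after the softmax.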

### Top-k sampling

Top-k sampling with temperature 1.0 and `top_k=20`.

In [None]:
inference_model = sampling_mode.KVCachingTransformerLM.from_uncached(
    model,
    cache_len=1024,
    batch_axes={"batch": 1},
)
samples = simple_decoding_loop.temperature_sample_pyloop(
    (
        pz.select(inference_model)
        .at(lambda root: root.body)
        .apply(jit_wrapper.Jitted)
    ),
    prompt=tokens,
    temperature=1.0,
    top_k=20,
    rng=jax.random.key(3),
    max_sampling_steps=512,
)
sample_tokens = samples.untag("batch", "seq").unwrap()[0]
penzai_out4 = tokenizer.decode(sample_tokens)
penzai_out4
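
Top-k sampling restricts sampling to the `top_k` highest-scoring tokens. A
minimal NumPy sketch of the filtering step follows; it is illustrative, not the
`gemma_penzai` implementation, and it keeps every token tied with the k-th
largest logit.

```python
import numpy as np


def top_k_filter(logits: np.ndarray, top_k: int) -> np.ndarray:
  """Sets every logit outside the top_k largest to -inf."""
  if top_k <= 0:
    raise ValueError("top_k must be positive")
  logits = np.asarray(logits, dtype=np.float64)
  if top_k >= logits.shape[-1]:
    return logits  # Nothing to filter.
  # Value of the k-th largest logit; ties at this threshold are all kept.
  kth_largest = np.sort(logits)[-top_k]
  return np.where(logits >= kth_largest, logits, -np.inf)


logits = np.array([1.0, 3.0, 2.0, 0.0])
print(top_k_filter(logits, 2))  # Only ids 1 and 2 keep finite logits.
```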