# Testing Gemma3 multimodal models in Penzai

This colab shows how to load Gemma3 multimodal models and run a forward pass
with them using our new package `gemma_penzai`. The original Penzai only
supports Gemma1 and Gemma2 models; `gemma_penzai` extends support to vision
models and vision language models such as Gemma3.

NOTE: we run this colab on a TPU **v5e-4** runtime; the default **v5e-1**
runtime is not sufficient.

We provide a step-by-step tutorial on how to set up a TPU v5e-4 runtime on [Google Cloud Platform (GCP)](https://cloud.google.com/):
1. Visit your [Google Cloud Platform (GCP)](https://cloud.google.com/) console and search for TPU, which is under "Virtual machines".
2. Create a TPU node with TPU type "v5litepod-4" and TPU software version "v2-alpha-tpuv5-lite". Set the TPU node name, description, and zone according to your preferences.
3. Open a terminal and connect to your TPU node with the following command, which also forwards the Jupyter port (usually `8888`) to your local machine:
    ```
    gcloud compute tpus tpu-vm ssh <YOUR_TPU_NAME> \
    --zone=<YOUR_ZONE> \
    --project=<YOUR_PROJECT_ID> \
    -- -L 8888:localhost:8888
    ```

Next, we recommend creating a virtual environment in which to install `jupyter`. The default Python version on the TPU node is 3.10, but this notebook requires Python 3.12 or later. Follow the steps below:

1. Install Python 3.12 on your TPU node by running the command:
    ```
    sudo apt-get update
    sudo add-apt-repository ppa:deadsnakes/ppa
    sudo apt-get update
    sudo apt-get install python3.12 python3.12-venv python3.12-dev
    ```
2. Create your virtual environment:
    ```
    python3.12 -m venv <YOUR_VENV_NAME>
    source <YOUR_VENV_NAME>/bin/activate
    ```
3. Install and start `Jupyter`:
    ```
    pip install jupyterlab
    jupyter lab --port=8888 --no-browser
    ```
4. In the terminal output, find a URL that looks like `http://localhost:8888/lab?token=abc123...` and paste it into Colab under `Connect -> Connect to a local runtime`.
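Once Colab is connected, it can be worth confirming that the kernel really runs inside the Python 3.12 virtual environment rather than the system Python 3.10. A minimal check; the helper name is our own and is not part of `gemma_penzai`:

```python
import sys


def python_is_recent_enough(min_version=(3, 12)):
    """Return True if the running interpreter is at least `min_version`."""
    # sys.version_info[:2] is a plain (major, minor) tuple, so it compares
    # element-wise with another tuple.
    return sys.version_info[:2] >= tuple(min_version)


version = sys.version.split()[0]
if python_is_recent_enough():
    print(f"OK: running Python {version}")
else:
    print(f"Python {version} is too old; activate the Python 3.12 venv first.")
```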

## Import packages

First, we clone the `gemma_penzai` repository and install `jax[tpu]`, the `gemma_penzai` package, and its dependencies.

In [None]:
# Clone the gemma_penzai package
!git clone https://github.com/google-deepmind/gemma_penzai.git

# Upgrade pip first (optional)
!pip install --upgrade pip

# Install JAX with TPU support
!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Install the package in editable mode (-e)
# This also installs the dependencies declared in the package's pyproject.toml
print("Installing gemma_penzai and dependencies...")
%cd gemma_penzai
!pip install -e .

Import miscellaneous packages.

In [None]:
import gc
import os
from gemma import gm
from IPython.display import clear_output
import kagglehub
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

Import JAX-related packages.

In [None]:
import jax
from jax.experimental import mesh_utils
import jax.numpy as jnp
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec
import orbax.checkpoint

# Check that JAX can see the TPU devices
jax.devices()

Import `penzai`-related packages (NOTE: we use the most up-to-date version of Penzai).

In [None]:
from penzai import pz
from penzai.toolshed import jit_wrapper
import treescope

treescope.basic_interactive_setup(autovisualize_arrays=True)

Import `gemma_penzai` package to use Gemma3 models.

In [None]:
from gemma_penzai import mllm
from gemma_penzai import vision

process_images = vision.image_utils.process_images
gemma_multimodal_from_pretrained_checkpoint = (
    mllm.load_gemma.gemma_multimodal_from_pretrained_checkpoint
)
sampling_mode = mllm.sampling_mode
simple_decoding_loop = mllm.simple_decoding_loop

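By default, JAX preallocates only 75% of GPU memory; this can be overridden with
the `XLA_PYTHON_CLIENT_MEM_FRACTION` environment variable. Note that this
variable targets GPU backends and only takes effect if it is set before JAX
initializes its backend.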

In [None]:
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"

## Loading Gemma3 multimodal model from Penzai

### Load and shard parameters

You can download the Gemma checkpoints using a Kaggle account and an API key. If
you don't have an API key already, you can:

1.  Visit https://www.kaggle.com/ and create an account if needed.

2.  Go to your account settings, then the 'API' section.

3.  Click 'Create new token' to download your key.

Next, input your "KAGGLE_USERNAME" and "KAGGLE_KEY" below.

In [None]:
KAGGLE_USERNAME = "<KAGGLE_USERNAME>"
KAGGLE_KEY = "<KAGGLE_KEY>"
try:
  kagglehub.config.set_kaggle_credentials(KAGGLE_USERNAME, KAGGLE_KEY)
except ImportError:
  kagglehub.login()
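If you prefer not to paste credentials into the notebook, a small helper can read them from the `KAGGLE_USERNAME`/`KAGGLE_KEY` environment variables, or from the `kaggle.json` file that 'Create new token' downloads (a JSON object with `username` and `key` fields). The helper below is hypothetical and not part of `kagglehub`; `~/.kaggle/kaggle.json` is the conventional location used by Kaggle's tools.

```python
import json
import os
from pathlib import Path


def load_kaggle_credentials(path="~/.kaggle/kaggle.json"):
    """Return (username, key), or None if no credentials are found.

    Environment variables take precedence; otherwise the kaggle.json token
    file at `path` is read.
    """
    username = os.environ.get("KAGGLE_USERNAME")
    key = os.environ.get("KAGGLE_KEY")
    if username and key:
        return username, key
    token_file = Path(path).expanduser()
    if token_file.is_file():
        data = json.loads(token_file.read_text())
        return data["username"], data["key"]
    return None
```

In the notebook, a non-`None` result could then be passed to `kagglehub.config.set_kaggle_credentials(*creds)`, the same call used in the cell above.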

We load the instruction-tuned Gemma3-4B model. Checkpoint paths can be found in
[Gemma's Documentation](https://gemma-llm.readthedocs.io/en/latest/checkpoints.html).
Note that only the Gemma3 4B, 12B, and 27B models include the vision module, and
this vision module is the same across the three models.

In [None]:
weights_dir = kagglehub.model_download("google/gemma-3/flax/gemma3-4b-it")
clear_output()

In [None]:
ckpt_path = os.path.join(weights_dir, "gemma3-4b-it")
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
metadata = checkpointer.metadata(ckpt_path)
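The checkpoint metadata lets you inspect parameter shapes before any weights are loaded. The sketch below walks a nested dict whose leaves expose a `.shape` attribute (the same attribute `get_flexible_sharding` relies on later) and counts parameters. `ArrayMeta` and the toy tree are hypothetical stand-ins; whether `metadata.item_metadata` can be traversed as a plain nested dict depends on your Orbax version.

```python
import math
from collections import namedtuple

# Hypothetical stand-in for Orbax array metadata: only `.shape` is used.
ArrayMeta = namedtuple("ArrayMeta", ["shape"])


def summarize_shapes(tree, prefix=""):
    """Yield (path, shape) for every leaf with a `.shape` in a nested dict."""
    for name, node in tree.items():
        path = f"{prefix}/{name}" if prefix else name
        if isinstance(node, dict):
            yield from summarize_shapes(node, path)
        elif hasattr(node, "shape"):
            yield path, tuple(node.shape)


def count_parameters(tree):
    """Total number of elements across all array leaves (scalars count as 1)."""
    return sum(math.prod(shape) for _, shape in summarize_shapes(tree))


# Toy tree with made-up names and shapes, purely for illustration.
toy_metadata = {
    "embedder": {"input_embedding": ArrayMeta((1000, 64))},
    "layer_0": {"attn": {"q": ArrayMeta((64, 4, 16))}, "norm": ArrayMeta((64,))},
}
for path, shape in summarize_shapes(toy_metadata):
    print(path, shape)
print("total parameters:", count_parameters(toy_metadata))
```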

Next, we build a one-dimensional device mesh over all local devices, with a
single mesh axis named `data`.

In [None]:
n_devices = jax.local_device_count()
sharding_devices = mesh_utils.create_device_mesh((n_devices,))
mesh = Mesh(sharding_devices, ("data",))

Because the multimodal model may require a large amount of memory, we may need
more TPU devices than a text-only setup (this notebook uses the TPU v5e-4
runtime described above). We therefore shard each parameter array along a
dimension whose size is divisible by the number of devices, searching from the
last dimension backwards. The strategy is defined in the following function.
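To see why sharding helps, a back-of-the-envelope estimate of per-device parameter memory is enough. The sketch assumes bfloat16 weights (2 bytes each) split evenly across devices; activations, the KV cache, and any replicated arrays come on top. The 4e9 figure is a hypothetical round number for a "4B" model, not an exact count.

```python
def param_memory_gib(
    num_params: int, bytes_per_param: int = 2, num_devices: int = 1
) -> float:
    """Approximate parameter memory per device, in GiB.

    Idealized: assumes every parameter is split evenly across `num_devices`
    and ignores activations, caches, and replicated arrays.
    """
    total_bytes = num_params * bytes_per_param
    return total_bytes / num_devices / 2**30


# Hypothetical round figure: 4e9 bfloat16 parameters.
for n in (1, 4):
    print(f"{n} device(s): ~{param_memory_gib(4_000_000_000, 2, n):.2f} GiB each")
```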

In [None]:
def get_flexible_sharding(
    array_metadata,
    num_devices: int,
) -> NamedSharding:
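  """Determines the sharding for an array based on divisibility by num_devices.

  Starting from the last dimension and moving left, the array is sharded along
  the first dimension whose size is divisible by num_devices. Uses the global
  `mesh` defined above.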

  Args:
      array_metadata: An object with 'shape' attribute (e.g., ArrayMetadata from
        Orbax).
      num_devices: The number of devices (e.g., TPUs).

  Returns:
      A NamedSharding object for the array.
  """
  shape = array_metadata.shape
  num_dims = len(shape)

  if num_dims == 0:  # Scalar, no sharding needed
    return NamedSharding(mesh, PartitionSpec())

  # Iterate from the last dimension backwards
  for i in range(num_dims - 1, -1, -1):
    if shape[i] % num_devices == 0:
      # If divisible, shard on this dimension
      sharding_spec = [None] * num_dims
      sharding_spec[i] = "data"
      return NamedSharding(mesh, PartitionSpec(*sharding_spec))

  # If no dimension is divisible, replicate the array on every device.
  # JAX generally expects a sharded dimension to be divisible by the number of
  # devices, so forcing a sharding on the last dimension may fail.
  print(
      f"No dimension of shape {shape} is divisible by {num_devices}."
      " Replicating the array instead."
  )
  return NamedSharding(mesh, PartitionSpec())
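To check which axis the rule picks for a given shape without a TPU, here is a pure-Python sketch of the same right-to-left search. It returns a `PartitionSpec`-style tuple and the resulting per-device shard shape. When no dimension divides evenly, this sketch returns an all-`None` (fully replicated) spec; that fallback is our assumption for the illustration. The shapes used below are made up.

```python
from typing import Optional, Sequence, Tuple


def find_shard_axis(shape: Sequence[int], num_devices: int) -> Optional[int]:
    """Index of the last dimension divisible by num_devices, or None."""
    for i in range(len(shape) - 1, -1, -1):
        if shape[i] % num_devices == 0:
            return i
    return None


def partition_spec_tuple(shape, num_devices, axis_name="data") -> Tuple:
    """`axis_name` on the chosen dimension, None elsewhere (all None = replicated)."""
    spec = [None] * len(shape)
    axis = find_shard_axis(shape, num_devices)
    if axis is not None:
        spec[axis] = axis_name
    return tuple(spec)


def per_device_shape(shape, num_devices) -> Tuple[int, ...]:
    """Shape of each device's shard under `partition_spec_tuple`."""
    axis = find_shard_axis(shape, num_devices)
    return tuple(d // num_devices if i == axis else d for i, d in enumerate(shape))


for shape in [(1000, 64), (64, 6), (3, 5)]:
    print(shape, partition_spec_tuple(shape, 4), per_device_shape(shape, 4))
```

Searching from the last dimension first tends to split the feature dimension of weight matrices, which keeps the choice deterministic and simple.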

We use `get_flexible_sharding` to build Orbax restore arguments, so that each parameter is restored directly with its target sharding.

In [None]:
restore_args = jax.tree_util.tree_map(
    lambda m: orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array,
        sharding=get_flexible_sharding(m, n_devices),
    ),  # One restore arg per array leaf of the metadata tree
    metadata.item_metadata,
)
flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)

### Bind with Penzai model

Now we prepare the Gemma3 multimodal language model definition in Penzai and
bind it with the sharded parameters.

In [None]:
model = gemma_multimodal_from_pretrained_checkpoint(
    flat_params,
    upcast_activations_to_float32=False,
)

### Model visualization

Directly visualizing the model definition together with its parameters takes a
long time. Therefore, we first use `pz.unbind_params` to separate the parameters
from the model, and then visualize only the model architecture.

In [None]:
model_unbound, _ = pz.unbind_params(model)
model_unbound

From the visualization above, we can see that the model class is
`MultiModalTransformerLM` and that it has three parts: `vision_transformer`,
`vision_projection`, and `body`.

`vision_transformer` is an instance of `SigLipFromPatches`, `vision_projection`
projects image tokens into the text embedding space, and `body` is the main
body of the language model.

Delete the raw parameter tree, which is no longer needed, to free some memory.

In [None]:
del flat_params
gc.collect()

## Run inference with Gemma3 models in Penzai

### Prepare the inputs

Load the tokenizer for Gemma3 models.

In [None]:
# Use gm.text.Gemma2Tokenizer() for Gemma 2 models.
tokenizer = gm.text.Gemma3Tokenizer()

tokenizer

In [None]:
tokenizer.vocab_size

In [None]:
tokenizer.special_tokens

Because we use an instruction-tuned model, the prompt must follow the chat
template with `<start_of_turn>` and `<end_of_turn>` markers. We also insert a
`<start_of_image>` special token wherever an image should appear.

In [None]:
prompt = """<start_of_turn>user
What can you say about this image:

<start_of_image>

<end_of_turn>
<start_of_turn>model
"""
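To build prompts like this for other questions, a small helper can reproduce the template above. The function is hypothetical, and the layout for more than one image (one `<start_of_image>` per image, each followed by a blank line) is our assumption extrapolated from the single-image case.

```python
def build_multimodal_prompt(question: str, num_images: int = 1) -> str:
    """Format a single user turn in the chat template used above.

    Each image gets one `<start_of_image>` placeholder after the question,
    followed by a blank line, before the turn is closed.
    """
    image_block = "<start_of_image>\n\n" * num_images
    return (
        "<start_of_turn>user\n"
        f"{question}\n\n"
        f"{image_block}"
        "<end_of_turn>\n"
        "<start_of_turn>model\n"
    )


print(build_multimodal_prompt("What can you say about this image:"))
```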

In [None]:
prompt = tokenizer.encode(prompt, add_bos=True)
prompt = jnp.asarray(prompt)[None, :]
prompt

In [None]:
tokens = pz.nx.wrap(prompt).tag("batch", "seq")

In [None]:
tokens

Next, we load a flower image from the `oxford_flowers102` dataset.

In [None]:
ds = tfds.data_source("oxford_flowers102", split="train")
image = ds[0]["image"]

Visualize the image using pyplot.

In [None]:
plt.imshow(image)
plt.axis("off")  # Turn off axis labels
plt.show()

Note that the input image can be of any size, so it must be preprocessed before
being fed into the vision model. The `process_images` function first resizes the
image and then splits it into patches. We pass the input as the nested list
`[[image]]` so that the output has a `batch` dimension (the number of chat
samples) and a `frame` dimension (the number of images in each chat sample).
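The patchify step can be sketched in NumPy to show where the `patch` and `embedding` axes come from. This is an illustration under stated assumptions, not the code of `process_images`: it assumes the image has already been resized so that height and width are multiples of `patch_size`, uses non-overlapping patches flattened in row-major order, and skips any normalization the real function may apply.

```python
import numpy as np


def patchify(image: np.ndarray, patch_size: int) -> np.ndarray:
    """Split an (H, W, C) image into flattened non-overlapping patches.

    Returns shape (num_patches, patch_size * patch_size * C), patches ordered
    row by row across the image.
    """
    h, w, c = image.shape
    if h % patch_size or w % patch_size:
        raise ValueError("resize the image to a multiple of patch_size first")
    gh, gw = h // patch_size, w // patch_size
    patches = image.reshape(gh, patch_size, gw, patch_size, c)
    patches = patches.transpose(0, 2, 1, 3, 4)  # -> (gh, gw, p, p, c)
    return patches.reshape(gh * gw, patch_size * patch_size * c)


def patchify_nested(images, patch_size: int) -> np.ndarray:
    """Map [[image, ...], ...] to (batch, frame, patch, embedding).

    Every chat sample must contain the same number of equally sized images.
    """
    return np.stack(
        [np.stack([patchify(img, patch_size) for img in chat]) for chat in images]
    )


toy_image = np.arange(4 * 4 * 3).reshape(4, 4, 3)  # tiny made-up "image"
print(patchify_nested([[toy_image]], patch_size=2).shape)
```

The outer list becomes the `batch` axis and the inner list the `frame` axis, which matches the axis names tagged on the real output in the next cell.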

In [None]:
images = process_images([[image]])
images = pz.nx.wrap(images).tag("batch", "frame", "patch", "embedding")

In [None]:
images.named_shape

### Prepare the model with KV cache

As with the text-only model, before running inference we convert the model into
an inference model with a KV cache. We also need to pass the number of tokens
per image, which is available from `model.metadata`.

In [None]:
inference_model = sampling_mode.KVCachingTransformerMultiModalLM.from_uncached(
    model,
    cache_len=1024,
    batch_axes={"batch": 1},
    num_tokens_per_image=model.metadata.num_tokens_per_image,
)
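When choosing `cache_len`, a rough budget check helps. The sketch assumes that each image occupies `tokens_per_image` cache positions in addition to the text tokens and that every generated token takes one more position; any extra formatting tokens a processor inserts around images are not counted, so leave some slack. The numbers in the example are hypothetical.

```python
def required_cache_len(
    num_text_tokens: int,
    num_images: int,
    tokens_per_image: int,
    max_new_tokens: int,
) -> int:
    """Rough number of KV-cache positions needed for one sequence."""
    return num_text_tokens + num_images * tokens_per_image + max_new_tokens


def fits_in_cache(
    cache_len: int,
    num_text_tokens: int,
    num_images: int,
    tokens_per_image: int,
    max_new_tokens: int,
) -> bool:
    """True if the rough budget fits within `cache_len` positions."""
    needed = required_cache_len(
        num_text_tokens, num_images, tokens_per_image, max_new_tokens
    )
    return needed <= cache_len


# Hypothetical: 20 prompt tokens, one image, 256 tokens per image, 512 new tokens.
print(required_cache_len(20, 1, 256, 512))
print(fits_in_cache(1024, 20, 1, 256, 512))
```

In the notebook, the real inputs would be `tokens.named_shape["seq"]`, `model.metadata.num_tokens_per_image`, and the `max_sampling_steps` passed to the decoding loop.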

### Text generation based on multimodal inputs

Then we JIT-compile the language-model body and sample the output with the
decoding loop, as for the text-only model. For demonstration, we employ a greedy
decoding approach.
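To make the difference between greedy decoding and temperature sampling concrete, here is a framework-free sketch of a single decoding step. It illustrates the general technique only and is not the implementation inside `simple_decoding_loop`. Whether the call below is strictly greedy depends on the default `temperature` of `temperature_sample_pyloop`, since that argument is left commented out.

```python
import math
import random
from typing import Optional, Sequence


def sample_next_token(
    logits: Sequence[float],
    temperature: float,
    rng: Optional[random.Random] = None,
) -> int:
    """Pick a token id from unnormalized logits.

    temperature == 0 means greedy decoding (argmax); otherwise the id is drawn
    from softmax(logits / temperature).
    """
    if temperature == 0:
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random()
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max before exp() for numerical stability
    weights = [math.exp(x - m) for x in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


logits = [1.0, 3.0, 0.5]
print(sample_next_token(logits, temperature=0))  # greedy: always index 1
print(sample_next_token(logits, temperature=1.0, rng=random.Random(3)))
```

Lower temperatures sharpen the distribution toward the argmax; a temperature of exactly 0 is treated as pure greedy decoding here.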

In [None]:
samples = simple_decoding_loop.temperature_sample_pyloop(
    (
        pz.select(inference_model)
        .at(lambda root: root.body)
        .apply(jit_wrapper.Jitted)
    ),
    prompt=tokens,
    images=images,
    # temperature=1.0,
    rng=jax.random.key(3),
    max_sampling_steps=512,
)

Finally, we decode the output.

In [None]:
sample_tokens = samples.untag("batch", "seq").unwrap()[0]
tokenizer.decode(sample_tokens)
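Depending on how the decoding loop stops, the sampled sequence may contain tokens emitted after the model ends its turn. A simple post-processing sketch keeps only the tokens before the first stop token. The helper is hypothetical; the stop ids (for example those of `<end_of_turn>` or EOS) should be looked up in `tokenizer.special_tokens`, and their exact attribute names depend on your `gemma` version.

```python
from typing import Iterable, List


def truncate_at_stop(tokens: Iterable[int], stop_ids: Iterable[int]) -> List[int]:
    """Keep tokens up to, but not including, the first stop token."""
    stop = set(stop_ids)
    out = []
    for t in tokens:
        if t in stop:
            break
        out.append(t)
    return out


# Made-up token ids; 1 plays the role of a stop token here.
print(truncate_at_stop([5, 6, 1, 7, 8], stop_ids=[1]))
```

In the notebook this would look like `tokenizer.decode(truncate_at_stop(sample_tokens.tolist(), stop_ids))`.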