# Testing Gemma3 SigLip in Penzai

This colab shows how to load the SigLip vision encoder of the Gemma3 series and
run a forward pass with our new package `gemma_penzai`. The original Penzai does
not support vision transformers; the current version adds that support.

NOTE: we run this colab on a TPU **v5e-1** runtime. Please see our notebook
`./notebooks/gemma3_multimodal_penzai.ipynb` for how to build a local runtime.

## Import packages

First, we install `jax[tpu]`, the `gemma_penzai` package, and its dependencies.

In [None]:
# Clone the gemma_penzai package
!git clone https://github.com/google-deepmind/gemma_penzai.git

# Upgrade pip, just in case
!pip install --upgrade pip

# Installs JAX with TPU support
!pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

# Install the package in editable mode (-e)
# This installs dependencies defined in your pyproject.toml
print("Installing gemma_penzai and dependencies...")
%cd gemma_penzai
!pip install -e .

Import miscellaneous packages.

In [None]:
import gc
import os
from IPython.display import clear_output
import kagglehub
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds

Import JAX related packages.

In [None]:
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec
import orbax.checkpoint

# Check whether the runtime is connected to a TPU
jax.devices()

Import `penzai` related packages (NOTE: we use the most up-to-date version).

In [None]:
from penzai import pz
import treescope

treescope.basic_interactive_setup(autovisualize_arrays=True)

Import `gemma_penzai` package to use Gemma3 models.

In [None]:
from gemma_penzai import vision

process_images = vision.image_utils.process_images
gemma_vision_from_pretrained_checkpoint = (
    vision.siglip.gemma_vision_from_pretrained_checkpoint
)

## Loading Gemma3 models

### Load and shard model parameters

You can download the Gemma checkpoints using a Kaggle account and an API key. If
you don't have an API key already, you can:

1.  Visit https://www.kaggle.com/ and create an account if needed.

2.  Go to your account settings, then the 'API' section.

3.  Click 'Create new token' to download your key.

Next, input your "KAGGLE_USERNAME" and "KAGGLE_KEY" below.

In [None]:
KAGGLE_USERNAME = "<KAGGLE_USERNAME>"
KAGGLE_KEY = "<KAGGLE_KEY>"
try:
  kagglehub.config.set_kaggle_credentials(KAGGLE_USERNAME, KAGGLE_KEY)
except ImportError:
  kagglehub.login()

We load the Gemma3-4B instruction-tuned model. The checkpoint path can be found
in [Gemma's Documentation](https://gemma-llm.readthedocs.io/en/latest/checkpoints.html).
Please note that only Gemma3 4B / 12B / 27B have the vision module, which is the
same across these models.

In [None]:
weights_dir = kagglehub.model_download("google/gemma-3/flax/gemma3-4b-it")
clear_output()

In [None]:
ckpt_path = os.path.join(weights_dir, "gemma3-4b-it")
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
metadata = checkpointer.metadata(ckpt_path)

Prepare the devices.

In [None]:
n_devices = jax.local_device_count()
sharding_devices = mesh_utils.create_device_mesh((n_devices,))
mesh = Mesh(sharding_devices, ("data",))

Shard the model parameters. The following sharding strategy splits each
parameter across the TPU devices along its last dimension.

In [None]:
restore_args = jax.tree_util.tree_map(
    lambda m: orbax.checkpoint.ArrayRestoreArgs(
        restore_type=jax.Array,
        sharding=NamedSharding(
            mesh, PartitionSpec(*(None,) * (len(m.shape) - 1), "data")
        ),
    ),
    metadata.item_metadata,  # fall back to `metadata` if this raises an error
)
flat_params = checkpointer.restore(ckpt_path, restore_args=restore_args)
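The last-dimension strategy above can be tried on a toy array without a checkpoint. The following is a minimal sketch, not part of `gemma_penzai`: the helper `last_dim_sharding` and the array shape are hypothetical, and the mesh is built from whatever local devices are available (a single CPU device also works).

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# One-dimensional mesh over all local devices, with a single axis named "data".
devices = np.array(jax.local_devices())
mesh = Mesh(devices, ("data",))


def last_dim_sharding(ndim):
  """Hypothetical helper: replicate every axis except the last, which is split."""
  return NamedSharding(mesh, PartitionSpec(*(None,) * (ndim - 1), "data"))


n_devices = len(devices)
# Toy parameter whose last dimension is a multiple of the device count.
x = jnp.zeros((4, 8 * n_devices))
x_sharded = jax.device_put(x, last_dim_sharding(x.ndim))

# Each device holds a slice of the last dimension and the full leading axes.
shard_shapes = [shard.data.shape for shard in x_sharded.addressable_shards]
print(shard_shapes)
```

In this sketch the last dimension is chosen to be a multiple of the device count so that it splits evenly; parameters whose last dimension does not divide evenly, or scalar parameters, would likely need a different `PartitionSpec`.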

### Bind with Penzai model

Now we prepare the SigLip model definition and bind it with the parameters.

In [None]:
vision_model = gemma_vision_from_pretrained_checkpoint(
    flat_params,
)

### Model visualization

Directly visualizing the model definition together with its parameters takes a
long time. Therefore, we first use the `unbind_params` function to extract the
model architecture, and then visualize only the architecture without parameters.

In [None]:
model_unbound, _ = pz.unbind_params(vision_model)
model_unbound

From the above visualization, we can see that the model class is
`SigLipFromPatches` and that it has two parts: `siglip_encoder` and
`siglip_exit`. `siglip_encoder` is an object of class `VisionTransformer`, while
`siglip_exit` is a sequence of model layers used for downstream tasks. In this
case, `siglip_exit` downsamples the image tokens and also stops the gradient (as
we want to freeze the parameters of the vision module).
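To illustrate what stopping the gradient means for freezing a module, here is a small sketch using `jax.lax.stop_gradient`. The functions `frozen_features` and `loss` and all weights are hypothetical stand-ins (a single linear map plays the role of the vision encoder); this is not how `siglip_exit` is implemented.

```python
import jax
import jax.numpy as jnp


def frozen_features(w_vision, x):
  # Hypothetical stand-in for a vision encoder: one linear map whose output
  # is wrapped in stop_gradient, so no gradient flows back into w_vision.
  return jax.lax.stop_gradient(x @ w_vision)


def loss(w_vision, w_head, x):
  # A toy downstream head on top of the frozen features.
  feats = frozen_features(w_vision, x)
  return jnp.sum(feats @ w_head)


w_vision = jnp.ones((3, 2))
w_head = jnp.ones((2, 1))
x = jnp.ones((4, 3))

grad_vision, grad_head = jax.grad(loss, argnums=(0, 1))(w_vision, w_head, x)
print(grad_vision)  # all zeros: the "vision" weights receive no gradient
print(grad_head)    # non-zero: the downstream head is still trainable
```

The forward values are unchanged by `stop_gradient`; only the backward pass is cut, which is why the downstream head still receives a gradient.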

Free some memory.

In [None]:
del flat_params
gc.collect()

## Test the model input/output for the Gemma3 vision module

### Prepare the inputs

First, let's load an image.

In [None]:
ds = tfds.data_source("oxford_flowers102", split="train")
image = ds[0]["image"]
clear_output()

Then we visualize this image; it is a flower.

In [None]:
plt.imshow(image)
plt.axis("off")  # Turn off axis labels
plt.show()

Check the image shape.

In [None]:
image.shape

Note that the image can be of any size, so we need to process it before feeding
it into the vision model. We provide the `process_images` function, which first
resizes the image and then patchifies it. Please note that we pass the input as
`[[image]]` so that the output has the `batch` (how many chat samples) and
`frame` (how many images in each chat sample) dimensions.

In [None]:
images = process_images([[image]])
images = pz.nx.wrap(images).tag("batch", "frame", "patch", "embedding")
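For intuition about the patchify step, the sketch below splits an already resized image into flattened, non-overlapping patches with NumPy. It is only an illustration and not the implementation of `process_images`: the function `patchify`, the 896x896 resolution, and the patch size of 14 are assumptions made for this example, and resizing and normalization are omitted.

```python
import numpy as np


def patchify(image, patch_size):
  """Split an (H, W, C) image into flattened, non-overlapping patches."""
  h, w, c = image.shape
  assert h % patch_size == 0 and w % patch_size == 0
  grid_h, grid_w = h // patch_size, w // patch_size
  # (grid_h, patch, grid_w, patch, C) -> (grid_h, grid_w, patch, patch, C)
  patches = image.reshape(grid_h, patch_size, grid_w, patch_size, c)
  patches = patches.transpose(0, 2, 1, 3, 4)
  # One row per patch, each row holding patch_size * patch_size * C values.
  return patches.reshape(grid_h * grid_w, patch_size * patch_size * c)


# Assumed input: an image that has already been resized to 896x896 RGB.
resized = np.zeros((896, 896, 3), dtype=np.float32)
patches = patchify(resized, patch_size=14)

# Add leading `batch` and `frame` axes, mirroring the `[[image]]` nesting.
batched = patches[None, None]
print(patches.shape, batched.shape)
```

Under these assumptions, a 64x64 grid of patches yields 4096 patch vectors per image, and the two leading axes of size 1 correspond to one chat sample containing one image.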

Check the dimensions for model input.

In [None]:
images.named_shape

In [None]:
images

### Model forward

Then we run the model forward pass.

In [None]:
out = vision_model(images)

Visualize the output.

In [None]:
out

As can be seen here, the image tokens are downsampled from 4k to 256. In the
full multimodal model, these image tokens are then stitched together with the
text tokens.
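One way to reduce 4096 tokens to 256 is to average-pool the token grid. The sketch below assumes that the tokens lie on a 64x64 grid in row-major order and that a 4x4 pooling window is used; the function `avg_pool_tokens` and the toy embedding width of 8 are hypothetical, and the actual layers in `siglip_exit` should be read from the visualization above.

```python
import numpy as np


def avg_pool_tokens(tokens, grid, window):
  """Average-pool (grid * grid, D) tokens with a (window, window) kernel."""
  n, d = tokens.shape
  assert n == grid * grid and grid % window == 0
  out = grid // window
  # Group the grid into (out, window, out, window) blocks and average them.
  blocks = tokens.reshape(out, window, out, window, d)
  return blocks.mean(axis=(1, 3)).reshape(out * out, d)


# Toy tokens: 4096 positions with a small embedding width of 8.
rng = np.random.default_rng(0)
tokens = rng.normal(size=(4096, 8)).astype(np.float32)

pooled = avg_pool_tokens(tokens, grid=64, window=4)
print(tokens.shape, "->", pooled.shape)
```

With these assumed sizes, each output token is the mean of a 4x4 neighbourhood of input tokens, so the 64x64 grid shrinks to 16x16, that is, 256 tokens.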