# Finetune PaliGemma

> *These models and code are not official Google products and were trained and released for research purposes.*


**This notebook shows how to finetune PaliGemma 2 on a vision-language task.**
The training data consists of 90 pairs of images and long captions describing them.
To make it runnable on a T4 Colab runtime with 16GB HBM and 12GB RAM, we opt to finetune only the attention layers of the language model and freeze the other parameters.

**This setup is illustrative.** In a real use case, the amount of data, the trainable parameters, the number of training steps, the hyper-parameters, and the obtained results could differ significantly.

This notebook uses the model reference implementation from [big_vision](https://github.com/google-research/big_vision)
and shows how to:

 * Install deps, download model checkpoint and training data.
 * Load the model onto GPU devices.
 * Prepare the input to the model for training and inference.
 * Finetune the model and inspect output in validation split.

## Setup

In [1]:
# @title Fetch big_vision code and install dependencies.
import os
import sys

# Colab runtimes backed by remote TPUs are not supported.
if "COLAB_TPU_ADDR" in os.environ:
  raise RuntimeError("It seems you are using Colab with remote TPUs which is not supported.")

# Fetch big_vision repository if python doesn't know about it and install
# dependencies needed for this notebook.
!rm /home/dougljia@amd.com/big_vision/big_vision/configs/proj/paligemma/big_vision_repo -rf #Revise path on a different machine
if not os.path.exists("big_vision_repo"):
  # Use local repository
  !ln -s /home/dougljia@amd.com/big_vision big_vision_repo  ##Need to change this on a different machine.

# Append big_vision code to python import path
if "big_vision_repo" not in sys.path:
  sys.path.append("big_vision_repo")

# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "overrides" "ml_collections" "einops~=0.7" "sentencepiece"

### Configure your API key to access Kaggle

To use PaliGemma, you must provide your Kaggle username and a Kaggle API key.

1. To generate a Kaggle API key, go to the **Account** tab of your Kaggle user profile and select **Create New Token**. This will trigger the download of a `kaggle.json` file containing your API credentials.
1. In Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.

To download the checkpoint, you also need to accept the PaliGemma Terms and Conditions at:

* https://www.kaggle.com/models/google/paligemma/



In [2]:
import os
# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json

# os.environ["KAGGLE_USERNAME"] = '*************'
# os.environ["KAGGLE_KEY"] = '*************'

# ! export KAGGLE_API_TOKEN='*************'

# The T4 runtime is tight on memory to finetune this model. Preallocate
# all memory ahead of time to avoid OOM'ing due to fragmentation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"
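
Outside Colab, the comment above mentions two options: setting the environment variables directly, or providing `~/.kaggle/kaggle.json`. The following is a minimal sketch of bridging the two, assuming the file holds the usual `username` and `key` fields. The helper name `load_kaggle_credentials` is hypothetical and is not part of `kagglehub`.

```python
import json
import os
from pathlib import Path


def load_kaggle_credentials(path="~/.kaggle/kaggle.json", env=None):
    """Copy username/key from a kaggle.json file into environment variables.

    Hypothetical helper for illustration only. Variables that are already
    set are left untouched. Returns True if both variables are set afterwards.
    """
    env = os.environ if env is None else env
    cred_file = Path(path).expanduser()
    if cred_file.is_file():
        creds = json.loads(cred_file.read_text())
        env.setdefault("KAGGLE_USERNAME", creds["username"])
        env.setdefault("KAGGLE_KEY", creds["key"])
    return "KAGGLE_USERNAME" in env and "KAGGLE_KEY" in env


if not load_kaggle_credentials():
    print("Kaggle credentials not found; set KAGGLE_USERNAME and KAGGLE_KEY.")
```

Using `setdefault` means explicitly exported variables take precedence over the file.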

In [3]:
# import kagglehub

# # Authenticate
# kagglehub.login() # This will prompt you for your credentials.

In [4]:
!ls

README.md		  logs		  paligemma_tokenizer.model
big_vision_repo		  longcap100	  profile
convergence_study	  paligemma.png   transfers
finetune_paligemma.ipynb  paligemma2.png


In [5]:
# @title Download checkpoint, tokenizer and dataset to local filesystem.
#
import os
import kagglehub

# Use these for PaliGemma-2 3B 224px²
LLM_VARIANT = "gemma2_2b"
MODEL_PATH = "./paligemma2-3b-pt-224.b16.npz"
KAGGLE_HANDLE = "google/paligemma-2/jax/paligemma2-3b-pt-224"  # Path to fetch from Kaggle.

# Use these for PaliGemma 1:
# LLM_VARIANT = "gemma_2b"
# MODEL_PATH = "./paligemma-3b-pt-224.f16.npz"
# KAGGLE_HANDLE = "google/paligemma/jax/paligemma-3b-pt-224"

if not os.path.exists(MODEL_PATH):
  print("Downloading the checkpoint from Kaggle, this could take a few minutes....")
  MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)
  print(f"Model path: {MODEL_PATH}")

TOKENIZER_PATH = "./paligemma_tokenizer.model"
if not os.path.exists(TOKENIZER_PATH):
  print("Downloading the model tokenizer...")
  !gsutil cp gs://big_vision/paligemma_tokenizer.model {TOKENIZER_PATH}
  print(f"Tokenizer path: {TOKENIZER_PATH}")

DATA_DIR="./longcap100"
if not os.path.exists(DATA_DIR):
  print("Downloading the dataset...")
  !gsutil -m -q cp -n -r gs://longcap100/ .
  print(f"Data path: {DATA_DIR}")

Downloading the checkpoint from Kaggle, this could take a few minutes....
Model path: /root/.cache/kagglehub/models/google/paligemma-2/jax/paligemma2-3b-pt-224/1/./paligemma2-3b-pt-224.b16.npz


## Notebook

In [6]:
import base64
import functools
import html
import io
import os
import warnings

import jax
import jax.numpy as jnp
import numpy as np
import ml_collections

import tensorflow as tf
import sentencepiece

from IPython.display import display, HTML
from PIL import Image

# Import model definition from big_vision
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns

# Import big vision utilities
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding

# Don't let TF use the GPU or TPUs
tf.config.set_visible_devices([], "GPU")
tf.config.set_visible_devices([], "TPU")

backend = jax.extend.backend.get_backend()
print(f"JAX version:  {jax.__version__}")
print(f"JAX platform: {backend.platform}")
print(f"JAX devices:  {jax.device_count()}")

2025-12-02 07:21:17.594088: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1764660077.603264   53465 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1764660077.605810   53465 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1764660077.614182   53465 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1764660077.614190   53465 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1764660077.614192   53465 computation_placer.cc:177] computation placer alr

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'


2025-12-02 07:21:19.285535: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)
=== Source Location Trace: ===
learning/45eac/tfrc/runtime/libtpu_init_utils.cc:287


JAX version:  0.6.2
JAX platform: gpu
JAX devices:  8


In [7]:
# @title Construct model and load params into RAM.

# Define model
# IMPORTANT: Gemma-2 has a "final_logits_softcap" property, we set it to 0.0
# for better transfer results.
model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152, "variant": LLM_VARIANT, "final_logits_softcap": 0.0},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})
model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(TOKENIZER_PATH)

# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, MODEL_PATH, model_config)

# Define `decode` function to sample outputs from the model.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())

In [8]:
# @title Move params to GPU/TPU memory.
#
# To keep HBM usage low and fit in a T4 GPU (16GB HBM) we opt to only finetune
# a part of the parameters. Additionally we keep the frozen params in float16
# and cast trainable to float32.

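# Create a pytree mask of the trainable params.
# Note: as written, every branch returns True, so all parameters are trained.
# To finetune only the attention layers of the language model, return False
# in the "llm/" and "img/" branches.
def is_trainable_param(name, param):  # pylint: disable=unused-argument
  if name.startswith("llm/layers/attn/"):  return True
  if name.startswith("llm/"):              return True
  if name.startswith("img/"):              return True
  raise ValueError(f"Unexpected param name {name}")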
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)

#
# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data",))

data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))

params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)

# Yes: Some donated buffers are not usable.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")

@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
  # Cast others to float16, since some GPUs don't support bf16.
  return jax.tree.map(lambda p, m: p.astype(jnp.float32)
                      if m else p.astype(jnp.float16),
                      params, trainable)
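
The introduction describes finetuning only the attention layers of the language model. The sketch below shows that name-based selection in plain Python on a flat list of parameter names, together with the dtype each group would be cast to. The names are hypothetical and chosen only to exercise each branch; the notebook itself builds the mask over the `params` pytree.

```python
def attention_only(name):
    """Return True only for language-model attention parameters."""
    if name.startswith("llm/layers/attn/"):
        return True
    if name.startswith(("llm/", "img/")):
        return False
    raise ValueError(f"Unexpected param name {name}")


# Hypothetical parameter names, for illustration only.
names = [
    "llm/layers/attn/q_einsum/w",
    "llm/layers/mlp/linear",
    "llm/embedder/input_embedding",
    "img/embedding/kernel",
]

mask = {name: attention_only(name) for name in names}
# Trainable params are kept in float32, frozen ones in float16.
dtypes = {name: "float32" if trainable else "float16"
          for name, trainable in mask.items()}

for name in names:
    print(f"{name:32s} trainable={mask[name]!s:5s} dtype={dtypes[name]}")
```

The order of the checks matters: the more specific `llm/layers/attn/` prefix has to be tested before the general `llm/` prefix.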

In [9]:
# Loading all params simultaneously - albeit much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default (12GB RAM).
# Instead we do it param by param.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
  params[idx] = big_vision.utils.reshard(params[idx], sharding)
  params[idx] = maybe_cast_to_f32(params[idx], trainable)
  params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)

# Print params to show what the model is made of.
def parameter_overview(params):
  for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
    print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")

print(" == Model params == ")
parameter_overview(params)

 == Model params == 
img/Transformer/encoder_norm/bias                                                (1152,)                float32
img/Transformer/encoder_norm/scale                                               (1152,)                float32
img/Transformer/encoderblock/LayerNorm_0/bias                                    (27, 1152)             float32
img/Transformer/encoderblock/LayerNorm_0/scale                                   (27, 1152)             float32
img/Transformer/encoderblock/LayerNorm_1/bias                                    (27, 1152)             float32
img/Transformer/encoderblock/LayerNorm_1/scale                                   (27, 1152)             float32
img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias                             (27, 4304)             float32
img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel                           (27, 1152, 4304)       float32
img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias                             (2
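
The shape and dtype columns in this overview are enough to estimate how much memory each parameter occupies, which is what drives the choice between float32 and float16. A minimal sketch, using the `MlpBlock_0/Dense_0/kernel` entry listed above; the helper `param_bytes` is hypothetical.

```python
import math

BYTES_PER_ELEMENT = {"float32": 4, "float16": 2}


def param_bytes(shape, dtype):
    """Size in bytes of a dense array with the given shape and dtype."""
    return math.prod(shape) * BYTES_PER_ELEMENT[dtype]


shape = (27, 1152, 4304)  # img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel
for dtype in ("float32", "float16"):
    mib = param_bytes(shape, dtype) / 2**20
    print(f"{dtype}: {mib:.1f} MiB")
```

For this single kernel the estimate is roughly 511 MiB in float32 and half of that in float16, before any sharding across devices.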

In [10]:
# @title Define preprocess functions to create inputs to the model.

def preprocess_image(image, size=224):
  # Model has been trained to handle images of different aspect ratios
  # resized to 224x224 in the range [-1, 1]. Bilinear and antialias resize
  # options are helpful to improve quality in some tasks.
  image = np.asarray(image)
  if image.ndim == 2:  # Grayscale image without a channel axis: replicate it into 3 channels.
    image = np.stack((image,)*3, axis=-1)
  image = image[..., :3]  # Remove alpha layer.
  assert image.shape[-1] == 3

  image = tf.constant(image)
  image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)
  return image.numpy() / 127.5 - 1.0  # [0, 255]->[-1,1]

def preprocess_tokens(prefix, suffix=None, seqlen=None):
  # Model has been trained to handle tokenized text composed of a prefix with
  # full attention and a suffix with causal attention.
  separator = "\n"
  tokens = tokenizer.encode(prefix, add_bos=True) + tokenizer.encode(separator)
  mask_ar = [0] * len(tokens)    # 0 to use full attention for prefix.
  mask_loss = [0] * len(tokens)  # 0 to not use prefix tokens in the loss.

  if suffix:
    suffix = tokenizer.encode(suffix, add_eos=True)
    tokens += suffix
    mask_ar += [1] * len(suffix)    # 1 to use causal attention for suffix.
    mask_loss += [1] * len(suffix)  # 1 to use suffix tokens in the loss.

  mask_input = [1] * len(tokens)    # 1 if it's a token, 0 if padding.
  if seqlen:
    padding = [0] * max(0, seqlen - len(tokens))
    tokens = tokens[:seqlen] + padding
    mask_ar = mask_ar[:seqlen] + padding
    mask_loss = mask_loss[:seqlen] + padding
    mask_input = mask_input[:seqlen] + padding

  return jax.tree.map(np.array, (tokens, mask_ar, mask_loss, mask_input))

def postprocess_tokens(tokens):
  tokens = tokens.tolist()  # np.array to list[int]
  try:  # Remove tokens at and after EOS if any.
    eos_pos = tokens.index(tokenizer.eos_id())
    tokens = tokens[:eos_pos]
  except ValueError:
    pass
  return tokenizer.decode(tokens)
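
To make the masks built in `preprocess_tokens` concrete, here is the same logic with a toy whitespace tokenizer. This is only a sketch: words stand in for token ids, and `<bos>`, `<eos>` and `<pad>` are placeholder strings, not the real SentencePiece tokens.

```python
def toy_preprocess(prefix, suffix=None, seqlen=None):
    """Mirror of the mask logic, using words instead of token ids."""
    tokens = ["<bos>"] + prefix.split() + ["\n"]
    mask_ar = [0] * len(tokens)    # prefix: full attention
    mask_loss = [0] * len(tokens)  # prefix: excluded from the loss

    if suffix:
        suffix_tokens = suffix.split() + ["<eos>"]
        tokens += suffix_tokens
        mask_ar += [1] * len(suffix_tokens)    # suffix: causal attention
        mask_loss += [1] * len(suffix_tokens)  # suffix: included in the loss

    mask_input = [1] * len(tokens)  # 1 for real tokens, 0 for padding
    if seqlen:
        pad = max(0, seqlen - len(tokens))
        tokens = tokens[:seqlen] + ["<pad>"] * pad
        mask_ar = mask_ar[:seqlen] + [0] * pad
        mask_loss = mask_loss[:seqlen] + [0] * pad
        mask_input = mask_input[:seqlen] + [0] * pad
    return tokens, mask_ar, mask_loss, mask_input


rows = toy_preprocess("caption en", "a red car", seqlen=10)
for name, row in zip(("tokens", "mask_ar", "mask_loss", "mask_input"), rows):
    print(f"{name:10s} {row}")
```

With a suffix, `mask_ar` and `mask_loss` coincide; at inference time no suffix is passed, so both stay at zero and only `mask_input` distinguishes real tokens from padding.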


In [11]:
# @title Function to iterate over train and validation examples.
SEQLEN = 128

# TODO: Consider data iterators skipping big_vision and tf.data?
train_dataset = big_vision.datasets.jsonl.DataSource(
    os.path.join(DATA_DIR, "data_train90.jsonl"),
    fopen_keys={"image": DATA_DIR})

val_dataset = big_vision.datasets.jsonl.DataSource(
    os.path.join(DATA_DIR, "data_val10.jsonl"),
    fopen_keys={"image": DATA_DIR})


def train_data_iterator(seed=None):
  """Never ending iterator over training examples."""
  # Shuffle examples and repeat so one can train for many epochs.
  dataset = train_dataset.get_tfdata().shuffle(1_000, seed=seed).repeat()
  for example in dataset.as_numpy_iterator():
    image = Image.open(io.BytesIO(example["image"]))
    image = preprocess_image(image)

    prefix = "caption en"  # Could also be a different prefix per example.
    suffix = example["suffix"].decode().lower()
    tokens, mask_ar, mask_loss, _ = preprocess_tokens(prefix, suffix, SEQLEN)

    yield {
        "image": np.asarray(image),
        "text": np.asarray(tokens),
        "mask_ar": np.asarray(mask_ar),
        "mask_loss": np.asarray(mask_loss),
    }


def validation_data_iterator():
  """Single iterator over validation examples."""
  for example in val_dataset.get_tfdata(ordered=True).as_numpy_iterator():
    image = Image.open(io.BytesIO(example["image"]))
    image = preprocess_image(image)

    prefix = "caption en"  # Could also be a different prefix per example.
    tokens, mask_ar, _, mask_input = preprocess_tokens(prefix, seqlen=SEQLEN)

    yield {
        "image": np.asarray(image),
        "text": np.asarray(tokens),
        "mask_ar": np.asarray(mask_ar),
        "mask_input": np.asarray(mask_input),
    }
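
The training iterator relies on `shuffle(...).repeat()` to produce an endless stream of examples. As a rough pure-Python analogy (not how `tf.data` is implemented), and assuming the shuffle buffer is at least as large as the dataset, this behaves approximately like reshuffling the full list on every pass. The generator name `shuffled_forever` is hypothetical.

```python
import itertools
import random


def shuffled_forever(examples, seed=None):
    """Yield examples endlessly, reshuffling at the start of every pass."""
    rng = random.Random(seed)
    while True:
        order = list(examples)
        rng.shuffle(order)
        yield from order


stream = shuffled_forever(["a", "b", "c"], seed=42)
first_two_passes = list(itertools.islice(stream, 6))
print(first_two_passes)
```

Each block of three items is a permutation of the dataset, and a fixed seed makes the order reproducible.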

In [12]:
# @title Inspect training examples.
def render_inline(image, resize=(128, 128)):
  """Convert image into inline html."""
  image = Image.fromarray(image)
  image = image.resize(resize)  # resize() returns a new image.
  with io.BytesIO() as buffer:
    image.save(buffer, format='jpeg')
    image_b64 = str(base64.b64encode(buffer.getvalue()), "utf-8")
    return f"data:image/jpeg;base64,{image_b64}"

def render_example(image, caption):
  image = ((image + 1)/2 * 255).astype(np.uint8)  # [-1,1] -> [0, 255]
  return f"""
    <div style="display: inline-flex; align-items: center; justify-content: center;">
        <img style="width:128px; height:128px;" src="{render_inline(image, resize=(64,64))}" />
        <p style="width:256px; margin:10px; font-size:small;">{html.escape(caption)}</p>
    </div>
    """

html_out = ""
for idx, example in zip(range(8), train_data_iterator()):
  caption = postprocess_tokens(example["text"])  # detokenize model input.
  caption = caption[len("caption en\n"):]        # strip prefix
  html_out += render_example(example["image"], caption)

print("Training examples")
display(HTML(html_out))

Training examples
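
Rendering relies on undoing the `[0, 255] -> [-1, 1]` scaling applied in `preprocess_image`. The two mappings can be sketched on scalar pixel values as follows; the function names are hypothetical.

```python
def normalize(pixel):
    """Map a pixel value in [0, 255] to [-1, 1]."""
    return pixel / 127.5 - 1.0


def denormalize(value):
    """Map a value in [-1, 1] back to [0, 255]."""
    return (value + 1.0) / 2.0 * 255.0


for pixel in (0, 64, 255):
    value = normalize(pixel)
    print(pixel, value, round(denormalize(value)))
```

The sketch rounds before converting back to an integer. A plain cast to `uint8` truncates instead, which can turn a value that lands just below an integer into the next lower pixel value.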


In [13]:
# @title Define the training step and evaluation loop.
#
# The main update_fn using AdamW optimizer via Optax.
#
import optax

# Optimizer will be defined in the training loop cell.

@functools.partial(jax.jit, donate_argnums=(0, 1))
def update_fn(params, opt_state, batch, step, rng):
  imgs, txts, mask_ar = batch["image"], batch["text"], batch["mask_ar"]

  def loss_fn(params):
    # Pass rngs={"dropout": rng} to ensure deterministic behavior if dropout is used
    text_logits, _ = model.apply({"params": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True, rngs={"dropout": rng})
    logp = jax.nn.log_softmax(text_logits, axis=-1)

    # The model takes as input txts[:, :-1] but the loss is defined as predicting
    # next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens
    # are part of the loss (e.g. prefix and padded tokens are not included).
    mask_loss = batch["mask_loss"][:, 1:]
    targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])

    # Compute the loss per example. i.e. the mean of per token pplx.
    # Since each example has a different number of tokens we normalize it.
    token_pplx = jnp.sum(logp * targets, axis=-1)  # sum across vocab_size.
    example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1)  # sum across seq_len.
    example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1)  # weight by num of tokens.

    # batch_loss: mean of per example loss.
    return jnp.mean(example_loss)

  loss, grads = jax.value_and_grad(loss_fn)(params)

  # Apply gradients to trainable params using AdamW.
  # We only want to update trainable parameters.
  # Optax updates all parameters by default, so we mask the gradients.
  grads = jax.tree.map(lambda g, t: g if t else jnp.zeros_like(g), grads, trainable_mask)
  
  updates, new_opt_state = optimizer.update(grads, opt_state, params)
  new_params = optax.apply_updates(params, updates)

  return new_params, new_opt_state, loss

# Evaluation/inference loop.
def make_predictions(data_iterator, *, num_examples=None,
                     batch_size=4, seqlen=SEQLEN, sampler="greedy"):
  outputs = []
  while True:
    # Construct a list of examples in the batch.
    examples = []
    try:
      for _ in range(batch_size):
        examples.append(next(data_iterator))
        examples[-1]["_mask"] = np.array(True)  # Indicates true example.
    except StopIteration:
      if len(examples) == 0:
        return outputs

    # Not enough examples to complete a batch. Pad by repeating last example.
    while len(examples) % batch_size:
      examples.append(dict(examples[-1]))
      examples[-1]["_mask"] = np.array(False)  # Indicates padding example.

    # Convert list of examples into a dict of np.arrays and load onto devices.
    batch = jax.tree.map(lambda *x: np.stack(x), *examples)
    batch = big_vision.utils.reshard(batch, data_sharding)

    # Make model predictions
    tokens = decode({"params": params}, batch=batch,
                    max_decode_len=seqlen, sampler=sampler)

    # Fetch model predictions to host and detokenize.
    tokens, mask = jax.device_get((tokens, batch["_mask"]))
    tokens = tokens[mask]  # remove padding examples.
    responses = [postprocess_tokens(t) for t in tokens]

    # Append to html output.
    for example, response in zip(examples, responses):
      outputs.append((example["image"], response))
      if num_examples and len(outputs) >= num_examples:
        return outputs
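
The loss in `update_fn` is a masked, per-example mean of the negative log-likelihood of the target tokens. The following is a small pure-Python sketch of that computation for a single example, written with lists instead of JAX arrays so the arithmetic can be followed by hand. The function names are hypothetical.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one vocabulary-sized list."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def example_loss(logits, targets, mask_loss):
    """Mean negative log-likelihood over positions where mask_loss is 1."""
    total = 0.0
    for position_logits, target, m in zip(logits, targets, mask_loss):
        total -= m * log_softmax(position_logits)[target]
    # Same guard as jnp.clip(sum, 1): avoid dividing by zero.
    return total / max(sum(mask_loss), 1)


# Toy case: 3 positions, vocabulary of 2, uniform logits.
logits = [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
targets = [1, 0, 1]
mask_loss = [0, 1, 1]  # first position is prefix, so it is ignored
loss = example_loss(logits, targets, mask_loss)
print(f"{loss:.4f}")
```

With uniform logits over two tokens every counted position contributes `log(2)`, so the per-example loss is `log(2)` regardless of how many positions are masked out.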

In [14]:
# @title Run training loop.
#
# Run a short training loop with cosine learning rate schedule.
#
# Note: the first step can be quite slow on some machines (up to several minutes)
# due to XLA compilation of the jax.jit'd function.
#

import time
import datetime
import random
import sys

BATCH_SIZE = 8
TRAIN_STEPS = 1000
LEARNING_RATE = 3e-5
PROFILE_STEP = -10  # Negative: no step matches, so profiling is disabled.
SEED = 42

# Set random seeds for reproducibility
tf.random.set_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Initialize JAX RNG key
rng = jax.random.PRNGKey(SEED)

# Create a timestamped directory for the trace and logs
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
trace_dir = f"profile/jax-trace_{timestamp}"
log_dir = f"logs/fit/{timestamp}"
summary_writer = tf.summary.create_file_writer(log_dir)

EVAL_STEPS = TRAIN_STEPS // 10

train_data_it = train_data_iterator(seed=SEED)

sched_fn = big_vision.utils.create_learning_rate_schedule(
    total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,
    decay_type="cosine", warmup_percent=0.10)

# Define optimizer and initialize state
optimizer = optax.adamw(learning_rate=sched_fn, weight_decay=1e-4)
opt_state = optimizer.init(params)

step_durations = []

for step in range(1, TRAIN_STEPS+1):
  if step == PROFILE_STEP:
    jax.profiler.start_trace(trace_dir)

  step_start = time.time()

  # Make list of N training examples.
  examples = [next(train_data_it) for _ in range(BATCH_SIZE)]

  # Convert list of examples into a dict of np.arrays and load onto devices.
  batch = jax.tree.map(lambda *x: np.stack(x), *examples)
  batch = big_vision.utils.reshard(batch, data_sharding)

  # Split RNG key for the current step
  rng, step_rng = jax.random.split(rng)

  # Training step and report training loss
  # Pass opt_state, step, and step_rng to update_fn
  params, opt_state, loss = update_fn(params, opt_state, batch, step, step_rng)

  # Block until ready to measure accurate GPU time
  jax.tree.leaves(params)[0].block_until_ready()
  loss = jax.device_get(loss)
  
  step_end = time.time()

  if step == PROFILE_STEP:
    jax.profiler.stop_trace()
    print(f"Profiled step {step}. Trace saved to {trace_dir}")

  step_duration = step_end - step_start

  if step > 5:
    step_durations.append(step_duration)

  learning_rate = sched_fn(step)
  print(f"step: {step:2d}/{TRAIN_STEPS:2d}   lr: {learning_rate:.5f}   loss: {loss:.4f}   time: {step_duration:.2f}s")
  sys.stdout.flush()

  # Write metrics to TensorBoard
  with summary_writer.as_default():
    tf.summary.scalar('learning_rate', learning_rate, step=step)
    tf.summary.scalar('loss', loss, step=step)
    tf.summary.scalar('step_time', step_duration, step=step)
    
    # Calculate and log images per second per GPU
    # Avoid division by zero if step_duration is extremely small (though unlikely with blocking)
    if step_duration > 0:
      img_per_sec_per_gpu = (BATCH_SIZE / step_duration) / jax.device_count()
      tf.summary.scalar('img_per_sec_per_gpu', img_per_sec_per_gpu, step=step)

  if step == 1 or (step % EVAL_STEPS) == 0:
    print(f"Model predictions at step {step}")
    html_out = ""
    # Batch size must be divisible by the number of devices (8).
    for image, caption in make_predictions(
        validation_data_iterator(), num_examples=8, batch_size=8):
      html_out += render_example(image, caption)
    display(HTML(html_out))

# Wait for the last step to finish to get accurate overall timing
jax.tree.leaves(params)[0].block_until_ready()

if step_durations:
  avg_step_time = sum(step_durations) / len(step_durations)
  print(f"Average step time (after first 5 steps): {avg_step_time:.2f}s")

MultiHeadDotProductAttention out shape: (8, 256, 1152)
RCCL version : 2.26.6-HEAD:ba59a6c
HIP version  : 7.0.51831-a577394db
ROCm version : 7.0.2.0-43-9428210
Hostname     : tw003
Librccl path : /opt/rocm/lib/librccl.so.1


2025-12-02 07:21:48.237206: E external/xla/xla/service/rendezvous.cc:90] This thread has been waiting for `initialize clique for rank 4; clique=devices=[0,1,2,3,4,5,6,7]; stream=0; groups=[[0,1,2,3,4,5,6,7]]; root_device=-1; num_local_participants=8; run_id=-1260338913` for 10 seconds and may be stuck. All 8 threads joined the rendezvous, however the leader has not marked the rendezvous as completed. Leader can be deadlocked inside the rendezvous callback.
2025-12-02 07:21:48.237226: E external/xla/xla/service/rendezvous.cc:90] This thread has been waiting for `initialize clique for rank 5; clique=devices=[0,1,2,3,4,5,6,7]; stream=0; groups=[[0,1,2,3,4,5,6,7]]; root_device=-1; num_local_participants=8; run_id=-1260338913` for 10 seconds and may be stuck. All 8 threads joined the rendezvous, however the leader has not marked the rendezvous as completed. Leader can be deadlocked inside the rendezvous callback.
2025-12-02 07:21:48.237262: E external/xla/xla/service/rendezvous.cc:90] This 

step:  1/1000   lr: 0.00000   loss: 3.3200   time: 39.69s
Model predictions at step 1
MultiHeadDotProductAttention out shape: (8, 256, 1152)
step:  2/1000   lr: 0.00000   loss: 3.5434   time: 4.27s
step:  3/1000   lr: 0.00000   loss: 3.0472   time: 0.92s
step:  4/1000   lr: 0.00000   loss: 2.6542   time: 0.99s
step:  5/1000   lr: 0.00000   loss: 2.4592   time: 0.72s
step:  6/1000   lr: 0.00000   loss: 2.3388   time: 0.76s
step:  7/1000   lr: 0.00000   loss: 2.2083   time: 0.81s
step:  8/1000   lr: 0.00000   loss: 2.0599   time: 0.74s
step:  9/1000   lr: 0.00000   loss: 1.9532   time: 0.96s

step: 101/1000   lr: 0.00003   loss: 0.1407   time: 0.91s
step: 102/1000   lr: 0.00003   loss: 0.1372   time: 0.90s
step: 103/1000   lr: 0.00003   loss: 0.1046   time: 0.77s
step: 104/1000   lr: 0.00003   loss: 0.1152   time: 0.87s
step: 105/1000   lr: 0.00003   loss: 0.1678   time: 0.76s
step: 106/1000   lr: 0.00003   loss: 0.1519   time: 0.98s
step: 107/1000   lr: 0.00003   loss: 0.2302   time: 0.68s
step: 108/1000   lr: 0.00003   loss: 0.1739   time: 0.84s
step: 109/1000   lr: 0.00003   loss: 0.1180   time: 0.88s
step: 110/1000

step: 201/1000   lr: 0.00003   loss: 0.1042   time: 0.99s
step: 202/1000   lr: 0.00003   loss: 0.1422   time: 0.89s
step: 203/1000   lr: 0.00003   loss: 0.0951   time: 0.73s
step: 204/1000   lr: 0.00003   loss: 0.1006   time: 0.76s
step: 205/1000   lr: 0.00003   loss: 0.0773   time: 0.88s
step: 206/1000   lr: 0.00003   loss: 0.0834   time: 0.80s
step: 207/1000   lr: 0.00003   loss: 0.1783   time: 0.78s
step: 208/1000   lr: 0.00003   loss: 0.1120   time: 0.89s
step: 209/1000   lr: 0.00003   loss: 0.1004   time: 0.99s
step: 210/1000

step: 301/1000   lr: 0.00003   loss: 0.0894   time: 0.91s
step: 302/1000   lr: 0.00003   loss: 0.0752   time: 0.86s
step: 303/1000   lr: 0.00003   loss: 0.0896   time: 0.94s
step: 304/1000   lr: 0.00003   loss: 0.0996   time: 0.87s
step: 305/1000   lr: 0.00003   loss: 0.0868   time: 0.78s
step: 306/1000   lr: 0.00003   loss: 0.0757   time: 0.97s
step: 307/1000   lr: 0.00003   loss: 0.0697   time: 0.76s
step: 308/1000   lr: 0.00003   loss: 0.0711   time: 1.13s
step: 309/1000   lr: 0.00003   loss: 0.0828   time: 0.88s
step: 310/1000

step: 401/1000   lr: 0.00002   loss: 0.1131   time: 0.84s
step: 402/1000   lr: 0.00002   loss: 0.0701   time: 0.76s
step: 403/1000   lr: 0.00002   loss: 0.0869   time: 0.83s
step: 404/1000   lr: 0.00002   loss: 0.0719   time: 0.93s
step: 405/1000   lr: 0.00002   loss: 0.0726   time: 0.92s
step: 406/1000   lr: 0.00002   loss: 0.0695   time: 0.88s
step: 407/1000   lr: 0.00002   loss: 0.0674   time: 0.80s
step: 408/1000   lr: 0.00002   loss: 0.0687   time: 0.76s
step: 409/1000   lr: 0.00002   loss: 0.0877   time: 1.01s
step: 410/1000

step: 501/1000   lr: 0.00002   loss: 0.0706   time: 0.81s
step: 502/1000   lr: 0.00002   loss: 0.0788   time: 0.93s
step: 503/1000   lr: 0.00002   loss: 0.0727   time: 0.80s
step: 504/1000   lr: 0.00002   loss: 0.0704   time: 0.77s
step: 505/1000   lr: 0.00002   loss: 0.0772   time: 0.94s
step: 506/1000   lr: 0.00002   loss: 0.0819   time: 0.75s
step: 507/1000   lr: 0.00002   loss: 0.0725   time: 0.91s
step: 508/1000   lr: 0.00002   loss: 0.0687   time: 0.79s
step: 509/1000   lr: 0.00002   loss: 0.0729   time: 0.83s
step: 510/1000

step: 601/1000   lr: 0.00001   loss: 0.0769   time: 0.86s
step: 602/1000   lr: 0.00001   loss: 0.0735   time: 0.78s
step: 603/1000   lr: 0.00001   loss: 0.0818   time: 0.86s
step: 604/1000   lr: 0.00001   loss: 0.0635   time: 0.87s
step: 605/1000   lr: 0.00001   loss: 0.0703   time: 0.81s
step: 606/1000   lr: 0.00001   loss: 0.0746   time: 0.83s
step: 607/1000   lr: 0.00001   loss: 0.0819   time: 1.02s
step: 608/1000   lr: 0.00001   loss: 0.0889   time: 0.88s
step: 609/1000   lr: 0.00001   loss: 0.0635   time: 0.94s
step: 610/1000

step: 701/1000   lr: 0.00001   loss: 0.0663   time: 0.79s
step: 702/1000   lr: 0.00001   loss: 0.0695   time: 0.86s
step: 703/1000   lr: 0.00001   loss: 0.0633   time: 0.79s
step: 704/1000   lr: 0.00001   loss: 0.0636   time: 0.96s
step: 705/1000   lr: 0.00001   loss: 0.0785   time: 0.89s
step: 706/1000   lr: 0.00001   loss: 0.0677   time: 0.86s
step: 707/1000   lr: 0.00001   loss: 0.0727   time: 0.89s
step: 708/1000   lr: 0.00001   loss: 0.0663   time: 0.96s
step: 709/1000   lr: 0.00001   loss: 0.0865   time: 0.90s
step: 710/1000

step: 801/1000   lr: 0.00000   loss: 0.0599   time: 0.75s
step: 802/1000   lr: 0.00000   loss: 0.0701   time: 0.84s
step: 803/1000   lr: 0.00000   loss: 0.0643   time: 0.79s
step: 804/1000   lr: 0.00000   loss: 0.0530   time: 0.96s
step: 805/1000   lr: 0.00000   loss: 0.0810   time: 1.02s
step: 806/1000   lr: 0.00000   loss: 0.0668   time: 0.94s
step: 807/1000   lr: 0.00000   loss: 0.0669   time: 0.76s
step: 808/1000   lr: 0.00000   loss: 0.0571   time: 0.76s
step: 809/1000   lr: 0.00000   loss: 0.0790   time: 0.87s
step: 810/1000

step: 901/1000   lr: 0.00000   loss: 0.0704   time: 0.78s
step: 902/1000   lr: 0.00000   loss: 0.0588   time: 0.80s
step: 903/1000   lr: 0.00000   loss: 0.0621   time: 0.82s
step: 904/1000   lr: 0.00000   loss: 0.0577   time: 0.85s
step: 905/1000   lr: 0.00000   loss: 0.0553   time: 0.90s
step: 906/1000   lr: 0.00000   loss: 0.0670   time: 0.89s
step: 907/1000   lr: 0.00000   loss: 0.0659   time: 0.95s
step: 908/1000   lr: 0.00000   loss: 0.0628   time: 0.85s
step: 909/1000   lr: 0.00000   loss: 0.0541   time: 0.88s
step: 910/1000

Average step time (after first 5 steps): 0.85s
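
The average above skips the first 5 steps, presumably because early steps carry one-off overhead such as compilation. A minimal sketch of that calculation, assuming per-step wall times are collected in a plain list; the helper name `average_step_time` and the sample timings are hypothetical and not taken from the notebook's training loop:

```python
def average_step_time(step_times, warmup_steps=5):
    """Mean wall time per step, ignoring the first `warmup_steps` entries."""
    steady = step_times[warmup_steps:]
    if not steady:
        raise ValueError("need more than `warmup_steps` recorded steps")
    return sum(steady) / len(steady)


# Hypothetical timings: slow early steps followed by steady-state steps.
times = [30.0, 2.0, 1.5, 1.0, 1.0, 0.8, 0.9, 0.85]
print(f"Average step time (after first 5 steps): {average_step_time(times):.2f}s")
```

Excluding the warm-up steps keeps a single slow first step from dominating the reported average.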


In [15]:
%%time
# @title Evaluate the model on all examples.
#
# The validation data consists of 10 images from a different domain than the
# training data.

print("Model predictions")
html_out = ""
# Batch size must be divisible by the number of devices (8).
for image, caption in make_predictions(validation_data_iterator(), batch_size=8):
  html_out += render_example(image, caption)
display(HTML(html_out))

Model predictions


CPU times: user 24.4 s, sys: 1.69 s, total: 26.1 s
Wall time: 8.83 s
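
The batch-size comment in the cell above reflects a general constraint of data-parallel execution: each batch is split evenly across devices. A minimal sketch of padding a short batch up to the next multiple and masking out the padding, assuming NumPy arrays; `pad_to_multiple` is a hypothetical helper and is not part of the notebook or big_vision:

```python
import numpy as np


def pad_to_multiple(batch, multiple):
    """Zero-pad `batch` along axis 0 up to the next multiple of `multiple`.

    Returns the padded array and a boolean mask marking the real rows.
    """
    n = batch.shape[0]
    pad = (-n) % multiple
    mask = np.concatenate([np.ones(n, dtype=bool), np.zeros(pad, dtype=bool)])
    pad_width = [(0, pad)] + [(0, 0)] * (batch.ndim - 1)
    return np.pad(batch, pad_width), mask


# 10 validation examples (as in this notebook) with a hypothetical feature size.
examples = np.arange(10 * 3, dtype=np.float32).reshape(10, 3)
padded, mask = pad_to_multiple(examples, multiple=8)
print(padded.shape, int(mask.sum()))
```

The mask lets downstream code drop predictions for the padded rows before rendering results.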


# Save the final checkpoint

In [16]:
# def npsave(pytree, path):
#   names_and_vals, _ = big_vision.utils.tree_flatten_with_names(pytree)
#   with open(path, "wb") as f:
#     np.savez(f, **{k:v for k, v in names_and_vals})

# # Takes around 4 minutes
# npsave(params, 'my-custom-paligemma-ckpt.npz')
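
If the cell above is uncommented, the parameters are written to a flat `.npz` archive keyed by parameter name. A minimal sketch of the same idea, plus reading the file back, assuming a nested dict of NumPy arrays; `flatten_with_names` is a hypothetical stand-in for `big_vision.utils.tree_flatten_with_names`, whose exact key format may differ, and `npload` is likewise illustrative:

```python
import os
import tempfile

import numpy as np


def flatten_with_names(tree, prefix=""):
    """Yield (name, array) pairs from a nested dict, joining keys with '/'."""
    for key, value in tree.items():
        name = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            yield from flatten_with_names(value, name)
        else:
            yield name, np.asarray(value)


def npsave(tree, path):
    """Write every leaf of `tree` into one .npz archive."""
    with open(path, "wb") as f:
        np.savez(f, **dict(flatten_with_names(tree)))


def npload(path):
    """Read the archive back as a flat {name: array} dict."""
    with np.load(path) as data:
        return {name: data[name] for name in data.files}


# Toy parameters; real model parameters are far larger.
params = {"llm": {"attn": {"kernel": np.ones((2, 2))}}, "img": {"bias": np.zeros(3)}}
path = os.path.join(tempfile.mkdtemp(), "toy-ckpt.npz")
npsave(params, path)
restored = npload(path)
print(sorted(restored))
```

The restored dict is flat; rebuilding the nested structure would require splitting each name on `/`.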

In [17]:
import inspect
import big_vision.models.vit as vit_module

print(f"Loaded file: {vit_module.__file__}")
print("-" * 40)
print("Source code of Encoder1DBlock currently in memory:")
print(inspect.getsource(vit_module.Encoder1DBlock))

Loaded file: /home/dougljia@amd.com/big_vision/big_vision/configs/proj/paligemma/big_vision_repo/big_vision/models/vit.py
----------------------------------------
Source code of Encoder1DBlock currently in memory:
class Encoder1DBlock(nn.Module):
  """Single transformer encoder block (MHSA + MLP)."""
  mlp_dim: Optional[int] = None  # Defaults to 4x input dim
  num_heads: int = 12
  dropout: float = 0.0
  dtype_mm: str = "float32"

  @nn.compact
  def __call__(self, x, deterministic=True):
    out = {}
    x = nn.with_logical_constraint(x, ("act_batch", "act_len", "act_emb"))
    y = nn.LayerNorm()(x)
    y = out["sa"] = MultiHeadDotProductAttention(
        num_heads=self.num_heads,
        kernel_init=nn.initializers.xavier_uniform(),
        deterministic=deterministic,
        dtype=self.dtype_mm,
    )(y, y)
    y = nn.with_logical_constraint(y, ("act_batch", "act_len", "act_emb"))
    y = nn.Dropout(rate=self.dropout)(y, deterministic)
    x = out["+sa"] = x + y

    y = nn.Layer