# Ink Interactions @ AMLD

# Section 1 - Introduction
In this colab you will be able to interact with a "webpage"-like document using ink. To this end, we will experiment with different models and approaches on the following task:

Given an image of the webpage with a rendered gesture on top of it, identify:
- The **type of the gesture**
- The **target of the gesture** (i.e. its bounding box)
- Additional **text instructions**

By constraining this problem to a specific set of gestures with pre-defined
behavior, we can interact with, update, or change the document based on the
predicted attributes of a gesture.

We will work with the following gestures:
- **insert**: to insert text at a given region
- **select**: to highlight parts of the webpage
- **question**: to get clarifications about a part of the webpage
- **delete**: to delete parts of the webpage
- **crop**: to crop a region of the webpage
- **underline**: to highlight parts of the webpage
- **instruct**: to interact with the content (image/text) via instruction

<img src="https://storage.googleapis.com/amld_workshop_natural_interactions_with_llms/gesture_classes.png" height="400"/>

------

Note that a given type of gesture can have multiple variants. For example, a `delete` operation can be done by drawing a squiggly line, by striking out the text with a single horizontal stroke or by crossing out the text. The pictures above only show a single variant per class. If you are interested, you can look at the TFRecord dataset contained in `data/gestures` (see below, the part about downloading the datasets for this colab).

As described above, the original training set for the model consists of images of webpages taken from Wikipedia with a handwritten gesture rendered on top. Since images are static, we converted them into a very (very) crude document structure. This way, we can simulate what an "interaction" with the model looks like. A document is essentially a collection of words, text lines and paragraphs, each represented by a location and a size in JSON format:


```
"elements": [
  {
    "id": 1,
    "parent_id": 0,
    "class_name": "textline",
    "bbox": {"left": 23, "top": 35, "right": 384, "bottom": 63},
    "children_ids": [2, 3, 4, …]
  },
  {
    "id": 2,
    "parent_id": 1,
    "class_name": "word",
    "text": "foo",
    "bbox": {"left": 23, "top": 35, "right": 185, "bottom": 63},
  },
  …
]
```
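
To make this structure concrete, here is a minimal sketch of how such a flat element list could be navigated. The sample data and the helpers `index_by_id` and `line_text` are hypothetical and only mirror the fields shown above; the actual `document_editing` module may organize things differently.

```python
import json

# Hypothetical sample mirroring the fields shown above (not real dataset content).
DOC_JSON = """
{
  "elements": [
    {"id": 1, "parent_id": 0, "class_name": "textline",
     "bbox": {"left": 23, "top": 35, "right": 384, "bottom": 63},
     "children_ids": [2, 3]},
    {"id": 2, "parent_id": 1, "class_name": "word", "text": "foo",
     "bbox": {"left": 23, "top": 35, "right": 185, "bottom": 63}},
    {"id": 3, "parent_id": 1, "class_name": "word", "text": "bar",
     "bbox": {"left": 195, "top": 35, "right": 384, "bottom": 63}}
  ]
}
"""


def index_by_id(document):
  """Maps each element id to its element for direct lookups."""
  return {element["id"]: element for element in document["elements"]}


def line_text(elements_by_id, line_id):
  """Joins the text of the word children of a text line, in listed order."""
  line = elements_by_id[line_id]
  children = [elements_by_id[child_id] for child_id in line.get("children_ids", [])]
  return " ".join(child["text"] for child in children if child["class_name"] == "word")


elements_by_id = index_by_id(json.loads(DOC_JSON))
print(line_text(elements_by_id, 1))  # foo bar
```

Because every element carries both `parent_id` and `children_ids`, the hierarchy can be walked in either direction once the elements are indexed by `id`.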

We predefined for you a series of simple document "edits" (available in the `document_editing` module) that can be triggered using the output of the model. These are mostly intended for illustrative purposes and to play with the model in the [interactive ink canvas](#scrollTo=f-9JG2yJnopD). In the figure below, for example, we used the `insert` tool to add some text to the paragraph. In these visualizations, each individual element is represented by its bounding box (light blue for words, blue for text lines and orange for paragraphs) and text content.

You will notice that they look quite different from the data that the model has seen during training. However, in practice the model is quite robust to small differences in rendering.

<img src="https://storage.googleapis.com/amld_workshop_natural_interactions_with_llms/document_editing_illustration.png" height="400"/>

As part of this workshop and in addition to using the model to perform known tasks, you will:
- Experiment with in-context learning through zero-/few-shot approaches with Gemini to solve the tasks without any fine-tuning
- Attempt to define a brand new task with a novel gesture class and fine-tune the pre-trained model on it

## Setting up code repositories and datasets

This section imports all the code and downloads all the data required for the workshop. You can explore the assets used by this notebook by clicking on the navigation bar on the left. All datasets are located under `/content/data` while model files are located under `/content/models`. This will take a few minutes ⏳!

**Important note**: you may need to run `!gcloud init` in a cell the first time you use the `gcloud` command.

⏩ You do not need to interact with the code in the **following cells**, so you may **keep them collapsed** to keep the colab as readable as possible.

### Code setup
Import code and clone repositories required for the workshop.

In [None]:
import copy
import functools
import glob
import io
import json
import os
import random
import shutil
import sys
import warnings

from IPython.display import HTML, Javascript, display
from PIL import Image, ImageDraw
import tensorflow as tf
from tqdm import tqdm
import numpy as np

# The T4 runtime is tight on memory to finetune this model. Preallocate
# all memory ahead of time to avoid OOM'ing due to fragmentation.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

In [None]:
!git clone https://github.com/google-deepmind/amld_workshop_natural_interactions_with_llms.git
!git clone --branch=main --depth=1 https://github.com/google-research/big_vision big_vision_repository

# Install libraries needed for cairo.
!apt-get install -q libcairo2-dev libjpeg-dev libgif-dev

# Install missing dependencies. Assume jax~=0.4.25 with GPU available.
!pip3 install -q "crcmod" "overrides" "ml_collections" "einops~=0.7" "sentencepiece" "jiwer" "pycairo"

sys.path.append('amld_workshop_natural_interactions_with_llms')
sys.path.append("big_vision_repository")

import ml_collections
import sentencepiece

import big_vision
import jax
import big_vision.models.vit
from big_vision.models.proj.paligemma import gemma_bv
from big_vision.models.proj.paligemma import paligemma
from big_vision.trainers.proj.paligemma import predict_fns
import big_vision.datasets.jsonl
import big_vision.utils
import big_vision.sharding
import jax.numpy as jnp

Re-execute this cell if you make changes to the code under `/content/amld_workshop_natural_interactions_with_llms`

In [None]:
%load_ext autoreload
%autoreload 2

import arrow_ink_tools
import data_processing
import document_editing
import in_context_learning
import metrics
import model_api
import notebook_canvas
import paligemma_gesture_preparation
import paligemma_tools
import rendering
import sampling

### Data setup
This will download all the workshop data to the local filesystem.

In [None]:
INK_GCP_BUCKET_URL = 'gs://amld_workshop_natural_interactions_with_llms'
BIG_VISION_BUCKET_URL = 'gs://big_vision'

PAGES_DIR = 'data/wikipedia_public'
INK_VALID_PATH = 'data/gestures/ink_gestures_valid.json'
INK_TRAIN_PATH = 'data/gestures/ink_gestures_train.json'
TF_VALID_PATH = 'data/gestures/annotation_gestures_valid.tfrecord'
TF_TRAIN_PATH = 'data/gestures/annotation_gestures_train.tfrecord'
PALIGEMMA_MODEL_PATH = 'models/paligemma_ft_digink_448.b16.npz'
PALIGEMMA_TOKENIZER_PATH = 'models/paligemma_tokenizer.model'


def download_workshop_data(file_path: str, bucket=INK_GCP_BUCKET_URL):
  file_name = os.path.basename(file_path)
  if os.path.exists(file_path):
    return

  print(f'Downloading {file_name} into {file_path}.')
  if file_name.endswith('.zip'):
    !gsutil cp {os.path.join(bucket, file_name)} {file_path}
    !unzip -q {file_path} -d {os.path.dirname(file_path)}
  else:
    !gsutil cp {os.path.join(bucket, file_name)} {file_path}
  print()

download_workshop_data(os.path.join(PAGES_DIR, 'wikipedia_public.zip'))
download_workshop_data(INK_TRAIN_PATH)
download_workshop_data(INK_VALID_PATH)
download_workshop_data(TF_TRAIN_PATH)
download_workshop_data(TF_VALID_PATH)
download_workshop_data(PALIGEMMA_MODEL_PATH)
download_workshop_data(PALIGEMMA_TOKENIZER_PATH, bucket=BIG_VISION_BUCKET_URL)

### Utility functions
This section defines additional utility functions to load, manipulate and visualize the data that we just downloaded and that we will be using during the workshop.

In [None]:
_IMAGE_RESOLUTION = 448

def convert_bbox(
    left: float, top: float, right: float, bottom: float
) -> document_editing.BoundingBox:
  """Converts (left, top, right, bottom) coordinates into a `BoundingBox`."""
  return document_editing.BoundingBox(
      top=top,
      left=left,
      bottom=bottom,
      right=right,
  )

_EMPTY_STRING = tf.convert_to_tensor("", dtype="string").numpy()
_EMPTY_BOUNDING_BOX = np.array([0.0, 0.0, 0.0, 0.0], dtype="float32")
_EMPTY_IMAGE = tf.io.encode_png(tf.constant([[[0]]], dtype="uint8")).numpy()
FEATURE_SPEC = {
    "ink_hash": tf.io.FixedLenFeature(
        [], dtype=tf.string, default_value=_EMPTY_STRING
    ),
    "example_id": tf.io.FixedLenFeature(
        [], dtype=tf.string, default_value=_EMPTY_STRING
    ),
    "annotation_bbox": tf.io.FixedLenFeature(
        [4], dtype=tf.float32, default_value=_EMPTY_BOUNDING_BOX
    ),
    "composition_bbox":  tf.io.FixedLenFeature(
        [4], dtype=tf.float32, default_value=_EMPTY_BOUNDING_BOX
    ),
    "label": tf.io.FixedLenFeature(
        [], dtype=tf.string, default_value=_EMPTY_STRING
    ),
    "image/encoded": tf.io.FixedLenFeature(
        [], dtype=tf.string, default_value=_EMPTY_IMAGE
    ),
    "image/encoded_original": tf.io.FixedLenFeature(
        [], dtype=tf.string, default_value=_EMPTY_IMAGE
    ),
    "annotation_text": tf.io.FixedLenFeature(
        [], dtype=tf.string, default_value=_EMPTY_STRING
    ),
    "writing_guide": tf.io.FixedLenFeature(
        [4], dtype=tf.float32, default_value=_EMPTY_BOUNDING_BOX
    ),
}

def parse_single_example(elem):
  return tf.io.parse_single_example(elem, FEATURE_SPEC)

def read_examples(file_path):
  """Reads dataset examples from a TFRecord file, keyed by ink hash."""
  examples = {}

  for sample in tf.data.TFRecordDataset([file_path]).map(parse_single_example):
    classname = sample['label'].numpy().decode()
    bbox = sample['annotation_bbox'].numpy()
    bbox = [
        int(bbox[0] * paligemma_tools.LOCATION_TOKENS_RANGE_MAX / _IMAGE_RESOLUTION),
        int(bbox[1] * paligemma_tools.LOCATION_TOKENS_RANGE_MAX / _IMAGE_RESOLUTION),
        int(bbox[2] * paligemma_tools.LOCATION_TOKENS_RANGE_MAX / _IMAGE_RESOLUTION),
        int(bbox[3] * paligemma_tools.LOCATION_TOKENS_RANGE_MAX / _IMAGE_RESOLUTION),
    ]

    image = sample['image/encoded'].numpy()
    original_image = sample['image/encoded_original'].numpy()
    text = sample['annotation_text'].numpy().decode()
    label = (
        f'{classname} '
        f'{bbox[1]:.0f} {bbox[0]:.0f} '
        f'{bbox[3]:.0f} {bbox[2]:.0f} {text}'.strip()
    )
    ink_hash = sample['ink_hash'].numpy().decode()
    examples[ink_hash] = {
        'ink_hash': ink_hash,
        'bbox': convert_bbox(*bbox),
        'classname': classname,
        'composition_bbox': convert_bbox(*sample['composition_bbox'].numpy()),
        'image': Image.open(io.BytesIO(image)),
        'label': label,
        'original_image': Image.open(io.BytesIO(original_image)),
        'page_id': sample['example_id'].numpy().decode(),
        'text': text,
        'writing_guide': convert_bbox(*sample['writing_guide'].numpy()),
    }

  return examples

In [None]:
with open(INK_VALID_PATH, 'r') as f:
  ink_by_hash = {}

  for ink_hash, ink_data in json.load(f).items():
    ink_by_hash[ink_hash] = document_editing.Ink(
        strokes=[
            document_editing.Stroke(xs=stroke['xs'], ys=stroke['ys'])
            for stroke in ink_data['strokes']
        ]
    )

## Load all inks and documents
This will read all the training and validation examples in [TFExample](https://www.tensorflow.org/api_docs/python/tf/train/Example?) format and prepare `document_editing.Page` objects that represent documents with their elements and images.

⏩ You do not need to interact with the code in the **following cells**, so you may **keep them collapsed** to keep the colab as readable as possible.

In [None]:
train_examples = read_examples(TF_TRAIN_PATH)
train_page_ids = [example['page_id'] for example in train_examples.values()]
valid_examples = read_examples(TF_VALID_PATH)
valid_page_ids = [example['page_id'] for example in valid_examples.values()]

valid_pages_data = [document_editing.load_page(PAGES_DIR, page_id) for page_id in tqdm(valid_page_ids)]
valid_pages = dict(zip(valid_page_ids, valid_pages_data))

## Load the pre-trained PaliGemma 2 model for gesture recognition

This part will initialize a [PaliGemma](https://ai.google.dev/gemma/docs/paligemma) vision-language model (VLM) trained on gesture recognition. This checkpoint was obtained by taking a standard pre-trained PaliGemma model and fine-tuning it on a ~50/50 mixture of two tasks:

1. A gesture recognition task, where the model has to interpret a gesture, given the image of an annotated document.
1. A text detection task, where the model has to output a bounding box, given some text that appears on the image.


For gesture recognition, the target is a string formatted like:

```
<type of gesture> <the gesture's target bounding box> <detected text of the gesture>
```

For example:

<img src="https://storage.googleapis.com/amld_workshop_natural_interactions_with_llms/gesture_recognition_task.png" height="300"/>
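
As an illustration of this target format, the sketch below parses the plain-text form of a label, with the bounding box given as four integers in `top left bottom right` order (the order used when the labels are built in `read_examples`). The `ParsedLabel` class and `parse_label` function are hypothetical helpers, not part of the workshop modules, and the raw model output may encode coordinates differently (for example as location tokens).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ParsedLabel:
  gesture: str
  top: int
  left: int
  bottom: int
  right: int
  text: str


def parse_label(label: str) -> Optional[ParsedLabel]:
  """Parses '<gesture> <top> <left> <bottom> <right> [text]'; None if malformed."""
  parts = label.strip().split(" ")
  if len(parts) < 5:
    return None
  try:
    top, left, bottom, right = (int(value) for value in parts[1:5])
  except ValueError:
    return None
  # Everything after the four coordinates is the (possibly empty) text.
  return ParsedLabel(parts[0], top, left, bottom, right, " ".join(parts[5:]))


print(parse_label("insert 120 48 160 300 hello world"))
print(parse_label("delete 10 20"))  # None
```

Returning `None` for malformed strings keeps the caller in control of what to do with outputs that cannot be interpreted.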

For the text detection task, the target is just the bounding box but the input is an image and some text appearing in the image.

<img src="https://storage.googleapis.com/amld_workshop_natural_interactions_with_llms/text_detection_task.png" height="300"/>

When training this model, we found that adding the text detection task to the mixture was one of the most important factors in getting it to output accurate bounding boxes around text elements. This also made the model more robust in detecting the target of a gesture, even when users draw it imprecisely (for example, by including or omitting neighboring characters when circling).

Also, you might have noticed that the background in the gesture recognition task appears slightly transparent. This is just a simple trick that helps to separate the gesture overlay from the rest of the scene, making it easier for the model to "see". This transparency is safe for inference, as the document and gesture overlay are distinct rendering layers that we fully control.


This code is based on the example colab from the creators of PaliGemma: https://colab.research.google.com/github/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/finetune_paligemma.ipynb

**Important** You can reload the model by rerunning all cells in this section.

⏩ This code will be discussed in the subsequent section on fine-tuning ↩️ . Loading the model from disk will take a few minutes ⏳!

In [None]:
# @title Initialize the PaLIGemma model.

# Don't let TF use the GPU or TPUs
#tf.config.set_visible_devices([], "GPU")
#tf.config.set_visible_devices([], "TPU")

backend = jax.extend.backend.get_backend()
print(f"JAX version:  {jax.__version__}")
print(f"JAX platform: {backend.platform}")
print(f"JAX devices:  {jax.device_count()}")

# Use these for PaliGemma-2 3B 448px²
LLM_VARIANT = "gemma2_2b"

model_config = ml_collections.FrozenConfigDict({
    "llm": {"vocab_size": 257_152, "variant": LLM_VARIANT, "final_logits_softcap": 0.0},
    "img": {"variant": "So400m/14", "pool_type": "none", "scan": True, "dtype_mm": "float16"}
})

model = paligemma.Model(**model_config)
tokenizer = sentencepiece.SentencePieceProcessor(PALIGEMMA_TOKENIZER_PATH)
paligemma_tokenizer = paligemma_tools.PaliGemmaTokenizer(tokenizer)

# Load params - this can take up to 1 minute in T4 colabs.
params = paligemma.load(None, PALIGEMMA_MODEL_PATH, model_config)

# Define `decode` function to sample outputs from the model.
decode_fn = predict_fns.get_all(model)['decode']
decode = functools.partial(decode_fn, devices=jax.devices(), eos_token=tokenizer.eos_id())

This section sets up the model for subsequent partial fine-tuning. PaliGemma has two input modalities: **image** and **text**. Let's look at the model's parameters.

In [None]:
# @title Print parameters of the model.
def parameter_overview(params):
  for path, arr in big_vision.utils.tree_flatten_with_names(params)[0]:
    print(f"{path:80s} {str(arr.shape):22s} {arr.dtype}")

print(" == Model params == ")
parameter_overview(params)

It is common practice to target the attention matrices $Q, K, V, O$ for fine-tuning ([background on LoRA and QLoRA](https://medium.com/@danushidk507/fine-tuning-with-lora-and-qlora-enhancing-efficiency-in-neural-network-adaptation-8b4d1473274b)). You can choose to train those parameters only in the Gemma (language) part, or in both the SigLIP (image) and Gemma (language) parts.
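
To get a feel for what such a selection means, the sketch below estimates which share of a model would be trainable under each option. The name prefixes mirror the ones used in this notebook, but the parameter names, the toy shapes and the `trainable_fraction` helper are hypothetical and far smaller than the real model.

```python
import math

# Hypothetical parameter names with toy shapes, loosely modeled on the listing above.
PARAM_SHAPES = {
    "img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel": (4, 8, 2, 4),
    "img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel": (4, 8, 32),
    "llm/layers/attn/q_einsum/w": (4, 2, 8, 4),
    "llm/layers/mlp/gating_einsum": (4, 2, 8, 32),
    "llm/embedder/input_embedding": (100, 8),
}

ATTENTION_PREFIXES = {
    "language": ("llm/layers/attn",),
    "image+language": (
        "llm/layers/attn",
        "img/Transformer/encoderblock/MultiHeadDotProductAttention_0",
    ),
}


def trainable_fraction(param_shapes, prefixes):
  """Returns the share of scalar parameters whose name starts with a prefix."""
  total = sum(math.prod(shape) for shape in param_shapes.values())
  trainable = sum(
      math.prod(shape)
      for name, shape in param_shapes.items()
      if name.startswith(prefixes)
  )
  return trainable / total


for mode, prefixes in ATTENTION_PREFIXES.items():
  print(f"{mode}: {trainable_fraction(PARAM_SHAPES, prefixes):.1%} trainable")
```

Running the same computation over the real parameter tree would show how much memory the optimizer state needs for each choice.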


In [None]:
# @title Pick trainable parameters
#
# To keep HBM usage low and fit in a T4 GPU (16GB HBM) we opt to only finetune
# a part of the parameters. Additionally we keep the frozen params in float16
# and cast trainable to float32.

# Create a pytree mask of the trainable params.
def is_trainable_param_image_language(name, param):  # pylint: disable=unused-argument
  if name.startswith("img/Transformer/encoderblock/MultiHeadDotProductAttention_0"):  return True
  if name.startswith('llm/layers/attn'):    return True
  if name.startswith("llm/"):              return False
  if name.startswith("img/"):              return False
  raise ValueError(f"Unexpected param name {name}")

def is_trainable_param_language(name, param):  # pylint: disable=unused-argument
  if name.startswith("llm/layers/attn"):  return True
  if name.startswith("llm/"):              return False
  if name.startswith("img/"):              return False
  raise ValueError(f"Unexpected param name {name}")

finetuning = "image+language" # @param ["image+language", "language"]
is_trainable_param = is_trainable_param_language if finetuning == "language" else is_trainable_param_image_language
trainable_mask = big_vision.utils.tree_map_with_names(is_trainable_param, params)

# If more than one device is available (e.g. multiple GPUs) the parameters can
# be sharded across them to reduce HBM usage per device.
mesh = jax.sharding.Mesh(jax.devices(), ("data",))

data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec("data"))

params_sharding = big_vision.sharding.infer_sharding(
    params, strategy=[('.*', 'fsdp(axis="data")')], mesh=mesh)

# Yes: Some donated buffers are not usable.
warnings.filterwarnings(
    "ignore", message="Some donated buffers were not usable")

@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))
def maybe_cast_to_f32(params, trainable):
  # Cast others to float16, since some GPUs don't support bf16.
  return jax.tree.map(lambda p, m: p.astype(jnp.float32)
                      if m else p.astype(jnp.float16),
                      params, trainable)

Let's take a look at the parameter mask for fine-tuning.

In [None]:
# @title Trainable parameters mask.
trainable_mask

In [None]:
# @title Move params to GPU/TPU memory.

# Loading all params simultaneously - albeit much faster and more succinct -
# requires more RAM than the T4 colab runtimes have by default (12GB RAM).
# Instead we do it param by param.
params, treedef = jax.tree.flatten(params)
sharding_leaves = jax.tree.leaves(params_sharding)
trainable_leaves = jax.tree.leaves(trainable_mask)
for idx, (sharding, trainable) in enumerate(zip(sharding_leaves, trainable_leaves)):
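  #params[idx] = big_vision.utils.reshard(params[idx], sharding)
  # Passing False casts every param to float16 (inference only); pass
  # `trainable` instead to cast the trainable params to float32.
  params[idx] = maybe_cast_to_f32(params[idx], False)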
  params[idx].block_until_ready()
params = jax.tree.unflatten(treedef, params)

In [None]:
# @title Define the Evaluation/Inference loop.
def make_predictions(data_iterator,
                     postprocess_tokens,
                     num_examples=None,
                     batch_size=4, seqlen=paligemma_gesture_preparation._SEQLEN,
                     sampler="greedy"):
  num_predicted = 0
  while True:
    # Construct a list of examples in the batch.
    examples = []
    try:
      for _ in range(batch_size):
        examples.append(next(data_iterator))
        examples[-1]["_mask"] = np.array(True)  # Indicates true example.
    except StopIteration:
      if len(examples) == 0:
        break

    # Not enough examples to complete a batch. Pad by repeating last example.
    while len(examples) % batch_size:
      examples.append(dict(examples[-1]))
      examples[-1]["_mask"] = np.array(False)  # Indicates padding example.

    # Convert list of examples into a dict of np.arrays and load onto devices.
    batch = jax.tree.map(lambda *x: np.stack(x), *examples)
    batch = big_vision.utils.reshard(batch, data_sharding)

    # Make model predictions
    tokens = decode({"params": params}, batch=batch,
                    max_decode_len=seqlen, sampler=sampler)

    # Fetch model predictions from the device to the host and detokenize.
    tokens, mask = jax.device_get((tokens, batch["_mask"]))
    tokens = tokens[mask]  # remove padding examples.
    responses = [postprocess_tokens(t) for t in tokens]

    if num_examples and num_predicted + len(responses) >= num_examples:
      responses = responses[:num_examples - num_predicted]
      yield from responses
      break
    yield from responses
    num_predicted += len(responses)
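
The padding logic in the loop above can be isolated into a small, framework-free sketch. The helper names `padded_batches` and `predict_all` are hypothetical, and a toy function stands in for the model call. Keeping every batch the same size matters because jit-compiled JAX functions are specialized to their input shapes, so a smaller final batch would likely trigger a recompilation.

```python
from typing import Callable, Iterable, Iterator, List, Tuple


def padded_batches(items: Iterable, batch_size: int) -> Iterator[Tuple[List, List[bool]]]:
  """Yields (batch, mask); the last batch is padded by repeating its last item."""
  batch = []
  for item in items:
    batch.append(item)
    if len(batch) == batch_size:
      yield batch, [True] * batch_size
      batch = []
  if batch:
    num_real = len(batch)
    num_padding = batch_size - num_real
    yield batch + [batch[-1]] * num_padding, [True] * num_real + [False] * num_padding


def predict_all(items: Iterable, batch_size: int, predict_batch: Callable) -> Iterator:
  """Runs `predict_batch` on fixed-size batches and drops outputs of padding items."""
  for batch, mask in padded_batches(items, batch_size):
    outputs = predict_batch(batch)
    yield from (output for output, keep in zip(outputs, mask) if keep)


# Toy "model" that doubles its inputs, standing in for the real decode call.
results = list(predict_all(range(5), batch_size=4, predict_batch=lambda b: [2 * x for x in b]))
print(results)  # [0, 2, 4, 6, 8]
```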

## Apply the model on the validation dataset

Now that the model is loaded in memory, let's apply it on the validation dataset and see what it does. The next few sections define some helper functions to load the data, pass it to the model and visualize the output.

In [None]:
def valid_data_iterator(examples_to_display):
  for _, example in examples_to_display:
    yield paligemma_gesture_preparation.prepare_inference_input(
        paligemma_tokenizer, example['image']
    )

In [None]:
examples_to_display = [
    (k, v)
    for k, v in valid_examples.items()
    if not v['label'].startswith('instruct') and not v['label'].startswith('question')
]
random.seed(123)
random.shuffle(examples_to_display)

In [None]:
predictions = []
valid_data_iter = valid_data_iterator(examples_to_display)
for pred in tqdm(make_predictions(valid_data_iter,
                     postprocess_tokens=paligemma_tokenizer.postprocess_tokens,
                     num_examples=16)):
  predictions.append(data_processing.DocumentEditingLabel.from_string(pred))

In [None]:
correct_class = []

for (page_id, example), prediction in zip(examples_to_display, predictions):
  if prediction is None:
    continue
  correct_class.append(prediction.gesture == example['classname'])

print(f'Gesture classification accuracy: {np.mean(correct_class)}')
for i, pred in enumerate(predictions):
  print('prediction:', pred)
  print('target:    ', examples_to_display[i][1]['label'])
  print('-----------')

## Visualize the predictions on the validation dataset

In [None]:
# @title Define visualization helper functions.

def format_label(prefix: str, label: str, color: str) -> str:
  parts = label.split(' ')
  class_name = parts[0]
  bbox = parts[1:5]
  text = parts[5:]
  return (
      f'<tt>{prefix} '
      f'<b>{class_name}</b> <span'
      f' style="background-color:{color}">{" ".join(bbox)}</span>'
      f' <span>{" ".join(text)}</span></tt>'
  )

def format_prediction(prefix: str, prediction: data_processing.DocumentEditingLabel, color: str) -> str:
  return (
      f'<tt>{prefix} '
      f'<b>{prediction.gesture}</b> <span'
      f' style="background-color:{color}">{prediction.bbox.top:.0f} '
      f'{prediction.bbox.left:.0f} {prediction.bbox.bottom:.0f} '
      f'{prediction.bbox.right:.0f}</span>'
      f' <span>{prediction.text}</span></tt>'
  )

def table_row(contents: list[list[str]]) -> str:
  cell_contents = ['<br/>'.join(content) for content in contents]
  cells = [
      '<td style="border: 1px solid lightgray; text-align:'
      f' center">{cell_content}</td>'
      for cell_content in cell_contents
  ]
  return f'<tr>{"".join(cells)}</tr>'

In [None]:
html = ["""
<table style="width: 1500px; border-collapse: collapse; border: 1px solid lightgray;">
<thead>
  <tr>
    <th>Page</th>
    <th>Model output</th>
    <th>Document before</th>
    <th>Document after</th>
  </tr>
</thead>
<tbody>
"""]

for (_, example), prediction_full in tqdm(zip(examples_to_display, predictions)):
  if prediction_full is None:
    continue

  html_row_contents = []
  page = valid_pages[example['page_id']]

  # Display the page ID with a rendering of the gesture.
  html_row_contents.append([
      f'<h4>{example["page_id"]}</h4>',
      rendering.to_html_image(example['original_image'], width=300),
  ])

  # Ground truth and prediction strings.
  html_row_contents.append([
      format_label('Label:', example['label'], 'rgba(0, 255, 0, 0.3)'),
      format_prediction('Pred :', prediction_full, 'rgba(255, 0, 255, 0.3)'),
  ])

  # Document before and after the edit.
  page_copy = copy.deepcopy(page)

  bbox = example['bbox']
  composition_bbox = example['composition_bbox']
  writing_guide_bbox = example['writing_guide']

  bbox = rendering.bbox_to_image_space(bbox, composition_bbox)

  prediction = prediction_full.bbox
  prediction = rendering.bbox_to_image_space(prediction, composition_bbox)
  prediction_classname = prediction_full.gesture
  prediction_text = prediction_full.text

  page_copy.edit(
      edit_name=prediction_classname,
      edit_bbox=prediction,
      text=prediction_text,
  )

  # We first compute the "after" state, which returns an area of interest we can
  # crop around for better readability in the output table.
  rendering_after = rendering.render_document(
      page_copy,
      overlay_bboxes={'': composition_bbox},
      crop_area=True,
  )
  html_image_after = rendering.to_html_image(rendering_after.image, width=400)

  rendering_before = rendering.render_document(
      page,
      overlay_bboxes={'lime': bbox, 'fuchsia': prediction},
      ink=ink_by_hash[example['ink_hash']],
      crop_area=rendering_after.area_of_interest,
  )
  html_image_before = rendering.to_html_image(rendering_before.image, width=400)

  html_row_contents.append([html_image_before])
  html_row_contents.append([html_image_after])
  html.append(table_row(html_row_contents))

html.append('</tbody></table>')
display(HTML(''.join(html)))

## Interactive Ink Canvas
In this section you can load one document from the dataset and play with the model. Draw a gesture with the mouse and click the `interpret` button to make a model call. If everything goes well, you should be able to see on the right the image provided as input to the model and, overlaid on top of it, the detected target.


✅ Feel free to come back to this part throughout the colab (in particular in later stages after fine-tuning to experiment with the predictions of the model).

**After fine-tuning:** The `margin` parameter controls how much context is kept when cropping the image around the ink. You probably want to keep it the same as in the synthetic data generation.
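
The geometry behind this cropping can be sketched without any imaging library. The `Box` class and both helpers below are hypothetical stand-ins for the workshop's own types, and the coordinate range of 1024 is an assumption about the value of `paligemma_tools.LOCATION_TOKENS_RANGE_MAX`.

```python
from dataclasses import dataclass


@dataclass
class Box:
  left: float
  top: float
  right: float
  bottom: float


def square_area_around(ink_box: Box, margin: int) -> Box:
  """Returns a square centered on the ink with side max(width, height) + margin."""
  size = max(ink_box.right - ink_box.left, ink_box.bottom - ink_box.top) + margin
  center_x = (ink_box.left + ink_box.right) / 2
  center_y = (ink_box.top + ink_box.bottom) / 2
  half = size // 2
  return Box(center_x - half, center_y - half, center_x + half, center_y + half)


def to_page_space(pred: Box, area: Box, coord_range: int = 1024) -> Box:
  """Maps a box in [0, coord_range] crop coordinates back to page pixels."""
  scale_x = (area.right - area.left) / coord_range
  scale_y = (area.bottom - area.top) / coord_range
  return Box(
      area.left + pred.left * scale_x,
      area.top + pred.top * scale_y,
      area.left + pred.right * scale_x,
      area.top + pred.bottom * scale_y,
  )


area = square_area_around(Box(left=100, top=200, right=300, bottom=260), margin=40)
print(area)  # Box(left=80.0, top=110.0, right=320.0, bottom=350.0)
print(to_page_space(Box(0, 0, 1024, 1024), area))
```

A larger margin shrinks the gesture relative to the crop, which is why a mismatch with the training data can change how the gesture appears to the model.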

In [None]:
# @title Define prediction helper function.

margin = 40 #@param {type: "integer"}
def canvas_predict_fn(ink: document_editing.Ink, image: Image.Image):
  # Prepare a square area around the gesture.
  rendered_ink = rendering.render_ink_on_image(ink, image, add_semi_transparent_overlay=True)

  ink_bbox = ink.get_bbox()
  size = max(ink_bbox.width, ink_bbox.height) + margin
  gesture_area = document_editing.BoundingBox(
      top=ink_bbox.center.y - size // 2,
      left=ink_bbox.center.x - size // 2,
      bottom=ink_bbox.center.y + size // 2,
      right=ink_bbox.center.x + size // 2,
  )
  model_input = rendered_ink.crop((
      gesture_area.left,
      gesture_area.top,
      gesture_area.right,
      gesture_area.bottom,
  )).resize((_IMAGE_RESOLUTION, _IMAGE_RESOLUTION))

  # Prepare the input for the model.
  composition = Image.new("RGBA", model_input.size, "white")
  composition.paste(model_input, mask=model_input)
  inference_input = paligemma_gesture_preparation.prepare_inference_input(
      paligemma_tokenizer, image=composition
  )

  notebook_canvas.set_debug_output('', rendering.to_data_url(composition))
  prediction = next(
      make_predictions(
          iter([inference_input]),
          postprocess_tokens=paligemma_tokenizer.postprocess_tokens,
          batch_size=1,
          num_examples=1,
      )
  )
  parsed_prediction = data_processing.DocumentEditingLabel.from_output(
      prediction, loc_tokens=True
  )
  if not parsed_prediction:
    notebook_canvas.set_debug_output(f'❌ (could not parse) {prediction}', '')
    return data_processing.DocumentEditingLabel(
        gesture='none',
        bbox=document_editing.BoundingBox(top=0,left=0, bottom=0, right=0),
        text=''
    )

  scale = _IMAGE_RESOLUTION / paligemma_tools.LOCATION_TOKENS_RANGE_MAX

  # Show the predicted bounding box as an overlay to the input composition.
  draw = ImageDraw.Draw(composition)
  if parsed_prediction.gesture == 'point':
    draw.circle(
        (
            parsed_prediction.bbox.left * scale,
            parsed_prediction.bbox.top * scale
        ),
        fill='fuchsia',
        radius=8
    )
    draw.circle(
        (
            parsed_prediction.bbox.right * scale,
            parsed_prediction.bbox.bottom * scale
        ),
        fill='fuchsia',
        radius=8
    )
  else:
    draw.rectangle(
        [
            (
                parsed_prediction.bbox.left * scale,
                parsed_prediction.bbox.top * scale,
            ),
            (
                parsed_prediction.bbox.right * scale,
                parsed_prediction.bbox.bottom * scale,
            ),
        ],
        outline="fuchsia",
        width=2,
    )

  # Show a view of the composition with the predicted bounding box on the side panel.
  composition_image_url = rendering.to_data_url(composition)
  notebook_canvas.set_debug_output('✅ ' + parsed_prediction.to_string(data_processing.BBOX_FORMAT), composition_image_url)

  # Convert the predicted bounding box back to image space.
  parsed_prediction.bbox = rendering.bbox_to_image_space(
      parsed_prediction.bbox, gesture_area
  )

  return parsed_prediction

In [None]:
one_page = document_editing.load_page(PAGES_DIR, '9790964376811024979')

# This makes the word bounding box align more closely to the matplotlib
# rendering used by the colab as opposed to the original element sizes in the
# webpage renderings.
one_page.tighten_bboxes_for_colab_canvas()

In [None]:
canvas = notebook_canvas.Canvas(one_page, canvas_predict_fn, canvas_max_width=800, canvas_max_height=1600)
canvas.display_interaction_widget()

# Section 2 - Few-shot with Gemini

In this part, we will try to use larger foundation models (e.g. Gemini 2.0) to solve the document editing task through 0-shot/few-shot approaches. We are not reusing a model from the PaliGemma family because these models have not been trained for instruction following and are therefore unlikely to work well.

## Data Loading

We will load the training and validation datasets for document editing. We will use the training set to sample the different few-shot examples, and run the evaluation on the validation dataset. The few-shot examples are sampled in a stratified way (where the strata are the gesture types) to ensure the model sees examples of each class.
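
The idea of stratified sampling can be sketched in plain Python as follows. This is only an illustration with hypothetical labels and a hypothetical `stratified_sample` helper; the `sampling.StratifiedSampler` class used in this notebook may work differently.

```python
import collections
import random


def stratified_sample(examples, get_stratum, shots_per_stratum, seed=0):
  """Samples up to `shots_per_stratum` examples from every stratum."""
  by_stratum = collections.defaultdict(list)
  for example in examples:
    by_stratum[get_stratum(example)].append(example)

  rng = random.Random(seed)
  shots = []
  for stratum in sorted(by_stratum):
    candidates = by_stratum[stratum]
    shots.extend(rng.sample(candidates, min(shots_per_stratum, len(candidates))))
  return shots


# Hypothetical labels; the first token is the gesture type.
labels = [
    "delete 10 20 30 40",
    "delete 15 25 35 45",
    "select 5 5 50 50",
    "insert 1 2 3 4 hello",
    "insert 9 8 7 6 world",
]
shots = stratified_sample(labels, lambda label: label.split(" ", 1)[0], shots_per_stratum=1)
print(sorted(shot.split(" ")[0] for shot in shots))  # ['delete', 'insert', 'select']
```

Sampling uniformly over all examples instead could leave rare gesture types out of the prompt entirely.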

⏩ You do not need to interact with the code in the following cells, so you may keep them collapsed to keep the colab as readable as possible.

In [None]:
# Validation dataset
eval_samples = valid_examples.values()
valid_dataset = tf.data.Dataset.from_tensor_slices(
    {
      'ink_hash': np.array([example["ink_hash"] for example in eval_samples]),
      'image/encoded': np.array([example["image"] for example in eval_samples]),
      'label': np.array([example["label"] for example in eval_samples]),
      'image_width': np.array([example["original_image"].width for example in eval_samples]),
      'image_height': np.array([example["original_image"].height for example in eval_samples]),
    }
)

# Train dataset
train_samples = train_examples.values()
train_dataset = tf.data.Dataset.from_tensor_slices(
    {
      'ink_hash': np.array([example["ink_hash"] for example in train_samples]),
      'image/encoded': np.array([example["image"] for example in train_samples]),
      'label': np.array([example["label"] for example in train_samples]),
      'image_width': np.array([example["original_image"].width for example in train_samples]),
      'image_height': np.array([example["original_image"].height for example in train_samples]),
    }
)

def get_gesture(example):
  return tf.strings.split(example["label"], " ", 1)[0]

def get_normalization_factor(example):
  return (example['image_width'] / 1000, example['image_height'] / 1000)


shot_sampler = sampling.StratifiedSampler(train_dataset, get_gesture)
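
The implementation of `sampling.StratifiedSampler` is not shown in this colab. As a rough illustration of the idea only, the sketch below groups examples by gesture and draws from the groups in round-robin order. The function name `stratified_sample` and the round-robin policy are assumptions made for this example, not the helper's actual behavior.

```python
import random
from collections import defaultdict


def stratified_sample(examples, key_fn, num_examples, seed=0):
    """Draws up to `num_examples` items, cycling over strata.

    Hypothetical helper for illustration; not the workshop's StratifiedSampler.
    """
    rng = random.Random(seed)
    strata = defaultdict(list)
    for example in examples:
        strata[key_fn(example)].append(example)
    # Shuffle inside each stratum so different seeds give different shots.
    for members in strata.values():
        rng.shuffle(members)

    keys = sorted(strata)
    rng.shuffle(keys)
    picked = []
    position = 0
    # Visit strata in turn until enough examples are drawn or all are empty.
    while len(picked) < num_examples and any(strata[k] for k in keys):
        key = keys[position % len(keys)]
        if strata[key]:
            picked.append(strata[key].pop())
        position += 1
    return picked


def gesture_of(label):
    """Reads the gesture name, i.e. the first space-separated token."""
    return label.split(" ", 1)[0]


labels = ["delete 10 20 30 40", "delete 5 5 9 9", "select 1 2 3 4", "crop 0 0 50 50"]
shots = stratified_sample(labels, key_fn=gesture_of, num_examples=3)
print(sorted(gesture_of(label) for label in shots))
```

With three strata and three shots, every gesture appears exactly once. With fewer shots than gestures, some classes are necessarily missing from the prompt.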

## Model Loading

This cell loads the Gemini Model API Client and the corresponding inference function to be used later on to infer the prediction for a given document gesture.

✅ Please add a Gemini API Key through your secrets manager under the "🔑" in the left panel.

In [None]:
from google.colab import userdata
API_KEY = userdata.get("GOOGLE_API_KEY")
MODEL_NAME = "models/gemini-2.0-flash"

client = model_api.get_client(API_KEY)
inference_fn = model_api.client_to_inference_fn(client, MODEL_NAME)

## Few Shot Prompt Definition

In this section, we define the prompt format used for querying Gemini in 0-shot/few-shot settings. We provide some helper classes, functions and an overall template for the prompt, but feel free to experiment with them to reach the best possible results!

The provided template is composed of 3 main parts:
- The `PROMPT` variable which corresponds to the main instruction given to the model.
- The `shot_prefix` variable which will add text to the prefix if there is at least 1 shot example provided.
- The `GestureFewShotPrompter.prepare_example` function which defines the format into which examples are to be generated to be added to the prompt (when `is_shot` is `True`) and used for formatting the current inference example (when `is_shot` is `False`).


Note: Gemini normalizes the coordinate system to be `([0, 1000], [0, 1000])` instead of `([0, width], [0, height])` and prefers being prompted with y coordinate first and x coordinate second.
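
To make the coordinate convention concrete, here is a minimal sketch of the conversion in both directions. The function names `to_gemini_bbox` and `from_gemini_bbox` are hypothetical and only serve this example; the colab itself relies on `get_normalization_factor` and the `data_processing` helpers.

```python
def to_gemini_bbox(xmin, ymin, xmax, ymax, image_width, image_height):
    """Maps a pixel box to (ymin, xmin, ymax, xmax) on a 0-1000 grid."""
    return (
        round(ymin * 1000 / image_height),
        round(xmin * 1000 / image_width),
        round(ymax * 1000 / image_height),
        round(xmax * 1000 / image_width),
    )


def from_gemini_bbox(ymin, xmin, ymax, xmax, image_width, image_height):
    """Maps a normalized, y-first box back to pixel (xmin, ymin, xmax, ymax)."""
    return (
        xmin * image_width / 1000,
        ymin * image_height / 1000,
        xmax * image_width / 1000,
        ymax * image_height / 1000,
    )


# A box on an 800x1600 image, given in pixels as (xmin, ymin, xmax, ymax).
normalized = to_gemini_bbox(200, 400, 400, 800, image_width=800, image_height=1600)
print(normalized)  # (250, 250, 500, 500)
print(from_gemini_bbox(*normalized, image_width=800, image_height=1600))
```

Note that both the axis order and the scale change: the pixel box is x-first, while the normalized box is y-first.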

⏩ Feel free to skip the content of these cells, since we only define helper methods therein.

In [None]:
#@title Helper class
GESTURES = (
    'crop',
    'delete',
    'insert',
    'instruct_image',
    'instruct_text',
    'question',
    'select',
    'underline',
)

class GestureFewShotPrompter(in_context_learning.FewShotPrompter):
  def prepare_example(self, example, is_shot):
    data = []

    image = self.load_image(example["image/encoded"])
    data.append(image)

    if is_shot:
      label = example["label"].decode()
      data.append(label)

    return data

## Prompt definition
✅ Please expand the following cell to play around with the prompt.

In [None]:
PROMPT = (
f"""You receive as input an image that contains some text and a human gesture in red ink strokes.

Your task is to predict the type of gesture, the bounding box of the text it is annotating with a transcription of what the user wrote in handwriting (if anything was written) in the following format <gesture> <ymin> <xmin> <ymax> <xmax> [<transcription>].
The bounding box should have normalized coordinates as int [0, 1000). (0, 0) is the top left corner and the y, x coordinate values are relative values with respect to image height and width. Only output the gesture type, bounding box and text (if present) and nothing else.

The possible gestures are:
{os.linesep.join(f"- {gesture}" for gesture in GESTURES)}

"""
)
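
The output format requested in the prompt is simple enough to parse with a regular expression. The sketch below is an illustration only: `LABEL_PATTERN` and `parse_label` are hypothetical names, and the pattern assumes single spaces between fields and an optional free-text transcription at the end. The colab's own parsing lives in the `data_processing` module, which may behave differently.

```python
import re

# <gesture> <ymin> <xmin> <ymax> <xmax> [<transcription>]
LABEL_PATTERN = re.compile(
    r"^(?P<gesture>\w+) (?P<ymin>\d+) (?P<xmin>\d+) (?P<ymax>\d+) (?P<xmax>\d+)"
    r"(?: (?P<text>.+))?$"
)


def parse_label(raw):
    """Returns a dict for a well-formed label, or None if it does not match."""
    match = LABEL_PATTERN.match(raw.strip())
    if match is None:
        return None
    fields = match.groupdict()
    return {
        "gesture": fields["gesture"],
        "bbox": tuple(int(fields[k]) for k in ("ymin", "xmin", "ymax", "xmax")),
        "text": fields["text"] or "",
    }


print(parse_label("insert 120 340 160 520 hello world"))
print(parse_label("delete 10 20 30 40"))
print(parse_label("I think this is a delete gesture"))  # None: not in the format
```

Returning `None` for malformed outputs makes it easy to count them as errors during evaluation instead of crashing.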

In [None]:
#@title Preview entire input to the model
shot_prefix = """Now, we show you some examples for each gesture type.

"""

example = next(iter(valid_dataset.as_numpy_iterator()))
n_shots = 2 # @param
shots = list(shot_sampler.sample(num_examples=n_shots))
shots_gemini = [
    data_processing.transform_example_label(
        example,
        data_processing.BBOX_PATTERN,
        data_processing.BBOX_FORMAT,
        get_normalization_factor(example)
    )
    for example in shots
]
prompter = GestureFewShotPrompter(PROMPT, shot_prefix, shots_gemini)
prompter.display_prompt(example)

## Few-shot inference

Based on the prompt you defined above, the following cells will run inference with different numbers of shots. Feel free to modify your prompt to try and reach the best results!

**Note**: free AI Studio API keys are limited to 15 requests per model per minute and 1500 requests per model per day, and the following cell will consume 800 of those requests!

In [None]:
results = {}
confusion_matrix = {}

N_SHOTS = [0, 1, 2, 4, 8]

for n_shots in N_SHOTS:
  shots = list(shot_sampler.sample(num_examples=n_shots))
  shots_gemini = [
      data_processing.transform_example_label(
          example,
          data_processing.BBOX_PATTERN,
          data_processing.BBOX_FORMAT,
          get_normalization_factor(example)
      )
      for example in shots
  ]
  prompter = GestureFewShotPrompter(PROMPT, shot_prefix, shots_gemini)
  targets, predictions = in_context_learning.infer_fewshot(
      inference_fn,
      prompter,
      valid_dataset,
      normalize_fn=get_normalization_factor
  )

  accuracies, ious, cers = metrics.compute_document_editing_metrics(targets, predictions)
  cm = metrics.confusion_matrix(predictions, targets)

  results[n_shots] = {
      'Accuracy': sum(accuracies) / len(accuracies),
      'IoU': sum(ious) / len(ious),
      'CER': sum(cers) / len(cers),
  }
  confusion_matrix[n_shots] = cm

metrics.plot_fewshot_results(results)
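
For reference, the two less familiar metrics reported above can be sketched as follows. These are the standard textbook definitions of intersection over union (IoU) and character error rate (CER); the names `bbox_iou` and `character_error_rate` are hypothetical, and `metrics.compute_document_editing_metrics` may differ in details such as how it scores predictions with the wrong gesture.

```python
def bbox_iou(box_a, box_b):
    """IoU of two (ymin, xmin, ymax, xmax) boxes."""
    inter_h = max(0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_w = max(0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    intersection = inter_h * inter_w
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - intersection
    return intersection / union if union > 0 else 0.0


def character_error_rate(target, prediction):
    """Levenshtein distance between the strings, divided by the target length."""
    previous = list(range(len(prediction) + 1))
    for i, target_char in enumerate(target, start=1):
        current = [i]
        for j, predicted_char in enumerate(prediction, start=1):
            substitution = previous[j - 1] + (target_char != predicted_char)
            current.append(min(previous[j] + 1, current[j - 1] + 1, substitution))
        previous = current
    return previous[-1] / max(len(target), 1)


print(bbox_iou((0, 0, 10, 10), (0, 5, 10, 15)))  # overlap 50, union 150
print(character_error_rate("hello", "hallo"))    # one substitution in five chars
```

Higher is better for accuracy and IoU, while lower is better for CER.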

# Section 3 - Extending the model

In this section you will attempt to define a brand new gesture class and further finetune the existing model for recognizing it.

## Synthetic dataset generation
The trained model classifies and localizes a fixed set of classes. **What if we want to add a new class?**

In this section we show how to generate additional data for a new class: **arrow**. As a first step, we need a source of arrow inks. We use the [MathWriting](https://github.com/google-research/google-research/tree/master/mathwriting) dataset, which contains handwritten math formulas, and extract the following LaTeX symbols:

$$\leftrightarrow, \Leftrightarrow$$

Let's load the dataset (may take around 5 minutes).

**Important:** if you've restarted the runtime and previously generated the dataset, you don't need to rerun this section. The data should be located in the directory `data/mathwriting-2024/arrow_dataset`.

✅ You can modify dataset generation parameters, such as the margin around an arrow.

### Create dataset on the spot

In [None]:
# @title Load MathWriting dataset.

MATHWRITING_FILE_NAME = 'mathwriting-2024.tgz'
MATHWRITING_BASE_PATH = "data/mathwriting-2024"
MATHWRITING_FILE_PATH = os.path.join(MATHWRITING_BASE_PATH, MATHWRITING_FILE_NAME)

if not os.path.exists(MATHWRITING_BASE_PATH):
  !mkdir -p {MATHWRITING_BASE_PATH}

!wget -nc https://storage.googleapis.com/mathwriting_data/mathwriting-2024.tgz --no-check-certificate -O {MATHWRITING_FILE_PATH}
shutil.unpack_archive(MATHWRITING_FILE_PATH, 'data/')

In [None]:
# @title Extract arrows.

def get_symbol_ink(symbol: arrow_ink_tools.InkPart) -> document_editing.Ink:
  """Computes the actual ink from an InkPart object."""
  ink = arrow_ink_tools.read_inkml_file(
      os.path.join(MATHWRITING_BASE_PATH, "train", f"{symbol.source_sample_id}.inkml"))
  strokes = [ink.strokes[i] for i in symbol.stroke_indices]
  return document_editing.Ink(strokes=strokes)

import matplotlib.pyplot as plt
def plot_ink(ax: plt.Axes, ink: document_editing.Ink):
  """Plots the ink data on the given axes."""
  for stroke in ink.strokes:
    ax.plot(stroke.xs, stroke.ys, color="red", linewidth=2, zorder=100)

symbols = arrow_ink_tools.read_symbols_file(os.path.join(MATHWRITING_BASE_PATH, 'symbols.jsonl'))
arrows = [s for s in symbols if s.label in {'\\leftrightarrow', '\\Leftrightarrow'}]
print(f"We've extracted {len(arrows)} unique arrows from MathWriting dataset")

For each arrow we find the **left and right critical points**, i.e. the two ends of the arrow.

In [None]:
arrows_downloaded = []
for i, arrow in enumerate(arrows):
  arrow_ink = arrow_ink_tools.ArrowInk(ink=get_symbol_ink(arrow))
  arrows_downloaded.append(arrow_ink)
  if i == 0:
    plt.close()
    plot_ink(ink=arrow_ink.ink, ax=plt)
    plt.plot(arrow_ink.left_critical_point.x, arrow_ink.left_critical_point.y, 'bo')
    plt.plot(arrow_ink.right_critical_point.x, arrow_ink.right_critical_point.y, 'bo')
    plt.show()

In [None]:
# @title Helper functions for sampling words and arrows.
def find_random_word_pair(page: document_editing.Page) -> tuple[int, int]:
  """Finds two random words that are close on the page.

  The returned pair is ordered so that the upper word comes first.
  """
  word_ids = []
  for (id_, element) in page.element_from_id.items():
    if element.class_name == 'word':
      word_ids.append(id_)
  page_reader = rendering.render_document(page)
  # The page extent is the same for every pair, so compute it once.
  y_size = page_reader.extent.bottom - page_reader.extent.top
  x_size = page_reader.extent.right - page_reader.extent.left

  possible_pairs = []
  for i in range(len(word_ids)):
    for j in range(i + 1, len(word_ids)):
      # Word centers are read from the `page` argument.
      c1 = page.element_from_id[word_ids[i]].bbox.center
      x1, y1 = c1.x, c1.y
      c2 = page.element_from_id[word_ids[j]].bbox.center
      x2, y2 = c2.x, c2.y

      # Distance between the centers as a share of the page size.
      share_x = abs((x2 - x1) / x_size)
      share_y = abs((y2 - y1) / y_size)

      if share_x <= 0.1 and share_y <= 0.1:
        possible_pairs.append((word_ids[i], word_ids[j]))

  random_pair = random.choice(possible_pairs)
  y1 = page.element_from_id[random_pair[0]].bbox.center.y
  y2 = page.element_from_id[random_pair[1]].bbox.center.y

  if y1 >= y2:
    return (random_pair[1], random_pair[0])

  return random_pair

def sample_two_words_and_fit_an_arrow(page_example):
  page_with_tight_bboxes = copy.deepcopy(page_example)
  page_with_tight_bboxes.tighten_bboxes_for_colab_canvas()

  word_id1, word_id2 = find_random_word_pair(page_with_tight_bboxes)
  arrow = random.sample(arrows_downloaded, 1)[0]
  arrow_fitter = arrow_ink_tools.ArrowPageFitter(page_with_tight_bboxes, word_id1, word_id2)
  return arrow_fitter, arrow_fitter.fit_to_page(arrow=arrow, verbose=False)

Our next step is to choose two random words on the page that are relatively close together. We then fit an arrow between them by aligning one of its endpoints with the center of the first word, and then rotating and scaling the arrow until its other endpoint aligns with the center of the second word.
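
The geometry behind this fitting step can be sketched with complex numbers, where a single multiplication rotates and scales a point. This is an illustration under simplifying assumptions, not the implementation of `arrow_ink_tools.ArrowPageFitter`: the function name `fit_stroke_between` is hypothetical, and strokes are represented as plain lists of `(x, y)` tuples.

```python
def fit_stroke_between(points, source_start, source_end, target_start, target_end):
    """Moves `points` so that the source endpoints land on the target points.

    Applies one similarity transform (translation, rotation, uniform scale).
    """
    a = complex(*source_start)
    b = complex(*source_end)
    p = complex(*target_start)
    q = complex(*target_end)
    if a == b:
        raise ValueError("source endpoints must differ")
    # Multiplying by this complex number rotates and scales in one step.
    rotate_and_scale = (q - p) / (b - a)
    transformed = []
    for x, y in points:
        z = p + (complex(x, y) - a) * rotate_and_scale
        transformed.append((z.real, z.imag))
    return transformed


# A horizontal arrow-like stroke, fitted between two word centers that are
# vertically aligned: the stroke gets rotated by 90 degrees and doubled in size.
arrow = [(0.0, 0.0), (1.0, 0.5), (2.0, 0.0)]
fitted = fit_stroke_between(arrow, arrow[0], arrow[-1], (10.0, 10.0), (10.0, 14.0))
print(fitted)
```

Because the scale is uniform, the arrow keeps its shape; only its position, orientation and size change.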

In [None]:
dataset = []
num_tries = 3
for valid_page_id in tqdm(valid_page_ids):
  page_example = valid_pages[valid_page_id]
  for _ in range(num_tries):
    arrow_fitter, final_arrow = sample_two_words_and_fit_an_arrow(page_example)
    if arrow_ink_tools.check_that_arrow_is_located_correctly(arrow_fitter, final_arrow):
      dataset.append((arrow_fitter, final_arrow))
      break

Next, we crop images around each arrow (with a margin) and save them together with the targets.

In [None]:
# @title Save images on disk and prepare targets.
GENERATED_DATASET_PATH = os.path.join(MATHWRITING_BASE_PATH, 'arrow_dataset')
if not os.path.exists(GENERATED_DATASET_PATH):
  ! mkdir {GENERATED_DATASET_PATH}

margin = 40 #@param {type: "integer"}
prepared_targets = []
for i, (arrow_fitter, final_arrow) in tqdm(enumerate(dataset)):
  center = final_arrow.ink.get_bbox().center
  x, y = center.x, center.y
  size = max(final_arrow.ink.get_bbox().width, final_arrow.ink.get_bbox().height) + margin
  crop_area = document_editing.BoundingBox(y - size // 2, x - size // 2, y + size // 2, x + size // 2)
  image = arrow_fitter.page.image_from_id[-1]
  image_with_ink = rendering.render_ink_on_image(
      final_arrow.ink,
      image,
      add_semi_transparent_overlay=True
  )
  cropped_image = image_with_ink.crop(
      (crop_area.left, crop_area.top, crop_area.right, crop_area.bottom)
  )
  cropped_image.save(os.path.join(GENERATED_DATASET_PATH, f'{i+1}.png'))
  prepared_targets.append(arrow_ink_tools.get_arrow_target(arrow_fitter, crop_area))

In [None]:
# @title Save jsonl files for datasets.
data_path = os.path.join(GENERATED_DATASET_PATH, 'data_train90.jsonl')
with open(data_path, 'w') as f:
  for i, target in enumerate(prepared_targets[:90]):
    json_data = {"prefix": "", "suffix": target, "image": f"{i+1}.png"}
    json.dump(json_data, f)
    f.write('\n')

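data_path = os.path.join(GENERATED_DATASET_PATH, 'data_val10.jsonl')
with open(data_path, 'w') as f:
  # Start counting at 90 so that each validation target points to its own
  # image file (images were saved as f'{i+1}.png' over the full dataset).
  for i, target in enumerate(prepared_targets[90:], start=90):
    json_data = {"prefix": "", "suffix": target, "image": f"{i+1}.png"}
    json.dump(json_data, f)
    f.write('\n')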

## Example with arrow on a page and target

In [None]:
example_id = 5 #@param {type: "integer"}
arrow_fitter, final_arrow = dataset[example_id]
page_reader = rendering.render_document(
      arrow_fitter.page,
      overlay_bboxes={'lime': arrow_fitter.word_id1_bbox, 'fuchsia': arrow_fitter.word_id2_bbox,},
      ink=final_arrow.ink)
page_reader.image

In [None]:
print(f'Example of prepared target: {prepared_targets[example_id]}')

## PaliGemma fine-tuning on synthetic data

In this section we further fine-tune the model so that it learns the new arrow-based gesture, which swaps two words on a page. The training data is the following mixture:

* 50% new arrow dataset
* 50% existing training dataset with other gestures
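
A 50/50 mixture can be obtained by simply alternating between the two sources. The sketch below shows the idea on plain lists; `interleave_forever` is a hypothetical name, and both inputs are assumed to be non-empty and small enough to be cached by `itertools.cycle`.

```python
import itertools


def interleave_forever(new_examples, original_examples):
    """Yields one new example, then one original example, indefinitely."""
    new_cycle = itertools.cycle(new_examples)
    original_cycle = itertools.cycle(original_examples)
    while True:
        yield next(new_cycle)
        yield next(original_cycle)


mixed = interleave_forever(["arrow_1", "arrow_2"], ["delete_1", "select_1", "crop_1"])
first_six = list(itertools.islice(mixed, 6))
print(first_six)
```

With an even batch size, every batch then contains the same number of arrow examples and original gesture examples.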

In [None]:
# @title Load train and validation datasets from jsonl format.

train_dataset = big_vision.datasets.jsonl.DataSource(
    os.path.join(GENERATED_DATASET_PATH, "data_train90.jsonl"),
    fopen_keys={"image": GENERATED_DATASET_PATH})

val_dataset = big_vision.datasets.jsonl.DataSource(
    os.path.join(GENERATED_DATASET_PATH, "data_val10.jsonl"),
    fopen_keys={"image": GENERATED_DATASET_PATH})

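def original_train_data_iterator(train_examples):
  """Never ending iterator over the original gesture training examples."""
  # Repeat the examples so that `next()` in `train_data_iterator` keeps
  # working when training runs for more than one pass over this data.
  while True:
    for example in train_examples.values():
      yield paligemma_gesture_preparation.prepare_train_input(
          paligemma_tokenizer, image=example['image'], suffix=example['label']
          )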

def train_data_iterator():
  """Never ending iterator over training examples."""
  # Shuffle examples and repeat so one can train for many epochs.
  dataset = train_dataset.get_tfdata().shuffle(1_000).repeat()
  original_dataset = original_train_data_iterator(train_examples)
  for example in dataset.as_numpy_iterator():
    image = Image.open(io.BytesIO(example['image']))
    suffix = example['suffix'].decode()
    yield paligemma_gesture_preparation.prepare_train_input(paligemma_tokenizer, suffix=suffix, image=image)
    yield next(original_dataset)

def val_data_iterator():
  """Single-pass iterator over validation examples."""
  dataset = val_dataset.get_tfdata(ordered=True)
  for example in dataset.as_numpy_iterator():
    image = Image.open(io.BytesIO(example['image']))
    yield paligemma_gesture_preparation.prepare_inference_input(paligemma_tokenizer, image=image)

We check what the model currently outputs on arrows.

In [None]:
# @title Make predictions on arrows.
for pred in make_predictions(val_data_iterator(),
                     postprocess_tokens=paligemma_tokenizer.postprocess_tokens,
                     num_examples=8):
  print(data_processing.DocumentEditingLabel.from_string(pred))

We then train the model on the dataset with arrows.

In [None]:
# @title Define the training step.
#
# The main update_fn using simple SGD.
#
@functools.partial(jax.jit, donate_argnums=(0,))
def update_fn(params, batch, learning_rate):
  imgs, txts, mask_ar = batch["image"], batch["text"], batch["mask_ar"]

  def loss_fn(params):
    text_logits, _ = model.apply({"params": params}, imgs, txts[:, :-1], mask_ar[:, :-1], train=True)
    logp = jax.nn.log_softmax(text_logits, axis=-1)

    # The model takes as input txts[:, :-1] but the loss is defined as predicting
    # next tokens txts[:, 1:]. Additionally, mask_loss[:, 1:] indicates which tokens
    # are part of the loss (e.g. prefix and padded tokens are not included).
    mask_loss = batch["mask_loss"][:, 1:]
    targets = jax.nn.one_hot(txts[:, 1:], text_logits.shape[-1])

    # Compute the loss per example. i.e. the mean of per token pplx.
    # Since each example has a different number of tokens we normalize it.
    token_pplx = jnp.sum(logp * targets, axis=-1)  # sum across vocab_size.
    example_loss = -jnp.sum(token_pplx * mask_loss, axis=-1)  # sum across seq_len.
    example_loss /= jnp.clip(jnp.sum(mask_loss, -1), 1)  # weight by num of tokens.

    # batch_loss: mean of per example loss.
    return jnp.mean(example_loss)

  loss, grads = jax.value_and_grad(loss_fn)(params)

  # Apply gradients to trainable params using SGD.
  def apply_grad(param, gradient, trainable):
    if not trainable: return param
    return param - learning_rate * gradient

  params = jax.tree_util.tree_map(apply_grad, params, grads, trainable_mask)

  return params, loss
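
To make the masking and normalization in `loss_fn` easier to follow, here is a plain-Python sketch of the loss for a single example. It is an illustration of the same computation on lists rather than JAX arrays; `log_softmax` and `masked_example_loss` are hypothetical names written for this example.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over one token's vocabulary logits."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return [v - log_norm for v in logits]


def masked_example_loss(token_logits, target_ids, loss_mask):
    """Mean negative log-likelihood over the tokens where the mask is 1."""
    total = 0.0
    for logits, target_id, keep in zip(token_logits, target_ids, loss_mask):
        total -= keep * log_softmax(logits)[target_id]
    # Divide by the number of counted tokens, at least 1 to avoid 0/0.
    return total / max(sum(loss_mask), 1)


# Three positions over a vocabulary of four tokens with uniform logits.
# The first position is masked out (e.g. a prefix token).
uniform = [[0.0, 0.0, 0.0, 0.0]] * 3
loss = masked_example_loss(uniform, target_ids=[1, 2, 3], loss_mask=[0, 1, 1])
print(loss, math.log(4))
```

A model that is uniform over four tokens pays `log(4)` per counted token, so the two printed values agree. The batch loss in `loss_fn` is the mean of this quantity over the examples.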

In [None]:
# @title Run training loop.
#
# Run a short training loop with cosine learning rate schedule.
#
# Note: the first step can be quite slow on some machines (up to several minutes)
# due to XLA compilation of the jax.jit'd function.

BATCH_SIZE = 2 # @param {type:"integer"}
LEARNING_RATE = 0.0001  # @param {type:"number"}

TRAIN_STEPS = 64 # @param {type:"integer"}
EVAL_STEPS = 32
NUM_EVAL_EXAMPLES = 4

# collect valid targets
dataset = val_dataset.get_tfdata(ordered=True)
eval_targets = []
for example in dataset.as_numpy_iterator():
  eval_targets.append(example['suffix'])
  if len(eval_targets) == NUM_EVAL_EXAMPLES:
    break

train_data_it = train_data_iterator()

sched_fn = big_vision.utils.create_learning_rate_schedule(
    total_steps=TRAIN_STEPS+1, base=LEARNING_RATE,
    decay_type="cosine", warmup_percent=0.10)

for step in range(1, TRAIN_STEPS+1):
  # Make list of N training examples.
  examples = [next(train_data_it) for _ in range(BATCH_SIZE)]

  # Convert list of examples into a dict of np.arrays and load onto devices.
  batch = jax.tree.map(lambda *x: np.stack(x), *examples)
  #batch = big_vision.utils.reshard(batch, data_sharding)

  # Training step and report training loss
  learning_rate = sched_fn(step)
  params, loss = update_fn(params, batch, learning_rate)

  loss = jax.device_get(loss)
  print(f"step: {step:2d}/{TRAIN_STEPS:2d}   lr: {learning_rate:.5f}   loss: {loss:.4f}")

  if step == 1 or (step % EVAL_STEPS) == 0:
    print(f"Model predictions at step {step}")
    for pred, target in zip(make_predictions(val_data_iterator(),
                                 postprocess_tokens=paligemma_tokenizer.postprocess_tokens,
                                 num_examples=NUM_EVAL_EXAMPLES), eval_targets):
      print('prediction: ', pred)
      print('target:      ', target.decode())
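
The learning rate schedule used in the loop above combines a short warmup with a cosine decay. The sketch below shows one common form of such a schedule; it is an assumption about the general shape and is not guaranteed to match `big_vision.utils.create_learning_rate_schedule` step for step. The function name `warmup_cosine_lr` is hypothetical.

```python
import math


def warmup_cosine_lr(step, total_steps, base, warmup_percent):
    """Linear warmup to `base`, then cosine decay to zero."""
    warmup_steps = int(total_steps * warmup_percent)
    if warmup_steps > 0 and step < warmup_steps:
        return base * step / warmup_steps
    decay_span = max(total_steps - warmup_steps, 1)
    progress = min((step - warmup_steps) / decay_span, 1.0)
    return base * 0.5 * (1.0 + math.cos(math.pi * progress))


# Same settings as the training cell: 64 steps, base rate 1e-4, 10% warmup.
for step in (0, 3, 6, 35, 65):
    print(step, f"{warmup_cosine_lr(step, 65, 1e-4, 0.10):.6f}")
```

The warmup avoids large updates while the first gradients are still noisy, and the decay lets the final steps make smaller adjustments.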

In [None]:
# @title Prediction on original validation.
predictions_after_finetune = []
valid_data_iter = valid_data_iterator(examples_to_display)
for pred in tqdm(make_predictions(valid_data_iter,
                     postprocess_tokens=paligemma_tokenizer.postprocess_tokens,
                     num_examples=16)):
  predictions_after_finetune.append(data_processing.DocumentEditingLabel.from_string(pred))

correct_class_after_finetune = []

for (page_id, example), prediction in zip(examples_to_display, predictions_after_finetune):
  if prediction is None:
    correct_class_after_finetune.append(False)
    continue
  correct_class_after_finetune.append(prediction.gesture == example['classname'])

print(f'Gesture classification accuracy: {np.mean(correct_class_after_finetune)}')
for i, pred in enumerate(predictions_after_finetune):
  print('prediction:', pred)
  print('target:    ', examples_to_display[i][1]['label'])
  print('-----------')