##### Copyright 2025 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# T5Gemma Example

<table align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Research/[T5Gemma]Example.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
</table>

We present [T5Gemma (aka encoder-decoder Gemma)](https://arxiv.org/abs/2504.06225), a family of encoder-decoder large language models developed by adapting pretrained decoder-only models into encoder-decoder models.

T5Gemma includes pretrained and instruction-tuned variants, each with two groups of scales:
* [Gemma 2 scale](https://ai.google.dev/gemma/docs/core/model_card_2): 2B-2B, 9B-2B, and 9B-9B.
* [T5 scale](https://arxiv.org/abs/1910.10683): Small, Base, Large, and XL. An additional ML-scale model, sized between T5 Large and T5 XL, is also included.

Find the model weights on [Hugging Face](https://huggingface.co/collections/google/t5gemma-686ba262fe290b881d21ec86) and [Kaggle](https://www.kaggle.com/models/google/t5gemma).

In this notebook, we walk you through sampling from T5Gemma Base-Base (`t5gemma-b-b`) with Hugging Face Transformers and with Flax, and fine-tuning it with Flax.


# Hugging Face

## Hugging Face login

In [None]:
# Log in to the Hugging Face Hub before downloading the T5Gemma checkpoints.
import huggingface_hub

huggingface_hub.notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

## Sampling

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Instruction-tuned (PrefixLM-pretrained) T5Gemma Base-Base checkpoint.
HF_MODEL_ID = "google/t5gemma-b-b-prefixlm-it"

tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(HF_MODEL_ID)

# Gemma-style chat turn: the whole prompt goes to the encoder and the decoder
# generates the model turn.
chat_template = '<start_of_turn>user\n{user_input}<end_of_turn>\n<start_of_turn>model\n'
prompt = chat_template.format(
    user_input='Tell me an unknown interesting biology fact about the brain.'
)

# `tokenizer(...)` returns a BatchEncoding (input_ids + attention_mask), not
# bare ids; unpack all of it into `generate`.
inputs = tokenizer(prompt, return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=128)

print(tokenizer.decode(output[0], skip_special_tokens=True))


tokenizer_config.json:   0%|          | 0.00/46.4k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/34.4M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/577 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.18G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

The brain is actually **not** a single, spherical structure. It's a complex network of interconnected neurons, each with its own unique structure and function. 

Think of it like a giant, interconnected network of neurons, each with its own unique role in processing information. 



# Flax


In [None]:
!pip install -q git+https://github.com/google-deepmind/gemma.git


  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.7/5.7 MB[0m [31m61.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m486.6/486.6 kB[0m [31m32.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m55.4/55.4 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m400.4/400.4 kB[0m [31m31.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m65.3/65.3 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0

## Imports

In [None]:
from etils import ecolab
import os
import optax
import treescope
import kagglehub


from kauldron import kd
from gemma import gm

from gemma.research import t5gemma


## Kaggle login

In [None]:
# Authenticate with Kaggle so the T5Gemma checkpoints can be fetched below.
kagglehub.login()

VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Kaggle credentials set.
Kaggle credentials successfully validated.


## Sampling


In [None]:
# Build T5Gemma Base-Base and load its instruction-tuned, PrefixLM-pretrained
# checkpoint from Kaggle.
preset = t5gemma.T5GemmaPreset.GEMMA2_BASE_BASE

t5gemma_model = preset.config.make('transformer')
t5gemma_ckpt = preset.get_checkpoint_from_kaggle(
    t5gemma.CKPTType.IT,
    t5gemma.PretrainType.PREFIXLM,
)

t5gemma_params = gm.ckpts.load_params(t5gemma_ckpt)

# Sampling. NOTE: max_input_length / max_output_length presumably bound the
# prompt and generated lengths in tokens — see `t5gemma.Sampler`.
sampler = t5gemma.Sampler(
    model=t5gemma_model,
    params=t5gemma_params,
    tokenizer=preset.tokenizer,
    max_input_length=64,
    max_output_length=32,
)

chat_template = '<start_of_turn>user\n{user_input}<end_of_turn>\n<start_of_turn>model\n'
output = sampler.sample(
    chat_template.format(
        user_input='Tell me an unknown interesting biology fact about the brain.'
    ),
    max_new_tokens=32,
)

print(output)

The brain is actually **not** a single, spherical structure. It's a complex network of interconnected neurons, each with its own unique structure and function.


## Finetuning

A simple example of fine-tuning the encoder-decoder model for machine translation (English → French on the MTNT dataset).

### Preprocessor

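Convert examples from the decoder-only format produced by `gm.data.Seq2SeqTask` into encoder-decoder inputs: prompt tokens go to the encoder, response tokens to the decoder.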

In [None]:
import dataclasses
from etils import enp
from grain import python as grain

@dataclasses.dataclass(kw_only=True, frozen=True)
class Deconly2EncDecPreprocessor(grain.MapTransform):
  """Converts a decoder-only example into encoder-decoder inputs.

  Intended to run after `gm.data.Seq2SeqTask`. Positions whose loss mask is 0
  (the prompt) become the encoder input; positions whose loss mask is 1 (the
  response) become the decoder input/target. Every output is left-aligned and
  padded with 0 (regardless of `pad_id`).

  Assumes an unbatched example: `in_input` is 1-D `[seq_len]`, while
  `in_target` and `in_loss_mask` carry a trailing singleton axis
  `[seq_len, 1]` (see the `[..., 0]` indexing in `map`) — TODO confirm against
  `Seq2SeqTask`. The loss mask may be bool or 0/1 integers.

  Attributes:
    in_input: Key of the decoder-only input tokens.
    in_target: Key of the decoder-only (next-token) targets.
    in_loss_mask: Key of the decoder-only loss mask.
    out_encoder_input: Key the encoder tokens are written to.
    out_decoder_input: Key the decoder input tokens are written to.
    out_target: Key the decoder targets are written to (shape `[max_len, 1]`).
    out_loss_mask: Key the decoder loss mask is written to (`[max_len, 1]`).
    pad_id: Token id that marks padding in `in_input`.
    max_len: Outputs are truncated to their first `max_len` positions; `None`
      keeps the full input length. No extra padding is added.
  """

  in_input: kd.kontext.Key           # "input"
  in_target: kd.kontext.Key          # "target"
  in_loss_mask: kd.kontext.Key       # "loss_mask"

  out_encoder_input: kd.kontext.Key  # "encoder_input"
  out_decoder_input: kd.kontext.Key  # "decoder_input"
  out_target: kd.kontext.Key         # "target"
  out_loss_mask: kd.kontext.Key      # "loss_mask"

  pad_id: int = 0
  max_len: int | None = None


  def map(self, element):
    """Preprocess converting deconly example to encoder-decoder.

    Example:
      Deconly:
        Input:         <s>  A  B  C  1  2  3
        Target:         A   B  C  1  2  3 </s>
        Loss Mask:      0   0  0  1  1  1  1

      ==>
      Encoder-Decoder:
        Encoder Input:  A   B  C
        Decoder Input:           <s>  1  2  3
        Target:                   1  2  3 </s>
        Loss Mask:                1  1  1  1

    Args:
      element: input single example in a dictionary format. All input fields
        are read before any output field is written, so output keys may reuse
        input keys (e.g. "target", "loss_mask").

    Returns:
      The same `element`, updated in place with the encoder-decoder fields.
    """
    # Extract the values from the `dict` example.
    deconly_input = kd.kontext.get_by_path(element, self.in_input)
    deconly_target = kd.kontext.get_by_path(element, self.in_target)
    deconly_loss_mask = kd.kontext.get_by_path(element, self.in_loss_mask)

    xnp = enp.lazy.get_xnp(deconly_input, strict=False)

    deconly_target = deconly_target[..., 0]
    deconly_loss_mask = deconly_loss_mask[..., 0]
    deconly_input_mask = deconly_input != self.pad_id
    seq_len = deconly_input.shape[0]

    # Encoder input tokens
    # Encoder mask -> positions -> gather input tokens from positions
    # [1, 1, 1, 0, 0, 0, 0]
    # `logical_not` (not `~`): `~` on an integer 0/1 mask is a bitwise NOT
    # (-1/-2, both truthy) and would send every token to the encoder.
    encdec_encoder_input_mask = xnp.logical_and(
        xnp.logical_not(deconly_loss_mask),
        deconly_input_mask,
    ).astype(xnp.int32)
    # We didn't subtract it by 1 due to skipping <s>
    # [1, 2, 3, 0, 0, 0, 0]
    encdec_encoder_input_positions = xnp.cumsum(
        encdec_encoder_input_mask, axis=-1
    ) * encdec_encoder_input_mask
    # To avoid input-only errors: when every position is an encoder position the
    # last index equals seq_len (out of range); it is remapped to 0.
    # NOTE(review): that slot then gathers input[0] (<s>) rather than a prompt
    # token.
    encdec_encoder_input_positions *= (
        encdec_encoder_input_positions < seq_len
    ).astype(xnp.int32)
    # [A, B, C, 0, 0, 0, 0]
    encdec_encoder_input_tokens = xnp.take_along_axis(
        deconly_input, encdec_encoder_input_positions, axis=-1
    ) * encdec_encoder_input_mask

    # Decoder input tokens
    # Decoder mask -> positions -> move to beginning by sorting -> gather tokens
    # [3]
    num_encoder_tokens = xnp.sum(
        encdec_encoder_input_mask, axis=-1, keepdims=True
    )
    # [0, 0, 0, 1, 1, 1, 1]
    encdec_decoder_mask = xnp.logical_and(
        deconly_loss_mask,
        deconly_input_mask,
    ).astype(xnp.int32)
    # [0, 0, 0, 1, 2, 3, 4]
    encdec_decoder_positions = xnp.cumsum(
        encdec_decoder_mask, axis=-1
    ) * encdec_decoder_mask
    # Invalid tokens are set to seq_len+1
    # [8, 8, 8, 1, 2, 3, 4]
    encdec_decoder_positions += (1 - encdec_decoder_mask) * (seq_len+1)
    # After sorting, all valid tokens are put into the beginning in order
    # [1, 2, 3, 4, 8, 8, 8]
    encdec_decoder_positions = xnp.sort(
        encdec_decoder_positions, axis=-1
    )
    # Valid tokens should have positions <= seq_len
    # [1, 1, 1, 1, 0, 0, 0]
    encdec_decoder_mask = (
        encdec_decoder_positions <= seq_len).astype(xnp.int32)
    # Shift past the encoder tokens to index into the decoder-only sequence.
    # [4, 5, 6, 7, 11, 11, 11]
    encdec_decoder_positions += num_encoder_tokens
    # [3, 4, 5, 6, 0, 0, 0]
    encdec_decoder_target_positions = (
        encdec_decoder_positions - 1
    ) * encdec_decoder_mask
    # The first token now changed to <s> for decoder input
    # [0, 4, 5, 6, 0, 0, 0]
    encdec_decoder_input_positions = xnp.pad(
        encdec_decoder_positions,
        (1, 0),
        'constant',
        constant_values=0,
    )[:-1]
    encdec_decoder_input_positions *= encdec_decoder_mask

    # [<s>, 1, 2, 3, 0, 0, 0]
    encdec_decoder_input_tokens = xnp.take_along_axis(
        deconly_input, encdec_decoder_input_positions, axis=-1
    ) * encdec_decoder_mask
    # [1, 2, 3, </s>, 0, 0, 0]
    encdec_decoder_target_tokens = xnp.take_along_axis(
        deconly_target, encdec_decoder_target_positions, axis=-1
    ) * encdec_decoder_mask

    max_len = self.max_len
    if max_len is None:
      max_len = seq_len

    # Add the fields to the output `dict`.
    # Equivalent to `element[self.out_input] = ...`
    kd.kontext.set_by_path(
        element,
        self.out_encoder_input,
        encdec_encoder_input_tokens[:max_len],
    )
    kd.kontext.set_by_path(
        element,
        self.out_decoder_input,
        encdec_decoder_input_tokens[:max_len],
    )
    # Target and loss mask get their trailing singleton axis back.
    kd.kontext.set_by_path(
        element,
        self.out_target,
        encdec_decoder_target_tokens[:max_len, None],
    )
    kd.kontext.set_by_path(
        element,
        self.out_loss_mask,
        encdec_decoder_mask[:max_len, None],
    )
    return element

# Sequence length (tokens) shared by tokenization/truncation and the
# encoder-decoder conversion, so the two settings cannot drift apart.
MAX_SEQ_LENGTH = 200

ds = kd.data.py.Tfds(
    name='mtnt/en-fr',
    split='train',
    shuffle=True,
    batch_size=8,
    transforms=[
        # Create the model inputs/targets/loss_mask.
        gm.data.Seq2SeqTask(
            # Select which field from the dataset to use.
            # https://www.tensorflow.org/datasets/catalog/mtnt
            in_prompt='src',
            in_response='dst',
            # Output batch is {'input': ..., 'target': ..., 'loss_mask': ...}
            out_input='input',
            out_target='target',
            out_target_mask='loss_mask',
            tokenizer=preset.tokenizer,
            # Padding parameters
            max_length=MAX_SEQ_LENGTH,
            truncate=True,
        ),
        # Split each decoder-only example into encoder / decoder fields.
        Deconly2EncDecPreprocessor(
            in_input='input',
            in_target='target',
            in_loss_mask='loss_mask',
            out_encoder_input='encoder_input',
            out_decoder_input='decoder_input',
            out_target='target',
            out_loss_mask='loss_mask',
            pad_id=preset.tokenizer.special_tokens.PAD,
            max_len=MAX_SEQ_LENGTH,
        ),
    ],
)

# Inspect the first batch of preprocessed examples.
ex = ds[0]

treescope.show(ex)



Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/mtnt/en-fr/1.0.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/mtnt/en-fr/incomplete.2PGJ2G_1.0.0/mtnt-train.array_record*...:   0%|     …

Generating test examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/mtnt/en-fr/incomplete.2PGJ2G_1.0.0/mtnt-test.array_record*...:   0%|      …

Generating valid examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/mtnt/en-fr/incomplete.2PGJ2G_1.0.0/mtnt-valid.array_record*...:   0%|     …

Dataset mtnt downloaded and prepared to /root/tensorflow_datasets/mtnt/en-fr/1.0.0. Subsequent calls will reuse this data.


In [None]:
# Decode the encoder and decoder inputs of the first example in the batch.
for key in ('encoder_input', 'decoder_input'):
  text = preset.tokenizer.decode(ex[key][0])
  # `rstrip` so the closing fence always lands on its own line, whether or not
  # the decoded text ends with a newline.
  print(f"{key}\n```\n{text.rstrip()}\n```\n")

decoder_input
```
Est-ce que les femmes passent encore la nuit avant le mariage dans une autre chambre que celle de leur fiancée ???```

encoder_input
```
<start_of_turn>user
Do woman still spend the night before their wedding away from their fiancee???<end_of_turn>
<start_of_turn>model
```



### Trainer

Training is built on [Kauldron](https://kauldron.readthedocs.io/en/latest/), following the Gemma library's fine-tuning setup.


In [None]:
# Token-level cross-entropy on the decoder targets. Only positions where
# `batch.loss_mask` is set (the decoder tokens written by
# Deconly2EncDecPreprocessor) contribute to the loss.
loss = kd.losses.SoftmaxCrossEntropyWithIntLabels(
    mask="batch.loss_mask",
    labels="batch.target",
    logits="preds.logits",
)

In [None]:
# Same T5Gemma architecture, wired to read its encoder and decoder tokens from
# the batch fields produced by the data pipeline.
model = preset.config.make(
    "transformer",
    input_tokens="batch.encoder_input",
    target_tokens="batch.decoder_input",
)

# Fine-tuning starts from the instruction-tuned PrefixLM checkpoint.
checkpoint = preset.get_checkpoint_from_kaggle(
    t5gemma.CKPTType.IT,
    t5gemma.PretrainType.PREFIXLM,
)

optimizer = optax.adafactor(learning_rate=1e-4)
# Data sharded along its first dimension, parameters sharded FSDP-style.
sharding_strategy = kd.sharding.ShardingStrategy(
    ds=kd.sharding.FIRST_DIM,
    params=kd.sharding.FSDPSharding(),
)

trainer = kd.train.Trainer(
    seed=42,  # The seed of enlightenment
    workdir='/tmp/ckpts',
    train_ds=ds,
    model=model,
    init_transform=gm.ckpts.LoadCheckpoint(checkpoint),  # initial params
    num_train_steps=500,
    train_losses={"loss": loss},
    optimizer=optimizer,
    sharding=sharding_strategy,
)

In [None]:
# Run the fine-tuning loop. `state.params` are used for sampling below;
# `aux` presumably holds auxiliary training outputs/metrics — see kd.train.Trainer.
state, aux = trainer.train()

Starting training loop at step 0


train:   0%|          | 0/501 [00:00<?, ?it/s]

### Sampling

In [None]:
# Sample from the fine-tuned parameters. The prompt uses the same chat-turn
# format as the training encoder inputs, so the reply should be in French.
prompt = (
    '<start_of_turn>user\n'
    'Hello! My next holidays are in Paris.<end_of_turn>\n'
    '<start_of_turn>model\n'
)
sampler = t5gemma.Sampler(
    model=model,
    params=state.params,
    tokenizer=preset.tokenizer,
)
output = sampler.sample(prompt)
print(output)

Bonjour ! Je vais faire des vacances en France.<end_of_turn>
