
<a href="https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Overview

This is the second Colab in a [series of tutorials on how to use T5X](https://github.com/google-research/t5x/blob/main/docs/tutorials.md). We assume that you have already completed the [introductory Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb), or have a basic understanding of T5X models, checkpoints, partitioners, and the `InteractiveModel`.

In the [previous Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb) in this tutorial series, we presented a quick and easy way to use the `InteractiveModel` to run training on natural text inputs in only a few lines of code. In this Colab, we will dive into how the `InteractiveModel` restores models from checkpoints and runs training, and along the way introduce the T5X trainer. Note that the code snippets below exactly replicate the `InteractiveModel` `__init__()` and `train_step()` methods (see [source code](https://github.com/google-research/t5x/blob/main/t5x/interactive_model.py)); we expose this functionality here to demonstrate how the components of the T5X codebase work together to train a model.

# Set-Up

Note: If you are using a public Colab, please use its `Connect to a local runtime` option by following the [setup guide](https://github.com/google-research/t5x/blob/main/t5x/notebooks/README.md).

In [None]:
from collections.abc import Sequence
import enum
import functools
import inspect
import itertools
import logging
import os
import re
from typing import Any, Callable, Iterator, Optional, Tuple, Union

import jax
from jax import random
from jax.experimental import multihost_utils
import numpy as np
import seqio
import tensorflow as tf
import tensorflow_datasets as tfds
import t5.data

In [None]:
import clu.data
from t5x.examples.t5 import network
import t5x
from t5x import models
from t5x import partitioning
from t5x import trainer as trainer_lib
from t5x import utils
from t5x.infer import _extract_tokens_and_aux_values
from t5x.infer import _Inferences
from t5x.interactive_model import InteractiveModel
from t5x.interactive_model import get_batches_from_seqio
from t5x.interactive_model import get_dataset_from_natural_text_examples
from t5x.interactive_model import get_gin_config_from_interactive_model
from t5x.interactive_model import T5XScriptType

Before we begin, let's initialize instances of the constructor arguments for the `InteractiveModel`. As mentioned previously, this will enable us to dive into how the `InteractiveModel` restores models from checkpoints and runs training.

If you don't understand the lines of code below, or have questions about how to initialize these parameters, please see the [first Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb) in this tutorial series.

In [None]:
# Define a model. The configuration below corresponds to the T5 1.1 Small model.
t5_config = network.T5Config(
    vocab_size=32128,
    dtype='bfloat16',
    emb_dim=512,
    num_heads=6,
    num_encoder_layers=8,
    num_decoder_layers=8,
    head_dim=64,
    mlp_dim=1024,
    mlp_activations=('gelu', 'linear'),
    dropout_rate=0.0,
    logits_via_embedding=False)
module = network.Transformer(config=t5_config)
model = t5x.models.EncoderDecoderModel(
    module=module,
    input_vocabulary=t5.data.get_default_vocabulary(),
    output_vocabulary=t5.data.get_default_vocabulary(),
    optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0))
# Define checkpoint arguments.
checkpoint_path='gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'
dtype='bfloat16'
restore_mode='specific'
# Define a partitioner.
partitioner=partitioning.PjitPartitioner(num_partitions=2)
# Define additional, miscellaneous constructor arguments.
batch_size=8
task_feature_lengths = {'inputs': 38, 'targets': 18}
output_dir='/tmp/output_dir'
input_shapes = {
    'encoder_input_tokens': np.array([8, 38]),
    'decoder_target_tokens': np.array([8, 18]),
    'decoder_input_tokens': np.array([8, 18]),
    'decoder_loss_weights': np.array([8, 18])
}

# Training Deep Dive

Let's start by going over what happens when we initialize the InteractiveModel.

The `InteractiveModel` `__init__()` method performs six main actions:


1.   Configure and possibly create an output directory.
2.   Initialize RNGs.
3.   Validate the partitioner.
4.   Create a checkpoint manager.
5.   Restore the model from a checkpoint or initialize from scratch.
6.   Create a trainer.



**Configuring the Output Directory** \
There is minimal work required to configure the output directory for our model: we collapse repeated slashes in the directory path (leaving the `//` of a `gs://` prefix intact) to avoid inconsistencies, and create the directory if it doesn't already exist.

In [None]:
output_dir = re.sub(r"(?<!gs:)([\/]{2,})", "/", output_dir)
if not os.path.exists(output_dir):
  os.mkdir(output_dir)
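
As a standalone aside (not one of the replicated `__init__()` steps), the sketch below wraps the same regular expression in a hypothetical `normalize_output_dir` helper so you can see its effect: runs of slashes collapse to one, but the `//` of a `gs://` prefix survives because of the `(?<!gs:)` lookbehind. Only `gs://` is special-cased by this pattern; the bucket name `my-bucket` is made up.

```python
import re

# The same pattern used above: match a run of two or more slashes that is
# not immediately preceded by "gs:".
_REPEATED_SLASHES = re.compile(r"(?<!gs:)([\/]{2,})")


def normalize_output_dir(path: str) -> str:
  """Hypothetical helper: collapses repeated slashes, keeping a gs:// prefix."""
  return _REPEATED_SLASHES.sub("/", path)


print(normalize_output_dir("/tmp//output_dir///run1"))    # /tmp/output_dir/run1
print(normalize_output_dir("gs://my-bucket//t5x//run1"))  # gs://my-bucket/t5x/run1
```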

**Initializing RNGs** \
Initializing RNGs is straightforward with JAX's random operations.

We first broadcast an initial seed from the first host to all hosts with `multihost_utils.broadcast_one_to_all`, so that every host starts from the same seed, and call `utils.set_hardware_rng_ops()` to enable hardware RNG ops. We then create a root key with JAX's `random.PRNGKey` and split it into two keys: one for initializing the model and one for training it. Splitting ensures that no RNG key is ever reused.

In [None]:
init_random_seed = 42
random_seed = multihost_utils.broadcast_one_to_all(np.int32(init_random_seed))
utils.set_hardware_rng_ops()
rng = random.PRNGKey(random_seed)
init_rng, trainer_rng = random.split(rng, 2)
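
As a standalone aside, the sketch below shows the two properties the cell above relies on: splitting is deterministic (the same seed always yields the same keys), and the two resulting keys differ. It also shows `jax.random.fold_in`, one common way to derive a fresh key per step from a base key; we are not claiming this is exactly how the T5X `Trainer` uses `trainer_rng`. The sketch skips `utils.set_hardware_rng_ops()`, so it uses JAX's default RNG implementation.

```python
import jax.numpy as jnp
from jax import random

# Create a root key and split it into two keys, as in the cell above.
rng = random.PRNGKey(42)
init_rng, trainer_rng = random.split(rng, 2)

# Re-running the same split from the same seed reproduces the same keys.
init_rng_again, _ = random.split(random.PRNGKey(42), 2)

# Derive a distinct key for each step by folding the step number into a
# base key (an illustrative pattern, not necessarily T5X's internals).
step_rngs = [random.fold_in(trainer_rng, step) for step in range(3)]

print("reproducible:", bool(jnp.array_equal(init_rng, init_rng_again)))
print("distinct:", not bool(jnp.array_equal(init_rng, trainer_rng)))
```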

**Validating the Partitioner** \
Because we've already constructed the partitioner, we only need to validate it. In particular, the total number of devices must be a multiple of the number of model partitions. If the partitioner was given a `model_parallel_submesh`, the number of partitions is the product of the submesh dimensions.
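
In isolation, this rule is easy to express as a pure function. The standalone sketch below uses a hypothetical `validate_num_partitions` that takes the device count as an argument (instead of calling `jax.device_count()`) so it runs without accelerators. A submesh such as `(2, 2, 1, 1)` gives 4 partitions, and the device count divided by the number of partitions is the number of data-parallel groups.

```python
import math
from typing import Optional, Sequence


def validate_num_partitions(
    device_count: int,
    num_partitions: Optional[int] = None,
    model_parallel_submesh: Optional[Sequence[int]] = None) -> int:
  """Returns the number of model partitions, raising if devices don't divide evenly."""
  # As in the cell below, a submesh takes precedence over num_partitions.
  if model_parallel_submesh:
    num_partitions = math.prod(model_parallel_submesh)
  if not num_partitions:
    raise ValueError("Set either num_partitions or model_parallel_submesh.")
  if device_count % num_partitions != 0:
    raise ValueError(
        "The number of devices available must be a multiple of the number "
        f"of partitions. There are {device_count} devices available, but the "
        f"number of partitions is set to {num_partitions}.")
  return num_partitions


print(validate_num_partitions(8, num_partitions=2))                     # 2
print(validate_num_partitions(8, model_parallel_submesh=(2, 2, 1, 1)))  # 4
```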

In [None]:
if partitioner._model_parallel_submesh:
  num_partitions = np.prod(partitioner._model_parallel_submesh)
else:
  num_partitions = partitioner._num_partitions
if jax.device_count() % num_partitions != 0:
  # Adjacent string literals (no commas) concatenate into a single message.
  raise ValueError(
    "The number of devices available must be a multiple of the number of"
    f" partitions. There are {jax.device_count()} devices available, but"
    f" the number of partitions is set to {num_partitions}. Please"
    " provide a different number of partitions.")

**Create a Checkpoint Manager**

We make use of the T5X [`LegacyCheckpointManager`](https://github.com/google-research/t5x/blob/main/t5x/utils.py) to restore our model and save any future checkpoints. The `LegacyCheckpointManager` requires several constructor arguments:



1.   `save_checkpoint_cfg`: an instance of the `SaveCheckpointConfig` wrapper class, which contains information about where and how to save future checkpoints.
2.   `restore_checkpoint_cfg`: an instance of the `RestoreCheckpointConfig` wrapper class, which contains information about where and how to load checkpoints and restore the model.
3.   `train_state_shape`: our model will load and save a T5X [`TrainState`](https://github.com/google-research/t5x/blob/main/t5x/train_state.py), which (as the name implies) stores information about the current state of training. We provide information about the shape of this train state to the checkpoint manager to enable saving this train state in checkpoints.
4.   `partitioner`: our predefined partitioner.
5.   `model_dir`: our previously configured output directory, where we will save any future checkpoints.

Before we define these constructor arguments and initialize the checkpoint manager, let's discuss the T5X `TrainState` in a bit more depth. T5X is a JAX-based library, so its training code follows functional programming patterns.

Specifically, functions transformed by JAX (for example, with `jax.jit`) must be free of side effects, so we pass the model parameters, step number, optimizer state, and so on as inputs and receive updated values as outputs. The T5X `TrainState` holds all of this state in a single object, and we will later define a `train_step` function that takes in a train state and returns an updated train state with new values.
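
The functional pattern can be illustrated without any T5X code. In this standalone sketch, the hypothetical `ToyTrainState` is an immutable dataclass (T5X's real `TrainState` is richer and built for JAX), and `toy_train_step` applies a plain SGD update (not Adafactor) and returns a new state rather than modifying its input.

```python
import dataclasses
from typing import Dict


@dataclasses.dataclass(frozen=True)
class ToyTrainState:
  """Hypothetical, simplified train state: a step counter plus parameters."""
  step: int
  params: Dict[str, float]


def toy_train_step(state: ToyTrainState, grads: Dict[str, float],
                   learning_rate: float) -> ToyTrainState:
  """Applies one SGD update and returns a new state; `state` is not modified."""
  new_params = {
      name: value - learning_rate * grads[name]
      for name, value in state.params.items()
  }
  return dataclasses.replace(state, step=state.step + 1, params=new_params)


state0 = ToyTrainState(step=0, params={"w": 2.0})
state1 = toy_train_step(state0, grads={"w": 1.0}, learning_rate=0.5)
print(state0)  # ToyTrainState(step=0, params={'w': 2.0})
print(state1)  # ToyTrainState(step=1, params={'w': 1.5})
```

Because the old state is never mutated, a caller can always keep (or checkpoint) the previous state while computing the next one.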

We define these constructor arguments and initialize the checkpoint manager below.



In [None]:
# Define CheckpointCfg wrappers.
save_checkpoint_cfg = utils.SaveCheckpointConfig(
        dtype=dtype,
        keep=5, # The number of checkpoints to keep in the output_dir.
        save_dataset=False)
restore_checkpoint_cfg = utils.RestoreCheckpointConfig(
        dtype=dtype,
        mode=restore_mode,
        path=checkpoint_path)

# Define a train state initializer, which will help us get information about the
# TrainState shape.
train_state_initializer = utils.TrainStateInitializer(
        optimizer_def=model.optimizer_def,
        init_fn=model.get_initial_variables,
        input_shapes=input_shapes,
        input_types=None,
        partitioner=partitioner)

checkpoint_manager = utils.LegacyCheckpointManager(
        save_cfg=save_checkpoint_cfg,
        restore_cfg=restore_checkpoint_cfg,
        train_state_shape=train_state_initializer.global_train_state_shape,
        partitioner=partitioner,
        ds_iter=None,
        model_dir=output_dir)
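
As a standalone aside, the `keep=5` argument above limits how many checkpoints are retained in `output_dir`. The hypothetical helper below illustrates which steps would survive, assuming the oldest checkpoints are deleted first and that `keep=None` means "keep everything"; T5X's actual cleanup logic may differ in details. Since the `train_step` helper later in this Colab saves a checkpoint after every step, the 100-step loop would leave only the most recent few.

```python
from typing import Iterable, List, Optional


def checkpoints_to_keep(saved_steps: Iterable[int],
                        keep: Optional[int]) -> List[int]:
  """Returns the steps whose checkpoints survive if only `keep` are retained.

  `keep=None` is treated as "keep everything" (an assumption for this sketch).
  """
  steps = sorted(saved_steps)
  if keep is None:
    return steps
  # Keep the `keep` most recent steps (or all of them, if there are fewer).
  return steps[max(len(steps) - keep, 0):]


print(checkpoints_to_keep(range(1, 11), keep=5))  # [6, 7, 8, 9, 10]
```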

**Restore the Model from a Checkpoint or Initialize from Scratch** \
We try two strategies for obtaining a train state. First, we try to restore the model from a checkpoint using the `LegacyCheckpointManager`. If no checkpoint can be found (for example, because `checkpoint_path` is empty), we initialize the model from scratch.

Finally, we log model initialization information (such as parameter shapes and partitioning annotations) to the output directory.

In [None]:
def get_state(rng):
  return train_state_initializer.from_scratch(rng).state_dict()

# 1. Try to restore a model from a checkpoint.
train_state = checkpoint_manager.restore(
  [restore_checkpoint_cfg.path],
  restore_checkpoint_cfg,
  utils.get_fallback_state(restore_checkpoint_cfg, get_state, init_rng)
)

# 2. If no checkpoint to restore, init from scratch.
if train_state is None:
  train_state = train_state_initializer.from_scratch(init_rng)

# Validate that we got an expected form of TrainState.
if isinstance(train_state, Sequence):
  raise ValueError(
    "Expected a single train state, but instead received a Sequence.")
train_state_axes = train_state_initializer.train_state_axes

# Log the variable shapes information and write to a file.
log_file = os.path.join(output_dir, "model-info.txt")
utils.log_model_info(log_file,
                     train_state_initializer.global_train_state_shape,
                     partitioner)
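
Stripped of the T5X specifics, the cell above follows a restore-or-initialize pattern. The standalone sketch below captures it with a hypothetical `restore_or_init` helper that takes two zero-argument callables; in the real code, those roles are played by `checkpoint_manager.restore(...)` and `train_state_initializer.from_scratch(init_rng)`.

```python
from collections.abc import Sequence
from typing import Any, Callable, Optional


def restore_or_init(restore_fn: Callable[[], Optional[Any]],
                    init_fn: Callable[[], Any]) -> Any:
  """Returns a restored state if one exists, otherwise a freshly initialized one."""
  state = restore_fn()  # None signals "no checkpoint found".
  if state is None:
    state = init_fn()
  # Mirror the validation above: expect exactly one state, not a sequence.
  if isinstance(state, Sequence):
    raise ValueError("Expected a single train state, but received a Sequence.")
  return state


print(restore_or_init(lambda: None, lambda: {"step": 0}))
print(restore_or_init(lambda: {"step": 1110000}, lambda: {"step": 0}))
```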

**Create a Trainer**

Finally, we use many of the parameters we've defined above to create an instance of the T5X [Trainer](https://github.com/google-research/t5x/blob/main/t5x/trainer.py). The trainer takes in several constructor arguments:



1.   `model`: the model to train.
2.   `train_state`: a train state with parameters and optimizer state, which we've restored or initialized above.
3.   `partitioner`: the partitioner to use.
4.   `summary_dir`: the output directory, where training summaries are written.
5.   `train_state_axes`: partitioning annotations for the train state (model parameters and optimizer state), which we obtained from the train state initializer above.
6.   `rng`: the JAX RNG key used for training.
7.   `learning_rate_fn`: a function that returns the learning rate given the current step. T5X provides helper functions that define common learning rate schedules; we use one of these helpers in our example.
8.   `eval_names`: the names of the evaluation datasets the trainer should track; we pass an empty list because we don't run evaluation here.
9.   `num_microbatches`: the number of microbatches each batch is split into for gradient accumulation; `None` disables microbatching.
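
To make `learning_rate_fn` concrete, here is a standalone, plain-Python sketch of an inverse-square-root schedule with linear warmup. It assumes that `utils.create_learning_rate_scheduler()` with no arguments defaults to the factors `'constant * linear_warmup * rsqrt_decay'` with `base_learning_rate=0.5`, `warmup_steps=1000`, and `min_learning_rate=1e-8`; please check `t5x/utils.py` for the defaults in your version. Under these assumptions, the rate peaks at step 1000 and then decays as 1/sqrt(step), so a model restored at step 1,110,000 would train with a rate of roughly 4.7e-4.

```python
import math


def rsqrt_schedule(step: int,
                   base_learning_rate: float = 0.5,
                   warmup_steps: int = 1000,
                   min_learning_rate: float = 1e-8) -> float:
  """Sketch of 'constant * linear_warmup * rsqrt_decay' (assumed defaults)."""
  step = max(0, step)
  lr = base_learning_rate                   # constant
  lr *= min(1.0, step / warmup_steps)       # linear_warmup: ramps from 0 to 1
  lr /= math.sqrt(max(step, warmup_steps))  # rsqrt_decay: fixed divisor during warmup
  return max(lr, min_learning_rate)


for step in (0, 500, 1000, 4000, 1_110_000):
  print(step, rsqrt_schedule(step))
```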

We initialize a sample Trainer below.



In [None]:
trainer = trainer_lib.Trainer(
  model=model,
  train_state=train_state,
  partitioner=partitioner,
  eval_names=[],
  summary_dir=output_dir,
  train_state_axes=train_state_axes,
  rng=trainer_rng,
  learning_rate_fn=utils.create_learning_rate_scheduler(),
  num_microbatches=None)

The code snippets above exactly replicate the `InteractiveModel` `__init__()` method (see [source code](https://github.com/google-research/t5x/blob/main/t5x/interactive_model.py)); running the code snippets above is exactly equivalent to running the single code snippet below.

In [None]:
interactive_model = InteractiveModel(
  batch_size=batch_size,
  task_feature_lengths=task_feature_lengths,
  output_dir=output_dir,
  partitioner=partitioner,
  model=model,
  dtype=dtype,
  restore_mode=restore_mode,
  checkpoint_path=checkpoint_path,
  input_shapes=input_shapes
)

**Defining a Batch of Examples to Train On**\
We are now ready to begin training!

First, we'll define a batch of examples to train on; these examples should be either a list of inputs or a list of dictionaries mapping 'target'/'input' keys to their values, as shown below. For this Colab, we'll use a set of natural text questions and answers.

In [None]:
examples = [
  {
      'target': b'Ajay Tyagi',
      'input':b'nq question: who has been appointed as the new chairman of sebi'
  },
  {
      'target': b'C. S. Lewis',
      'input': b'nq question: who wrote the book lion the witch and the wardrobe'},
  {
      'target': b'29',
      'input': b'nq question: how many planes did japan lose at pearl harbor'},
  {
      'target': b'Jack Keil',
      'input': b'nq question: who does the voice of mcgruff the dog'},
  {
      'target': b'Journey',
      'input': b'nq question: who sings the wheels in the sky keep on turning'},
  {
      'target': b'Kumiko Watanabe',
      'input': b'nq question: who voices regina in glitter force doki doki'},
  {
      'target': b'during World War II',
      'input': b'nq question: when did the us become allies with britain'},
  {
      'target': b'the United States',
      'input': b'nq question: who won the rugby 7 in las vegas'},
]

We also define the required features of the examples. For this Colab, we will only require an `inputs` and `targets` entry, as defined below.

In [None]:
output_features = {
        "inputs":
            seqio.Feature(
                vocabulary=model.input_vocabulary, add_eos=True),
        "targets":
            seqio.Feature(
                vocabulary=model.output_vocabulary, add_eos=True)
    }
features = dict(sorted(output_features.items()))

Now, let's (similarly) break down what the interactive model does when it takes a single step of training.

The `InteractiveModel` `train_step()` method only performs two actions:


1.   Convert the natural text examples into a `tf.data.Dataset`.
2.   Take a single training step, using the T5X Trainer.



**Prepare the Dataset** \
Preparing the data for training is straightforward. First, we validate that enough examples have been provided to fill at least one full batch.

Then, we convert the natural text examples into a `tf.data.Dataset` and run any preprocessors; T5X has a helper function, `get_dataset_from_natural_text_examples`, that does exactly that. For this example, the only preprocessing steps are tokenization and appending an EOS token. If you are interested in learning more about preprocessors, please take a look at the [first Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb) in this tutorial series.

Finally, we convert the features with the model's feature converter, then pad and batch the examples (dropping any incomplete final batch).
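
As a standalone aside, the sketch below mimics in plain Python what the encoder-decoder feature converter produces for a single unpacked example. It is a simplification: the real converter works on `tf.data` tensors, and the token ids here are made up, with 0 as padding and 1 as EOS (as in the default T5 vocabulary). `decoder_input_tokens` is the target sequence shifted right by one position (teacher forcing), and `decoder_loss_weights` masks out padding so it doesn't contribute to the loss.

```python
from typing import Dict, List


def to_encdec_features(inputs: List[int], targets: List[int],
                       lengths: Dict[str, int]) -> Dict[str, List[int]]:
  """Simplified, unpacked encoder-decoder feature conversion (illustrative)."""
  def trim_and_pad(ids: List[int], length: int) -> List[int]:
    ids = ids[:length]                      # trim to the feature length
    return ids + [0] * (length - len(ids))  # right-pad with pad id 0

  decoder_target_tokens = trim_and_pad(targets, lengths["targets"])
  return {
      "encoder_input_tokens": trim_and_pad(inputs, lengths["inputs"]),
      "decoder_target_tokens": decoder_target_tokens,
      # Shift targets right by one; the decoder starts from id 0.
      "decoder_input_tokens": [0] + decoder_target_tokens[:-1],
      # 1 for real target tokens (including EOS), 0 for padding.
      "decoder_loss_weights": [int(t != 0) for t in decoder_target_tokens],
  }


print(to_encdec_features([5, 6, 1], [7, 1], {"inputs": 4, "targets": 3}))
```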

In [None]:
# Validate num examples.
if len(examples) < batch_size:
  raise ValueError(
    "At least one batch of data must be provided. Please decrease the "
    "batch_size or provide more examples.")
# Get a tf.Dataset.
train_dataset = get_dataset_from_natural_text_examples(
    examples=examples,
    preprocessors=[
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos
    ],
    task_feature_lengths=task_feature_lengths,
    features=features)

# Convert and pad features.
feature_converter = model.FEATURE_CONVERTER_CLS(pack=False)
train_dataset = feature_converter(
        train_dataset, task_feature_lengths=task_feature_lengths)
train_dataset = train_dataset.padded_batch(batch_size, drop_remainder=True)
train_iter = clu.data.dataset_iterator.TfDatasetIterator(train_dataset, checkpoint=False)

**Run 1 Training Step** \
We'll define a helper function that takes a single training step, which makes it easy to call repeatedly to train for multiple steps.

The T5X trainer does most of the work; we only add logic to check that training is allowed to proceed and to save checkpoints. In total, the helper performs the following actions:


1.   Validate that training can occur.
2.   Take a training step.
3.   Save a checkpoint.



In [None]:
def train_step(
    trainer: t5x.trainer.Trainer,
    train_state: t5x.train_state.TrainState,
    train_iter: clu.data.dataset_iterator.TfDatasetIterator,
    checkpoint_manager: utils.LegacyCheckpointManager,
    save_checkpoint_cfg: utils.SaveCheckpointConfig):
  # Validate that training can occur.
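  if trainer.stop_training:
    logging.info("Stopping training early since `stop_training` is requested.")
    # Return the unchanged state (and no summary) so callers can still unpack.
    return trainer.train_state, None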

  # Take a training step.
  try:
    first_step = int(utils.get_local_data(train_state.step))
    train_summary = trainer.train(
      train_iter, 1, start_step=first_step)
  except trainer_lib.PreemptionError as e:
    logging.info("Saving emergency checkpoint.")
    checkpoint_manager.save(
      trainer.train_state,
      save_checkpoint_cfg.state_transformation_fns)
    logging.info("Saving emergency checkpoint done.")
    raise e

  # Save a checkpoint.
  logging.info("Saving checkpoint.")
  checkpoint_manager.save(
      trainer.train_state,
      save_checkpoint_cfg.state_transformation_fns)

  # Wait until computations are done before exiting
  multihost_utils.sync_global_devices("complete")
  return trainer.train_state, train_summary.result()
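
As a standalone aside, the `try`/`except` block in `train_step` follows a general save-on-preemption pattern: if a step is interrupted, save the current state before re-raising the error so that training can resume later. The sketch below isolates that pattern; `PreemptionError` here is a hypothetical stand-in for `trainer_lib.PreemptionError`, and `ToyTrainer` is a made-up class.

```python
from typing import Callable, List


class PreemptionError(Exception):
  """Hypothetical stand-in for trainer_lib.PreemptionError."""


class ToyTrainer:
  """Made-up trainer whose step either succeeds or is preempted."""

  def __init__(self, preempt: bool):
    self.step = 0
    self.preempt = preempt

  def train_one_step(self) -> None:
    if self.preempt:
      raise PreemptionError("host is being preempted")
    self.step += 1


def step_and_save(trainer: ToyTrainer, save: Callable[[int], None]) -> int:
  try:
    trainer.train_one_step()
  except PreemptionError:
    save(trainer.step)  # Emergency checkpoint of the last good state.
    raise               # Re-raise so the caller knows training stopped.
  save(trainer.step)    # Regular checkpoint after a successful step.
  return trainer.step


saved: List[int] = []
print(step_and_save(ToyTrainer(preempt=False), saved.append))  # 1
try:
  step_and_save(ToyTrainer(preempt=True), saved.append)
except PreemptionError:
  print("preempted; checkpointed steps:", saved)  # [1, 0]
```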

In [None]:
print(f"Current Step: {train_state.step}")
train_state, train_summary = train_step(trainer, train_state, train_iter, checkpoint_manager, save_checkpoint_cfg)
print(f"Current Step: {train_state.step}")
print(f"Summary of Training: {train_summary}")

The code snippets above exactly replicate the `InteractiveModel` `train_step()` method (see [source code](https://github.com/google-research/t5x/blob/main/t5x/interactive_model.py)); running the code snippets above is exactly equivalent to running `interactive_model.train_step(examples)`.

Alternatively, you can call this helper function in a loop to finetune or pretrain a model (the code snippet below may take about 5 minutes to run).

In [None]:
num_steps = 100
for _ in range(num_steps):
  # Reset the iterator, since we use the same batch for every step.
  train_iter = clu.data.dataset_iterator.TfDatasetIterator(train_dataset, checkpoint=False)
  train_state, train_summary = train_step(
      trainer,
      train_state,
      train_iter,
      checkpoint_manager,
      save_checkpoint_cfg
  )
print(f"Current Step: {train_state.step}")
print(f"Summary of Training: {train_summary}")

The code snippets above demonstrate how T5X runs training. You can exactly replicate this behavior by using the `InteractiveModel`, as described above.

# Advanced Topics

## T5X Training Binaries and Other Advanced Features

T5X offers training binaries that provide the same functionality as the `InteractiveModel`, plus additional features (more advanced compilation, custom checkpointing periods, etc.). Importantly, these binaries are configured using [Gin](https://github.com/google/gin-config/blob/main/README.md); if you are not familiar with Gin, please take a look at this [Gin Primer](https://github.com/google-research/t5x/blob/main/docs/usage.md/gin) to get started.

If you are familiar with Gin and interested in using the T5X training binaries, we have provided a helper function, `get_gin_config_from_interactive_model`, which will take an InteractiveModel instance and generate the gin config that you can use to run the T5X training binaries; this gin config will exactly reproduce the InteractiveModel training functionality we've described above. We've provided an example below.

Importantly, the `InteractiveModel` takes in an already-constructed model and partitioner, along with the data, so the helper cannot generate Gin configs for these components. Instead, you can pass Gin config strings for the model and partitioner to the helper function, as demonstrated below, and you can pass the name of a SeqIO task containing your data. See the section below if you are unfamiliar with SeqIO.

In [None]:
# Define an InteractiveModel instance, based on the `tiny` T5X EncoderDecoder model.
# Match the batch size (8) and task_feature_lengths (32/32) used below.
input_shapes = {
    'encoder_input_tokens': np.array([8, 32]),
    'decoder_target_tokens': np.array([8, 32]),
    'decoder_input_tokens': np.array([8, 32]),
    'decoder_loss_weights': np.array([8, 32])
}
t5_config = network.T5Config(
    vocab_size=32128,
    dtype='bfloat16',
    emb_dim=8,
    num_heads=4,
    num_encoder_layers=2,
    num_decoder_layers=2,
    head_dim=3,
    mlp_dim=16,
    mlp_activations=('gelu', 'linear'),
    dropout_rate=0.0,
    logits_via_embedding=False)
module = network.Transformer(config=t5_config)
model = t5x.models.EncoderDecoderModel(
    module=module,
    input_vocabulary=t5.data.get_default_vocabulary(),
    output_vocabulary=t5.data.get_default_vocabulary(),
    optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
    decode_fn=functools.partial(
        t5x.decoding.temperature_sample, temperature=1.0, topk=40))
interactive_model = InteractiveModel(
      batch_size=8,
      task_feature_lengths={
            'inputs': 32,
            'targets': 32
        },
      output_dir='/tmp',
      partitioner=partitioning.PjitPartitioner(
        num_partitions=2,
        model_parallel_submesh=None,
        logical_axis_rules=partitioning.standard_logical_axis_rules()),
      model=model,
      dtype='float32',
      restore_mode='specific',
      checkpoint_path='',
      input_shapes=input_shapes,
      input_types=None)

# Define Gin Config strings for the model, partitioner, and any imports.
imports_str = """from t5x import models
from t5x import partitioning
import t5.data.mixtures
include 't5x/examples/t5/t5_1_1/tiny.gin'"""
partitioner_config = 'partitioning.PjitPartitioner.num_partitions = 2'
model_config = """models.EncoderDecoderModel:
  z_loss = 0.0
  label_smoothing = 0.0
  loss_normalizing_factor = None"""

gin_config_str = get_gin_config_from_interactive_model(
    interactive_model=interactive_model,
    script_type=T5XScriptType.PRETRAINING,
    task_name='wmt19_ende_v003',
    partitioner_config_str=partitioner_config,
    model_config_str=model_config,
    train_steps=3,
    imports_str=imports_str)
print(gin_config_str)


Once you have generated the `gin_config_str` as above, you can write this string to a file and launch your training experiment locally by running the following from the command line:


```sh
# Set GIN_FILE_PATH to the file containing gin_config_str before running.
MODEL_DIR="/tmp/pretrain-model/"
python -m t5x.train \
  --gin_file=${GIN_FILE_PATH} \
  --gin.MODEL_DIR=\"${MODEL_DIR}\" \
  --alsologtostderr
```
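
One way to handle the write-to-file step is a small helper like the hypothetical `write_gin_config` below. It uses plain Python file I/O, so it only works for local paths; for a `gs://` path you would need a GCS-aware API such as `tf.io.gfile.GFile` instead.

```python
import os


def write_gin_config(gin_config_str: str, path: str) -> str:
  """Writes a Gin config string to a local file and returns the file path."""
  directory = os.path.dirname(path)
  if directory:
    os.makedirs(directory, exist_ok=True)  # Create parent directories if needed.
  with open(path, "w") as f:
    f.write(gin_config_str)
  return path
```

In this Colab you could then call, for example, `write_gin_config(gin_config_str, '/tmp/pretrain-model.gin')` (the path is only an example) and set `GIN_FILE_PATH` to the returned path before running the command above.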

For more details on training using the T5X training binaries, please see the [Pretraining](https://github.com/google-research/t5x/blob/main/docs/usage.md/pretrain) or [Finetuning](https://github.com/google-research/t5x/blob/main/docs/usage.md/finetune) tutorials.

## SeqIO

If you are interested in T5X, you may also be interested in, or have heard of, SeqIO. SeqIO is a library for processing sequential data to be fed into downstream sequence models. At a high level, SeqIO relies on user-defined `Tasks` and `Mixtures` that can be used to retrieve and evaluate datasets.

We won't go into details about SeqIO here; we recommend checking out this [SeqIO Introductory guide](https://github.com/google/seqio/blob/main/README.md/index) and/or clicking below to run a SeqIO Introductory Colab. The rest of this section will assume a basic understanding of SeqIO.

<a href="https://colab.research.google.com/github/google-research/seqio/blob/main/seqio/notebooks/Basics_Task_and_Mixtures.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

If you are already familiar with SeqIO and have a SeqIO task/mixture that you would like to use in this Colab, we do provide a SeqIO bridge that takes in a SeqIO task/mixture and produces batches of examples that can be processed by the code snippets above. We've provided an example of this bridge below.

In [None]:
!git clone https://github.com/google-research/google-research.git google_research

In [None]:
import google_research.t5_closed_book_qa.t5_cbqa.tasks
batches = get_batches_from_seqio(
        task_or_mixture_name='natural_questions_open',
        split='validation',
        batch_size=8,
        num_batches=2,
        seed=42)
print(f"Batches: {batches}")
# Train the interactive model on the provided batches.
original_step = interactive_model.step
_ = interactive_model.train_loop(num_steps=len(batches), train_batches=batches)
print(f"Original Step: {original_step}, Current Step: {interactive_model.step}")

The `get_batches_from_seqio` bridge accepts several arguments:


1.   `task_or_mixture_name`: the name of the SeqIO task/mixture to read data from. It should be noted that your task/mixture must already be registered with SeqIO, and you must import the module that defines your task/mixture here (as seen above).
2.   `split`: the split of the Task/Mixture to read data from.
3.   `batch_size`: how many examples should appear in each batch.
4.   `num_batches`: the total number of batches to return.
5.   `get_pretokenized_examples`: optional. A boolean, defaulting to True, that determines whether we should read the `inputs_pretokenized`/`targets_pretokenized` elements from an example, or the `inputs`/`targets` elements. \
The `train_step`, `predict`, `predict_with_aux`, `score`, and `evaluate` methods of the InteractiveModel assume that we should run [tokenization](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) and [appending an EOS token](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) as the only preprocessors. To use these methods with this pre-defined list of preprocessors, you can set `get_pretokenized_examples=True` to retrieve examples that still need to be tokenized, and these InteractiveModel methods will handle running these preprocessors. This setting can also be helpful if you want to inspect the natural text inputs/targets of your SeqIO task. \
However, some SeqIO tasks apply preprocessing beyond tokenization and appending an EOS token (e.g., span corruption), so their final model inputs cannot be reproduced from the pretokenized text by those two preprocessors alone. In that case, set `get_pretokenized_examples=False`, and this bridge will read the fully preprocessed examples from the SeqIO task. You can then run `train_step_with_preprocessors`, `infer_with_preprocessors`, or `evaluate_with_preprocessors` with an empty preprocessors list (because the bridge has already completed all preprocessing) to run training, inference, or evaluation. We have provided an example of using this bridge to retrieve fully preprocessed examples below.
6.   `sequence_length`: optional. A dictionary mapping feature key to maximum length (int) for that feature. Used by SeqIO to retrieve the dataset/examples.
7.   `**get_dataset_kwargs`: there are many [additional parameters](https://github.com/google/seqio/tree/main/seqio/dataset_providers.py) that can be set in the `seqio.get_dataset` function, such as `seed` (used in both examples in this section). If you would like to set any of these arguments, pass them as additional keyword arguments.
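
As a conceptual model of how `batch_size` and `num_batches` interact (not the bridge's actual implementation, which reads from SeqIO), the hypothetical `take_batches` helper below groups the first `batch_size * num_batches` examples of a stream into a list of batches. Raising an error when the stream runs short is a choice made for this sketch; the real bridge's behavior in that case is not described here.

```python
import itertools
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def take_batches(examples: Iterable[T], batch_size: int,
                 num_batches: int) -> List[List[T]]:
  """Groups the first batch_size * num_batches examples into full batches."""
  iterator = iter(examples)
  batches = []
  for _ in range(num_batches):
    batch = list(itertools.islice(iterator, batch_size))
    if len(batch) < batch_size:
      raise ValueError(
          f"Needed {batch_size * num_batches} examples, but the stream ran out.")
    batches.append(batch)
  return batches


print(take_batches(range(20), batch_size=8, num_batches=2))
```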



In [None]:
import t5.data.tasks
batches = get_batches_from_seqio(
    task_or_mixture_name='c4_v220_span_corruption',
    split='validation',
    batch_size=8,
    num_batches=1,
    get_pretokenized_examples=False,
    sequence_length=interactive_model._task_feature_lengths,
    seed=42)
batch = batches[0]  # We expect only a single batch.
original_step = interactive_model.step
interactive_model.train_step_with_preprocessors(
        examples=batch, preprocessors=[])
print(f"Original Step: {original_step}, Current Step: {interactive_model.step}")