
<a href="https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Overview

T5X is a modular, composable, research-friendly framework for high-performance, configurable, self-service training, evaluation, and inference of sequence models (starting with language) at many scales.

It is essentially a new and improved implementation of the [T5 codebase](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.md) (based on Mesh TensorFlow) in JAX and Flax.

# Getting Started

In the following Colab, we present an introductory tutorial to get you started interacting with the T5X codebase. In particular, we'll introduce the major components of the T5X codebase and get you started running training, inference, and evaluation on natural text inputs.

Note: If you are using a public Colab, please use its `Connect to a local runtime` option by following the [setup guide](https://github.com/google-research/t5x/blob/main/t5x/notebooks/README.md).

In [None]:
import functools
import os

import clu.data.dataset_iterator
import tensorflow as tf
import jax
from jax import random
from jax.experimental import multihost_utils
import jax.numpy as jnp
from flax import linen
import numpy as np
import seqio
import t5.data
from t5.evaluation import metrics as t5_metrics

In [None]:
import t5x
from t5x import partitioning
from t5x import train_state as train_state_lib
from t5x import utils
from t5x.examples.t5 import network
from t5x.examples.scalable_t5 import network as scalable_network
from t5x.interactive_model import InteractiveModel
from t5x.interactive_model import get_batches_from_seqio
from t5x.interactive_model import InferenceType
import nest_asyncio
nest_asyncio.apply()

# T5X Components

Let's start by going over some of the major components of the T5X codebase: models, checkpoints, and partitioners.

We will define instances of some of these components in the following subsections before we use them to run training, inference, and evaluation.

## T5X Models
One of the primary contributions of the T5X codebase is its easy-to-use collection of models. 

The T5X codebase provides an abstract base class, [`BaseModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py), which should be subclassed to define specific model architectures. This abstraction allows us to flexibly extend the T5X framework to custom architectures. Importantly, the `BaseModel` and all subclasses are free from parallelism-related features (this is handled by the partitioner; see following sections).

The T5X codebase also provides several widely-used subclasses of the `BaseModel`, namely the [`EncoderDecoderModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py) and the [`DecoderModel`](https://github.com/google-research/t5x/blob/main/t5x/models.py).

Importantly, the structure of `BaseModel` and its subclasses does not require the model to be implemented with a specific layers library. Instead, every subclass of `BaseModel` takes an `nn.Module` constructor argument that implements the model's architecture. These modules can be built directly in Flax (e.g. [minimal T5](https://github.com/google-research/t5x/blob/main/t5x/examples/t5/network.py)) or on top of a layers library such as [Flaxformer](https://github.com/google/flaxformer).
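To illustrate this separation in framework-free Python, here is a minimal sketch. `SketchBaseModel`, `SketchArgmaxModel`, and the two toy modules are hypothetical stand-ins, not T5X classes: the model class holds task-level logic, the architecture is swapped in through the `module` argument, and neither refers to devices or sharding.

```python
import abc
from typing import Any, Callable, List

Module = Callable[[List[int]], List[float]]


class SketchBaseModel(abc.ABC):
    """Hypothetical stand-in for the BaseModel idea (not the T5X API)."""

    def __init__(self, module: Module):
        # The architecture is injected; the model class never builds it.
        self.module = module

    @abc.abstractmethod
    def predict(self, tokens: List[int]) -> Any:
        """Task-level logic that subclasses must define."""


class SketchArgmaxModel(SketchBaseModel):
    """Toy subclass: returns the index of the highest score from the module."""

    def predict(self, tokens: List[int]) -> int:
        scores = self.module(tokens)
        return max(range(len(scores)), key=scores.__getitem__)


def sum_module(tokens: List[int]) -> List[float]:
    """Toy architecture: puts the token sum in slot 0."""
    return [float(sum(tokens)), 0.0, 1.0]


def length_module(tokens: List[int]) -> List[float]:
    """Toy architecture: puts the sequence length in slot 2."""
    return [0.0, 1.0, float(len(tokens))]


# Same model class, different architectures.
print(SketchArgmaxModel(sum_module).predict([3, 4]))     # 0
print(SketchArgmaxModel(length_module).predict([3, 4]))  # 2
```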

We've provided a sample model definition below. For this example, we will instantiate an `EncoderDecoderModel`, which will also require us to define input and output vocabularies, an optimizer, and a decode function. We'll use the minimal T5 module to implement our model architecture. 

In [None]:
# Define EncoderDecoderModel constructor args (except the module).
input_vocabulary=t5.data.get_default_vocabulary()
output_vocabulary=t5.data.get_default_vocabulary()
optimizer=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0, logical_factor_rules=t5x.adafactor.standard_logical_factor_rules())
decode_fn=functools.partial(t5x.decoding.temperature_sample, temperature=1.0, topk=40)

# Define a model using the minimal T5 module.
t5_module = network.Transformer(config=network.T5Config(
    vocab_size=32128,
    dtype='bfloat16',
    emb_dim=512,
    num_heads=6,
    num_encoder_layers=8,
    num_decoder_layers=8,
    head_dim=64,
    mlp_dim=1024,
    mlp_activations=('gelu', 'linear'),
    dropout_rate=0.0,
    logits_via_embedding=False))
model = t5x.models.EncoderDecoderModel(
    module=t5_module,
    input_vocabulary=input_vocabulary,
    output_vocabulary=output_vocabulary,
    optimizer_def=optimizer,
    decode_fn=decode_fn)
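`decode_fn` above selects `t5x.decoding.temperature_sample` with `temperature=1.0` and `topk=40`. As a rough, framework-free illustration of what those two parameters do at a single decoding step (not T5X's batched JAX implementation), here is a sketch in plain Python; `sample_top_k` is a hypothetical helper.

```python
import math
import random
from typing import List, Optional


def sample_top_k(logits: List[float], temperature: float = 1.0, topk: int = 40,
                 rng: Optional[random.Random] = None) -> int:
    """Samples one token ID from the top-k logits after temperature scaling."""
    if temperature <= 0:
        raise ValueError('temperature must be positive in this sketch')
    rng = rng or random.Random()
    # Keep only the k highest-scoring token IDs.
    top_ids = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:topk]
    # Temperature < 1 sharpens the distribution; > 1 flattens it.
    scaled = [logits[i] / temperature for i in top_ids]
    # Numerically stable (unnormalized) softmax weights over the survivors.
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    return rng.choices(top_ids, weights=weights, k=1)[0]


# With topk=1 the sampler always returns the argmax token.
print(sample_top_k([0.1, 2.0, -1.0], topk=1))  # 1
```

Setting `topk=1` reduces sampling to greedy decoding, which is a handy sanity check when experimenting with decoding settings.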

## Checkpoints

The T5X codebase also includes checkpoints for a wide variety of pre-trained T5X models. A full list of all publicly available checkpoints is available at https://github.com/google-research/t5x/blob/main/docs/models.md.

For the following example, we have selected a pretrained [T5 1.1 Small model](https://github.com/google-research/t5x/blob/main/docs/models.md) that has been additionally finetuned to answer natural questions using the (open domain) [Natural Questions benchmark](https://ai.google.com/research/NaturalQuestions). We use this finetuned checkpoint for this example in order to see improved performance on the natural question examples we will use for training/inference/evaluation later on. 

To restore our model from this checkpoint, we first define the path to our checkpoint and the `dtype` to restore.

In [None]:
# The checkpoint below is a T5-1.1-Small checkpoint (https://github.com/google-research/t5x/blob/main/docs/models.md) 
# that has additionally been finetuned on the (Open Domain) Natural Questions 
# benchmark (https://ai.google.com/research/NaturalQuestions).
checkpoint_path='gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'
dtype='bfloat16'

We also need to define how we want to restore our model. T5X supports multiple restore modes; for now, we will use "specific", which restores the exact checkpoint given by `checkpoint_path` (the "latest" mode, by contrast, loads the most recent checkpoint in a directory).

In [None]:
restore_mode='specific'
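To make the difference between restore modes concrete, here is a minimal sketch of how a mode could be resolved to a concrete path. It assumes that "specific" uses the given path as-is and that "latest" picks the highest-numbered `checkpoint_<step>` entry in a directory; `resolve_checkpoint` is a hypothetical helper, not part of T5X, and it only works on local paths.

```python
import os
import re
import tempfile
from typing import Optional


def resolve_checkpoint(path: str, mode: str) -> Optional[str]:
    """Hypothetical helper: maps a restore mode to one checkpoint path."""
    if mode == 'specific':
        # Use exactly the checkpoint the caller named.
        return path
    if mode == 'latest':
        # Assumed layout: one `checkpoint_<step>` entry per saved step.
        steps = []
        for name in os.listdir(path):
            match = re.fullmatch(r'checkpoint_(\d+)', name)
            if match:
                steps.append(int(match.group(1)))
        # Compare steps numerically, not lexicographically.
        return os.path.join(path, f'checkpoint_{max(steps)}') if steps else None
    raise ValueError(f'Unsupported restore mode in this sketch: {mode}')


with tempfile.TemporaryDirectory() as ckpt_dir:
    for step in (100, 2000, 300):
        os.mkdir(os.path.join(ckpt_dir, f'checkpoint_{step}'))
    print(os.path.basename(resolve_checkpoint(ckpt_dir, 'latest')))  # checkpoint_2000
```

Note that a lexicographic comparison would wrongly pick `checkpoint_300`, which is why the sketch parses the step number.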

Finally, it should be noted that if you are restoring your model from a checkpoint, then the model architecture you defined above must match the model architecture of your checkpoint. For all T5X checkpoints listed at https://github.com/google-research/t5x/blob/main/docs/models.md, you can find the correct architecture for the given checkpoint in its corresponding Gin file.

## Partitioners

Partitioning is the process of dividing and replicating machine learning model parameters, activations, and data across accelerator devices in order to:


*   Train and infer from models too large to fit in the memory of a single device
*   Use extremely large batch sizes
*   Train faster

In T5X, partitioning is primarily provided through [jax.pjit](https://github.com/google/jax/tree/main/jax/experimental/pjit.py), which T5X exposes via the `PjitPartitioner` class. `PjitPartitioner` has three primary constructor arguments:
*    `model_parallel_submesh`
*    `num_partitions`
*    `logical_axis_rules`

The `model_parallel_submesh` and `num_partitions` arguments provide two mutually exclusive methods of specifying the submesh of devices to use for model partitioning. If you specify `num_partitions`, T5X will use this value to generate a default `model_parallel_submesh` that is suitable, but may not be the optimal configuration. If you are interested in optimizing performance, you can try out different submeshes using the `model_parallel_submesh` parameter. For simplicity, we will use `num_partitions` in this Colab.
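As a simplified mental model (not T5X's actual submesh-selection logic, which may also take the physical device topology into account), `num_partitions` can be thought of as the size of the "model" axis of a 2-D device mesh, with the remaining factor forming the "data" axis across which the batch is split. `mesh_shape` and `per_replica_batch` are hypothetical helpers.

```python
from typing import Tuple


def mesh_shape(num_devices: int, num_partitions: int) -> Tuple[int, int]:
    """Returns a (data, model) mesh shape for a simple 2-D device layout."""
    if num_devices % num_partitions:
        raise ValueError('num_partitions must divide the number of devices')
    # Each group of `num_partitions` devices holds one copy of the model,
    # split across them; the groups then divide the batch between them.
    return num_devices // num_partitions, num_partitions


def per_replica_batch(global_batch: int, num_devices: int, num_partitions: int) -> int:
    """Examples each model replica processes per step under pure data parallelism."""
    data_parallel, _ = mesh_shape(num_devices, num_partitions)
    if global_batch % data_parallel:
        raise ValueError('batch size must be divisible by the data-parallel degree')
    return global_batch // data_parallel


print(mesh_shape(8, 2))            # (4, 2)
print(per_replica_batch(8, 8, 2))  # 2
```

With `num_partitions=1` on a single device, as in this Colab, the mesh degenerates to `(1, 1)` and no partitioning takes place.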

If you are interested in learning more about partitioning, please take a look at our T5X: Partitioning Deep Dive Colab (Colab status: WIP, link is upcoming).


In [None]:
partitioner=partitioning.PjitPartitioner(
        num_partitions=1,
        model_parallel_submesh=None)

# Running Training, Inference, and Evaluation

Now, let's get started running training, inference, and evaluation on natural text inputs. T5X provides an `InteractiveModel` class that we can wrap around our model, checkpoint, and partitioner components, enabling us to run training, inference, and evaluation in one line of code each.

The InteractiveModel requires a few additional constructor arguments, namely:


1.   `batch_size`: the number of examples per batch for training, inference, and evaluation.
2.   `task_feature_lengths`: a dictionary mapping each task feature key to the maximum length (int) for that feature. If a feature is longer than this length after preprocessing, it will be truncated. May be set to `None` to avoid truncation. \
For context, task features are specific to tasks (e.g., `inputs` and `targets`) and are mapped to model-specific features by the model's feature converter. For example, if we are using a decoder-only model, the concatenation of inputs and targets is mapped to the model feature `decoder_target_tokens`.
3.   `output_dir`: Path to directory where we will write new model checkpoints.
4.   `input_shapes`: a mapping from key to array shape for each model feature in the global (unsharded) input batch. These input shapes are used to define and initialize the train state. Importantly, these input shapes define the *model features* shape, in contrast to the task features described above.
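For the encoder-decoder model used in this notebook, the mapping from task features to model features can be sketched as follows. This is a simplified, unpacked illustration written to our understanding of SeqIO's encoder-decoder feature converter, not its actual code; the helper names are hypothetical, and the sketch assumes token ID 0 serves as both padding and the decoder start token.

```python
from typing import Dict, List

PAD_ID = 0  # Assumption: ID 0 is padding, as in the default T5 vocabulary.


def pad(ids: List[int], length: int) -> List[int]:
    """Truncates or right-pads a token list to exactly `length` entries."""
    ids = ids[:length]
    return ids + [PAD_ID] * (length - len(ids))


def enc_dec_features(inputs: List[int], targets: List[int],
                     task_feature_lengths: Dict[str, int]) -> Dict[str, List[int]]:
    """Maps task features to encoder-decoder model features (no packing)."""
    targets = pad(targets, task_feature_lengths['targets'])
    return {
        'encoder_input_tokens': pad(inputs, task_feature_lengths['inputs']),
        'decoder_target_tokens': targets,
        # Teacher forcing: the decoder sees the targets shifted right by one,
        # starting from token 0.
        'decoder_input_tokens': [0] + targets[:-1],
        # Only non-padding target positions contribute to the loss.
        'decoder_loss_weights': [int(t != PAD_ID) for t in targets],
    }


print(enc_dec_features([5, 6, 1], [7, 1], {'inputs': 4, 'targets': 3}))
```

The four keys produced here are the same model features that appear in `input_shapes` below.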

We define these arguments and an instance of the InteractiveModel below. Importantly, it should be noted that the InteractiveModel handles restoring our model from the provided checkpoint path, so once we instantiate the InteractiveModel, we will be ready to run training, inference, and evaluation. Restoring the model from a checkpoint may take a minute or two.



In [None]:
batch_size=8
task_feature_lengths = {'inputs': 38, 'targets': 18}
output_dir='/tmp/output_dir'
input_shapes = {
    'encoder_input_tokens': np.array([8, 38]),
    'decoder_target_tokens': np.array([8, 18]),
    'decoder_input_tokens': np.array([8, 18]),
    'decoder_loss_weights': np.array([8, 18])
}

interactive_model = InteractiveModel(
  batch_size=batch_size,
  task_feature_lengths=task_feature_lengths,
  output_dir=output_dir,
  partitioner=partitioner,
  model=model,
  dtype=dtype,
  restore_mode=restore_mode,
  checkpoint_path=checkpoint_path,
  input_shapes=input_shapes
)
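The hand-written `input_shapes` above follow directly from `batch_size` and `task_feature_lengths`: the encoder feature uses the `inputs` length, and all decoder-side features use the `targets` length. A small hypothetical helper makes that relationship explicit; it returns plain tuples, whereas the cell above uses NumPy arrays holding the same values.

```python
from typing import Dict, Tuple


def enc_dec_input_shapes(batch_size: int,
                         task_feature_lengths: Dict[str, int]) -> Dict[str, Tuple[int, int]]:
    """Derives global (unsharded) model-feature shapes for an encoder-decoder model."""
    inputs_len = task_feature_lengths['inputs']
    targets_len = task_feature_lengths['targets']
    return {
        'encoder_input_tokens': (batch_size, inputs_len),
        # All decoder-side features share the target length.
        'decoder_target_tokens': (batch_size, targets_len),
        'decoder_input_tokens': (batch_size, targets_len),
        'decoder_loss_weights': (batch_size, targets_len),
    }


print(enc_dec_input_shapes(8, {'inputs': 38, 'targets': 18}))
```

Deriving the shapes this way keeps them consistent if you later change the batch size or feature lengths.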

Next, let's define some examples that we want to use for training/inference/evaluation. These examples should either be a list of inputs, or a list of dictionaries mapping 'target'/'input' keys to corresponding values, as shown below. We will define two sets of examples: one set to be trained on, and one set to run inference/evaluation on.

We are using natural question/answer pairs for our examples. As described in the [T5 paper](https://arxiv.org/abs/1910.10683), we must add a task-specific prefix to each input before we feed it to the model, in order to specify which task the model should perform on that input. For natural questions, we use the "nq question:" prefix.
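Because the prefix is part of the input text itself, a small helper can add it to either of the two example formats described above (a list of input strings, or a list of dicts with `'input'` and `'target'` keys). `add_task_prefix` is hypothetical and not part of T5X; in the cell below, the prefix is simply written out by hand.

```python
from typing import Dict, List, Sequence, Union

Example = Union[str, Dict[str, str]]


def add_task_prefix(examples: Sequence[Example], prefix: str = 'nq question: ') -> List[Example]:
    """Prepends a task prefix to each input, for either supported example format."""
    prefixed: List[Example] = []
    for ex in examples:
        if isinstance(ex, str):
            # Inputs-only format.
            prefixed.append(prefix + ex)
        else:
            # Dict format: only the input gets the prefix; the target is unchanged.
            prefixed.append({**ex, 'input': prefix + ex['input']})
    return prefixed


print(add_task_prefix([{'input': 'who wrote hamlet', 'target': 'William Shakespeare'}]))
```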

In [None]:
training_examples = [
  {
      'input': 'nq question: who has been appointed as the new chairman of sebi',
      'target': 'Ajay Tyagi'
  }, 
  {
      'input': 'nq question: who wrote the book lion the witch and the wardrobe',
      'target': 'C. S. Lewis'
  }, 
  {
      'input': 'nq question: how many planes did japan lose at pearl harbor',
      'target': '29'
  }, 
  {
      'input': 'nq question: who does the voice of mcgruff the dog',
      'target': 'Jack Keil'
  }, 
  {
      'input': 'nq question: who sings the wheels in the sky keep on turning',
      'target': 'Journey'
  }, 
  {
      'input': 'nq question: who voices regina in glitter force doki doki',
      'target': 'Kumiko Watanabe'
  }, 
  {
      'input': 'nq question: when did the us become allies with britain',
      'target': 'during World War II'
  }, 
  {
      'input': 'nq question: who won the rugby 7 in las vegas',
      'target': 'the United States'
  },
]

validation_examples = [
  {
      'input': 'nq question: who is the president of the united states',
      'target': 'Joe Biden'
  },
  {
      'input': 'nq question: who wrote the book the great gatsby',
      'target': 'F. Scott Fitzgerald'
  },
  {
      'input': 'nq question: in what year did the first world war begin',
      'target': '1914'
  },
  {
      'input': 'nq question: who does the voice of elsa in Frozen',
      'target': 'Idina Menzel'
  },
  {
      'input': 'nq question: who sings shake it off',
      'target': 'Taylor Swift'
  },
  {
      'input': 'nq question: who voices spongebob squarepants',
      'target': 'Tom Kenny'
  },
  {
      'input': 'nq question: when did the great british bake off start',
      'target': '2010'
  },
  {
      'input': 'nq question: who won the superbowl in 2018',
      'target': 'the Philadelphia Eagles'
  },
]

Now, we can run training, inference and evaluation on these examples with a single line of code for each task. Below, we run training and inference (evaluation requires a few more arguments, so we go over evaluation in a following section). This may take ~60 seconds.

In [None]:
interactive_model.train_step(examples=training_examples)
print(f"Training Summary: {interactive_model.train_summary}\n")
print(f"Step Number: {interactive_model.step}\n")

examples_and_predictions, _ = interactive_model.predict_with_aux(examples=validation_examples)
predictions = [prediction for example, prediction in examples_and_predictions]
print(f"Predictions: {predictions}\n")

examples_and_scores = interactive_model.score(examples=validation_examples)
scores = [score for example, score in examples_and_scores]
print(f"Scores: {scores}\n")
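The scores printed above are, to our understanding, log-likelihoods of each target given its input under the model: the sum of the log-probabilities of the target tokens, ignoring padded positions. Values closer to zero mean the model finds the target more likely. A minimal sketch of that arithmetic, with made-up per-token probabilities and a hypothetical helper name:

```python
import math
from typing import List


def sequence_log_likelihood(target_token_probs: List[float],
                            loss_weights: List[int]) -> float:
    """Sums log-probabilities of the target tokens, skipping padded positions."""
    return sum(math.log(p) for p, w in zip(target_token_probs, loss_weights) if w)


# Two real tokens (p=0.5, p=0.25) followed by one padded position.
print(sequence_log_likelihood([0.5, 0.25, 0.9], [1, 1, 0]))  # log(0.125), about -2.079
```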

Alternatively, you can run a training/inference/evaluation loop over multiple batches. The training loop below runs training and inference for each step, using the provided batches, and returns the predictions and scores from the final step. This may take ~60 seconds (note: if you use XL or XXL model sizes, this loop may take a while to complete; we are working on improved compilation strategies that optimize for runtime in b/247170488).

In [None]:
second_batch_of_examples = [
    {
        'input': 'nq question: who won the most academy awards in his lifetime',
        'target': 'Walt Disney'
    }, 
    {
        'input': 'nq question: who starred in the hand that rocks the cradle',
        'target': 'Rebecca De Mornay'
    }, 
    {
        'input': 'nq question: what does a red license plate mean in ontario',
        'target': 'diplomat'
    }, 
    {
        'input': 'nq question: who sang i dreamed a dream on britain\'s got talent',
        'target': 'Susan Magdalane Boyle'
    }, 
    {
        'input': 'nq question: when is season 7 of game of thrones being released',
        'target': 'August 27, 2017'
    }, 
    {
        'input': 'nq question: when is anne with an e season two coming out',
        'target': 'in 2018'
    }, 
    {
        'input': 'nq question: when was hard rock hotel las vegas built',
        'target': '1995'
    }, 
    {
        'input': 'nq question: what type of reaction leads to the production of polymers',
        'target': 'condensation reaction'
    }
]
all_training_batches = [training_examples, second_batch_of_examples]
examples_and_predictions, examples_and_scores, _ = interactive_model.train_loop(
    num_steps=2,
    train_batches=all_training_batches,
    predict_batches=[validation_examples],
    score_batches=[validation_examples])

print("\nAll Predictions:")
for example, prediction in examples_and_predictions:
  print(f"Input: {example['inputs_pretokenized']}, Prediction: {prediction}")
print("\nAll Scores:")
for example, score in examples_and_scores:
  print(f"Input: {example['inputs_pretokenized']}, Score: {score}")

### Preprocessors

By default, the only preprocessors that the methods above run are [tokenization](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) and [appending an EOS token](https://github.com/google/seqio/tree/main/seqio/preprocessors.py). If you would like to use different preprocessors, you can do so using the `train_step_with_preprocessors` or `infer_with_preprocessors` methods. We've provided a sample below:

In [None]:
preprocessors = [
    seqio.preprocessors.tokenize,
    seqio.preprocessors.append_eos
]

interactive_model.train_step_with_preprocessors(examples=training_examples, preprocessors=preprocessors)
print(f"Training Summary: {interactive_model.train_summary}\n")
print(f"Step Number: {interactive_model.step}\n")

# Note: when we use a custom list of preprocessors, we must use the general
# `infer_with_preprocessors` method rather than `predict_with_aux` or `score`.
# Thus, we must also specify the type of inference to run, for example
# `InferenceType.PREDICT_WITH_AUX` or `InferenceType.SCORE`.
examples_and_predictions, _ = interactive_model.infer_with_preprocessors(
    mode=InferenceType.PREDICT_WITH_AUX, 
    examples=validation_examples, 
    preprocessors=preprocessors)
predictions = [prediction for example, prediction in examples_and_predictions]
print(f"Predictions: {predictions}\n")

examples_and_scores, _ = interactive_model.infer_with_preprocessors(
    mode=InferenceType.SCORE, 
    examples=validation_examples, 
    preprocessors=preprocessors)
scores = [score for example, score in examples_and_scores]
print(f"Scores: {scores}\n")

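Because these are the same preprocessors that the default methods run, the results should be comparable to those above, though not necessarily identical, since the model has now been trained for additional steps.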

If you are interested in learning more about preprocessors, please see [this preprocessors guide](https://github.com/google/seqio/blob/main/README.md#preprocessors), which also contains links to implementations of common preprocessors.
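Conceptually, a preprocessor list is a pipeline: each preprocessor receives the output of the previous one, and an empty list leaves the examples unchanged. The sketch below shows that chaining with plain Python generators. It is an illustration only: SeqIO's real preprocessors operate on `tf.data.Dataset`s and use SentencePiece, whereas `TOY_VOCAB`, the whitespace tokenizer, and the helper names here are hypothetical, and the special IDs (EOS = 1, UNK = 2) are assumed from the default T5 vocabulary.

```python
from typing import Callable, Dict, Iterable, Iterator, List, Sequence

Example = Dict[str, object]
Preprocessor = Callable[[Iterable[Example]], Iterator[Example]]

EOS_ID, UNK_ID = 1, 2  # Assumption: IDs follow the default T5 vocabulary.
TOY_VOCAB = {'who': 3, 'wrote': 4, 'hamlet': 5, 'shakespeare': 6}  # Hypothetical.


def toy_tokenize(examples: Iterable[Example]) -> Iterator[Example]:
    """Stand-in for tokenization: whitespace split plus a tiny lookup table."""
    for ex in examples:
        yield {k: [TOY_VOCAB.get(w, UNK_ID) for w in str(v).lower().split()]
               for k, v in ex.items()}


def toy_append_eos(examples: Iterable[Example]) -> Iterator[Example]:
    """Stand-in for EOS appending: adds EOS_ID to the end of every feature."""
    for ex in examples:
        yield {k: list(v) + [EOS_ID] for k, v in ex.items()}


def run_preprocessors(examples: Iterable[Example],
                      preprocessors: Sequence[Preprocessor]) -> List[Example]:
    """Applies each preprocessor to the output of the previous one, in order."""
    for preprocess in preprocessors:
        examples = preprocess(examples)
    return list(examples)


raw = [{'inputs': 'Who wrote Hamlet', 'targets': 'Shakespeare'}]
print(run_preprocessors(raw, [toy_tokenize, toy_append_eos]))
# [{'inputs': [3, 4, 5, 1], 'targets': [6, 1]}]
```

The empty-list case is what the final cell of this notebook relies on when it passes `preprocessors=[]` for examples that are already fully preprocessed.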

### Evaluation and Metrics Functions

We can similarly run evaluation in a single line. Running evaluation requires that we specify a metric function and (optionally) a list of postprocessors to run on the data before we compute metrics. 

There are a variety of sample metrics defined in [t5/evaluation/metrics.py](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/evaluation/metrics.py). For this example, we will use the [SQuAD metric function](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/evaluation/metrics.py) defined in this file. Because we are using natural questions, we will also specify a postprocessor to correctly format question and answer pairs for metrics calculations; specifically, we will use the [`t5.data.postprocessors.qa` method](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/postprocessors.py). We will continue to use the same preprocessors.

In [None]:
metrics = interactive_model.evaluate_with_preprocessors(
        examples=validation_examples,
        preprocessors=preprocessors,
        metric_fns=[t5_metrics.squad],
        postprocessor=t5.data.postprocessors.qa)
print(f"Metrics: {metrics}")

If you are interested in learning more about metrics functions or postprocessors, please see this [Metrics guide](https://github.com/google/seqio/blob/main/README.md#metrics) and/or this [Postprocessors guide](https://github.com/google/seqio/blob/main/README.md#postprocessor).
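The SQuAD metric reports exact match (EM) and token-level F1. As a self-contained illustration of those two quantities, here is our own reimplementation of the standard SQuAD v1.1 recipe (lowercase, strip punctuation and articles, collapse whitespace, then compare). It is not the code in `t5/evaluation/metrics.py`, and averaging over examples and any scaling to percentages are omitted.

```python
import collections
import re
import string
from typing import List

_PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """SQuAD-style normalization: lowercase, drop punctuation and articles, fix spaces."""
    text = text.lower()
    text = ''.join(ch for ch in text if ch not in _PUNCTUATION)
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return ' '.join(text.split())


def _f1(prediction: str, reference: str) -> float:
    pred_tokens = normalize_answer(prediction).split()
    ref_tokens = normalize_answer(reference).split()
    common = collections.Counter(pred_tokens) & collections.Counter(ref_tokens)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


def exact_match(prediction: str, references: List[str]) -> float:
    """1.0 if the normalized prediction equals any normalized reference."""
    return max(float(normalize_answer(prediction) == normalize_answer(r)) for r in references)


def token_f1(prediction: str, references: List[str]) -> float:
    """Best token-overlap F1 between the prediction and any reference."""
    return max(_f1(prediction, r) for r in references)


print(exact_match('F Scott Fitzgerald', ['F. Scott Fitzgerald']))  # 1.0
print(token_f1('Eagles', ['the Philadelphia Eagles']))            # 0.666...
```

Normalization is why "F Scott Fitzgerald" counts as an exact match for "F. Scott Fitzgerald", while a partially correct answer such as "Eagles" still earns partial F1 credit.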

# Advanced Topics

## SeqIO

If you are interested in T5X, you may also be interested in, or have heard of, SeqIO. SeqIO is a library for processing sequential data to be fed into downstream sequence models. At a high level, SeqIO relies on user-defined `Tasks` and `Mixtures` that can be used to retrieve and evaluate datasets. 

We won't go into details about SeqIO here; we recommend checking out this [SeqIO Introductory guide](https://github.com/google/seqio/blob/main/README.md) and/or clicking below to run a SeqIO Introductory Colab. The rest of this section will assume a basic understanding of SeqIO.

<a href="https://colab.research.google.com/github/google/seqio/blob/main/seqio/notebooks/Basics_Task_and_Mixtures.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

If you are already familiar with SeqIO and have a SeqIO task/mixture that you would like to use in this Colab, we do provide a SeqIO bridge that takes in a SeqIO task/mixture and produces batches of examples that can be processed by the InteractiveModel above. We've provided an example of this bridge below. 

In [None]:
!git clone https://github.com/google-research/google-research.git
!cp -r google-research/t5_closed_book_qa/ ./
import t5_closed_book_qa.t5_cbqa.tasks

In [None]:
batches = get_batches_from_seqio(
        task_or_mixture_name='natural_questions_open',
        split='validation',
        batch_size=8,
        num_batches=2,
        seed=42)
print(f"Batches: {batches}")
# Train the interactive model on the provided batches.
original_step = interactive_model.step
_ = interactive_model.train_loop(num_steps=len(batches), train_batches=batches)
print(f"Original Step: {original_step}, Current Step: {interactive_model.step}")

The `get_batches_from_seqio` bridge takes several arguments:


1.   `task_or_mixture_name`: the name of the SeqIO task/mixture to read data from. It should be noted that your task/mixture must already be registered with SeqIO, and you must import the module that defines your task/mixture here (as seen above).
2.   `split`: the split of the Task/Mixture to read data from.
3.   `batch_size`: how many examples should appear in each batch.
4.   `num_batches`: the total number of batches to return.
5.   `get_pretokenized_examples`: optional. A boolean, defaulting to True, that determines whether we should read the `inputs_pretokenized`/`targets_pretokenized` elements from an example, or the `inputs`/`targets` elements. \
The `train_step`, `predict`, `predict_with_aux`, `score`, and `evaluate` methods of the InteractiveModel assume that we should run [tokenization](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) and [appending an EOS token](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) as the only preprocessors. To use these methods with this pre-defined list of preprocessors, you can set `get_pretokenized_examples=True` to retrieve examples that still need to be tokenized, and these InteractiveModel methods will handle running these preprocessors. This setting can also be helpful if you want to inspect the natural text inputs/targets of your SeqIO task. \
However, for some SeqIO tasks (e.g., span corruption), the final `inputs` and `targets` are produced by preprocessors that run after tokenization, so they cannot be recovered from the pretokenized text. In that case, set `get_pretokenized_examples=False`, and this bridge will read the fully preprocessed examples from the SeqIO task. You can then run `train_step_with_preprocessors`, `infer_with_preprocessors`, or `evaluate_with_preprocessors` with an empty preprocessors list (because all preprocessing has already been completed by this bridge) to run training/inference/evaluation. We have provided an example of using this bridge to retrieve fully preprocessed examples below.
6.   `sequence_length`: optional. A dictionary mapping feature key to maximum length (int) for that feature. Used by SeqIO to retrieve the dataset/examples.
7.   `**get_dataset_kwargs`: there are many [additional parameters](https://github.com/google/seqio/tree/main/seqio/dataset_providers.py) that can be passed to SeqIO's `get_dataset`. If you would like to set any of these arguments, you can pass them as additional keyword arguments to this bridge.
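To make the batching and the `get_pretokenized_examples` switch concrete, here is a hypothetical sketch of the kind of grouping such a bridge performs; it is not the actual implementation of `get_batches_from_seqio`. It assumes dataset examples are dicts carrying `inputs_pretokenized`/`targets_pretokenized` text alongside tokenized `inputs`/`targets`, emits the `'input'`/`'target'` format used earlier in this notebook when pretokenized text is requested, and drops any final partial batch.

```python
from typing import Dict, Iterable, List


def to_batches(examples: Iterable[Dict[str, object]], batch_size: int,
               num_batches: int, get_pretokenized_examples: bool = True) -> List[List[dict]]:
    """Hypothetical sketch: group dataset examples into fixed-size batches."""
    batches: List[List[dict]] = []
    current: List[dict] = []
    for ex in examples:
        if get_pretokenized_examples:
            # Keep only raw text so the default tokenize/EOS steps can run later.
            ex = {'input': ex['inputs_pretokenized'], 'target': ex['targets_pretokenized']}
        current.append(ex)
        if len(current) == batch_size:
            batches.append(current)
            current = []
            if len(batches) == num_batches:
                break
    return batches


data = [{'inputs_pretokenized': f'q{i}', 'targets_pretokenized': f'a{i}'} for i in range(5)]
print(to_batches(data, batch_size=2, num_batches=2))
```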



In [None]:
import t5.data.tasks
batches = get_batches_from_seqio(
    task_or_mixture_name='c4_v220_span_corruption',
    split='validation',
    batch_size=8,
    num_batches=1,
    get_pretokenized_examples=False,
    sequence_length=task_feature_lengths,  # Same lengths passed to the InteractiveModel.
    seed=42)
batch = batches[0]  # We expect only a single batch.
original_step = interactive_model.step
interactive_model.train_step_with_preprocessors(
        examples=batch, preprocessors=[])
print(f"Original Step: {original_step}, Current Step: {interactive_model.step}")
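The span corruption task used above illustrates why `get_pretokenized_examples=False` is needed: its inputs and targets are built from token IDs, using sentinel tokens that never appear in the raw text. The sketch below follows the example in the T5 paper ("Thank you <X> me to your party <Y> week." mapped to "<X> for inviting <Y> last <Z>"), but with the corrupted spans given explicitly rather than sampled randomly as in the real task, and without EOS handling. `corrupt_spans` is hypothetical, and the sentinel IDs assume the default T5 vocabulary of 32,100 tokens, where `<extra_id_0>` is 32099 and later sentinels count down.

```python
from typing import List, Sequence, Tuple

VOCAB_SIZE = 32100  # Assumption: sentinel IDs count down from VOCAB_SIZE - 1.


def corrupt_spans(tokens: Sequence[int],
                  spans: Sequence[Tuple[int, int]]) -> Tuple[List[int], List[int]]:
    """Replaces each (start, length) span with a sentinel; targets list the dropped spans."""
    inputs: List[int] = []
    targets: List[int] = []
    pos = 0
    for i, (start, length) in enumerate(sorted(spans)):
        sentinel = VOCAB_SIZE - 1 - i
        inputs.extend(tokens[pos:start])  # Keep the uncorrupted text.
        inputs.append(sentinel)           # Mark where the span was removed.
        targets.append(sentinel)          # Targets restate each removed span...
        targets.extend(tokens[start:start + length])
        pos = start + length
    inputs.extend(tokens[pos:])
    if pos < len(tokens):
        # ...and end with one more sentinel when uncorrupted text follows.
        targets.append(VOCAB_SIZE - 1 - len(spans))
    return inputs, targets


print(corrupt_spans(list(range(10)), [(2, 2), (7, 1)]))
# ([0, 1, 32099, 4, 5, 6, 32098, 8, 9], [32099, 2, 3, 32098, 7, 32097])
```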