<!--BEGIN GOOGLE-INTERNAL-->
**GOOGLERS: To run internally, you will need to connect to a Brain Frameworks (TPU) Colab runtime.**

# Overview

T5X is a modular, composable, research-friendly framework for high-performance, configurable, self-service training, evaluation, and inference of sequence models (starting with language) at many scales.

It is essentially a new and improved implementation of the [T5 codebase](https://g3doc.corp.google.com/third_party/py/t5/README.md?cl=head) (based on Mesh TensorFlow) in JAX and Flax.

# Getting Started

In the following sections, we present four quick tutorials to get you started with common T5X use cases:

1.)    **[DRAFT]** *Fine-tuning a Model:* This tutorial outlines the steps to fine-tune an existing pre-trained model with T5X on common downstream tasks/mixtures available on [SeqIO](go/seqio). This is one of the simplest and most common use cases of T5X. If you're new to T5X, this tutorial is the recommended starting point.

2.)    **[DRAFT]** *Running Evaluation on a Model:* This tutorial outlines the steps to evaluate a model with T5X on downstream tasks/mixtures defined in SeqIO.


3.)    **[DRAFT]** *Running Inference on a Model:* This tutorial outlines the steps to run inference on a model with T5X.

4.)    **[DRAFT]** *Training a Model from Scratch:* This tutorial outlines the steps to pretrain a model with T5X on tasks/mixtures defined in SeqIO.



**Note: Please run the "Set-up" section before beginning any of the tutorial sections.**

# Set-up

In the following section, we import all required modules and define two helper functions that will allow us to easily parse gin configs.

In [None]:
from typing import Sequence
import functools
import os
import re

from absl import logging  # Used by `parse_gin_strings` below.
import gin
import seqio
from google3.pyglib import gfile
from colabtools import adhoc_import
from colabtools import googlelog
with adhoc_import.Google3SubmittedChangelist(build_targets=['//t5x:train', '//t5x:eval', '//t5x:infer']):
  import t5x
  from t5x import train as train_lib
  from t5x.train import train
  from t5x import eval as eval_lib
  from t5x.eval import evaluate
  from t5x import infer as infer_script
  from t5x.infer import infer
  from t5x import gin_utils
  from t5.data import mixtures
  from t5x import adafactor
  from t5x import examples
  from t5x.examples.t5 import network
  from t5x import utils

Below, we define a helper function that parses a Gin config string and any additional gin bindings. If you're not familiar with Gin, reading the [T5X Gin Primer](https://g3doc.corp.google.com/t5x/g3doc/usage/gin.md?cl=head) is recommended.

In [None]:
def parse_gin_strings(gin_search_paths: Sequence[str],
                      gin_file_str: str,
                      gin_bindings: Sequence[str],
                      skip_unknown: bool = False,
                      finalize_config: bool = True):
  """Parses a Gin config string, then applies override bindings.

  Args:
    gin_search_paths: paths that will be searched for Gin files referenced via
      `include` statements in the config string.
    gin_file_str: Gin config to be parsed. Paths in `include` statements may be
      relative to paths in `gin_search_paths`.
    gin_bindings: individual Gin bindings to be applied after the config string
      is parsed. They are applied in order, with conflicting settings being
      overridden by later ones.
    skip_unknown: whether to ignore unknown bindings or raise an error (default
      behavior).
    finalize_config: whether to finalize the config so that it cannot be
      modified (default behavior).
  """
  # Register .gin file search paths with Gin.
  for search_path in gin_search_paths:
    gin.add_config_file_search_path(search_path)

  # Parse the config string first, then the bindings, so bindings take
  # precedence.
  gin.parse_config(gin_file_str, skip_unknown=skip_unknown)
  gin.parse_config(gin_bindings, skip_unknown=skip_unknown)
  if finalize_config:
    gin.finalize()
  logging.info('Gin Configuration:\n%s', gin.config_str())

The code snippet below defines a second helper function, `parse_gin_and_get_configurable_fn`. It takes a function that is configurable with Gin, such as the fine-tuning `train` function, parses the Gin config for your experiment, and returns the fully configured function.

In [None]:
def parse_gin_and_get_configurable_fn(configurable_fn,
                                      gin_config_str="",
                                      gin_file_paths=[],
                                      gin_bindings=[],
                                      gin_search_paths=['.'],
                                      tfds_data_dir=None,
                                      seqio_additional_cache_dirs=[]):
  '''Returns a Gin-configured version of `configurable_fn`.

  Args:
    configurable_fn: a function that is configurable with Gin. We will return an
      instance of this function that has been configured with the provided Gin
      config.
    gin_config_str: a string containing the Gin config to be parsed. If
      non-empty, `gin_file_paths` is ignored.
    gin_file_paths: paths to Gin configuration files. Multiple paths may be
      passed and will be parsed in the given order, with later configurations
      overriding earlier ones.
    gin_bindings: individual Gin bindings, applied after the config string or
      files have been parsed.
    gin_search_paths: list of Gin config path prefixes to be prepended to
      relative Gin file paths. Only the first prefix that produces a valid path
      for each file will be used.
    tfds_data_dir: If set, this directory will be used to store datasets
      prepared by TensorFlow Datasets that are not available in the public TFDS
      GCS bucket. Note that this flag overrides the `tfds_data_dir` attribute of
      all `Task`s.
    seqio_additional_cache_dirs: Directories to search for cached Tasks in
      addition to defaults.

    One of gin_config_str or gin_file_paths must be provided.
  '''
  if not gin_config_str and not gin_file_paths:
    raise ValueError(
        "One of `gin_config_str` or `gin_file_paths` must be provided.")

  with googlelog.Capture():
    if tfds_data_dir:
      seqio.set_tfds_data_dir_override(tfds_data_dir)
    seqio.add_global_cache_dirs(seqio_additional_cache_dirs)

    # Create a Gin-configurable version of `configurable_fn`.
    fn_using_gin = gin.configurable(configurable_fn)

    # Directories that relative Gin paths (e.g. 't5x/configs/runs/finetune.gin')
    # are resolved against. Search paths must be directories, not .gin files.
    default_gin_search_paths = [
      "/google_src/head/depot/google3/",
      "/google_src/head/depot/google3/t5x/",
    ]
    # User-provided gin paths take precedence if relative paths conflict.
    gin_search_paths = gin_search_paths + default_gin_search_paths

    with gin.unlock_config():
      if gin_config_str:
        print("Parsing the provided gin_config_str and ignoring all provided gin file paths.")
        parse_gin_strings(
            gin_search_paths,
            gin_config_str,
            gin_bindings
        )
      else:
        print("Parsing the provided gin file paths.")
        gin_utils.parse_gin_flags(
            gin_search_paths,
            gin_file_paths,
            gin_bindings
        )
    return fn_using_gin



# Finetune a Model

This section outlines the steps to fine-tune an existing pre-trained model with T5X on common downstream tasks defined with [SeqIO](go/seqio). This is one of the simplest and most common use cases of T5X. 

Fine-tuning a model with T5X consists of the following steps:

1.   Choose the pre-trained model to fine-tune.
2.   Choose the SeqIO Task/Mixture to fine-tune the model on.
3.   Write a Gin file that configures the pre-trained model, SeqIO Task/Mixture and other details of your fine-tuning run.
4.   Launch your experiment locally.


These steps are explained in detail in the following sections. An example run that fine-tunes a T5 1.1 Small checkpoint on the WMT14 English-to-German translation benchmark is also showcased.



## Step 1: Choose a pre-trained model

To use a pre-trained model, you need a Gin config file that defines the model params, and the model checkpoint to load from. For your convenience, TensorFlow checkpoints and Gin configs for common T5 pre-trained models have been made available for use in T5X. A list of all the available pre-trained models (with model checkpoints and Gin config files) is available in the [Models](https://g3doc.corp.google.com/t5x/g3doc/models.md?cl=head) documentation.

For the example run, you will use the T5 1.1 Small model. The Gin file and checkpoint for this model are defined below:

In [None]:
GIN_FILE_PATH = "t5x/examples/t5/t5_1_1/small.gin"
MODEL_CHECKPOINT = "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000"

## Step 2: Choose a SeqIO Task/Mixture

A SeqIO Task encapsulates the data source, the preprocessing logic to be performed on the data before querying the model, the postprocessing logic to be performed on model outputs, and the metrics to be computed given the postprocessed outputs and targets. A SeqIO Mixture denotes a collection of Tasks and enables fine-tuning a model on multiple Tasks simultaneously.

### Standard Tasks
Many common datasets and benchmarks, e.g. [GLUE](https://gluebenchmark.com/), [SuperGLUE](https://super.gluebenchmark.com/), [WMT](https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate), [SQuAD](https://rajpurkar.github.io/SQuAD-explorer/), [CNN/Daily Mail](https://github.com/abisee/cnn-dailymail), etc. have been implemented as SeqIO Tasks/Mixtures and can be used directly. These Tasks/Mixtures are defined in [`third_party/py/t5/data/tasks.py`](https://source.corp.google.com/piper///depot/google3/third_party/py/t5/data/tasks.py) and [`third_party/py/t5/data/mixtures.py`](https://source.corp.google.com/piper///depot/google3/third_party/py/t5/data/mixtures.py).

For the example run, you will fine-tune the model on the WMT14 English to German translation benchmark, which has been implemented as the [`wmt_t2t_ende_v003`](https://source.corp.google.com/piper///depot/google3/third_party/py/t5/data/tasks.py;l=209;rcl=417815592) Task.

In [None]:
SEQIO_TASK = 'wmt_t2t_ende_v003'

### Custom Tasks
It is also possible to define your own custom Task. See the [SeqIO documentation](https://g3doc.corp.google.com/third_party/py/seqio/google/g3doc/index.md?cl=head) for how to do this. Tasks defined using the [old T5 codebase](https://source.corp.google.com/piper///depot/google3/third_party/py/t5/data/dataset_providers.py) can also be used by T5X. If using a custom Task, you will need to follow the instructions in the "Advanced Topics" section at the end of this tutorial to make sure the module containing your Task is included.

When defining a custom Task, you have the option to cache it on disk before fine-tuning. The instructions for this are [here](https://g3doc.corp.google.com/third_party/py/seqio/google/g3doc/index.md?cl=head#optional-offline-caching). Caching may improve performance for Tasks with expensive preprocessing. By default, T5X expects Tasks to be cached. To fine-tune on a Task that has not been cached, set `--gin.USE_CACHED_TASKS=False`.
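In Colab there is no commandline, so a flag like this has to be expressed as a Gin binding instead. To our understanding, T5X treats a `--gin.NAME=VALUE` flag as shorthand for the Gin binding `NAME = VALUE`. The hypothetical helper below (not part of T5X) sketches that translation so the result can be passed through the `gin_bindings` argument of `parse_gin_and_get_configurable_fn` from the Set-up section.

```python
from typing import List, Sequence

GIN_FLAG_PREFIX = "--gin."


def gin_flags_to_bindings(flags: Sequence[str]) -> List[str]:
  """Converts `--gin.NAME=VALUE` flags into `NAME = VALUE` Gin bindings.

  Hypothetical helper for illustration only; it is not a T5X API.
  """
  bindings = []
  for flag in flags:
    if not flag.startswith(GIN_FLAG_PREFIX):
      raise ValueError(f"Not a --gin. flag: {flag!r}")
    # Split on the first '=' only, so values may themselves contain '='.
    name, sep, value = flag[len(GIN_FLAG_PREFIX):].partition("=")
    if not sep or not name:
      raise ValueError(f"Expected --gin.NAME=VALUE, got {flag!r}")
    bindings.append(f"{name} = {value}")
  return bindings


print(gin_flags_to_bindings(["--gin.USE_CACHED_TASKS=False"]))
# ['USE_CACHED_TASKS = False']
```

Equivalently, you can add the line `USE_CACHED_TASKS = False` directly to your Gin config string.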

## Step 3: Write a Gin Config


After choosing the pre-trained model and SeqIO Task/Mixture for your run, the next step is to configure your run using Gin. If you're not familiar with Gin, reading the [T5X Gin Primer](https://g3doc.corp.google.com/t5x/g3doc/usage/gin.md?cl=head) is recommended.

T5X provides a Gin file that configures the T5X trainer for fine-tuning (located at [`t5x/configs/runs/finetune.gin`](https://source.corp.google.com/piper///depot/google3/t5x/configs/runs/finetune.gin)), and expects a few params from you. These params can be specified in a separate Gin file, or via commandline flags. Below are the required params, with the values they should be set to for the example run.

In [None]:
# This is the path to the pre-trained checkpoint (from Step 1). 
INITIAL_CHECKPOINT_PATH = MODEL_CHECKPOINT

# This is the SeqIO Task or Mixture name to run (from Step 2). 
MIXTURE_OR_TASK_NAME = SEQIO_TASK

# Number of fine-tuning steps. This includes the number of steps that the model 
# was pre-trained for, so make sure to add the step number from the 
# INITIAL_CHECKPOINT_PATH. For the example run, to fine-tune for 20,000 steps, 
# set this to 1,020,000, since the initial checkpoint is the 1,000,000th step.
TRAIN_STEPS = 1020000

# This is a dict mapping feature key to maximum int length for that feature. 
# After preprocessing, features are truncated to the provided value. 
TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256}

# A path to write fine-tuned checkpoints to. When launching using XManager, this
# path is automatically set and can be accessed from the XManager Artifacts 
# page. When running locally using Colab/Blaze, you should explicitly pass a 
# directory. Launch commands are provided in the next step.
MODEL_DIR = '/tmp/t5x_pretrain'

# When fine-tuning a model that was pre-trained using Mesh TensorFlow (e.g. the
# public T5 / mT5 / ByT5 models), this should be set to pretraining batch_size *
# pretrained target_token_length. Below are the recommended values for common 
# models.
# For T5 and T5.1.1: 2048 * 114. 
# For mT5: 1024 * 229. 
# For ByT5: 1024 * 189. 
# For MUM Base/Large/XL: 1024 * 256. 
# For MUM XXL: 1024 * 229. 
# For MUM Ace: 4096 * 229.
LOSS_NORMALIZING_FACTOR = 233472
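The two arithmetic comments above can be checked in a few lines. This sketch assumes, as the example checkpoint path suggests, that the step number is encoded in the checkpoint directory name as `checkpoint_<step>`; `checkpoint_step`, `FINETUNE_STEPS` and `PRETRAIN_BATCH_AND_TARGET_LENGTH` are hypothetical names introduced only for illustration.

```python
import re

# (pretraining batch_size, target_token_length) pairs from the comment above.
PRETRAIN_BATCH_AND_TARGET_LENGTH = {
    "t5_and_t5_1_1": (2048, 114),
    "mt5": (1024, 229),
    "byt5": (1024, 189),
}


def checkpoint_step(checkpoint_path: str) -> int:
  """Extracts the step from a path ending in `checkpoint_<step>`."""
  match = re.search(r"checkpoint_(\d+)/?$", checkpoint_path)
  if match is None:
    raise ValueError(f"No checkpoint step found in {checkpoint_path!r}")
  return int(match.group(1))


FINETUNE_STEPS = 20_000  # Hypothetical: additional fine-tuning steps wanted.
initial_step = checkpoint_step(
    "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000")
train_steps = initial_step + FINETUNE_STEPS  # TRAIN_STEPS counts from step 0.

batch_size, target_length = PRETRAIN_BATCH_AND_TARGET_LENGTH["t5_and_t5_1_1"]
loss_normalizing_factor = batch_size * target_length

print(train_steps)              # 1020000
print(loss_normalizing_factor)  # 233472
```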

In addition to the above params, you will need to include `finetune.gin` and the Gin file for the pre-trained model, which for the example run is `t5_1_1/small.gin`.

You will also need to import the Python module(s) that register SeqIO Tasks and Mixtures used in your run. For the example run, we add `import t5.data.tasks` since it is where `wmt_t2t_ende_v003` is registered.

Note that this module must also be included as a dependency in the T5X trainer [binary](https://source.corp.google.com/piper///depot/google3/t5x/BUILD;l=76;rcl=398627055) if you want to run your experiment via commandline/XManager; this is not necessary if you plan on training via Colab. Most common Task/Mixture modules, such as this one, are already included. If your module is not included, see the "Advanced Topics" section at the end of this tutorial for instructions to add it.

Finally, your Gin file should look like this:

In [None]:
sample_gin_config_str = f'''
include 't5x/configs/runs/finetune.gin'
include '{GIN_FILE_PATH}'

# Register necessary SeqIO Tasks/Mixtures.
import t5.data.tasks

MIXTURE_OR_TASK_NAME = "{MIXTURE_OR_TASK_NAME}"
TASK_FEATURE_LENGTHS = {TASK_FEATURE_LENGTHS}
TRAIN_STEPS = {TRAIN_STEPS}
DROPOUT_RATE = 0.0
INITIAL_CHECKPOINT_PATH = "{INITIAL_CHECKPOINT_PATH}"
LOSS_NORMALIZING_FACTOR = {LOSS_NORMALIZING_FACTOR}
MODEL_DIR = "{MODEL_DIR}"
'''
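Because the config is assembled with an f-string, a forgotten or misspelled macro only surfaces when Gin parses it. The sketch below is a lightweight, hypothetical pre-check (not a T5X or Gin API): it scans top-level `NAME = value` lines and reports which of the required params listed in this step are missing. It does not actually parse Gin and does not follow `include` statements.

```python
import re
from typing import Iterable, List

# Matches top-level macro assignments such as `TRAIN_STEPS = 1020000`.
_MACRO_RE = re.compile(r"^([A-Z][A-Z0-9_]*)\s*=", re.MULTILINE)


def missing_macros(config_str: str, required: Iterable[str]) -> List[str]:
  """Returns the required macro names that are not assigned in `config_str`."""
  defined = set(_MACRO_RE.findall(config_str))
  return [name for name in required if name not in defined]


REQUIRED_FINETUNE_MACROS = [
    "MIXTURE_OR_TASK_NAME", "TASK_FEATURE_LENGTHS", "TRAIN_STEPS",
    "INITIAL_CHECKPOINT_PATH", "LOSS_NORMALIZING_FACTOR", "MODEL_DIR",
]

example_config = """
include 't5x/configs/runs/finetune.gin'
MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003"
TRAIN_STEPS = 1020000
"""
print(missing_macros(example_config, REQUIRED_FINETUNE_MACROS))
```

Called on `sample_gin_config_str`, the check should return an empty list, since that config assigns all six params.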

## Step 4: Launch your experiment

You can launch your experiment locally via the commandline or directly via Colab. 

To launch your experiment in Colab, run the following code snippet. For the example given, you can expect to see results in ~5 mins if you are using a DRAGONFISH_DONUT runtime.


This code snippet runs a helper function, `parse_gin_and_get_configurable_fn`; this helper takes in a function that is configurable with gin, such as the fine-tuning `train` function. It then parses the parameters of our finetuning experiment to configure the function with Gin. When run with `train` as the configurable function, it is equivalent to the T5X finetune script found in `t5x/train.py` (this is the script that will be run if you choose to train via the commandline, as described below).

In [None]:
# Call a helper function that returns a train function based on our experiment
# parameters.
run_finetuning = parse_gin_and_get_configurable_fn(
    train,
    gin_config_str=sample_gin_config_str)

In [None]:
# Launch experiment.
with googlelog.Capture():
  run_finetuning()

### On Commandline

You can perform equivalent finetuning directly on the commandline as well. Please see the [Fine-tuning a Model](https://g3doc.corp.google.com/t5x/g3doc/usage/finetune.md?cl=head) g3doc for further instructions.

# Evaluation

This section outlines the steps to evaluate a model with T5X on downstream tasks defined with SeqIO. Evaluating a model with T5X consists of the following steps:

1.)    Choose the model to evaluate.

2.)    Choose the SeqIO Task/Mixture to evaluate the model on.

3.)    Write a Gin file that configures the model, SeqIO Task/Mixture and other details of your eval run.

4.)    Launch your experiment locally or on XManager.

These steps are explained in detail in the following sections. We also provide an example where we evaluate the T5-1.1-Small model we fine-tuned above on the [WMT14 English to German translation benchmark](https://source.corp.google.com/piper///depot/google3/third_party/py/t5/data/tasks.py;l=209;rcl=417815592).

## Step 1: Choose a Model

To evaluate a model, you need a Gin config file that defines the model params, and the model checkpoint to load from. 

For this example, we will use the same model that we fine-tuned in the previous section. Thus, we will use the same Gin config file to define our model and we will simply update our model checkpoint to use the checkpoint we just created during fine-tuning.

In [None]:
MODEL_CHECKPOINT = os.path.join(MODEL_DIR, "checkpoint_1020000")
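Hard-coding `checkpoint_1020000` assumes the fine-tuning run finished. A hypothetical alternative, sketched here with the standard library, is to pick the highest-numbered checkpoint. It assumes checkpoints are stored as `checkpoint_<step>` subdirectories of `MODEL_DIR` on a local filesystem; for `gs://` paths you would need `gfile` or a similar filesystem API instead of `os.listdir`.

```python
import os
import re
from typing import Optional

# Only names that are exactly `checkpoint_<digits>` count as checkpoints.
_CHECKPOINT_DIR_RE = re.compile(r"^checkpoint_(\d+)$")


def latest_checkpoint(model_dir: str) -> Optional[str]:
  """Returns the `checkpoint_<step>` subdirectory with the largest step."""
  best_step, best_path = -1, None
  for name in os.listdir(model_dir):
    match = _CHECKPOINT_DIR_RE.match(name)
    path = os.path.join(model_dir, name)
    # Compare steps numerically: 'checkpoint_999' sorts after
    # 'checkpoint_1020000' as a string, but not as a number.
    if match and os.path.isdir(path) and int(match.group(1)) > best_step:
      best_step, best_path = int(match.group(1)), path
  return best_path
```

Usage: `MODEL_CHECKPOINT = latest_checkpoint(MODEL_DIR)`. Directories whose names do not match exactly (for example, partially written ones with a suffix) are skipped.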

## Step 2: Choose a SeqIO Task/Mixture

A SeqIO Task encapsulates the data source, the preprocessing logic to be performed on the data before querying the model, the postprocessing logic to be performed on model outputs, and the metrics to be computed given the postprocessed outputs and targets.

We will evaluate our model on the same task we fine-tuned it on, the [WMT14 English to German translation benchmark](https://source.corp.google.com/piper///depot/google3/third_party/py/t5/data/tasks.py;l=209;rcl=417815592).

In [None]:
SEQIO_TASK = 'wmt_t2t_ende_v003'

## Step 3: Write a Gin Config



T5X provides a Gin file that configures the T5X eval job (located at [t5x/configs/runs/eval.gin](https://source.corp.google.com/piper///depot/google3/t5x/configs/runs/eval.gin)), and expects a few params from you. These params can be specified in a separate Gin file, or via commandline flags. Below are the required params, with the values they should be set to for the example run.

In [None]:
# This is the path to the fine-tuned model checkpoint (from Step 1).
CHECKPOINT_PATH = MODEL_CHECKPOINT

# This is the SeqIO Task or Mixture name to run eval on (from Step 2). 
MIXTURE_OR_TASK_NAME = SEQIO_TASK

# A path to write eval outputs to. When launching using XManager, this path is
# automatically set and can be accessed from the XManager Artifacts page. When
# running locally using Colab/Blaze, you can explicitly pass a directory. Launch
# commands are provided in the next step.
EVAL_OUTPUT_DIR = "/tmp/t5x_eval"

In addition to the above params, you will need to include `eval.gin` and the Gin file for the model, which for the example run is `t5_1_1/small.gin`.

You will also need to import the Python module(s) that register SeqIO Tasks and Mixtures used in your run. For the example run, we add `import t5.data.tasks` since it is where `wmt_t2t_ende_v003` is registered.

Note that this module must also be included as a dependency in the T5X [binary](https://source.corp.google.com/piper///depot/google3/t5x/BUILD;l=76;rcl=398627055) if you want to run your experiment via commandline/XManager; this is not necessary if you plan on evaluating via Colab. Most common Task/Mixture modules, such as this one, are already included. If your module is not included, see the "Advanced Topics" section at the end of this tutorial for instructions to add it.

Finally, your Gin file should look like this:

In [None]:
eval_gin_config_str = f'''
include 't5x/configs/runs/eval.gin'
include '{GIN_FILE_PATH}'

# Register necessary SeqIO Tasks/Mixtures.
import t5.data.tasks

CHECKPOINT_PATH = '{CHECKPOINT_PATH}'
MIXTURE_OR_TASK_NAME = '{MIXTURE_OR_TASK_NAME}'
DROPOUT_RATE = 0
EVAL_OUTPUT_DIR = '{EVAL_OUTPUT_DIR}'
'''

In this example, we run the evaluation on one checkpoint. It is common to evaluate multiple checkpoints, and T5X provides an easy way to do so without recompiling the model graph for each checkpoint: add `utils.RestoreCheckpointConfig.mode = "all"` to a Gin file. By default, `t5x/configs/runs/eval.gin` uses `"specific"` mode.
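One way to switch the example to evaluating every checkpoint is to append overrides to the config string, relying on later Gin bindings overriding earlier ones. The function below is a hypothetical sketch. It also assumes, based on our reading of T5X rather than on this tutorial, that in `"all"` mode `CHECKPOINT_PATH` must name the directory containing the checkpoints (e.g. `MODEL_DIR`) instead of a single checkpoint.

```python
def eval_all_checkpoints_config(base_config: str, checkpoint_dir: str) -> str:
  """Appends bindings that evaluate every checkpoint in `checkpoint_dir`.

  Hypothetical helper; later bindings override those in `base_config`.
  """
  overrides = [
      f"CHECKPOINT_PATH = '{checkpoint_dir}'",
      "utils.RestoreCheckpointConfig.mode = 'all'",
  ]
  return base_config.rstrip("\n") + "\n" + "\n".join(overrides) + "\n"


print(eval_all_checkpoints_config("include 't5x/configs/runs/eval.gin'\n",
                                  "/tmp/t5x_pretrain"))
```

In the Colab you would pass `eval_gin_config_str` and `MODEL_DIR` as the two arguments.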

## Step 4: Launch your experiment

You can launch your experiment locally via the commandline or directly via Colab.

To launch your experiment in Colab, run the following code snippet. For the example given, you can expect to see results in ~1 min if you are using a DRAGONFISH_DONUT runtime.

This code snippet runs the helper function `parse_gin_and_get_configurable_fn`, which takes a function that is configurable with Gin, such as the `evaluate` function, and configures it with the parameters of our evaluation experiment. When run with `evaluate` as the configurable function, it is equivalent to the T5X evaluation script found in `t5x/eval.py` (this is the script that will be run if you choose to evaluate via the commandline, as described below).

In [None]:
# Call a helper function that returns an evaluation function based on our experiment
# parameters.
run_evaluation = parse_gin_and_get_configurable_fn(evaluate,
    gin_config_str=eval_gin_config_str)

In [None]:
with googlelog.Capture():
  run_evaluation()

### On Commandline

You can perform equivalent evaluation directly on the commandline as well. Please see the [Evaluating a Model](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx/usage/eval#step-4-launch-your-experiment) g3doc for further instructions.





# Inference

T5X supports a few inference modes. Please refer to the appropriate tutorial based on your use-case:

1.)     Run inference on SeqIO Tasks/Mixtures.

2.)     Run inference on TF Example files.

## On SeqIO Task/Mixtures

To run inference on a model, you need a Gin config file that defines the model params, and the model checkpoint to load from. We will continue to use the same model as we used for evaluation. We provide the Gin config and model checkpoint for reference below.

In [None]:
GIN_FILE_PATH = "t5x/examples/t5/t5_1_1/small.gin"
MODEL_CHECKPOINT = os.path.join(MODEL_DIR, "checkpoint_1020000")

We must also provide a SeqIO Task/Mixture to run inference on; again, we will continue to use the [WMT14 English to German translation benchmark](https://source.corp.google.com/piper///depot/google3/third_party/py/t5/data/tasks.py;l=209;rcl=417815592).

In [None]:
SEQIO_TASK = 'wmt_t2t_ende_v003'

After choosing the model and SeqIO Task/Mixture for your run, the next step is to configure your run using Gin. T5X provides a Gin file that configures the T5X inference job (located at `t5x/configs/runs/infer.gin`) to run inference on SeqIO Task/Mixtures, and expects a few params from you. Below are the required params, with the values they should be set to for the example run.

In [None]:
# This is the path to the fine-tuned model checkpoint (from above).
CHECKPOINT_PATH = MODEL_CHECKPOINT

# This is the SeqIO Task or Mixture name to run inference on (from above).
MIXTURE_OR_TASK_NAME = SEQIO_TASK
# The Python module that registers the Task or Mixture above.
MIXTURE_OR_TASK_MODULE = 't5.data'

# This is a dict mapping feature key to maximum length for that feature. After 
# preprocessing, features are truncated to the provided value.
TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256}

# A path to write inference outputs to. When launching using XManager, this path is 
# automatically set and can be accessed from the XManager Artifacts page. When 
# running locally using Colab/Blaze, you can explicitly pass a directory. Launch
# commands are provided in the next step.
INFER_OUTPUT_DIR = "/tmp/t5x_infer"
# The directory needs to exist; create it (and any parents) if it doesn't.
os.makedirs(INFER_OUTPUT_DIR, exist_ok=True)


In addition to the above params, you will need to include `infer.gin` and the Gin file for the model, which for the example run is `t5_1_1/small.gin`.

Finally, your Gin file should look like this:

In [None]:
infer_gin_config_str = f'''
include 't5x/configs/runs/infer.gin'
include '{GIN_FILE_PATH}'

CHECKPOINT_PATH = '{CHECKPOINT_PATH}'
MIXTURE_OR_TASK_NAME = '{MIXTURE_OR_TASK_NAME}'
MIXTURE_OR_TASK_MODULE = '{MIXTURE_OR_TASK_MODULE}'
TASK_FEATURE_LENGTHS = {TASK_FEATURE_LENGTHS}
INFER_OUTPUT_DIR = '{INFER_OUTPUT_DIR}'
DROPOUT_RATE = 0
'''

You can launch your experiment locally via the commandline or directly via Colab.

To launch your experiment in Colab, run the following code snippet. 

This code snippet runs the helper function `parse_gin_and_get_configurable_fn`, which takes a function that is configurable with Gin, such as the `infer` function, and configures it with the parameters of our inference experiment. When run with `infer` as the configurable function, it is equivalent to the T5X inference script found in `t5x/infer.py` (this is the script that will be run if you choose to run inference via the commandline, as described below).

In [None]:
# Call a helper function that returns an inference function based on our
# experiment parameters.
run_inference = parse_gin_and_get_configurable_fn(infer,
    gin_config_str=infer_gin_config_str)

# Run inference.
with googlelog.Capture():
  run_inference()
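Once inference finishes, the predictions are written under `INFER_OUTPUT_DIR`. The reader below is a sketch under two assumptions we have not verified in this tutorial: that the output files are JSON Lines (one JSON object per line) and that their names match the glob pattern passed in. It assumes no particular field names, so inspect the keys of the first record to see what your T5X version writes. `read_jsonl_outputs` is a hypothetical helper.

```python
import glob
import json
import os
from typing import Any, Dict, List


def read_jsonl_outputs(output_dir: str,
                       pattern: str = "*.jsonl*") -> List[Dict[str, Any]]:
  """Loads every JSON object from files in `output_dir` matching `pattern`."""
  records = []
  for path in sorted(glob.glob(os.path.join(output_dir, pattern))):
    with open(path) as f:
      for line in f:
        line = line.strip()
        if line:  # Skip blank lines.
          records.append(json.loads(line))
  return records
```

For example: `records = read_jsonl_outputs(INFER_OUTPUT_DIR)`, then `sorted(records[0])` if any records were found.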

You can perform equivalent inference directly on the commandline as well. Please see the [Inference on SeqIO Task/Mixtures](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx/usage/infer-seqio#step-4-launch-your-experiment) g3doc for further instructions.

## On TF Example Files



To run inference on a model, you need a Gin config file that defines the model params, and the model checkpoint to load from. We will continue to use the same model as we used for evaluation. We provide the Gin config and model checkpoint for reference below.

In [None]:
GIN_FILE_PATH = "t5x/examples/t5/t5_1_1/small.gin"
MODEL_CHECKPOINT = os.path.join(MODEL_DIR, "checkpoint_1020000")

T5X supports running inference on `tfrecord`, `recordio` and `sstable` files containing TF Examples. For the example run, you will run inference on `tfrecord` files containing the `natural_questions_open` dataset located here: `/path/to/tfds/data/dir/natural_questions_open/1.0.0/natural_questions_open-validation.tfrecord*`. Here's an example of a single row of data from this file (you can explore this file further using GQUI):

```
{ # (tensorflow.Example) size=101B
  features: { # (tensorflow.Features) size=99B
    feature: { # (tensorflow.Features.FeatureEntry) size=27B
      key: "answer" # size=6
      value: { # (tensorflow.Feature) size=17B
        bytes_list: { # (tensorflow.BytesList) size=15B
          value: [ "Jason Flemyng" ] # size=13
        } # features.feature[0].value.bytes_list
      } # features.feature[0].value
    } # features.feature[0]
    feature: { # (tensorflow.Features.FeatureEntry) size=68B
      key: "question" # size=8
      value: { # (tensorflow.Feature) size=56B
        bytes_list: { # (tensorflow.BytesList) size=54B
          value: [ "who played hyde in league of extraordinary gentlemen" ] # size=52
        } # features.feature[1].value.bytes_list
      } # features.feature[1].value
    } # features.feature[1]
  } # features
}
```

In [None]:
TF_RECORD_FILEPATH = '/path/to/tfds/data/dir/natural_questions_open/1.0.0/natural_questions_open-validation.tfrecord*'

After choosing the model and file source for your run, the next step is to configure your run using Gin. T5X provides a Gin file that configures the T5X inference job (located at `t5x/configs/runs/infer_from_tfexample_file.gin`) to run inference on TF Example files, and expects a few params from you. Below are the required params, with the values they should be set to for the example run.

In [None]:
# This is the path to the model checkpoint.
CHECKPOINT_PATH = MODEL_CHECKPOINT

# This is a list of paths or glob patterns to read TF Examples from. 
TF_EXAMPLE_FILE_PATHS = [TF_RECORD_FILEPATH]

# This is the TF Example file format. Currently supported file formats are 
# tfrecord, recordio and sstable.
TF_EXAMPLE_FILE_TYPE = 'tfrecord'

# This is a dict mapping feature key to maximum int length for that feature. The
# TF Example features are truncated to the provided value. 
FEATURE_LENGTHS = {'inputs': 38, 'targets': 18}

# A path to write inference outputs to. When launching using XManager, this path is 
# automatically set and can be accessed from the XManager Artifacts page. When 
# running locally using Colab/Blaze, you can explicitly pass a directory. Launch
# commands are provided in the next step.
INFER_OUTPUT_DIR = "/tmp/t5x_infer"
# The directory needs to exist; create it (and any parents) if it doesn't.
os.makedirs(INFER_OUTPUT_DIR, exist_ok=True)

MIXTURE_OR_TASK_NAME = infer_script.create_task_from_tfexample_file(
    paths=TF_EXAMPLE_FILE_PATHS,
    file_type=TF_EXAMPLE_FILE_TYPE,
    inputs_key='question',  # Must match a feature key in the TF Examples.
    targets_key=None,
    features=FEATURE_LENGTHS)


In addition to the above params, you may also need to override the `create_task_from_tfexample_file.inputs_key` param based on the data format (it is set to `'inputs'` by default). You will also need to include `infer_from_tfexample_file.gin` and the Gin file for the model, which for the example run is `t5_1_1/small.gin`.
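The `inputs_key` must name a feature that actually exists in the TF Examples; in the record shown above the features are `answer` and `question`, so `'question'` is the right value. The stdlib sketch below models a decoded example as a plain Python dict (a simplification; real records are `tf.train.Example` protos) and checks the configured keys before a run is launched. `check_feature_keys` is a hypothetical helper, not a T5X API.

```python
from typing import Mapping, Optional


def check_feature_keys(example: Mapping[str, object],
                       inputs_key: str,
                       targets_key: Optional[str] = None) -> None:
  """Raises KeyError if a configured key is absent from `example`."""
  for key in (inputs_key, targets_key):
    if key is not None and key not in example:
      raise KeyError(
          f"{key!r} is not a feature; available: {sorted(example)}")


# The record from the natural_questions_open file shown above.
example = {
    "answer": [b"Jason Flemyng"],
    "question": [b"who played hyde in league of extraordinary gentlemen"],
}
check_feature_keys(example, inputs_key="question")  # Passes silently.
```

Calling it with `inputs_key='questions'` raises a `KeyError` that lists the available features, `['answer', 'question']`.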

For the purposes of this colab, we must make a few small edits to `infer_from_tfexample_file.gin`. Namely, we must replace all imports that refer to `__main__` with references to `t5x.infer`. We thus write `infer_from_tfexample_file.gin` to a temporary file where we can make these edits, using the helper function below.

In [None]:
def replace_main_in_config(config_filename, tmp_output_dir):
  '''
  Replace imports that refer to `__main__` with `t5x.infer`.
  '''
  original_gin_filepath = f'/google_src/head/depot/google3/t5x/configs/runs/{config_filename}'
  temporary_gin_filepath = os.path.join(tmp_output_dir, f"tmp_{config_filename}")
  with gfile.Open(original_gin_filepath, "rt") as original_gin_file:
    with open(temporary_gin_filepath, "w") as temporary_gin_file:
      for line in original_gin_file.readlines():
        line = re.sub("__main__", "t5x.infer", line)
        temporary_gin_file.write(line)
  return temporary_gin_filepath

temporary_gin_filepath = replace_main_in_config("infer_from_tfexample_file.gin", INFER_OUTPUT_DIR)

Finally, your Gin file should look like this:

In [None]:
infer_gin_config_str = f'''
include '{temporary_gin_filepath}'
include '{GIN_FILE_PATH}'

CHECKPOINT_PATH = '{CHECKPOINT_PATH}'
TF_EXAMPLE_FILE_PATHS = {TF_EXAMPLE_FILE_PATHS}
TF_EXAMPLE_FILE_TYPE = '{TF_EXAMPLE_FILE_TYPE}'
FEATURE_LENGTHS = {FEATURE_LENGTHS}
INFER_OUTPUT_DIR = '{INFER_OUTPUT_DIR}'
DROPOUT_RATE = 0
create_task_from_tfexample_file.inputs_key = 'question'
'''

You can launch your experiment locally via the commandline or directly via Colab.

To launch your experiment in Colab, run the following code snippet.

This code snippet runs the helper function `parse_gin_and_get_configurable_fn`, which takes a function that is configurable with Gin, such as the `infer` function, and configures it with the parameters of our inference experiment. When run with `infer` as the configurable function, it is equivalent to the T5X inference script found in `t5x/infer.py` (this is the script that will be run if you choose to run inference via the commandline, as described below).

In [None]:
# Call a helper function that returns an inference function based on our
# experiment parameters.
run_inference = parse_gin_and_get_configurable_fn(infer,
    gin_config_str=infer_gin_config_str)

# Run inference.
with googlelog.Capture():
  run_inference()
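Conceptually, Gin binds the `name.param = value` lines of a config to the keyword arguments of a configurable function. The toy sketch below mimics that idea with the standard library only (`ast.literal_eval` plus `functools.partial`). It is a hypothetical illustration, not Gin's implementation and not the notebook's `parse_gin_and_get_configurable_fn` helper; `toy_infer` is an invented stand-in for `infer`.

```python
import ast
import functools


def toy_bind(fn, config_str):
  """Binds `<fn name>.<param> = <literal>` lines in `config_str` to `fn`.

  A toy analog of Gin for illustration: it handles only single-line bindings
  whose value is a Python literal, and ignores lines for other functions.
  """
  kwargs = {}
  for line in config_str.splitlines():
    line = line.strip()
    if not line or line.startswith("#") or "=" not in line:
      continue
    target, value = (part.strip() for part in line.split("=", 1))
    fn_name, _, param = target.rpartition(".")
    if fn_name == fn.__name__:
      kwargs[param] = ast.literal_eval(value)
  return functools.partial(fn, **kwargs)


def toy_infer(mode="score", batch_size=1):
  """Hypothetical stand-in for a Gin-configurable function such as `infer`."""
  return f"mode={mode}, batch_size={batch_size}"


run_toy_inference = toy_bind(
    toy_infer, "toy_infer.mode = 'predict'\ntoy_infer.batch_size = 32\n")
print(run_toy_inference())
```

As with the real helper, the returned callable takes no further arguments: every parameter comes from the config.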

You can also run equivalent inference directly on the command line. Please see the [Inference on TF Example Files](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx/usage/infer-files#step-4-launch-your-experiment) g3doc for further instructions.

# [Advanced] Pre-training from Scratch

Pretraining a model with T5X consists of the following steps:

1.)    Choose the model architecture.

2.)    Choose the SeqIO Task/Mixture to use for training.

3.)    Write a Gin file that configures the model, SeqIO Task/Mixture and other details of your pretraining run.

4.)    Launch your experiment locally or on XManager.

These steps are explained in detail in the following sections, which also walk through an example run that trains a T5 1.1 Small model from scratch on the C4 dataset using the span corruption pretraining objective.

## Step 1: Choose a model architecture

To train a model, you need a Gin config file that defines the model params. For your convenience, Gin configs for common models have been made available for use in T5X. You can find a list of these models and their Gin locations in the [T5X g3doc](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx/usage/pretrain#step-1-choose-a-model-architecture).

For the example run, you will use the T5 1.1 Small model. The Gin file for this model is located at `t5x/examples/t5/t5_1_1/small.gin`.

In [None]:
GIN_FILE_PATH = 't5x/examples/t5/t5_1_1/small.gin'

## Step 2: Choose a SeqIO Task/Mixture

A SeqIO Task encapsulates the data source, the preprocessing logic to be performed on the data before querying the model, the postprocessing logic to be performed on model outputs, and the metrics to be computed given the postprocessed outputs and targets. A SeqIO Mixture denotes a collection of Tasks and enables pretraining a model on multiple Tasks simultaneously.
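To picture how these pieces fit together, here is a toy, pure-Python stand-in loosely modeled on the `source`, `preprocessors`, `postprocess_fn` and `metric_fns` arguments of `seqio.Task`. `ToyTask`, `exact_match` and the sample data are hypothetical; real SeqIO Tasks operate on `tf.data` datasets and are registered by name.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, Iterator, List


@dataclass
class ToyTask:
  """Hypothetical stand-in for the pieces a SeqIO Task bundles together."""
  source: Callable[[], Iterable[dict]]
  preprocessors: List[Callable[[dict], dict]]
  postprocess_fn: Callable[[str], str]
  metric_fns: List[Callable[[List[str], List[str]], Dict[str, float]]]

  def get_examples(self) -> Iterator[dict]:
    """Yields source examples after running every preprocessor in order."""
    for example in self.source():
      for preprocess in self.preprocessors:
        example = preprocess(example)
      yield example

  def evaluate(self, predictions: List[str],
               targets: List[str]) -> Dict[str, float]:
    """Postprocesses raw model outputs, then merges all metric results."""
    predictions = [self.postprocess_fn(p) for p in predictions]
    results = {}
    for metric_fn in self.metric_fns:
      results.update(metric_fn(targets, predictions))
    return results


def exact_match(targets, predictions):
  """Percentage of predictions that equal their target exactly."""
  matches = sum(t == p for t, p in zip(targets, predictions))
  return {"exact_match": 100.0 * matches / len(targets)}


task = ToyTask(
    source=lambda: [{"question": "What is 2 + 2?", "answer": "4"}],
    preprocessors=[
        lambda ex: {"inputs": ex["question"], "targets": ex["answer"]}],
    postprocess_fn=str.strip,
    metric_fns=[exact_match])

print(list(task.get_examples()))
print(task.evaluate(predictions=[" 4 "], targets=["4"]))
```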

For the example run, you will train the model on the [`c4_v220_span_corruption`](https://source.corp.google.com/piper///depot/google3/third_party/py/t5/data/tasks.py;l=42;rcl=370153959) Task, which implements the span corruption pretraining objective on the C4 dataset. This is the final pretraining Task used in the T5 paper.

In [None]:
SEQIO_TASK = 'c4_v220_span_corruption'
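To make the span corruption objective concrete, here is a minimal sketch that corrupts fixed, hand-picked spans of the example sentence from the T5 paper. The real Task picks spans at random over SentencePiece token IDs (about 15% of tokens, mean span length 3). The fixed spans and whitespace tokenization here are simplifications for readability. The `<extra_id_N>` strings follow the sentinel naming of the T5 vocabulary.

```python
def span_corrupt(tokens, spans):
  """Applies T5-style span corruption with caller-chosen spans.

  Args:
    tokens: List of tokens.
    spans: Sorted, non-overlapping (start, length) pairs to mask.

  Returns:
    (inputs, targets): each masked span is replaced by a sentinel in the
    inputs; the targets list each sentinel followed by the tokens it hid,
    and end with one final sentinel.
  """
  inputs, targets = [], []
  cursor = 0
  for i, (start, length) in enumerate(spans):
    sentinel = f"<extra_id_{i}>"
    inputs.extend(tokens[cursor:start])  # Keep the unmasked prefix.
    inputs.append(sentinel)
    targets.append(sentinel)
    targets.extend(tokens[start:start + length])  # The hidden span.
    cursor = start + length
  inputs.extend(tokens[cursor:])
  targets.append(f"<extra_id_{len(spans)}>")
  return inputs, targets


tokens = "Thank you for inviting me to your party last week .".split()
inputs, targets = span_corrupt(tokens, [(2, 2), (8, 1)])
print(" ".join(inputs))
print(" ".join(targets))
```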

## Step 3: Write a Gin Config

After choosing the model architecture and SeqIO Task/Mixture for your run, the next step is to configure your run using Gin. T5X provides a Gin file that configures the T5X trainer for pretraining (located at `t5x/configs/runs/pretrain.gin`) and expects a few params from you. These params can be specified in a separate Gin file or via command-line flags. Below are the required params, with the values they should be set to for the example run.

In [None]:
# Number of training steps.
TRAIN_STEPS = 100000

# This is the SeqIO Task or Mixture name to run (from Step 2).
MIXTURE_OR_TASK_NAME = SEQIO_TASK

# This is a dict mapping feature key to maximum int length for that feature. 
# After preprocessing, features are truncated to the provided value.
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114}

# A path to write pretrained checkpoints to. When launching using XManager, this
# path is automatically set and can be accessed from the XManager Artifacts 
# page. When running locally using Colab/Blaze, you can explicitly pass a 
# directory. Launch commands are provided in the next step.
MODEL_DIR = "/tmp/pretrain-round-2"
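The `targets` length of 114 is not arbitrary. The sketch below assumes T5's span corruption defaults (noise density 0.15, mean noise span length 3, one sentinel per noise span in both inputs and targets, plus one EOS token per sequence). Under those assumptions, it reproduces the arithmetic of `random_spans_helper` in the T5 library (`t5.data.preprocessors`). It finds the longest uncorrupted text whose corrupted inputs fit in 512 tokens, and reports the resulting targets length.

```python
def span_corruption_lengths(tokens_length, noise_density=0.15,
                            mean_noise_span_length=3.0):
  """Returns (inputs_length, targets_length) after span corruption.

  Each noise span costs one sentinel in the inputs and one in the targets;
  both sequences also get one EOS token.
  """
  num_noise_tokens = int(round(tokens_length * noise_density))
  num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length))
  num_nonnoise_tokens = tokens_length - num_noise_tokens
  inputs_length = num_nonnoise_tokens + num_noise_spans + 1
  targets_length = num_noise_tokens + num_noise_spans + 1
  return inputs_length, targets_length


def tokens_length_for_inputs(inputs_length, **kwargs):
  """Finds the longest raw length whose corrupted inputs fit `inputs_length`."""
  tokens_length = inputs_length - 1
  while span_corruption_lengths(tokens_length + 1, **kwargs)[0] <= inputs_length:
    tokens_length += 1
  return tokens_length


raw_length = tokens_length_for_inputs(512)
print(raw_length, span_corruption_lengths(raw_length))
```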

In addition to the above params, you will need to include `pretrain.gin` and the Gin file for the model architecture, which for the example run is `t5_1_1/small.gin`. You will also need to import the Python module(s) that register the SeqIO Tasks and Mixtures used in your run. For the example run, we add `import t5.data.mixtures`.

Finally, your Gin file should look like this:

In [None]:
pretrain_gin_config_str = f'''
include '{GIN_FILE_PATH}'
include 't5x/configs/runs/pretrain.gin'

# Register necessary SeqIO Tasks/Mixtures.
import t5.data.mixtures

MIXTURE_OR_TASK_NAME = "{MIXTURE_OR_TASK_NAME}"
TASK_FEATURE_LENGTHS = {TASK_FEATURE_LENGTHS}
TRAIN_STEPS = {TRAIN_STEPS}
MODEL_DIR = "{MODEL_DIR}"
DROPOUT_RATE = 0.0
BATCH_SIZE = 256
partitioning.PjitPartitioner.num_partitions=2
'''
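With `num_partitions=2`, the partitioner splits each copy of the model across 2 devices (model parallelism). The remaining factor of the device count is used for data parallelism. The back-of-envelope sketch below assumes a hypothetical 8-device slice (check `jax.devices()` for your real count) and shows how the global `BATCH_SIZE` of 256 would be divided. It is arithmetic only, not T5X's partitioning code.

```python
def split_devices(num_devices, num_partitions, global_batch_size):
  """Returns (data_parallel_size, per_replica_batch) for a 2D device mesh."""
  if num_devices % num_partitions:
    raise ValueError("num_partitions must divide the device count.")
  # Devices not used for model parallelism form data-parallel replicas.
  data_parallel_size = num_devices // num_partitions
  if global_batch_size % data_parallel_size:
    raise ValueError("Batch size must be divisible by the data-parallel size.")
  return data_parallel_size, global_batch_size // data_parallel_size


# Hypothetical 8-device slice with the settings from the config above.
print(split_devices(num_devices=8, num_partitions=2, global_batch_size=256))
```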

## Step 4: Launch your experiment

You can launch your experiment locally via the command line or directly via Colab.

To launch your experiment in Colab, run the following code snippet. For the example given, you can expect to see results in ~X mins if you are using a DRAGONFISH_DONUT runtime.

This code snippet calls a helper function, `parse_gin_and_get_configurable_fn`, which takes a function that is configurable with Gin, such as the pre-training `train` function, and configures it with the parameters of our pre-training experiment parsed from the Gin config. When run with `train` as the configurable function, it is equivalent to the T5X pre-training script found in `t5x/train.py` (the script that runs if you choose to train via the command line).

In [None]:
# Call a helper function that returns a train function based on our experiment
# parameters.
run_pretraining = parse_gin_and_get_configurable_fn(
    train,
    gin_config_str=pretrain_gin_config_str)

In [None]:
# Launch experiment.
with googlelog.Capture():
  run_pretraining()
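To launch the same run from the command line instead, the Gin file and any overrides are passed as flags to `t5x/train.py`. The sketch below only assembles such a command as a string. It assumes the `--gin_file` and `--gin.NAME=VALUE` flags described in the T5X README, and a hypothetical file `my_pretrain.gin` holding `pretrain_gin_config_str`. Internally you would typically launch through Blaze or XManager, so treat the `python3` prefix as an assumption too.

```python
import shlex


def build_train_command(gin_file, overrides, script="t5x/train.py"):
  """Assembles a T5X training command line (flag names are assumptions)."""
  args = ["python3", script, f"--gin_file={gin_file}"]
  for name, value in overrides.items():
    # repr() keeps Gin string values quoted, e.g. '/tmp/dir'.
    args.append(f"--gin.{name}={value!r}")
  # Quote each argument so the shell passes it through unchanged.
  return " ".join(shlex.quote(arg) for arg in args)


cmd = build_train_command(
    "my_pretrain.gin",  # Hypothetical file containing pretrain_gin_config_str.
    {"MODEL_DIR": "/tmp/pretrain-round-2", "TRAIN_STEPS": 100000})
print(cmd)
```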

In [None]:
import jax
print(jax.devices())

# Loading a Model

## With Gin

Let's define our Gin config and Gin search paths.

In [None]:
# Gin Config
gin_config = """
include 't5x/examples/t5/t5_1_1/base.gin'  # imports vocab, optimizer and model.
DROPOUT_RATE = 0.0
# ------------------- Network specification overrides --------------------------
network.Transformer.config = @network.T5Config()
network.T5Config:
  emb_dim = 512
  num_heads = 6
  num_encoder_layers = 8
  num_decoder_layers = 8
  head_dim = 64
  mlp_dim = 1024
"""

gin_path = "/google_src/head/depot/google3/"
gin_search_files = [
    "",
    "t5x/",
    "t5x/examples/t5/t5_1_1/small.gin",
    "t5x/configs/runs/finetune.gin",
]

Next, let's parse this config.

In [None]:
for filepath in gin_search_files:
  abs_filepath = gin_path + filepath
  gin.add_config_file_search_path(abs_filepath)
  print(f"Added {abs_filepath} to the search list.")

gin.parse_config(gin_config)
gin.finalize()
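The search paths matter because relative `include` statements, such as `include 't5x/examples/t5/t5_1_1/base.gin'`, are resolved against them. The stdlib sketch below imitates that lookup on a temporary directory. It approximates the behavior (try each search path in order, return the first existing match) rather than reproducing Gin's internals; `resolve_include` is a hypothetical helper.

```python
import os
import tempfile


def resolve_include(include_path, search_paths):
  """Returns the first existing `search_path/include_path`, else None."""
  for search_path in search_paths:
    candidate = os.path.join(search_path, include_path)
    if os.path.exists(candidate):
      return candidate
  return None


with tempfile.TemporaryDirectory() as root:
  # Create an empty stand-in for base.gin under a fake source tree.
  gin_dir = os.path.join(root, "t5x", "examples", "t5", "t5_1_1")
  os.makedirs(gin_dir)
  open(os.path.join(gin_dir, "base.gin"), "w").close()
  # The first search path does not contain the file; the second one does.
  found = resolve_include("t5x/examples/t5/t5_1_1/base.gin",
                          [os.path.join(root, "missing"), root])
  print(found == os.path.join(gin_dir, "base.gin"))
```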

Finally, let's test that we've parsed our Gin config correctly by querying some test parameters and loading our model.

In [None]:
print(gin.query_parameter('network.T5Config.emb_dim'))

In [None]:
model = gin.get_configurable("MODEL/macro")
print(model)

In [None]:
GIN_FILE_PATH = "t5x/examples/t5/t5_1_1/small.gin"
MODEL_CHECKPOINT = "gs://t5-data/pretrained_models/t5x/t5_1_1_small"

In [None]:
import jax


def restore_from_checkpoint(model, checkpoint_path, input_shapes, partitioner):
  """Restores a T5X train state for `model` from `checkpoint_path`.

  Args:
    model: A T5X model, e.g. the one loaded above.
    checkpoint_path: Path of the checkpoint to restore, e.g. MODEL_CHECKPOINT.
    input_shapes: Dict mapping model feature names to their shapes.
    partitioner: A `t5x.partitioning` partitioner, e.g. a `PjitPartitioner`.

  Returns:
    The restored train state.
  """
  train_state_initializer = t5x.utils.TrainStateInitializer(
      optimizer_def=model.optimizer_def,
      init_fn=model.get_initial_variables,
      input_shapes=input_shapes,
      partitioner=partitioner)

  # Restore exactly the checkpoint at `checkpoint_path`, in float32.
  restore_checkpoint_cfg = t5x.utils.RestoreCheckpointConfig(
      path=checkpoint_path, mode='specific', dtype='float32')

  return train_state_initializer.from_checkpoint_or_scratch(
      [restore_checkpoint_cfg], init_rng=jax.random.PRNGKey(0))
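`t5x.utils.TrainStateInitializer` needs `input_shapes` to initialize the model before restoring. For an encoder-decoder model this is a mapping from feature names to `(batch, length)` shapes. The key names below are an assumption based on the features T5X's `EncoderDecoderModel` consumes. `make_input_shapes` is a hypothetical helper, and the batch size of 8 is arbitrary. The lengths reuse the pretraining example's 512/114 feature lengths.

```python
def make_input_shapes(batch_size, feature_lengths):
  """Maps encoder-decoder feature names to (batch, length) shapes."""
  return {
      # Encoder tokens come from the "inputs" feature.
      "encoder_input_tokens": (batch_size, feature_lengths["inputs"]),
      # Decoder tokens come from the "targets" feature.
      "decoder_input_tokens": (batch_size, feature_lengths["targets"]),
  }


input_shapes = make_input_shapes(8, {"inputs": 512, "targets": 114})
print(input_shapes)
```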


## Without Gin

In [None]:
# The T5 1.1 network definition that the Gin configs above also use.
from t5x.examples.t5 import network

t5_config = network.T5Config(
    vocab_size=32128,
    dtype='bfloat16',
    emb_dim=768,
    num_heads=12,
    num_encoder_layers=12,
    num_decoder_layers=12,
    head_dim=64,
    mlp_dim=2048,
    mlp_activations=('gelu', 'linear'),
    dropout_rate=0.0,
    logits_via_embedding=False)
module = network.Transformer(config=t5_config)
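As a sanity check on these hyperparameters, the sketch below estimates the parameter count of this configuration. It assumes the T5 1.1 layer layout:

- bias-free attention projections;
- a gated-GELU MLP with two input projections (matching `mlp_activations=('gelu', 'linear')`);
- scale-only layer norms;
- one 32-bucket relative-position-bias table per stack;
- an untied output layer, since `logits_via_embedding=False`.

Under those assumptions it lands at roughly 248M parameters. `t5_1_1_param_count` is a hypothetical helper, not part of T5X.

```python
def t5_1_1_param_count(vocab_size, emb_dim, num_heads, head_dim, mlp_dim,
                       num_encoder_layers, num_decoder_layers,
                       num_mlp_inputs=2, num_buckets=32):
  """Back-of-envelope parameter count for a T5 1.1-style encoder-decoder."""
  # Query, key, value and output projections.
  attention = 4 * emb_dim * num_heads * head_dim
  # Gated MLP: one input projection per activation, plus the output projection.
  mlp = (num_mlp_inputs + 1) * emb_dim * mlp_dim
  encoder_layer = attention + mlp + 2 * emb_dim  # Two scale-only norms.
  decoder_layer = 2 * attention + mlp + 3 * emb_dim  # Self- and cross-attention.
  embeddings = vocab_size * emb_dim  # Shared input token embedding.
  logits = emb_dim * vocab_size  # Separate output layer.
  final_norms = 2 * emb_dim  # One final norm per stack.
  rel_pos = 2 * num_buckets * num_heads  # One bias table per stack.
  return (embeddings + logits + final_norms + rel_pos
          + num_encoder_layers * encoder_layer
          + num_decoder_layers * decoder_layer)


print(t5_1_1_param_count(vocab_size=32128, emb_dim=768, num_heads=12,
                         head_dim=64, mlp_dim=2048, num_encoder_layers=12,
                         num_decoder_layers=12))
```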

In [None]:
import t5.data  # Provides get_default_vocabulary().

model = t5x.models.EncoderDecoderModel(
    module=module,
    input_vocabulary=t5.data.get_default_vocabulary(),
    output_vocabulary=t5.data.get_default_vocabulary(),
    optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
    # Decode with top-k (k=40) sampling at temperature 1.0.
    decode_fn=functools.partial(
        t5x.decoding.temperature_sample, temperature=1.0, topk=40))