
<a href="https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Overview

This is the third Colab in a [series of tutorials on how to use T5X](https://github.com/google-research/t5x/blob/main/docs/tutorials.md). We assume that you have already completed the [Introductory Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb) and the [Training Deep Dive](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb), or have a basic understanding of the T5X models, checkpoints, partitioner, trainer, and `InteractiveModel`.

In the [previous Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb) in this tutorial series, we dove into how the `InteractiveModel` restores models from checkpoints and runs training, and we introduced the T5X trainer. In this Colab, we focus on how the `InteractiveModel` runs decoding to generate predictions and scores for a given input. Note that the code snippets below exactly replicate the `InteractiveModel` `__init__()` and `infer_with_preprocessors()` methods (see [source code](https://github.com/google-research/t5x/blob/main/t5x/interactive_model.py)); we expose this functionality here to demonstrate how the components of the T5X codebase work together to run inference on a model.

# Set-Up

Note: If you are using a public Colab, please use its `Connect to a local runtime` option by following the [setup guide](https://github.com/google-research/t5x/blob/main/t5x/notebooks/README.md).

In [None]:
from collections.abc import Sequence
import enum
import functools
import inspect
import itertools
import logging
import os
import re
from typing import Any, Callable, Iterator, Optional, Tuple, Union

import jax
from jax import random
from jax.experimental import multihost_utils
import numpy as np
import seqio
import tensorflow as tf
import tensorflow_datasets as tfds
import t5.data

In [None]:
import clu.data
from t5x.examples.t5 import network
import t5x
from t5x import models
from t5x import partitioning
from t5x import trainer as trainer_lib
from t5x import utils
from t5x.infer import _extract_tokens_and_aux_values
from t5x.infer import _Inferences
from t5x.interactive_model import InteractiveModel
from t5x.interactive_model import get_batches_from_seqio
from t5x.interactive_model import get_dataset_from_natural_text_examples
from t5x.interactive_model import get_gin_config_from_interactive_model
from t5x.interactive_model import T5XScriptType
from t5x.interactive_model import InferenceType

Before we begin, let's initialize instances of the constructor arguments for the `InteractiveModel`. As mentioned previously, this will enable us to dive into how the `InteractiveModel` runs inference.

If you don't understand the lines of code below, or have questions about how to initialize these parameters, please see the [first Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb) in this tutorial series.

In [None]:
# Define a model. The configuration below corresponds to the T5 1.1 Small model.
t5_config = network.T5Config(
    vocab_size=32128,
    dtype='bfloat16',
    emb_dim=512,
    num_heads=6,
    num_encoder_layers=8,
    num_decoder_layers=8,
    head_dim=64,
    mlp_dim=1024,
    mlp_activations=('gelu', 'linear'),
    dropout_rate=0.0,
    logits_via_embedding=False)
module = network.Transformer(config=t5_config)
model = t5x.models.EncoderDecoderModel(
    module=module,
    input_vocabulary=t5.data.get_default_vocabulary(),
    output_vocabulary=t5.data.get_default_vocabulary(),
    optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0))
# Define checkpoint arguments.
checkpoint_path='gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'
dtype='bfloat16'
restore_mode='specific'
# Define a partitioner.
partitioner=partitioning.PjitPartitioner(num_partitions=2)
# Define additional, miscellaneous constructor arguments.
batch_size=8
task_feature_lengths = {'inputs': 38, 'targets': 18}
output_dir='/tmp/output_dir'
input_shapes = {
    'encoder_input_tokens': np.array([8, 38]),
    'decoder_target_tokens': np.array([8, 18]),
    'decoder_input_tokens': np.array([8, 18]),
    'decoder_loss_weights': np.array([8, 18])
}

In addition, we will run all code that is performed when we initialize the InteractiveModel. If you don't understand the lines of code below or have any additional questions about how/why we do the steps below, please see the [second Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb) in this tutorial series.

In [None]:
# 1.) Configure the Output Directory
output_dir = re.sub(r"(?<!gs:)([\/]{2,})", "/", output_dir)
if not os.path.exists(output_dir):
  os.mkdir(output_dir)

# 2.) Initialize RNGs
init_random_seed = 42
random_seed = multihost_utils.broadcast_one_to_all(np.int32(init_random_seed))
utils.set_hardware_rng_ops()
rng = random.PRNGKey(random_seed)
init_rng, trainer_rng = random.split(rng, 2)

# 3.) Validate the Partitioner
if partitioner._model_parallel_submesh:
  num_partitions = np.prod(partitioner._model_parallel_submesh)
else:
  num_partitions = partitioner._num_partitions
if jax.device_count() % num_partitions != 0:
  raise ValueError(
    "The number of devices available must be a multiple of the number of"
    f" partitions. There are {jax.device_count()} devices available, but"
    f" the number of partitions is set to {num_partitions}. Please"
    " provide a different number of partitions.")

# 4.) Create a Checkpoint Manager
# a.) Define CheckpointCfg wrappers.
save_checkpoint_cfg = utils.SaveCheckpointConfig(
        dtype=dtype,
        keep=5, # The number of checkpoints to keep in the output_dir.
        save_dataset=False)
restore_checkpoint_cfg = utils.RestoreCheckpointConfig(
        dtype=dtype,
        mode=restore_mode,
        path=checkpoint_path)

# b.) Define a train state initializer, which will help us get information about the
# TrainState shape.
train_state_initializer = utils.TrainStateInitializer(
        optimizer_def=model.optimizer_def,
        init_fn=model.get_initial_variables,
        input_shapes=input_shapes,
        input_types=None,
        partitioner=partitioner)

# c.) Define the checkpoint manager.
checkpoint_manager = utils.LegacyCheckpointManager(
        save_cfg=save_checkpoint_cfg,
        restore_cfg=restore_checkpoint_cfg,
        train_state_shape=train_state_initializer.global_train_state_shape,
        partitioner=partitioner,
        ds_iter=None,
        model_dir=output_dir)

# 5.) Restore the Model from a Checkpoint, or Initialize from Scratch
def get_state(rng):
  return train_state_initializer.from_scratch(rng).state_dict()

# a.) Try to restore a model from a checkpoint.
train_state = checkpoint_manager.restore(
  [restore_checkpoint_cfg.path],
  restore_checkpoint_cfg,
  utils.get_fallback_state(restore_checkpoint_cfg, get_state, init_rng)
)

# b.) If no checkpoint to restore, init from scratch.
if train_state is None:
  train_state = train_state_initializer.from_scratch(init_rng)
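
Two of the set-up steps above are easy to try in isolation. The sketch below reuses the regular expression from step 1 and mirrors the divisibility check from step 3. The helper names `normalize_output_dir` and `check_num_partitions` are ours and are not part of T5X, and the device count is passed in as a plain integer rather than read from `jax.device_count()`.

```python
import re


def normalize_output_dir(path: str) -> str:
  """Collapses repeated slashes, leaving a leading `gs://` prefix intact."""
  return re.sub(r"(?<!gs:)([\/]{2,})", "/", path)


def check_num_partitions(device_count: int, num_partitions: int) -> None:
  """Raises if the devices cannot be split evenly across the partitions."""
  if device_count % num_partitions != 0:
    raise ValueError(
        "The number of devices available must be a multiple of the number of"
        f" partitions. There are {device_count} devices available, but the"
        f" number of partitions is set to {num_partitions}.")


print(normalize_output_dir("/tmp//output_dir"))
print(normalize_output_dir("gs://bucket//dir"))
check_num_partitions(device_count=8, num_partitions=2)  # Passes silently.
```

The negative lookbehind `(?<!gs:)` is what keeps the double slash in a Google Cloud Storage path from being collapsed.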

# Inference Deep Dive

**Defining a Batch of Examples to Run Inference On**\
Let's start by defining a batch of examples that we will get predictions and scores for.

These examples should be a list of inputs; we don't need any targets, because we will eventually generate predictions. For this Colab, we'll use a set of natural text questions (and we will generate the answers).

In [None]:
examples = [
    b'nq question: who has been appointed as the new chairman of sebi',
    b'nq question: who wrote the book lion the witch and the wardrobe',
    b'nq question: how many planes did japan lose at pearl harbor',
    b'nq question: who does the voice of mcgruff the dog',
    b'nq question: who sings the wheels in the sky keep on turning',
    b'nq question: who voices regina in glitter force doki doki',
    b'nq question: when did the us become allies with britain',
    b'nq question: who won the rugby 7 in las vegas'
]

We also define the required features of the examples. For this Colab, we will only require an `inputs` and `targets` entry, as defined below. `targets` will be empty for our examples, because we do not have any targets to provide at inference time.

In [None]:
output_features = {
        "inputs":
            seqio.Feature(
                vocabulary=model.input_vocabulary, add_eos=True),
        "targets":
            seqio.Feature(
                vocabulary=model.output_vocabulary, add_eos=True)
    }
features = dict(sorted(output_features.items()))
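
To make the effect of `add_eos=True` and the empty `targets` entry concrete, here is a toy sketch of what tokenizing an example and appending an EOS token amounts to. The whitespace tokenizer and the `TOY_VOCAB` mapping are stand-ins made up for illustration; the real vocabulary is a SentencePiece model. The only details carried over from T5 are the reserved IDs: 0 for padding, 1 for EOS, and 2 for unknown tokens.

```python
from typing import Dict, List

PAD_ID = 0  # T5 convention: ID 0 is padding.
EOS_ID = 1  # T5 convention: ID 1 is end-of-sequence.
UNK_ID = 2  # T5 convention: ID 2 is the unknown token.

# Hypothetical toy vocabulary; not the real SentencePiece vocabulary.
TOY_VOCAB: Dict[str, int] = {"nq": 3, "question:": 4, "who": 5, "wrote": 6}


def toy_tokenize(text: str) -> List[int]:
  """Maps whitespace-separated words to IDs, using UNK_ID for unknown words."""
  return [TOY_VOCAB.get(word, UNK_ID) for word in text.split()]


def toy_append_eos(ids: List[int]) -> List[int]:
  """Appends the EOS ID to a token sequence."""
  return ids + [EOS_ID]


example = {"inputs": "nq question: who wrote", "targets": ""}
tokenized = {
    key: toy_append_eos(toy_tokenize(text)) for key, text in example.items()
}
print(tokenized)
```

In this sketch, the empty `targets` string becomes a sequence that holds only the EOS ID.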

Finally, we'll have to determine whether we want to get predictions or scores for this batch. For this example, we'll get predictions, which we'll denote by setting an inference mode variable to `PREDICT_WITH_AUX`.

In [None]:
mode = InferenceType.PREDICT_WITH_AUX
# Try replacing this variable with `InferenceType.SCORE` to produce scores.

Now, let's break down what the interactive model does to run inference.

The `InteractiveModel` `infer_with_preprocessors()` method only performs three actions:


1.   Convert the natural text examples into a `tf.data.Dataset`.
2.   Define an `infer_fn`; depending on whether we want predictions or scores, this function will be equivalent to `model.predict_batch` or `model.score_batch`.
3.   Extract inferences and return them.



**Prepare the dataset**

Preparing the data for inference is fairly straightforward; in fact, this is nearly the same data preparation that happens for training.

First, we convert the natural text examples into a `tf.data.Dataset` and run any preprocessors; T5X has a helper function, `get_dataset_from_natural_text_examples`, that does exactly that. For this example, the only preprocessing we do is tokenization and appending an EOS token. If you are interested in learning more about preprocessors, please take a look at https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx-colab-intro.

Finally, we convert all features using the model's feature converter, batch and pad the data, and define an iterator over the batches (this allows us to run inference on multiple batches of examples).

In [None]:
dataset = get_dataset_from_natural_text_examples(
    examples,
    preprocessors=[
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos
    ],
    task_feature_lengths=task_feature_lengths,
    features=features)
feature_converter = model.FEATURE_CONVERTER_CLS(pack=False)
model_dataset = feature_converter(
    dataset, task_feature_lengths=task_feature_lengths)
# Zip task and model features.
infer_dataset = tf.data.Dataset.zip((dataset, model_dataset))
# Create batches and index them.
infer_dataset = infer_dataset.padded_batch(
    batch_size, drop_remainder=False).enumerate()
infer_dataset_iter: Iterator[Tuple[int, Any]] = iter(
    infer_dataset.prefetch(tf.data.experimental.AUTOTUNE))
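
The feature converter is what turns the task features (`inputs`, `targets`) into the model features named in `input_shapes`. The following is a simplified, pure-Python sketch of that mapping for a single unpacked example, assuming the usual encoder-decoder convention: targets are shifted right by one position to form the decoder inputs, and loss weights are 1 at non-padding target positions. The function names `pad_or_trim` and `to_model_features` are hypothetical; the real work is done by the class referenced by `model.FEATURE_CONVERTER_CLS`.

```python
from typing import Dict, List

PAD_ID = 0  # T5 convention: ID 0 is padding.


def pad_or_trim(ids: List[int], length: int) -> List[int]:
  """Pads with PAD_ID, or trims, so the result has exactly `length` IDs."""
  return (ids + [PAD_ID] * length)[:length]


def to_model_features(
    inputs: List[int],
    targets: List[int],
    task_feature_lengths: Dict[str, int]) -> Dict[str, List[int]]:
  """Maps task features to encoder-decoder model features (no packing)."""
  encoder_inputs = pad_or_trim(inputs, task_feature_lengths["inputs"])
  decoder_targets = pad_or_trim(targets, task_feature_lengths["targets"])
  return {
      "encoder_input_tokens": encoder_inputs,
      "decoder_target_tokens": decoder_targets,
      # Shift right by one: the decoder sees the previous target token.
      "decoder_input_tokens": [PAD_ID] + decoder_targets[:-1],
      # Only non-padding target positions contribute to the loss.
      "decoder_loss_weights": [int(t != PAD_ID) for t in decoder_targets],
  }


model_features = to_model_features(
    inputs=[3, 4, 5, 6, 1],
    targets=[7, 8, 1],
    task_feature_lengths={"inputs": 8, "targets": 5})
for name, values in model_features.items():
  print(name, values)
```

The short lengths here are chosen for readability; the cells above use `{'inputs': 38, 'targets': 18}`.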

**Define Infer Function**

We'll define a helper function that runs inference on a single batch, making it easy to loop over this helper and run inference for multiple batches. This `infer_fn` can either get predictions or scores, depending on the mode we've previously set.



In [None]:
if mode == InferenceType.PREDICT_WITH_AUX:
  infer_step = model.predict_batch_with_aux
elif mode == InferenceType.SCORE:
  infer_step = model.score_batch
else:
  raise ValueError("Mode must be `predict_with_aux` or `score`,"
                  f" but instead was {mode}.")
infer_fn = functools.partial(
  utils.get_infer_fn(
    infer_step=infer_step,
    batch_size=batch_size,
    train_state_axes=train_state_initializer.train_state_axes,
    partitioner=partitioner),
  train_state=train_state)
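
The pattern in the cell above, picking a step function by mode and then binding the train state with `functools.partial`, can be seen in miniature below. Everything here is a stand-in: `ToyMode`, `toy_predict`, `toy_score`, `make_infer_fn`, and the string used as a train state are hypothetical and only illustrate the control flow, not the behavior of `utils.get_infer_fn`.

```python
import enum
import functools
from typing import Callable, List


class ToyMode(enum.Enum):
  PREDICT_WITH_AUX = 1
  SCORE = 2


def toy_predict(batch: List[str], train_state: str) -> List[str]:
  """Stand-in for a prediction step; returns one string per example."""
  return [f"{train_state} prediction for {x!r}" for x in batch]


def toy_score(batch: List[str], train_state: str) -> List[float]:
  """Stand-in for a scoring step; returns one number per example."""
  return [0.0 for _ in batch]


def make_infer_fn(mode: ToyMode, train_state: str) -> Callable[..., list]:
  """Selects the step function for `mode` and binds the train state to it."""
  if mode == ToyMode.PREDICT_WITH_AUX:
    infer_step = toy_predict
  elif mode == ToyMode.SCORE:
    infer_step = toy_score
  else:
    raise ValueError(f"Unsupported mode: {mode}.")
  return functools.partial(infer_step, train_state=train_state)


infer = make_infer_fn(ToyMode.PREDICT_WITH_AUX, train_state="ckpt")
print(infer(["a", "b"]))
```

Binding the train state once means the loop over batches only has to supply the data for each call.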

**Extract Inferences**

Finally, we will extract inferences for each batch of examples provided. For each batch, we:

1.  Unzip the dataset to get both the task dataset and the model dataset (the model dataset is what you get when you've passed the task dataset through the model feature converter).
2.  Get an RNG for the batch.
3.  Extract predictions and auxiliary values using the T5X helper, `_extract_tokens_and_aux_values`.
4.  Decode the predictions using our vocabulary.
5.  Accumulate predictions, aux values, and inputs across all of our batches.



In [None]:
# Main Loop over "batches".
all_inferences = []
all_aux_values = {}
for chunk, chunk_batch in infer_dataset_iter:
  # Load the dataset for the next chunk. We can't use `infer_dataset_iter`
  # directly since `infer_fn` needs to know the exact size of each chunk,
  # which may be smaller for the final one.
  chunk_dataset = tf.data.Dataset.from_tensor_slices(chunk_batch)
  # tf.data transformations return a new dataset, so keep the result.
  chunk_dataset = chunk_dataset.cache().prefetch(
      tf.data.experimental.AUTOTUNE)

  # Unzip chunk dataset in to pretokenized and model datasets.
  task_dataset = chunk_dataset.map(
      lambda p, m: p, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  model_dataset = chunk_dataset.map(
      lambda p, m: m, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Get a chunk-specific RNG key.
  chunk_rng = jax.random.fold_in(jax.random.PRNGKey(0), chunk)

  inferences = _extract_tokens_and_aux_values(
      infer_fn(model_dataset.enumerate(), rng=chunk_rng))

  predictions, aux_values = inferences
  accumulated_inferences = []
  for idx, inputs in task_dataset.enumerate().as_numpy_iterator():
    prediction = predictions[idx]
    # Decode predictions if applicable.
    if mode == InferenceType.PREDICT_WITH_AUX:
      prediction = features["targets"].vocabulary.decode_tf(
          tf.constant(prediction)).numpy()
    accumulated_inferences.append((inputs, prediction))
  all_inferences += accumulated_inferences
  # Accumulate aux values over batches.
  if not all_aux_values:
    all_aux_values = aux_values
  else:
    for key, values in aux_values.items():
      all_aux_values[key] += values
print(all_inferences)
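
The accumulation at the end of the loop relies on each auxiliary value being a list, so that `+=` extends it batch by batch. Below is a minimal stand-alone sketch of that pattern, assuming list-valued aux entries. The `accumulate` helper and the numbers are made up for illustration and are not model outputs.

```python
from typing import Dict, List


def accumulate(
    all_aux: Dict[str, List[float]],
    batch_aux: Dict[str, List[float]]) -> Dict[str, List[float]]:
  """Extends the running aux values with the values from one batch."""
  if not all_aux:
    # Copy the lists so the first batch's dictionary is not mutated later.
    return {key: list(values) for key, values in batch_aux.items()}
  for key, values in batch_aux.items():
    all_aux[key] += values
  return all_aux


# Placeholder aux values for two batches; the last batch is smaller.
per_batch_aux = [{"scores": [-1.0, -2.0]}, {"scores": [-3.0]}]
all_aux_values: Dict[str, List[float]] = {}
for aux_values in per_batch_aux:
  all_aux_values = accumulate(all_aux_values, aux_values)
print(all_aux_values)
```

Unlike the loop above, this sketch copies the first batch's lists, which keeps the per-batch dictionaries unchanged if you want to inspect them afterwards.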

We can parse these predictions into a more readable format using the code below.

In [None]:
# Use `inputs` rather than `input` to avoid shadowing the Python builtin.
for inputs, prediction in all_inferences:
  print(f"Input: {inputs['inputs_pretokenized']}")
  print(f"Prediction: {prediction}\n")

The code snippets above exactly replicate the `InteractiveModel` `infer_with_preprocessors()` method (see [source code](https://github.com/google-research/t5x/blob/main/t5x/interactive_model.py)); running the code snippets above is exactly equivalent to running `interactive_model.infer_with_preprocessors(mode, examples, preprocessors=[seqio.preprocessors.tokenize, seqio.preprocessors.append_eos])`.

# Advanced Topics

## T5X Inference Binaries and Other Advanced Features

T5X offers inference binaries that have the same functionality as the InteractiveModel, with additional features as well (more advanced compiling, inference on TF Example files, prediction services, etc.). Importantly, these binaries are configured using [Gin](https://github.com/google/gin-config/blob/main/README.md); if you are not familiar with Gin, please take a look at this [Gin Primer](https://github.com/google-research/t5x/blob/main/docs/usage.md/gin) to get started.

If you are familiar with Gin and interested in using the T5X inference binaries, we have provided a helper function, `get_gin_config_from_interactive_model`, which takes an `InteractiveModel` instance and generates the Gin config that you can use to run the T5X inference binaries; this Gin config exactly reproduces the `InteractiveModel` inference functionality we've described above. We've provided an example below.

Importantly, the InteractiveModel takes in a model, partitioner, and data, so we cannot generate Gin configs for these components. You can pass Gin config strings for the model and partitioner components to the helper function, as demonstrated below. Additionally, you can pass a SeqIO task containing your data to the helper function. See the section below if you are unfamiliar with SeqIO.

In [None]:
!git clone https://github.com/google-research/google-research.git google_research

In [None]:
# Define an InteractiveModel instance, based on the `small` T5X EncoderDecoder model.
t5_config = network.T5Config(
    vocab_size=32128,
    dtype='bfloat16',
    emb_dim=512,
    num_heads=6,
    num_encoder_layers=8,
    num_decoder_layers=8,
    head_dim=64,
    mlp_dim=1024,
    mlp_activations=('gelu', 'linear'),
    dropout_rate=0.0,
    logits_via_embedding=False)
module = network.Transformer(config=t5_config)
model = t5x.models.EncoderDecoderModel(
    module=module,
    input_vocabulary=t5.data.get_default_vocabulary(),
    output_vocabulary=t5.data.get_default_vocabulary(),
    optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
    decode_fn=functools.partial(
        t5x.decoding.temperature_sample, temperature=1.0, topk=40))
interactive_model = InteractiveModel(
    batch_size=8,
    task_feature_lengths={'inputs': 38, 'targets': 18},
    output_dir='/tmp/output_dir',
    partitioner=partitioning.PjitPartitioner(
      num_partitions=1,
      model_parallel_submesh=None,
      logical_axis_rules=partitioning.standard_logical_axis_rules()),
    model=model,
    dtype='bfloat16',
    restore_mode='specific',
    checkpoint_path='gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000',
    input_shapes={
      'encoder_input_tokens': np.array([8, 38]),
      'decoder_target_tokens': np.array([8, 18]),
      'decoder_input_tokens': np.array([8, 18]),
      'decoder_loss_weights': np.array([8, 18])
    },
    input_types=None)

# Define Gin Config strings for the model, partitioner, and any imports.
imports_str = """from t5x import models
from t5x import partitioning
import t5.data.mixtures
include 't5x/examples/t5/t5_1_1/small.gin'

# Register necessary SeqIO Tasks/Mixtures.
import google_research.t5_closed_book_qa.t5_cbqa.tasks"""
partitioner_config = 'partitioning.PjitPartitioner.num_partitions = 2'

gin_config_str = get_gin_config_from_interactive_model(
  interactive_model=interactive_model,
  script_type=T5XScriptType.INFERENCE,
  task_name='closed_book_qa',
  partitioner_config_str=partitioner_config,
  model_config_str='',  # No config needed, since we just import the model.
  imports_str=imports_str,
)
print(gin_config_str)

Once you have generated the `gin_config_str` as above, you can write this string to a file and launch your inference experiment locally by running the following on the command line, where `GIN_FILE_PATH` is the path to the file you wrote:


```
INFER_OUTPUT_DIR="/tmp/inference-model/"
python -m t5x.infer \
  --gin_file=${GIN_FILE_PATH} \
  --gin.INFER_OUTPUT_DIR=\"${INFER_OUTPUT_DIR}\" \
  --alsologtostderr
```
For more details on inference using the T5X inference binaries, please see the [Inference](https://github.com/google-research/t5x/blob/main/docs/usage.md/infer-seqio) tutorial.
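
For the step of writing the config string to a file, a minimal sketch follows. The file name and the placeholder string are assumptions for illustration; in practice you would write the `gin_config_str` produced above and then set `GIN_FILE_PATH` to the resulting path.

```python
import os
import tempfile

# Placeholder standing in for the `gin_config_str` generated earlier.
placeholder_gin_config_str = "# Placeholder for the generated Gin config.\n"

# Hypothetical file name inside the system's temporary directory.
gin_file_path = os.path.join(
    tempfile.gettempdir(), "interactive_model_infer.gin")
with open(gin_file_path, "w") as f:
  f.write(placeholder_gin_config_str)

# Print a shell line that sets the variable used by the command above.
print(f"export GIN_FILE_PATH={gin_file_path}")
```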

## SeqIO

If you are interested in T5X, you may also be interested in, or have heard of, SeqIO. SeqIO is a library for processing sequential data to be fed into downstream sequence models. At a high level, SeqIO relies on user-defined `Tasks` and `Mixtures` that can be used to retrieve and evaluate datasets.

We won't go into details about SeqIO here; we recommend checking out this [SeqIO Introductory guide](https://github.com/google/seqio/blob/main/README.md/index) and/or clicking below to run a SeqIO Introductory Colab. The rest of this section will assume a basic understanding of SeqIO.

<a href="https://colab.research.google.com/github/google-research/seqio/blob/main/seqio/notebooks/Basics_Task_and_Mixtures.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

If you are already familiar with SeqIO and have a SeqIO task/mixture that you would like to use in this Colab, we do provide a SeqIO bridge that takes in a SeqIO task/mixture and produces batches of examples that can be processed by the code snippets above. We've provided an example of this bridge below.

In [None]:
import google_research.t5_closed_book_qa.t5_cbqa.tasks
batches = get_batches_from_seqio(
        task_or_mixture_name='natural_questions_open',
        split='validation',
        batch_size=8,
        num_batches=2,
        seed=42)
print(f"Batches: {batches}")
# Train the interactive model on the provided batches.
original_step = interactive_model.step
_ = interactive_model.train_loop(num_steps=len(batches), train_batches=batches)
print(f"Original Step: {original_step}, Current Step: {interactive_model.step}")

The `get_batches_from_seqio` bridge accepts several arguments:


1.   `task_or_mixture_name`: the name of the SeqIO task/mixture to read data from. It should be noted that your task/mixture must already be registered with SeqIO, and you must import the module that defines your task/mixture here (as seen above).
2.   `split`: the split of the Task/Mixture to read data from.
3.   `batch_size`: how many examples should appear in each batch.
4.   `num_batches`: the total number of batches to return.
5.   `get_pretokenized_examples`: optional. A boolean, defaulting to True, that determines whether we should read the `inputs_pretokenized`/`targets_pretokenized` elements from an example, or the `inputs`/`targets` elements. \
The `train_step`, `predict`, `predict_with_aux`, `score`, and `evaluate` methods of the InteractiveModel assume that we should run [tokenization](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) and [appending an EOS token](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) as the only preprocessors. To use these methods with this pre-defined list of preprocessors, you can set `get_pretokenized_examples=True` to retrieve examples that still need to be tokenized, and these InteractiveModel methods will handle running these preprocessors. This setting can also be helpful if you want to inspect the natural text inputs/targets of your SeqIO task. \
However, some SeqIO tasks do not use tokenization (e.g., span corruption). For these, you can set `get_pretokenized_examples=False`, and this bridge will read the fully preprocessed examples from the SeqIO task. You can then run `train_step_with_preprocessors`, `infer_with_preprocessors`, or `evaluate_with_preprocessors` and provide an empty preprocessors list (because all preprocessing has already been completed by this bridge) to run training/inference/evaluation. We have provided an example of using this bridge to retrieve fully preprocessed examples below.
6.   `sequence_length`: optional. A dictionary mapping feature key to maximum length (int) for that feature. Used by SeqIO to retrieve the dataset/examples.
7.   `**get_dataset_kwargs`: there are many [additional parameters](https://github.com/google/seqio/tree/main/seqio/dataset_providers.py) that can be set in the `seqio.get_dataset` function. If you would like to set any of these, pass them as extra keyword arguments to the bridge.
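
As an illustration of the `get_pretokenized_examples`, `batch_size`, and `num_batches` arguments, the sketch below picks either the natural-text fields or the preprocessed fields from each example and groups the results into batches. The example dictionaries, the token IDs, and the `select_fields` and `take_batches` helpers are all hypothetical; they mimic the behavior described in the list above rather than the bridge's actual implementation.

```python
import itertools
from typing import Any, Dict, Iterable, List


def select_fields(
    example: Dict[str, Any],
    get_pretokenized_examples: bool = True) -> Dict[str, Any]:
  """Reads the natural-text fields or the fully preprocessed fields."""
  suffix = "_pretokenized" if get_pretokenized_examples else ""
  return {
      "inputs": example[f"inputs{suffix}"],
      "targets": example[f"targets{suffix}"],
  }


def take_batches(
    examples: Iterable[Any], batch_size: int,
    num_batches: int) -> List[List[Any]]:
  """Returns up to `num_batches` lists of up to `batch_size` examples."""
  iterator = iter(examples)
  batches = []
  for _ in range(num_batches):
    batch = list(itertools.islice(iterator, batch_size))
    if not batch:
      break
    batches.append(batch)
  return batches


# Made-up examples carrying both natural-text and token-ID fields.
toy_examples = [{
    "inputs_pretokenized": f"question {i}",
    "targets_pretokenized": f"answer {i}",
    "inputs": [3, i, 1],
    "targets": [4, i, 1],
} for i in range(5)]

text_batches = take_batches(
    (select_fields(e, True) for e in toy_examples),
    batch_size=2, num_batches=2)
id_batches = take_batches(
    (select_fields(e, False) for e in toy_examples),
    batch_size=2, num_batches=2)
print(text_batches[0][0])
print(id_batches[0][0])
```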



In [None]:
import t5.data.tasks
batches = get_batches_from_seqio(
    task_or_mixture_name='c4_v220_span_corruption',
    split='validation',
    batch_size=8,
    num_batches=1,
    get_pretokenized_examples=False,
    sequence_length=interactive_model._task_feature_lengths,
    seed=42)
batch = batches[0]  # We expect only a single batch.
original_step = interactive_model.step
interactive_model.train_step_with_preprocessors(
        examples=batch, preprocessors=[])
print(f"Original Step: {original_step}, Current Step: {interactive_model.step}")