
<a href="https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Overview

This is the third Colab in a [series of tutorials on how to use T5X](https://github.com/google-research/t5x/blob/main/docs/tutorials.md). We assume that you have already completed the [Introductory Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb) and the [Training Deep Dive](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb), or have a basic understanding of the T5X models, checkpoints, partitioner, trainer, and `InteractiveModel`.

In the [previous Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb) in this tutorial series, we dove into how the InteractiveModel restores models from checkpoints and runs training, while also getting an introduction to the T5X trainer. In this Colab, we will focus on how the `InteractiveModel` does decoding to generate predictions and scores for a given input. It should be noted that the code snippets below exactly replicate the InteractiveModel `__init__()` and `infer_with_preprocessors()` methods (see [source code](https://github.com/google-research/t5x/blob/main/t5x/interactive_model.py)); we expose this functionality here in order to demonstrate how various components of the T5X codebase work together to run inference on a model.

# Set-Up

Note: If you are using a public Colab, please use its `Connect to a local runtime` option by following the [setup guide](https://github.com/google-research/t5x/blob/main/t5x/notebooks/README.md).

In [None]:
!pip install t5

In [None]:
from collections.abc import Sequence
import enum
import functools
import inspect
import itertools
import logging
import os
import re
from typing import Any, Callable, Iterator, Optional, Tuple, Union

import jax
from jax import random
from jax.experimental import multihost_utils
import numpy as np
import seqio
import tensorflow as tf
import tensorflow_datasets as tfds
import t5.data

In [None]:
!git clone https://github.com/google/airio.git

In [None]:
import sys
import os

# Assuming you cloned it into /content/airio
REPO_PATH = 'airio'

# Add the main repository directory
if REPO_PATH not in sys.path:
    sys.path.append(REPO_PATH)

# If the actual Python source code (the 'airio' package itself) is in a subfolder,
# you may need to add that too. Check the structure of the cloned repo.
# Example: If the source is in /content/airio/src:
# SRC_PATH = os.path.join(REPO_PATH, 'src')
# if SRC_PATH not in sys.path:
#     sys.path.append(SRC_PATH)

print(f"Added {REPO_PATH} to sys.path.")

In [None]:
import sys
import os

# Assumes the T5X repository was cloned into ./t5x, e.g. with
# !git clone https://github.com/google-research/t5x.git
REPO_PATH = 't5x'

# Add the main repository directory
if REPO_PATH not in sys.path:
    sys.path.append(REPO_PATH)

# If the actual Python source code (the 't5x' package itself) is in a subfolder,
# you may need to add that too. Check the structure of the cloned repo.
# Example: If the source is in ./t5x/src:
# SRC_PATH = os.path.join(REPO_PATH, 'src')
# if SRC_PATH not in sys.path:
#     sys.path.append(SRC_PATH)

print(f"Added {REPO_PATH} to sys.path.")

In [None]:
import clu.data
from t5x.examples.t5 import network
import t5x
from t5x import models
from t5x import partitioning
from t5x import trainer as trainer_lib
from t5x import utils
from t5x.infer import _extract_tokens_and_aux_values
from t5x.infer import _Inferences
from t5x.interactive_model import InteractiveModel
from t5x.interactive_model import get_batches_from_seqio
from t5x.interactive_model import get_dataset_from_natural_text_examples
from t5x.interactive_model import get_gin_config_from_interactive_model
from t5x.interactive_model import T5XScriptType
from t5x.interactive_model import InferenceType

Before we begin, let's initialize instances of the constructor arguments for the `InteractiveModel`. As mentioned previously, this will enable us to dive into how the `InteractiveModel` runs inference.

If you don't understand the lines of code below, or have questions about how to initialize these parameters, please see the [first Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/introduction.ipynb) in this tutorial series.

In [None]:
# Define a model. The configuration below corresponds to the T5 1.1 Small model.
t5_config = network.T5Config(
    vocab_size=32128,
    dtype='bfloat16',
    emb_dim=512,
    num_heads=6,
    num_encoder_layers=8,
    num_decoder_layers=8,
    head_dim=64,
    mlp_dim=1024,
    mlp_activations=('gelu', 'linear'),
    dropout_rate=0.0,
    logits_via_embedding=False)
module = network.Transformer(config=t5_config)
model = t5x.models.EncoderDecoderModel(
    module=module,
    input_vocabulary=t5.data.get_default_vocabulary(),
    output_vocabulary=t5.data.get_default_vocabulary(),
    optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0))
# Define checkpoint arguments.
checkpoint_path='gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000'
dtype='bfloat16'
restore_mode='specific'
# Define a partitioner.
partitioner=partitioning.PjitPartitioner(num_partitions=2)
# Define additional, miscellaneous constructor arguments.
batch_size=8
task_feature_lengths = {'inputs': 38, 'targets': 18}
output_dir='/tmp/output_dir'
input_shapes = {
    'encoder_input_tokens': np.array([8, 38]),
    'decoder_target_tokens': np.array([8, 18]),
    'decoder_input_tokens': np.array([8, 18]),
    'decoder_loss_weights': np.array([8, 18])
}

In addition, we will run all code that is performed when we initialize the InteractiveModel. If you don't understand the lines of code below or have any additional questions about how/why we do the steps below, please see the [second Colab](https://colab.research.google.com/github/google-research/t5x/blob/main/t5x/notebooks/training.ipynb) in this tutorial series.

In [None]:
# 1.) Configure the Output Directory
output_dir = re.sub(r"(?<!gs:)([\/]{2,})", "/", output_dir)
if not os.path.exists(output_dir):
  os.mkdir(output_dir)

# 2.) Initialize RNGs
init_random_seed = 42
random_seed = multihost_utils.broadcast_one_to_all(np.int32(init_random_seed))
utils.set_hardware_rng_ops()
rng = random.PRNGKey(random_seed)
init_rng, trainer_rng = random.split(rng, 2)

# 3.) Validate the Partitioner
if partitioner._model_parallel_submesh:
  num_partitions = np.prod(partitioner._model_parallel_submesh)
else:
  num_partitions = partitioner._num_partitions
if jax.device_count() % num_partitions != 0:
  # Adjacent string literals are concatenated into a single error message.
  raise ValueError(
    "The number of devices available must be a multiple of the number of"
    f" partitions. There are {jax.device_count()} devices available, but"
    f" the number of partitions is set to {num_partitions}. Please"
    " provide a different number of partitions.")

# 4.) Create a Checkpoint Manager
# a.) Define CheckpointCfg wrappers.
save_checkpoint_cfg = utils.SaveCheckpointConfig(
        dtype=dtype,
        keep=5, # The number of checkpoints to keep in the output_dir.
        save_dataset=False)
restore_checkpoint_cfg = utils.RestoreCheckpointConfig(
        dtype=dtype,
        mode=restore_mode,
        path=checkpoint_path)

# b.) Define a train state initializer, which will help us get information about the
# TrainState shape.
train_state_initializer = utils.TrainStateInitializer(
        optimizer_def=model.optimizer_def,
        init_fn=model.get_initial_variables,
        input_shapes=input_shapes,
        input_types=None,
        partitioner=partitioner)

# c.) Define the checkpoint manager.
checkpoint_manager = utils.LegacyCheckpointManager(
        save_cfg=save_checkpoint_cfg,
        restore_cfg=restore_checkpoint_cfg,
        train_state_shape=train_state_initializer.global_train_state_shape,
        partitioner=partitioner,
        ds_iter=None,
        model_dir=output_dir)

### 5.) Restore the Model from a Checkpoint, or Initialize from Scratch ###
def get_state(rng):
  return train_state_initializer.from_scratch(rng).state_dict()

# a.) Try to restore a model from a checkpoint.
train_state = checkpoint_manager.restore(
  [restore_checkpoint_cfg.path],
  restore_checkpoint_cfg,
  utils.get_fallback_state(restore_checkpoint_cfg, get_state, init_rng)
)

# b.) If no checkpoint to restore, init from scratch.
if train_state is None:
  train_state = train_state_initializer.from_scratch(init_rng)
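Steps 1 and 3 above are plain Python checks that do not need an accelerator. The standalone sketch below reproduces them so you can see what they do in isolation; the helper names `normalize_output_dir`, `resolve_num_partitions`, and `check_partitions` are hypothetical and are not part of the T5X API.

```python
import re
from typing import Optional, Sequence

import numpy as np


def normalize_output_dir(path: str) -> str:
  """Collapses runs of '/' into one, except the '//' right after 'gs:'."""
  return re.sub(r"(?<!gs:)([\/]{2,})", "/", path)


def resolve_num_partitions(num_partitions: Optional[int],
                           model_parallel_submesh: Optional[Sequence[int]]) -> int:
  """A model-parallel submesh, when given, determines the partition count."""
  if model_parallel_submesh:
    return int(np.prod(model_parallel_submesh))
  return int(num_partitions)


def check_partitions(device_count: int, num_partitions: int) -> None:
  """Raises if the devices cannot be split evenly into partitions."""
  if device_count % num_partitions != 0:
    raise ValueError(
        "The number of devices available must be a multiple of the number of"
        f" partitions. There are {device_count} devices available, but the"
        f" number of partitions is set to {num_partitions}.")


print(normalize_output_dir('/tmp//output_dir'))    # /tmp/output_dir
print(normalize_output_dir('gs://bucket//model'))  # gs://bucket/model
print(resolve_num_partitions(2, None))             # 2
print(resolve_num_partitions(None, (1, 1, 1, 2)))  # 2
check_partitions(device_count=8, num_partitions=2)  # Passes: 8 % 2 == 0.
```

The negative lookbehind `(?<!gs:)` is what keeps the `gs://` scheme intact while still collapsing accidental double slashes later in a GCS or local path.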

# Inference Deep Dive

**Defining a Batch of Examples to Run Inference On**\
Let's start by defining a batch of examples that we will get predictions and scores for.

These examples should be a list of inputs; we don't need any targets, because we will eventually generate predictions. For this Colab, we'll use a set of natural text questions (and we will generate the answers).

In [None]:
examples = [
    b'nq question: who has been appointed as the new chairman of sebi',
    b'nq question: who wrote the book lion the witch and the wardrobe',
    b'nq question: how many planes did japan lose at pearl harbor',
    b'nq question: who does the voice of mcgruff the dog',
    b'nq question: who sings the wheels in the sky keep on turning',
    b'nq question: who voices regina in glitter force doki doki',
    b'nq question: when did the us become allies with britain',
    b'nq question: who won the rugby 7 in las vegas'
]

We also define the required features of the examples. For this Colab, we only require `inputs` and `targets` entries, as defined below. `targets` will be empty for our examples, because we have no targets to provide at inference time.

In [None]:
output_features = {
        "inputs":
            seqio.Feature(
                vocabulary=model.input_vocabulary, add_eos=True),
        "targets":
            seqio.Feature(
                vocabulary=model.output_vocabulary, add_eos=True)
    }
features = dict(sorted(output_features.items()))

Finally, we'll have to determine whether we want to get predictions or scores for this batch. For this example, we'll get predictions, which we'll denote by setting an inference mode variable to `PREDICT_WITH_AUX`.

In [None]:
mode = InferenceType.PREDICT_WITH_AUX
# Try replacing this variable with `InferenceType.SCORE` to produce scores.

Now, let's break down what the interactive model does to run inference.

The `InteractiveModel` `infer_with_preprocessors()` method only performs three actions:


1.   Convert the natural text examples into a `tf.data.Dataset` and run the preprocessors.
2.   Define an `infer_fn`; depending on whether we want predictions or scores, this function wraps `model.predict_batch_with_aux` or `model.score_batch`.
3.   Extract inferences and return them.



**Prepare the dataset** \

Preparing the data for inference is fairly straightforward; in fact, this is nearly the same data preparation that happens for training.

First, we convert the natural text examples into a `tf.data.Dataset` and run any preprocessors; T5X has a helper function, `get_dataset_from_natural_text_examples`, that does exactly that. For this example, the only preprocessing we do is tokenization and appending an EOS token. If you are interested in learning more about preprocessors, please take a look at https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx-colab-intro.

Finally, we convert all features using the model's feature converter, pad all batches of data, and define an iterator over our data (this allows us to run inference on multiple batches of examples).
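To make the feature-conversion step less opaque, here is a plain NumPy approximation of what an unpacked encoder-decoder feature converter (`model.FEATURE_CONVERTER_CLS(pack=False)`, i.e. SeqIO's `EncDecFeatureConverter` for this model) produces for a single, already-tokenized example. This is a sketch under stated assumptions, not the SeqIO implementation: it assumes the default T5 vocabulary ids (PAD = 0, EOS = 1), works on Python lists instead of `tf.data`, and the helper names are hypothetical.

```python
from typing import Dict, List, Sequence

import numpy as np

PAD_ID = 0  # Padding id in the default T5 SentencePiece vocabulary.
EOS_ID = 1  # End-of-sequence id in the default T5 SentencePiece vocabulary.


def pad_or_trim(tokens: Sequence[int], length: int) -> np.ndarray:
  """Trims `tokens` to `length`, then right-pads with PAD_ID."""
  trimmed = list(tokens)[:length]
  padded = np.full(length, PAD_ID, dtype=np.int32)
  padded[:len(trimmed)] = trimmed
  return padded


def to_model_features(input_ids: List[int], target_ids: List[int],
                      lengths: Dict[str, int]) -> Dict[str, np.ndarray]:
  """Approximates unpacked encoder-decoder feature conversion for one example."""
  # `append_eos` adds EOS to each feature before trimming/padding.
  encoder_input = pad_or_trim(list(input_ids) + [EOS_ID], lengths['inputs'])
  decoder_target = pad_or_trim(list(target_ids) + [EOS_ID], lengths['targets'])
  # Decoder inputs are the targets shifted right by one position, starting at 0.
  decoder_input = np.concatenate([[0], decoder_target[:-1]]).astype(np.int32)
  return {
      'encoder_input_tokens': encoder_input,
      'decoder_target_tokens': decoder_target,
      'decoder_input_tokens': decoder_input,
      # Loss is only computed on non-padding target positions.
      'decoder_loss_weights': (decoder_target != PAD_ID).astype(np.int32),
  }


# Empty targets, as in this Colab's inference examples.
feats = to_model_features([100, 200, 300], [], {'inputs': 6, 'targets': 4})
print(feats['encoder_input_tokens'].tolist())  # [100, 200, 300, 1, 0, 0]
print(feats['decoder_input_tokens'].tolist())  # [0, 1, 0, 0]
print(feats['decoder_loss_weights'].tolist())  # [1, 0, 0, 0]
```

These four keys are the same ones used in `input_shapes` earlier in this Colab, which is why those shapes are `[batch_size, 38]` for the encoder and `[batch_size, 18]` for the decoder features.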

In [None]:
dataset = get_dataset_from_natural_text_examples(
    examples,
    preprocessors=[
        seqio.preprocessors.tokenize,
        seqio.preprocessors.append_eos
    ],
    task_feature_lengths=task_feature_lengths,
    features=features)
feature_converter = model.FEATURE_CONVERTER_CLS(pack=False)
model_dataset = feature_converter(
    dataset, task_feature_lengths=task_feature_lengths)
# Zip task and model features.
infer_dataset = tf.data.Dataset.zip((dataset, model_dataset))
# Create batches and index them.
infer_dataset = infer_dataset.padded_batch(
    batch_size, drop_remainder=False).enumerate()
infer_dataset_iter: Iterator[Tuple[int, Any]] = iter(
    infer_dataset.prefetch(tf.data.experimental.AUTOTUNE))

**Define Infer Function** \

We'll define a helper function that runs inference on a single batch, making it easy to loop over this helper and run inference for multiple batches. This `infer_fn` can either get predictions or scores, depending on the mode we've previously set.



In [None]:
if mode == InferenceType.PREDICT_WITH_AUX:
  infer_step = model.predict_batch_with_aux
elif mode == InferenceType.SCORE:
  infer_step = model.score_batch
else:
  raise ValueError("Mode must be `predict_with_aux`, or `score`,"
                  f" but instead was {mode}.")
infer_fn = functools.partial(
  utils.get_infer_fn(
    infer_step=infer_step,
    batch_size=batch_size,
    train_state_axes=train_state_initializer.train_state_axes,
    partitioner=partitioner),
  train_state=train_state)

**Extract Inferences** \

Finally, we will extract inferences for each batch of examples provided. For each batch, we:

1.  Unzip the dataset to get both the task dataset and the model dataset (the model dataset is what you get when you've passed the task dataset through the model feature converter).
2.  Get an RNG for the batch.
3.  Extract predictions and auxiliary values using the T5X helper, `_extract_tokens_and_aux_values`.
4.  Decode the predictions using our vocabulary.
5.  Accumulate predictions, aux values, and inputs across all of our batches.



In [None]:
# Main Loop over "batches".
all_inferences = []
all_aux_values = {}
for chunk, chunk_batch in infer_dataset_iter:
  # Load the dataset for the next chunk. We can't use `infer_dataset_iter`
  # directly since `infer_fn` needs to know the exact size of each chunk,
  # which may be smaller for the final one.
  chunk_dataset = tf.data.Dataset.from_tensor_slices(chunk_batch)
  # `cache()` and `prefetch()` return new datasets, so reassign the result.
  chunk_dataset = chunk_dataset.cache().prefetch(tf.data.experimental.AUTOTUNE)

  # Unzip chunk dataset into pretokenized and model datasets.
  task_dataset = chunk_dataset.map(
      lambda p, m: p, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  model_dataset = chunk_dataset.map(
      lambda p, m: m, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Get a chunk-specific RNG key.
  chunk_rng = jax.random.fold_in(jax.random.PRNGKey(0), chunk)

  inferences = _extract_tokens_and_aux_values(
      infer_fn(model_dataset.enumerate(), rng=chunk_rng))

  predictions, aux_values = inferences
  accumulated_inferences = []
  for idx, inputs in task_dataset.enumerate().as_numpy_iterator():
    prediction = predictions[idx]
    # Decode predictions if applicable.
    if mode == InferenceType.PREDICT_WITH_AUX:
      prediction = features["targets"].vocabulary.decode_tf(
          tf.constant(prediction)).numpy()
    accumulated_inferences.append((inputs, prediction))
  all_inferences += accumulated_inferences
  # Accumulate aux values over batches.
  if not all_aux_values:
    all_aux_values = aux_values
  else:
    for key, values in aux_values.items():
      all_aux_values[key] += values
print(all_inferences)

We can parse these predictions into a more readable format using the code below.

In [None]:
# Use `inputs` rather than `input` to avoid shadowing the Python builtin.
for inputs, prediction in all_inferences:
  print(f"Input: {inputs['inputs_pretokenized']}")
  print(f"Prediction: {prediction}\n")

The code snippets above exactly replicate the `InteractiveModel` `infer_with_preprocessors()` method (see [source code](https://github.com/google-research/t5x/blob/main/t5x/interactive_model.py)); running the code snippets above is exactly equivalent to running `interactive_model.infer_with_preprocessors(mode, examples, preprocessors=[seqio.preprocessors.tokenize, seqio.preprocessors.append_eos])`.

# Advanced Topics

## T5X Inference Binaries and Other Advanced Features

T5X offers inference binaries that have the same functionality as the InteractiveModel, with additional features as well (more advanced compiling, inference on TF Example files, prediction services, etc.). Importantly, these binaries are configured using [Gin](https://github.com/google/gin-config/blob/main/README.md); if you are not familiar with Gin, please take a look at this [Gin Primer](https://github.com/google-research/t5x/blob/main/docs/usage.md/gin) to get started.

If you are familiar with Gin and interested in using the T5X inference binaries, we have provided a helper function, `get_gin_config_from_interactive_model`, which takes an `InteractiveModel` instance and generates a Gin config that you can use to run the T5X inference binaries; this Gin config reproduces the `InteractiveModel` inference functionality described above. We've provided an example below.

Importantly, the `InteractiveModel` receives the model, partitioner, and data as Python objects rather than as Gin configurations, so the helper cannot generate Gin configs for these components on its own. Instead, you can pass Gin config strings for the model and partitioner to the helper function, as demonstrated below, and pass the name of a SeqIO task containing your data. See the section below if you are unfamiliar with SeqIO.

In [None]:
!git clone https://github.com/google-research/google-research.git google_research

In [None]:
# Define an InteractiveModel instance, based on the `small` T5X EncoderDecoder model.
t5_config = network.T5Config(
    vocab_size=32128,
    dtype='bfloat16',
    emb_dim=512,
    num_heads=6,
    num_encoder_layers=8,
    num_decoder_layers=8,
    head_dim=64,
    mlp_dim=1024,
    mlp_activations=('gelu', 'linear'),
    dropout_rate=0.0,
    logits_via_embedding=False)
module = network.Transformer(config=t5_config)
model = t5x.models.EncoderDecoderModel(
    module=module,
    input_vocabulary=t5.data.get_default_vocabulary(),
    output_vocabulary=t5.data.get_default_vocabulary(),
    optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0),
    decode_fn=functools.partial(
        t5x.decoding.temperature_sample, temperature=1.0, topk=40))
interactive_model = InteractiveModel(
    batch_size=8,
    task_feature_lengths={'inputs': 38, 'targets': 18},
    output_dir='/tmp/output_dir',
    partitioner=partitioning.PjitPartitioner(
      num_partitions=1,
      model_parallel_submesh=None,
      logical_axis_rules=partitioning.standard_logical_axis_rules()),
    model=model,
    dtype='bfloat16',
    restore_mode='specific',
    checkpoint_path='gs://t5-data/pretrained_models/cbqa/small_ssm_nq/model.ckpt-1110000',
    input_shapes={
      'encoder_input_tokens': np.array([8, 38]),
      'decoder_target_tokens': np.array([8, 18]),
      'decoder_input_tokens': np.array([8, 18]),
      'decoder_loss_weights': np.array([8, 18])
    },
    input_types=None)

# Define Gin Config strings for the model, partitioner, and any imports.
imports_str = """from t5x import models
from t5x import partitioning
import t5.data.mixtures
include 't5x/examples/t5/t5_1_1/small.gin'

# Register necessary SeqIO Tasks/Mixtures.
import google_research.t5_closed_book_qa.t5_cbqa.tasks"""
partitioner_config = 'partitioning.PjitPartitioner.num_partitions = 2'

gin_config_str = get_gin_config_from_interactive_model(
  interactive_model=interactive_model,
  script_type=T5XScriptType.INFERENCE,
  task_name='closed_book_qa',
  partitioner_config_str=partitioner_config,
  model_config_str='',  # No config needed, since we just import the model.
  imports_str=imports_str,
)
print(gin_config_str)

Once you have generated the `gin_config_str` as above, you can write this string to a file and launch your inference experiment locally by running the following on the command line:


```
GIN_FILE_PATH="/path/to/your_config.gin"  # Hypothetical: the file you saved gin_config_str to.
INFER_OUTPUT_DIR="/tmp/inference-model/"
python -m t5x.infer_unfragmented \
  --gin_file=${GIN_FILE_PATH} \
  --gin.INFER_OUTPUT_DIR=\"${INFER_OUTPUT_DIR}\" \
  --alsologtostderr
```
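If you prefer to stay inside the notebook, the sketch below writes a Gin config string to a file and builds the same inference command as an argument list. The helper names and file locations are hypothetical; the module name `t5x.infer_unfragmented` and the flags are taken from the shell command above.

```python
import os
from typing import List


def write_gin_file(gin_config: str, path: str) -> str:
  """Writes a Gin config string to `path` (creating parent dirs) and returns the path."""
  parent = os.path.dirname(path)
  if parent:
    os.makedirs(parent, exist_ok=True)
  with open(path, 'w') as f:
    f.write(gin_config)
  return path


def infer_argv(gin_file_path: str, infer_output_dir: str) -> List[str]:
  """Builds the inference command as an argument list.

  Passing a list to `subprocess.run` bypasses the shell, so the literal double
  quotes around INFER_OUTPUT_DIR (needed for Gin to parse the value as a
  string) are written directly instead of being backslash-escaped.
  """
  return [
      'python', '-m', 't5x.infer_unfragmented',
      f'--gin_file={gin_file_path}',
      f'--gin.INFER_OUTPUT_DIR="{infer_output_dir}"',
      '--alsologtostderr',
  ]
```

In this Colab you might call `path = write_gin_file(gin_config_str, '/tmp/infer_config.gin')` and then `subprocess.run(infer_argv(path, '/tmp/inference-model/'), check=True)`.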
For more details on inference using the T5X inference binaries, please see the [Inference](https://github.com/google-research/t5x/blob/main/docs/usage.md/infer-seqio) tutorial.

## SeqIO

If you are interested in T5X, you may also be interested in, or have heard of, SeqIO. SeqIO is a library for processing sequential data to be fed into downstream sequence models. At a high level, SeqIO relies on user-defined `Tasks` and `Mixtures` that can be used to retrieve and evaluate datasets.

We won't go into details about SeqIO here; we recommend checking out this [SeqIO Introductory guide](https://github.com/google/seqio/blob/main/README.md/index) and/or clicking below to run a SeqIO Introductory Colab. The rest of this section will assume a basic understanding of SeqIO.

<a href="https://colab.research.google.com/github/google-research/seqio/blob/main/seqio/notebooks/Basics_Task_and_Mixtures.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

If you are already familiar with SeqIO and have a SeqIO task/mixture that you would like to use in this Colab, we do provide a SeqIO bridge that takes in a SeqIO task/mixture and produces batches of examples that can be processed by the code snippets above. We've provided an example of this bridge below.

In [None]:
import nest_asyncio
nest_asyncio.apply()

In [None]:
import google_research.t5_closed_book_qa.t5_cbqa.tasks
batches = get_batches_from_seqio(
        task_or_mixture_name='natural_questions_open',
        split='validation',
        batch_size=8,
        num_batches=2,
        seed=42)
print(f"Batches: {batches}")
# Train the interactive model on the provided batches.
original_step = interactive_model.step
_ = interactive_model.train_loop(num_steps=len(batches), train_batches=batches)
print(f"Original Step: {original_step}, Current Step: {interactive_model.step}")

The `get_batches_from_seqio` bridge accepts several arguments:


1.   `task_or_mixture_name`: the name of the SeqIO task/mixture to read data from. It should be noted that your task/mixture must already be registered with SeqIO, and you must import the module that defines your task/mixture here (as seen above).
2.   `split`: the split of the Task/Mixture to read data from.
3.   `batch_size`: how many examples should appear in each batch.
4.   `num_batches`: the total number of batches to return.
5.   `get_pretokenized_examples`: optional. A boolean, defaulting to True, that determines whether we should read the `inputs_pretokenized`/`targets_pretokenized` elements from an example, or the `inputs`/`targets` elements. \
The `train_step`, `predict`, `predict_with_aux`, `score`, and `evaluate` methods of the InteractiveModel assume that we should run [tokenization](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) and [appending an EOS token](https://github.com/google/seqio/tree/main/seqio/preprocessors.py) as the only preprocessors. To use these methods with this pre-defined list of preprocessors, you can set `get_pretokenized_examples=True` to retrieve examples that still need to be tokenized, and these InteractiveModel methods will handle running these preprocessors. This setting can also be helpful if you want to inspect the natural text inputs/targets of your SeqIO task. \
However, some SeqIO tasks do not use tokenization (e.g., span corruption). For these, you can set `get_pretokenized_examples=False`, and this bridge will read the fully preprocessed examples from the SeqIO task. You can then run `train_step_with_preprocessors`, `infer_with_preprocessors`, or `evaluate_with_preprocessors` and provide an empty preprocessors list (because all preprocessing has already been completed by this bridge) to run training/inference/evaluation. We have provided an example of using this bridge to retrieve fully preprocessed examples below.
6.   `sequence_length`: optional. A dictionary mapping feature key to maximum length (int) for that feature. Used by SeqIO to retrieve the dataset/examples.
7.   `**get_dataset_kwargs`: there are many [additional parameters](https://github.com/google/seqio/tree/main/seqio/dataset_providers.py) that can be set in SeqIO's `get_dataset` function. If you would like to set any of these arguments, you can pass them through these keyword arguments.



In [None]:
import t5.data.tasks
batches = get_batches_from_seqio(
    task_or_mixture_name='c4_v220_span_corruption',
    split='validation',
    batch_size=8,
    num_batches=1,
    get_pretokenized_examples=False,
    sequence_length=interactive_model._task_feature_lengths,
    seed=42)
batch = batches[0]  # We expect only a single batch.
original_step = interactive_model.step
interactive_model.train_step_with_preprocessors(
        examples=batch, preprocessors=[])
print(f"Original Step: {original_step}, Current Step: {interactive_model.step}")

# Task
Explain how to adapt the provided T5X notebook for information retrieval by encoding documents on the fly and scoring them. Detail the necessary code modifications for the model, task, data processing, and inference mode.

## Understand information retrieval with t5x


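T5X can be adapted for information retrieval by using its encoder to turn text into numerical representations, known as embeddings. The encoder produces one vector per input token, so to obtain a single fixed-size vector per document these token vectors are pooled, for example by averaging them over the non-padding positions (mean pooling). The resulting document embeddings are meant to capture the documents' semantic content; how well they do so depends on the checkpoint, since an encoder trained only for sequence-to-sequence tasks is not necessarily a good retrieval encoder.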

At query time, the query is encoded into an embedding in the same way, with the same encoder and the same pooling. The relevance of a document to the query is then measured by a similarity function between the two embeddings, typically cosine similarity or a dot product, and documents are ranked by this score.
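The encode-pool-compare idea can be sketched in NumPy as follows. This assumes you already have per-token encoder outputs and the matching padding mask (how to extract them from the T5X encoder is not shown here). Mean pooling is one common choice, not necessarily the pooling a particular retrieval checkpoint was trained with; `mean_pool` and `cosine_scores` are hypothetical helper names.

```python
import numpy as np


def mean_pool(token_embeddings: np.ndarray, mask: np.ndarray) -> np.ndarray:
  """Averages per-token vectors over non-padding positions.

  Args:
    token_embeddings: [batch, length, emb_dim] encoder outputs.
    mask: [batch, length], 1 for real tokens and 0 for padding.
  Returns:
    [batch, emb_dim] pooled embeddings.
  """
  weights = mask[..., None].astype(token_embeddings.dtype)
  summed = (token_embeddings * weights).sum(axis=1)
  counts = np.maximum(weights.sum(axis=1), 1.0)  # Avoid dividing by zero.
  return summed / counts


def cosine_scores(query: np.ndarray, docs: np.ndarray) -> np.ndarray:
  """Cosine similarity between a query vector [emb_dim] and docs [n, emb_dim]."""
  q = query / np.linalg.norm(query)
  d = docs / np.linalg.norm(docs, axis=-1, keepdims=True)
  return d @ q


# Toy example: the third token of the document is padding and is ignored.
doc_tokens = np.array([[[1.0, 0.0], [3.0, 0.0], [100.0, 100.0]]])
doc_mask = np.array([[1, 1, 0]])
print(mean_pool(doc_tokens, doc_mask).tolist())  # [[2.0, 0.0]]

query_vec = np.array([1.0, 0.0])
doc_vecs = np.array([[2.0, 0.0], [0.0, 5.0], [1.0, 1.0]])
ranking = np.argsort(-cosine_scores(query_vec, doc_vecs))
print(ranking.tolist())  # [0, 2, 1]
```

Because document embeddings do not depend on the query, they can be computed once and reused; only the query has to be encoded at search time.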

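Alternatively, T5X's `score_batch` can score relevance directly, without embeddings. For each example, `score_batch` returns the log-likelihood that the model assigns to the `targets` given the `inputs` (the sum of per-token log-probabilities). It does not take embeddings as input, so it cannot compare a query embedding with a document embedding. Instead, it supports likelihood-based ranking: put a document in `inputs` and the query in `targets`, score every candidate document, and rank documents by how likely the model considers the query given each document (the query-likelihood approach). This requires running the full encoder-decoder once per (query, document) pair, which is more expensive than comparing precomputed embeddings, and it works best with a model fine-tuned for this kind of scoring.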
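A minimal sketch of the query-likelihood setup, assuming the `{'input': ..., 'target': ...}` dictionary format for natural-text examples that the earlier T5X Colabs use (verify against `get_dataset_from_natural_text_examples` in your T5X version). The `document:` prefix and both helper names are hypothetical, and the scores in the usage line are illustrative placeholders, not model output.

```python
from typing import Dict, List, Sequence, Tuple

import numpy as np


def query_likelihood_examples(query: str,
                              documents: Sequence[str]) -> List[Dict[str, bytes]]:
  """One example per document: the document is the input, the query the target."""
  # The 'document:' prefix is a hypothetical task prefix; use whatever format
  # your scoring model was trained with.
  return [{'input': f'document: {doc}'.encode('utf-8'),
           'target': query.encode('utf-8')} for doc in documents]


def rank_by_log_likelihood(documents: Sequence[str],
                           log_likelihoods: Sequence[float]) -> List[Tuple[str, float]]:
  """Sorts documents from most to least likely to have produced the query."""
  order = np.argsort(-np.asarray(log_likelihoods, dtype=np.float64), kind='stable')
  return [(documents[i], float(log_likelihoods[i])) for i in order]


docs = ['Paris is the capital of France.', 'Bananas are yellow.']
ir_examples = query_likelihood_examples('capital of france', docs)
print(ir_examples[0]['target'])  # b'capital of france'
# Placeholder scores; real values would come from score_batch.
print(rank_by_log_likelihood(docs, [-2.0, -9.5])[0][0])  # Paris is the capital of France.
```

With an `InteractiveModel`, such examples would be scored with `InferenceType.SCORE` (for instance via `infer_with_preprocessors`); a higher (less negative) log-likelihood means the document is ranked as more relevant.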

## Modified Code for Information Retrieval

Below are new cells containing the code modifications discussed for adapting this notebook to information retrieval. These cells replace or modify the functionality of the original cells to focus on encoding and scoring documents.

### Modified Model and Data Configuration

This cell modifies the model definition, input shapes, and task feature lengths to be more suitable for document encoding and scoring. **You may need to adjust the `t5_config` and model class based on the specific T5X model you intend to use for document encoding.**

In [None]:
# Define a model suitable for encoding documents.
# We'll use the T5 1.1 Small configuration as a starting point, configured as encoder-only.
t5_ir_config = network.T5Config(
    vocab_size=32128,
    dtype='bfloat16',
    emb_dim=512,
    num_heads=6,
    num_encoder_layers=8,
    num_decoder_layers=0, # Explicitly set to 0 for encoder-only
    head_dim=64,
    mlp_dim=1024,
    mlp_activations=('gelu', 'linear'),
    dropout_rate=0.0,
    logits_via_embedding=False)

# Build the Transformer module from the encoder-only configuration.
module_ir = network.Transformer(config=t5_ir_config)

# EncoderDecoderModel is used here with num_decoder_layers=0. The decoder's final
# layer norm and output projection are still created, so this is a workaround
# rather than a true encoder-only model. A dedicated encoder-only model class
# would be preferable if your T5X version provides one, and this setup may not
# line up with every checkpoint.
model_ir = t5x.models.EncoderDecoderModel(
    module=module_ir,
    input_vocabulary=t5.data.get_default_vocabulary(),
    output_vocabulary=t5.data.get_default_vocabulary(),
    optimizer_def=t5x.adafactor.Adafactor(decay_rate=0.8, step_offset=0))


# Define checkpoint arguments for your IR model.
# NOTE: The path below points to a GTR-Base retrieval checkpoint, not a T5 1.1
# Small checkpoint. Its parameters are Base-sized, so they will not match the
# Small configuration above, and its parameter tree may differ from the one
# EncoderDecoderModel expects. Either change `t5_ir_config` (and, if needed, the
# model class) to match the checkpoint, or use a checkpoint trained for the
# configuration above. Loading only the encoder parameters may require custom
# state-loading logic.
checkpoint_path_ir = 'gs://t5-data/pretrained_models/t5x/retrieval/gtr_base/checkpoint_1819900/'
dtype_ir = 'bfloat16'
restore_mode_ir = 'specific'

# Define a partitioner. Adjust based on your setup.
partitioner_ir = partitioning.PjitPartitioner(num_partitions=2) # Adjust num_partitions

# Define batch size and task feature lengths for your document data.
# Adjust these based on the maximum sequence length of your documents.
batch_size_ir = 8 # Adjust batch size
# Adjust input length based on expected document length. Targets can be minimal for scoring.
# For encoding, only 'inputs' is strictly necessary for the model input shape.
# However, the InteractiveModel and feature converter might expect 'targets' as well.
document_feature_lengths = {'inputs': 128, 'targets': 1}

# Define input shapes for your IR model.
# These shapes are based on the model's expected inputs for the forward pass.
input_shapes_ir = {
    'encoder_input_tokens': np.array([batch_size_ir, document_feature_lengths['inputs']]),
    'decoder_target_tokens': np.array([batch_size_ir, document_feature_lengths['targets']]), # Needed for InteractiveModel structure
    'decoder_input_tokens': np.array([batch_size_ir, document_feature_lengths['targets']]), # Needed for InteractiveModel structure
    'decoder_loss_weights': np.array([batch_size_ir, document_feature_lengths['targets']]) # Needed for InteractiveModel structure
}

### Modified InteractiveModel Initialization

This cell re-initializes the `InteractiveModel` with the configuration for information retrieval.

In [None]:
# Re-initialize InteractiveModel with IR configuration
interactive_model_ir = InteractiveModel(
    batch_size=batch_size_ir,
    task_feature_lengths=document_feature_lengths,
    output_dir='/tmp/ir_output_dir', # Consider a new output directory
    partitioner=partitioner_ir,
    model=model_ir,
    dtype=dtype_ir,
    restore_mode=restore_mode_ir, # InteractiveModel restores checkpoint_path_ir with this mode during construction.
    checkpoint_path=checkpoint_path_ir,
    input_shapes=input_shapes_ir,
    input_types=None)

# Initialize RNGs and Checkpoint Manager for the IR model
output_dir_ir = re.sub(r"(?<!gs:)([\/]{2,})", "/", '/tmp/ir_output_dir')
if not os.path.exists(output_dir_ir):
  os.mkdir(output_dir_ir)

init_random_seed_ir = 42
random_seed_ir = multihost_utils.broadcast_one_to_all(np.int32(init_random_seed_ir))
utils.set_hardware_rng_ops() # This might be redundant if already called
rng_ir = random.PRNGKey(random_seed_ir)
init_rng_ir, _ = random.split(rng_ir, 2) # We only need init_rng for restoring

save_checkpoint_cfg_ir = utils.SaveCheckpointConfig(
        dtype=dtype_ir,
        keep=5,
        save_dataset=False)
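restore_checkpoint_cfg_ir = utils.RestoreCheckpointConfig(
        dtype=dtype_ir,
        # `mode` must be 'specific', 'latest', or 'all'. Partial restoration is
        # controlled separately: `strict=False` is meant to tolerate model
        # parameters that are missing from the checkpoint.
        mode=restore_mode_ir,
        strict=False,
        path=checkpoint_path_ir)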

train_state_initializer_ir = utils.TrainStateInitializer(
        optimizer_def=model_ir.optimizer_def,
        init_fn=model_ir.get_initial_variables,
        input_shapes=input_shapes_ir,
        input_types=None,
        partitioner=partitioner_ir)

checkpoint_manager_ir = utils.LegacyCheckpointManager(
        save_cfg=save_checkpoint_cfg_ir,
        restore_cfg=restore_checkpoint_cfg_ir,
        train_state_shape=train_state_initializer_ir.global_train_state_shape,
        partitioner=partitioner_ir,
        ds_iter=None,
        model_dir=output_dir_ir)

def get_state_ir(rng):
  return train_state_initializer_ir.from_scratch(rng).state_dict()

# a.) Try to restore a model from a checkpoint.
# Use restore_checkpoint_cfg_ir (defined above) for the restore settings.
train_state_ir = checkpoint_manager_ir.restore(
  [restore_checkpoint_cfg_ir.path],
  restore_checkpoint_cfg_ir,
  utils.get_fallback_state(restore_checkpoint_cfg_ir, get_state_ir, init_rng_ir)
)


# b.) If no checkpoint to restore, init from scratch.
if train_state_ir is None:
  print("Warning: No checkpoint found for IR model. Initializing from scratch.")
  train_state_ir = train_state_initializer_ir.from_scratch(init_rng_ir)
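The restore above can be given a fallback state (via `utils.get_fallback_state`) for parameters that the checkpoint does not provide. The sketch below illustrates that idea on plain nested dictionaries. Checkpoint values are used where they exist, and freshly initialized values fill the gaps. This is an illustration under simplifying assumptions, not T5X's implementation, which operates on partitioned arrays and optimizer state. `merge_with_fallback` and its helpers are hypothetical names.

```python
from typing import Any, Dict, Tuple

Tree = Dict[str, Any]
FlatTree = Dict[Tuple[str, ...], Any]


def _flatten(tree: Tree, prefix: Tuple[str, ...] = ()) -> FlatTree:
    """Flattens a nested dict into {key_path: leaf} pairs."""
    flat: FlatTree = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(_flatten(value, path))
        else:
            flat[path] = value
    return flat


def _unflatten(flat: FlatTree) -> Tree:
    """Rebuilds a nested dict from {key_path: leaf} pairs."""
    tree: Tree = {}
    for path, value in flat.items():
        node = tree
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = value
    return tree


def merge_with_fallback(checkpoint_state: Tree, fallback_state: Tree) -> Tree:
    """Hypothetical helper: checkpoint leaves win, missing leaves use the fallback.

    The fallback (freshly initialized) state defines the expected structure, so
    leaves that exist only in the checkpoint are ignored in this sketch.
    """
    ckpt = _flatten(checkpoint_state)
    merged = {
        path: ckpt[path] if path in ckpt else value
        for path, value in _flatten(fallback_state).items()
    }
    return _unflatten(merged)


fallback = {'target': {'encoder': {'w': 0.0}, 'decoder': {'w': 0.0}}}
checkpoint = {'target': {'encoder': {'w': 1.5}}}
print(merge_with_fallback(checkpoint, fallback))
# {'target': {'encoder': {'w': 1.5}, 'decoder': {'w': 0.0}}}
```

Letting the fallback state define the structure mirrors the notebook's setup: the model being restored determines which parameters must exist, and the checkpoint only supplies values.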

### Conceptual Code for Loading Only Encoder Parameters

This cell provides a conceptual example of how you might use a state transformation function to load only the encoder parameters from a full T5 checkpoint. **This is an advanced technique, and the code is a template that will likely need significant adaptation to the structure of your model and checkpoint.**

In recent T5X versions, such functions are passed through the `state_transformation_fns` field of `utils.RestoreCheckpointConfig`. `LegacyCheckpointManager.restore` applies them to the checkpoint's state dict before matching it against the target train state, so they can be combined with the restore flow above.

In [None]:
from flax.traverse_util import flatten_dict, unflatten_dict

# Conceptual state transformation function. T5X calls each entry of
# `RestoreCheckpointConfig.state_transformation_fns` with the state dict read
# from the checkpoint and the target state dict. Adapt the filtering below to
# the exact structure of your checkpoint and your model's state dict.
def remove_decoder_params(source_state_dict, target_state_dict=None, is_resuming=False):
    """Removes decoder entries from a full T5 state dictionary.

    T5X state dicts keep parameters under 'target' and optimizer state under
    'state', and the shared token embedder sits outside 'encoder'. This
    function therefore drops every key path that contains 'decoder' instead of
    keeping only top-level 'encoder' keys.
    """
    del target_state_dict, is_resuming  # Unused by this simple filter.
    # Flatten to {('target', 'encoder', 'layers_0', ...): value} pairs.
    flat_state_dict = flatten_dict(source_state_dict)
    encoder_flat_state_dict = {
        key: value for key, value in flat_state_dict.items()
        if 'decoder' not in key
    }
    # Rebuild the nested structure from the remaining key paths. Check that it
    # matches the state dict your encoder-only model expects.
    return unflatten_dict(encoder_flat_state_dict)

# Example usage (sketch, untested): attach the function to a restore config.
# If the target model still has decoder parameters, `fallback_to_scratch=True`
# initializes them from scratch instead of failing.
# restore_cfg_encoder_only = utils.RestoreCheckpointConfig(
#     path=checkpoint_path_ir,
#     mode='specific',
#     dtype=dtype_ir,
#     strict=False,
#     fallback_to_scratch=True,
#     state_transformation_fns=(remove_decoder_params,))
# restored_state = checkpoint_manager_ir.restore(
#     [restore_cfg_encoder_only.path],
#     restore_cfg_encoder_only,
#     utils.get_fallback_state(restore_cfg_encoder_only, get_state_ir, init_rng_ir))

# If successful, restored_state holds the checkpoint's encoder parameters.

### Prepare Document Data for Scoring

This cell shows how to prepare your document data and queries for scoring. **You will need to replace the placeholder data with your actual document and query loading and preprocessing logic.** This might involve reading text files, tokenizing, and formatting them according to your model's requirements.

In [None]:
# Replace with your actual document data and queries
# This is a placeholder
documents = [
    b'document 1: This is the content of the first document.',
    b'document 2: The second document contains different information.',
    b'document 3: A third document for testing purposes.',
    b'document 4: Yet another document.',
    b'document 5: Content for the fifth document.',
    b'document 6: More text for document six.',
    b'document 7: Document number seven.',
    b'document 8: The eighth document content.'
]

query = b'query: information about documents' # Replace with your query

# Build one scoring example per (query, document) pair. With
# EncoderDecoderModel.score_batch, the query is the encoder input and the
# document is the decoder target, so each score is the model's log-likelihood
# of the document given the query. Adapt this pairing (for example, by
# concatenating query and document into a single input) and add a custom
# preprocessing function if your model was trained on a different format.
scoring_examples = []
for doc in documents:
    scoring_examples.append({'inputs': query, 'targets': doc})

# Print a sample of scoring_examples to inspect the structure
print("Sample scoring_examples structure:", scoring_examples[:2])


# Define output features for scoring.
# This might need adjustment based on your model's scoring output.
output_features_ir = {
        "inputs":
            seqio.Feature(
                vocabulary=model_ir.input_vocabulary, add_eos=True),
        "targets": # Targets here represent the documents being scored against the query
            seqio.Feature(
                vocabulary=model_ir.output_vocabulary, add_eos=True)
    }
features_ir = dict(sorted(output_features_ir.items()))

# Define preprocessors for your document and query data.
# These should align with how your chosen IR model was trained.
preprocessors_ir = [
    seqio.preprocessors.tokenize,
    seqio.preprocessors.append_eos
    # Add other relevant preprocessors (e.g., truncation, padding)
]

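# Convert examples to a dataset; adapt this part for your data source.
# `get_dataset_from_natural_text_examples` (as in T5X's InteractiveModel)
# reads the 'input' and 'target' keys of each example dict, so the keys used
# above are renamed here.
dataset_ir = get_dataset_from_natural_text_examples(
    [{'input': ex['inputs'], 'target': ex['targets']} for ex in scoring_examples],
    preprocessors=preprocessors_ir,
    task_feature_lengths=document_feature_lengths,
    features=features_ir)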

feature_converter_ir = model_ir.FEATURE_CONVERTER_CLS(pack=False)
model_dataset_ir = feature_converter_ir(
    dataset_ir, task_feature_lengths=document_feature_lengths)

infer_dataset_ir = tf.data.Dataset.zip((dataset_ir, model_dataset_ir))

infer_dataset_ir = infer_dataset_ir.padded_batch(
    batch_size_ir, drop_remainder=False).enumerate()
infer_dataset_iter_ir: Iterator[Tuple[int, Any]] = iter(
    infer_dataset_ir.prefetch(tf.data.experimental.AUTOTUNE))
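For an encoder-decoder model, the `score_batch` call used in the next cell scores each (query, document) pair by the log-likelihood of the target tokens given the inputs. The per-token log-probabilities are summed over non-padding positions. The NumPy sketch below mirrors that arithmetic on toy logits to make the ranking signal concrete. The helper names and toy values are illustrative assumptions, not T5X code.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the vocabulary axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def sequence_log_likelihood(logits: np.ndarray,
                            target_tokens: np.ndarray,
                            loss_weights: np.ndarray) -> np.ndarray:
    """Sums log p(target_t | inputs, targets_<t) over positions with weight 1.

    logits: [batch, length, vocab]; target_tokens, loss_weights: [batch, length].
    """
    log_probs = log_softmax(logits)
    # Pick the log-probability assigned to each gold target token.
    token_log_probs = np.take_along_axis(
        log_probs, target_tokens[..., None], axis=-1)[..., 0]
    # Padding positions have weight 0 and do not contribute to the sum.
    return (token_log_probs * loss_weights).sum(axis=-1)


# Toy batch: 2 sequences, 3 positions, vocabulary of 4, uniform predictions.
# The second sequence has only 2 real tokens; its last position is padding.
logits = np.zeros((2, 3, 4))
targets = np.array([[1, 2, 3], [1, 2, 0]])
weights = np.array([[1, 1, 1], [1, 1, 0]])
scores = sequence_log_likelihood(logits, targets, weights)
print(scores)  # approximately [-4.159, -2.773], i.e. 3*log(1/4) and 2*log(1/4)
```

Because the score is a sum, each additional target token adds another non-positive term. This matters when ranking documents of different lengths.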

### Define and Run Scoring Function

This cell sets the inference mode to scoring and runs the scoring function on your document data and query.

In [None]:
# Set inference mode to SCORE
mode_ir = InferenceType.SCORE

if mode_ir == InferenceType.PREDICT_WITH_AUX:
  infer_step_ir = model_ir.predict_batch_with_aux
elif mode_ir == InferenceType.SCORE:
  infer_step_ir = model_ir.score_batch
else:
  raise ValueError("Mode must be `predict_with_aux` or `score`,"
                  f" but instead was {mode_ir}.")

infer_fn_ir = functools.partial(
  utils.get_infer_fn(
    infer_step=infer_step_ir,
    batch_size=batch_size_ir,
    train_state_axes=train_state_initializer_ir.train_state_axes,
    partitioner=partitioner_ir),
  train_state=train_state_ir)

# Run scoring
all_scores = []
for chunk, chunk_batch in infer_dataset_iter_ir:
    chunk_dataset = tf.data.Dataset.from_tensor_slices(chunk_batch)
    chunk_dataset = chunk_dataset.cache().prefetch(tf.data.experimental.AUTOTUNE)

    task_dataset = chunk_dataset.map(
        lambda p, m: p, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    model_dataset = chunk_dataset.map(
        lambda p, m: m, num_parallel_calls=tf.data.experimental.AUTOTUNE)

    chunk_rng = jax.random.fold_in(jax.random.PRNGKey(0), chunk)

    # score_batch returns only per-example sequence scores (no auxiliary
    # values), so _extract_tokens_and_aux_values returns the scores in its
    # first slot and an empty aux dict.
    scores_batch, _ = _extract_tokens_and_aux_values(
        infer_fn_ir(model_dataset.enumerate(), rng=chunk_rng))

    all_scores.extend(scores_batch)

# Associate scores with documents
scored_documents = []
for i, score in enumerate(all_scores):
    # Assuming a one-to-one correspondence between scoring_examples and scores
    scored_documents.append({'document': scoring_examples[i]['targets'], 'score': score})

# Sort documents by score in descending order
scored_documents.sort(key=lambda x: x['score'], reverse=True)

print("Documents ranked by relevance to the query:")
for item in scored_documents:
    print(f"Document: {item['document']}, Score: {item['score']}\n")
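The ranking above sorts by the raw summed log-likelihood. Longer documents accumulate more negative per-token terms, so this ordering tends to favor shorter documents. One common adjustment is to rank by the average log-likelihood per target token instead. The sketch below assumes you also track the number of target tokens for each document. `rank_by_mean_log_likelihood` and the example numbers are hypothetical, and whether normalization helps depends on your model and data.

```python
from typing import List, Sequence, Tuple


def rank_by_mean_log_likelihood(
    documents: Sequence[bytes],
    scores: Sequence[float],
    num_target_tokens: Sequence[int],
) -> List[Tuple[bytes, float]]:
    """Returns (document, per-token score) pairs, most relevant first.

    scores: summed log-likelihoods, one per document.
    num_target_tokens: non-padding target tokens scored for each document.
    """
    if not (len(documents) == len(scores) == len(num_target_tokens)):
        raise ValueError('documents, scores and num_target_tokens must align')
    normalized = [
        (doc, float(score) / max(int(n), 1))  # Guard against empty targets.
        for doc, score, n in zip(documents, scores, num_target_tokens)
    ]
    # A higher (less negative) mean log-likelihood ranks first.
    return sorted(normalized, key=lambda item: item[1], reverse=True)


ranked = rank_by_mean_log_likelihood(
    [b'short doc', b'a much longer document'],
    scores=[-6.0, -10.0],
    num_target_tokens=[3, 10],
)
for doc, per_token in ranked:
    print(doc, round(per_token, 3))
# b'a much longer document' -1.0
# b'short doc' -2.0
```

Ranked by raw score, the short document would come first (-6.0 > -10.0). Per token, the longer one is the better fit. In this notebook, the token counts should be obtainable by summing `decoder_loss_weights` for each example of `model_dataset_ir`, since the encoder-decoder feature converter sets that weight to 1 on non-padding target positions.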

### Next Steps

You've now seen the core modifications needed to adapt the notebook for information retrieval using scoring.

To fully implement this, you will need to:

1.  **Choose and configure a T5X model** suitable for your information retrieval task. This might involve using a pre-trained model or fine-tuning on a relevant dataset. Update the model definition in the "Modified Model and Data Configuration" cell.
2.  **Implement your data loading and preprocessing logic** in the "Prepare Document Data for Scoring" cell. This is crucial for handling your specific document data and formatting it correctly for your chosen T5X model.
3.  **Refine the scoring and ranking logic** in the "Define and Run Scoring Function" cell based on how your model outputs scores and how you want to rank the results.
