##### Copyright 2020 The T5 Authors

Licensed under the Apache License, Version 2.0 (the "License");

In [0]:
# Copyright 2019 The T5 Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# Fine-Tuning the Text-To-Text Transfer Transformer (T5)
## _Or: What does T5 know?_

*The following tutorial guides you through the process of fine-tuning a pre-trained T5 model, evaluating its accuracy, and using it for prediction,
all on a free Google Cloud TPU <a href="https://colab.research.google.com/github/google-research/text-to-text-transfer-transformer/blob/master/notebooks/t5-trivia.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>.*

### Background

T5 was introduced in the paper [_Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer_](https://arxiv.org/abs/1910.10683). In that paper, we provided a comprehensive picture of how we pre-trained a standard text-to-text Transformer model on a large text corpus, achieving state-of-the-art results on many NLP tasks after fine-tuning.

We pre-trained T5 on a mixture of supervised and unsupervised tasks, with the majority of the data coming from an unlabeled dataset we developed called [C4](https://www.tensorflow.org/datasets/catalog/c4). C4 is based on a massive scrape of the web produced by [Common Crawl](https://commoncrawl.org). Loosely speaking, pre-training on C4 ideally gives T5 an understanding of natural language in addition to general world knowledge.

### How can we assess what T5 knows?

As the name implies, T5 is a text-to-text model, which enables us to train it on arbitrary tasks involving a textual input and output. As we showed in our paper, a huge variety of NLP tasks can be cast in this format, including translation, summarization, and even classification and regression tasks.
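
As a rough illustration of what "casting a task as text-to-text" means, the sketch below formats a few examples as plain input and target strings. The task prefixes follow the examples given in the T5 paper; the helper `to_text_to_text` is a hypothetical name used only for this illustration and is not part of the T5 library.

```python
def to_text_to_text(task_prefix, text, target):
    """Format one example as a pair of plain strings (hypothetical helper)."""
    # Targets are always strings, even for a regression score such as 3.8.
    return {"inputs": "%s %s" % (task_prefix, text), "targets": str(target)}


examples = [
    # Translation.
    to_text_to_text("translate English to German:", "That is good.",
                    "Das ist gut."),
    # Classification (linguistic acceptability).
    to_text_to_text("cola sentence:", "The course is jumping well.",
                    "not acceptable"),
    # Regression (sentence similarity), with the score rendered as text.
    to_text_to_text(
        "stsb sentence1:",
        "The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
        3.8),
]

for ex in examples:
    print(ex)
```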

One way to use this text-to-text framework is on reading comprehension problems, where the model is fed some context along with a question and is trained to predict the question's answer. For example, we might feed the model the text from the Wikipedia article about [Hurricane Connie](https://en.wikipedia.org/wiki/Hurricane_Connie) along with the question "On what date did Hurricane Connie occur?" and train the model to predict the answer "August 3rd, 1955".
A related task is open-domain question answering (QA) where the model is not provided with this oracle context. Typically, open-domain QA systems include a mechanism to look up information in an external knowledge source. This setting is similar to an "open-book" exam.


### Caveats

* While we provide instructions for running on a [Cloud TPU](https://cloud.google.com/tpu/) via Colab for free, a [Google Cloud Storage (GCS)](http://console.cloud.google.com/storage) bucket is required for storing model parameters and data. The [GCS free tier](https://cloud.google.com/free/) provides 5 GB of storage, which should be enough to train the `large` model and smaller but not the `3B` or `11B` parameter models. You can use part of your initial $300 credit to get more space.
* The Cloud TPU provided by Colab (a `v2-8`) does not have enough memory to fine-tune the `11B` parameter model. For this model, you will need to fine-tune inside of a GCP instance (see [README](https://github.com/google-research/text-to-text-transfer-transformer/)).


# Set Up

<h3><a href="https://cloud.google.com/tpu/"><img valign="middle" src="https://raw.githubusercontent.com/GoogleCloudPlatform/tensorflow-without-a-phd/master/tensorflow-rl-pong/images/tpu-hexagon.png" width="50"></a>  &nbsp;&nbsp;Train on TPU</h3>




   1. Create a Cloud Storage bucket for your data and model checkpoints at http://console.cloud.google.com/storage, and fill in the `BASE_DIR` parameter in the following form. There is a [free tier](https://cloud.google.com/free/) if you do not yet have an account.
   1. On the main menu, click **Runtime** and select **Change runtime type**. Set "TPU" as the hardware accelerator.
   1. Run the following cell and follow instructions to:
       * Set up a Colab TPU running environment
       * Verify that you are connected to a TPU device
       * Upload your credentials to TPU to access your GCS bucket


In [0]:
print("Installing dependencies...")
%tensorflow_version 2.x
!pip install -q t5

import functools
import os
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

import t5

BASE_DIR = "gs://t5_nlp_arc" #@param { type: "string" }
if not BASE_DIR or BASE_DIR == "gs://":
  raise ValueError("You must enter a BASE_DIR.")
DATA_DIR = os.path.join(BASE_DIR, "data")
MODELS_DIR = os.path.join(BASE_DIR, "models")
ON_CLOUD = True


if ON_CLOUD:
  print("Setting up GCS access...")
  import tensorflow_gcs_config
  from google.colab import auth
  # Set credentials for GCS reading/writing from Colab and TPU.
  TPU_TOPOLOGY = "2x2"
  try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
    TPU_ADDRESS = tpu.get_master()
    print('Running on TPU:', TPU_ADDRESS)
  except ValueError:
    raise RuntimeError('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!')
  auth.authenticate_user()
  tensorflow_gcs_config.configure_gcs_from_colab_auth('/device:CPU:0')
  tf.config.experimental_connect_to_host(TPU_ADDRESS)
  tensorflow_gcs_config.configure_gcs_from_colab_auth()

tf.disable_v2_behavior()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging

if ON_CLOUD:
  tf.get_logger().propagate = False
  py_logging.root.setLevel('INFO')

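@contextmanager
def tf_verbosity_level(level):
  """Temporarily set the TF logging verbosity, restoring it on exit."""
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    # Restore the original level even if the wrapped code raises.
    tf.logging.set_verbosity(og_level)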

Installing dependencies...
ERROR: tensorflow-probability 0.10.0rc0 has requirement gast>=0.3.2, but you'll have gast 0.2.2 which is incompatible.
ERROR: 

# Creating new Tasks and Mixture

Two core components of the T5 library are `Task` and `Mixture` objects.

A `Task` is a dataset along with preprocessing functions and evaluation metrics. A `Mixture` is a collection of `Task` objects along with a mixing rate or a function defining how to compute a mixing rate based on the properties of the constituent `Tasks`.
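
To make these two concepts concrete, here is a minimal pure-Python sketch of what a task bundles together. `ToyTask`, `toy_dataset_fn`, `lowercase_and_prefix`, and `toy_mixture` are hypothetical names invented for this illustration; they are not the T5 library's classes, which operate on `tf.data.Dataset` objects rather than Python lists.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, List


@dataclass
class ToyTask:
    """A dataset source plus preprocessing functions and metrics (illustration only)."""
    name: str
    dataset_fn: Callable[[str], List[Dict[str, str]]]
    preprocessors: List[Callable] = field(default_factory=list)
    metric_fns: List[Callable] = field(default_factory=list)

    def get_dataset(self, split):
        examples = self.dataset_fn(split)
        # Apply each preprocessor in order to every example.
        for preprocess in self.preprocessors:
            examples = [preprocess(ex) for ex in examples]
        return examples

    def evaluate(self, targets, predictions):
        results = {}
        for metric_fn in self.metric_fns:
            results.update(metric_fn(targets, predictions))
        return results


def toy_dataset_fn(split):
    # Hypothetical in-memory data standing in for files on disk.
    data = {"validation": [{"question": "Which is a renewable resource?",
                            "answer": "Trees"}]}
    return data[split]


def lowercase_and_prefix(ex):
    return {"inputs": "arc question: " + ex["question"].lower(),
            "targets": ex["answer"].lower()}


def accuracy(targets, predictions):
    correct = sum(t == p for t, p in zip(targets, predictions))
    return {"accuracy": 100.0 * correct / len(targets)}


toy_task = ToyTask("toy_arc", toy_dataset_fn, [lowercase_and_prefix], [accuracy])
# A mixture then pairs task names with mixing rates.
toy_mixture = {"toy_arc": 1.0}
print(toy_task.get_dataset("validation"))
```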

For this example, we will fine-tune the model to do closed-book question answering.

### ARC

Since the raw data splits are stored as JSONL files, we first convert them to TSV format so that they can be parsed in TensorFlow. Each record already carries a formatted question (`input_text`, which includes the answer choices) and its answer (`target_text`), so we keep only those two fields and drop everything else.

In [0]:
import gzip
import json

DATA_DIR = "gs://t5_nlp_arc/data_arc_challenge_update"

# GCS directory containing the ARC Challenge JSONL files.
NQ_JSONL_DIR = "gs://t5_nlp_arc/data_arc_challenge_update/"
NQ_SPLIT_FNAMES = {
    "train": "arcch_train.jsonl",
    "validation": "arcch_test.jsonl"
}
nq_counts_path = os.path.join(DATA_DIR, "arc-counts.json")
nq_tsv_path = {
    "train": os.path.join(DATA_DIR, "arc-train.tsv"),
    "validation": os.path.join(DATA_DIR, "arc-validation.tsv")
}
print(DATA_DIR)

def nq_jsonl_to_tsv(in_fname, out_fname):
  """Write a JSONL file as tab-separated (question, answer) lines; return the count."""
  count = 0
  with tf.io.gfile.GFile(in_fname, "rb") as infile,\
       tf.io.gfile.GFile(out_fname, "w") as outfile:
    for line in infile:
      ex = json.loads(line)
      question = ex["input_text"]
      answer = ex['target_text']
      # Write this line as <question>\t<answer>
      outfile.write("%s\t%s\n" % (question, answer))
      count += 1
      tf.logging.log_every_n(
          tf.logging.INFO,
          "Wrote %d examples to %s." % (count, out_fname),
          1000)
    return count

if tf.io.gfile.exists(nq_counts_path):
  # Use cached data and counts.
  tf.logging.info("Loading NQ from cache.")
  with tf.io.gfile.GFile(nq_counts_path) as f:
    num_nq_examples = json.load(f)
else:
  # Create TSVs and get counts.
  tf.logging.info("Generating NQ TSVs.")
  num_nq_examples = {}
  for split, fname in NQ_SPLIT_FNAMES.items():
    num_nq_examples[split] = nq_jsonl_to_tsv(
        os.path.join(NQ_JSONL_DIR, fname), nq_tsv_path[split])
  # Close the file explicitly so that the counts are flushed to GCS.
  with tf.io.gfile.GFile(nq_counts_path, "w") as f:
    json.dump(num_nq_examples, f)

gs://t5_nlp_arc/data_arc_challenge_update
INFO:tensorflow:Generating NQ TSVs.
INFO:tensorflow:Wrote 1 examples to gs://t5_nlp_arc/data_arc_challenge_update/arc-train.tsv.
INFO:tensorflow:Wrote 1001 examples to gs://t5_nlp_arc/data_arc_challenge_update/arc-train.tsv.
INFO:tensorflow:Wrote 2001 examples to gs://t5_nlp_arc/data_arc_challenge_update/arc-train.tsv.
INFO:tensorflow:Wrote 3001 examples to gs://t5_nlp_arc/data_arc_challenge_update/arc-train.tsv.
INFO:tensorflow:Wrote 4001 examples to gs://t5_nlp_arc/data_arc_challenge_update/arc-train.tsv.
INFO:tensorflow:Wrote 5001 examples to gs://t5_nlp_arc/data_arc_challenge_update/arc-train.tsv.
INFO:tensorflow:Wrote 6001 examples to gs://t5_nlp_arc/data_arc_challenge_update/arc-train.tsv.
INFO:tensorflow:Wrote 7001 examples to gs://t5_nlp_arc/data_arc_challenge_update/arc-train.tsv.
INFO:tensorflow:Wrote 8001 examples to gs://t5_nlp_arc/data_arc_challenge_update/arc-train.tsv.
INFO:tensorflow:Wrote 9001 examples to gs://t5_nlp_arc/data_a
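
The conversion cell writes each field verbatim, so a tab or newline inside a question would break the one-example-per-line TSV layout. The sketch below shows one way to guard against that using only the standard library and in-memory files. `clean_field` and `jsonl_to_tsv` are hypothetical helpers for illustration, and the sample record is made up.

```python
import io
import json


def clean_field(text):
    """Collapse tabs, newlines, and repeated spaces into single spaces."""
    return " ".join(text.split())


def jsonl_to_tsv(infile, outfile,
                 question_key="input_text", answer_key="target_text"):
    """Write one tab-separated (question, answer) line per JSON record."""
    count = 0
    for line in infile:
        if not line.strip():
            continue  # Skip blank lines.
        ex = json.loads(line)
        outfile.write("%s\t%s\n" % (clean_field(ex[question_key]),
                                    clean_field(ex[answer_key])))
        count += 1
    return count


# A made-up record whose question contains an embedded tab character.
raw = io.StringIO(
    '{"input_text": "ARCCH question: Which is a renewable\\tresource? '
    'choice1: oil choice2: trees", "target_text": "choice2: trees"}\n')
out = io.StringIO()
num_written = jsonl_to_tsv(raw, out)
print(num_written, repr(out.getvalue()))
```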

Next, we define a function to load the TSV data as a `tf.data.Dataset` in TensorFlow.

In [0]:
def nq_dataset_fn(split, shuffle_files=False):
  # We only have one file for each split.
  del shuffle_files

  # Load lines from the text file as examples.
  ds = tf.data.TextLineDataset(nq_tsv_path[split])
  # Split each "<question>\t<answer>" example into (question, answer) tuple.
  ds = ds.map(
      functools.partial(tf.io.decode_csv, record_defaults=["", ""],
                        field_delim="\t", use_quote_delim=False),
      num_parallel_calls=tf.data.experimental.AUTOTUNE)
  # Map each tuple to a {"question": ... "answer": ...} dict.
  ds = ds.map(lambda *ex: dict(zip(["question", "answer"], ex)))
  return ds

print("A few raw validation examples...")
for ex in tfds.as_numpy(nq_dataset_fn("validation").take(5)):
  print(ex)

A few raw validation examples...
{'question': b'ARCCH question: An astronomer observes that a planet rotates faster after a meteorite impact. Which is the most likely effect of this increase in rotation? choice1: Planetary density will decrease. choice2: Planetary years will become longer. choice3: Planetary days will become shorter. choice4: Planetary gravity will become stronger.', 'answer': b'choice3: Planetary days will become shorter.'}
{'question': b'ARCCH question: A group of engineers wanted to know how different building designs would respond during an earthquake. They made several models of buildings and tested each for its ability to withstand earthquake conditions. Which will most likely result from testing different building designs? choice1: buildings will be built faster choice2: buildings will be made safer choice3: building designs will look nicer choice4: building materials will be cheaper', 'answer': b'choice2: buildings will be made safer'}
{'question': b'ARCCH ques

Now, we write a preprocess function to convert the examples in the `tf.data.Dataset` into a text-to-text format, with both `inputs` and `targets` fields. The preprocessor also normalizes the text by lowercasing it and removing quotes since the answers are sometimes formatted in odd ways. Finally, we prepend 'arc question:' to the inputs so that the model knows what task it's trying to solve.

In [0]:
def arc_preprocessor(ds):
  def normalize_text(text):
    """Lowercase and remove quotes from a TensorFlow string."""
    text = tf.strings.lower(text)
    text = tf.strings.regex_replace(text,"'(.*)'", r"\1")
    return text

  def to_inputs_and_targets(ex):
    """Map {"question": ..., "answer": ...}->{"inputs": ..., "targets": ...}."""
    return {
        "inputs":
             tf.strings.join(
                 ["arc question: ", normalize_text(ex["question"])]),
        "targets": normalize_text(ex["answer"])
    }
  return ds.map(to_inputs_and_targets, 
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
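
For quick checks outside of a TensorFlow graph, the same normalization can be mirrored in plain Python. This is a sketch that assumes `re.sub` behaves like `tf.strings.regex_replace` for this simple pattern; it works on `str` values rather than TensorFlow string tensors, and the sample example is made up.

```python
import re


def normalize_text(text):
    """Lowercase and strip a pair of single quotes, keeping the quoted text."""
    text = text.lower()
    # Greedy match: removes the first and last single quote on the line.
    return re.sub(r"'(.*)'", r"\1", text)


def to_inputs_and_targets(ex):
    """Map {"question", "answer"} to {"inputs", "targets"}."""
    return {
        "inputs": "arc question: " + normalize_text(ex["question"]),
        "targets": normalize_text(ex["answer"]),
    }


sample = {"question": "What is 'DNA'?", "answer": "Choice1: 'Acid'"}
print(to_inputs_and_targets(sample))
```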

Finally, we put everything together to create a `Task`.

In [0]:
t5.data.TaskRegistry.add(
    "arc_challenge",
    # Supply a function which returns a tf.data.Dataset.
    dataset_fn=nq_dataset_fn,
    splits=["train", "validation"],
    # Supply a function which preprocesses text from the tf.data.Dataset.
    text_preprocessor=[arc_preprocessor],
    # Use the same vocabulary that we used for pre-training.
    sentencepiece_model_path=t5.data.DEFAULT_SPM_PATH,
    # Lowercase targets before computing metrics.
    postprocess_fn=t5.data.postprocessors.lower_text, 
    # We'll use accuracy as our evaluation metric.
    metric_fns=[t5.evaluation.metrics.accuracy],
    # Not required, but helps for mixing and auto-caching.
    num_input_examples=num_nq_examples
)

Let's look at a few pre-processed examples from the validation set. Note they contain both the tokenized (integer) and plain-text inputs and targets.


In [0]:
arc_task = t5.data.TaskRegistry.get("arc_challenge")
ds = arc_task.get_dataset(split="validation", sequence_length={"inputs": 128, "targets": 32})
print("A few preprocessed validation examples...")
for ex in tfds.as_numpy(ds.take(5)):
  print(ex)

A few preprocessed validation examples...
{'inputs_plaintext': b'trivia question: arcch question: for many years, scientists have found fossils in rock layers in all parts of the world. eventually scientists discovered that many of the rock layers showed the same pattern of older fossils being located in deeper rock layers and younger fossils being located in shallower layers. this discovery is an example of how scientific knowledge changes as choice1: fossils are preserved in different ways. choice2: conclusions about fossils are revised into questions. choice3: fossils are formed in different areas. choice4: new fossil evidence is reviewed.', 'inputs': array([22377,   822,    10,     3,  4667,   524,   822,    10,    21,
         186,   203,     6,  7004,    43,   435, 15722,     7,    16,
        2480,  7500,    16,    66,  1467,    13,     8,   296,     5,
        3725,  7004,  3883,    24,   186,    13,     8,  2480,  7500,
        3217,     8,   337,  3275,    13,  2749, 15722,  

## Dataset Mixture

We now create a `Mixture` from the `Task` defined above, which we will fine-tune on.

There are different ways to automatically set the rate (for example, based on the number of examples using `rate_num_examples`), but since this mixture contains a single task we simply hardcode a rate of 1.0.

In [0]:
t5.data.MixtureRegistry.remove("arc_challenge")
t5.data.MixtureRegistry.add(
    "arc_challenge",
    ["arc_challenge"],
     default_rate=1.0
)
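
When a mixture holds several tasks, the rates determine how often each task is sampled. The sketch below illustrates examples-proportional mixing with an optional size cap and temperature, following the description in the T5 paper. It is not the library's `rate_num_examples` implementation; the function name and the task sizes are hypothetical.

```python
def examples_proportional_rates(num_examples, maximum=None, temperature=1.0):
    """Turn per-task example counts into sampling probabilities."""
    scaled = {}
    for name, n in num_examples.items():
        if maximum is not None:
            n = min(n, maximum)  # Cap very large tasks.
        # A temperature above 1 flattens the distribution across tasks.
        scaled[name] = n ** (1.0 / temperature)
    total = sum(scaled.values())
    return {name: value / total for name, value in scaled.items()}


# Hypothetical task sizes, for illustration only.
sizes = {"arc_challenge": 1000, "other_task": 3000}
print(examples_proportional_rates(sizes))
print(examples_proportional_rates(sizes, maximum=1000))
# With a single task the rate is always 1.0, as hardcoded in the cell above.
print(examples_proportional_rates({"arc_challenge": 1000}))
```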

# Transferring to new Tasks

We are now ready to fine-tune one of the pre-trained T5 models on our new mixture.

First, we'll instantiate a `Model` object using the model size of your choice. Note that larger models are slower to train and use but will likely achieve higher accuracy. You also may be able to increase accuracy by training longer with more `FINETUNE_STEPS` below.


## Caveats

* Due to its memory requirements, you will not be able to train the `11B` parameter model on the TPU provided by Colab. Instead, you will need to fine-tune inside of a GCP instance (see [README](https://github.com/google-research/text-to-text-transfer-transformer/)).
* Due to the checkpoint size, you will not be able to use the 5GB GCS free tier for the `3B` parameter models. You will need at least 25GB of space, which you can purchase with your $300 of initial credit on GCP.
* While `large` can achieve decent results, it is recommended that you fine-tune at least the `3B` parameter model.


## Define Model

In [0]:
MODEL_SIZE = "3B" #@param["small", "base", "large", "3B", "11B"]
# Public GCS path for T5 pre-trained model checkpoints
BASE_PRETRAINED_DIR = "gs://t5-data/pretrained_models"
PRETRAINED_DIR = os.path.join(BASE_PRETRAINED_DIR, MODEL_SIZE)
MODEL_DIR = os.path.join(MODELS_DIR, MODEL_SIZE)

if ON_CLOUD and MODEL_SIZE == "3B":
  tf.logging.warn(
      "The `3B` model is too large to use with the 5GB GCS free tier. "
      "Make sure you have at least 25GB on GCS before continuing."
  )
elif ON_CLOUD and MODEL_SIZE == "11B":
  raise ValueError(
      "The `11B` parameter model is too large to fine-tune on the `v2-8` TPU "
      "provided by Colab. Please comment out this Error if you're running "
      "on a larger TPU."
  )

# Set parallelism and batch size to fit on v2-8 TPU (if possible).
# Limit number of checkpoints to fit within 5GB (if possible).
model_parallelism, train_batch_size, keep_checkpoint_max = {
    "small": (1, 256, 16),
    "base": (2, 128, 8),
    "large": (8, 64, 4),
    "3B": (8, 16, 1),
    "11B": (8, 16, 1)}[MODEL_SIZE]

tf.io.gfile.makedirs(MODEL_DIR)
# The models from our paper are based on the Mesh Tensorflow Transformer.
model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    sequence_length={"inputs": 128, "targets": 32},
    learning_rate_schedule=0.003,
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max if ON_CLOUD else None,
    iterations_per_loop=100,
)
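
The table of `(model_parallelism, train_batch_size, keep_checkpoint_max)` values trades model parallelism against data parallelism. As a rough sketch, assuming the 8 TPU cores are laid out as `model_parallelism` model shards times the remaining data-parallel replicas, the per-replica batch size can be worked out as follows. `parallelism_layout` is a hypothetical helper, not a T5 API.

```python
def parallelism_layout(num_cores, model_parallelism, batch_size):
    """Split cores into model shards and data-parallel replicas."""
    if num_cores % model_parallelism:
        raise ValueError("model_parallelism must divide the number of cores.")
    data_replicas = num_cores // model_parallelism
    if batch_size % data_replicas:
        raise ValueError("batch_size must be divisible by the data replicas.")
    return {"data_replicas": data_replicas,
            "batch_per_replica": batch_size // data_replicas}


# Settings taken from the table above, on the 8 cores of a v2-8.
for size, (parallelism, batch) in {"small": (1, 256), "3B": (8, 16)}.items():
    print(size, parallelism_layout(8, parallelism, batch))
```

For the `3B` settings this gives a single replica with a batch of 16, which agrees with the `batch_per_replica=16` entry in the fine-tuning log.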



## Fine-tune

We are now ready to fine-tune our model. This will take a while (~2 hours with default settings), so please be patient! The larger the model and more `FINETUNE_STEPS` you use, the longer it will take.

Don't worry, you can always come back later and increase the number of steps, and it will automatically pick up where you left off.
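
The step arithmetic behind resuming can be sketched as follows. The pre-trained step of 1,000,000 is an assumption inferred from the `model.ckpt-1025000` path that appears in the evaluation log after 25,000 fine-tuning steps. The helpers are hypothetical, and the sketch assumes checkpoints land on multiples of `save_checkpoints_steps`.

```python
PRETRAINED_STEP = 1_000_000  # Assumption: inferred from `model.ckpt-1025000`.


def finetune_plan(pretrained_step, finetune_steps, save_every):
    """Return the final global step and the steps at which checkpoints are due."""
    final_step = pretrained_step + finetune_steps
    first = (pretrained_step // save_every + 1) * save_every
    return final_step, list(range(first, final_step + 1, save_every))


def remaining_steps(pretrained_step, finetune_steps, current_step):
    """Steps still to run when resuming from `current_step`."""
    return max(0, pretrained_step + finetune_steps - current_step)


final_step, checkpoints = finetune_plan(PRETRAINED_STEP, 25_000, 5_000)
print(final_step, checkpoints)
# Raising the budget to 30,000 after a 25,000-step run leaves 5,000 steps to go.
print(remaining_steps(PRETRAINED_STEP, 30_000, final_step))
```

Note that with `keep_checkpoint_max=1`, as used for the `3B` model, only the most recent of these checkpoints is retained.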

In [0]:
FINETUNE_STEPS = 25000 #@param {type: "integer"}

model.finetune(
    mixture_or_task_name="arc_challenge",
    pretrained_model_dir=PRETRAINED_DIR,
    finetune_steps=FINETUNE_STEPS
)



INFO:tensorflow:Using config: {'_model_dir': 'gs://t5_nlp_arc/models/3B', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.57.79.146:8470"
    }
  }
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.57.79.146:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.57.79.146:8470', '_evaluation_master': 'grpc://10.57.79.146:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_t



INFO:tensorflow:enable_2d_tiling: False
INFO:tensorflow:num_cores_per_replica: 1
INFO:tensorflow:computation_shape: [1, 1, 1]
INFO:tensorflow:num_replicas: 8
INFO:tensorflow:device_assignment.topology.device_coordinates: [[[0 0 0]
  [0 0 1]
  [1 0 0]
  [1 0 1]
  [0 1 0]
  [0 1 1]
  [1 1 0]
  [1 1 1]]]
INFO:tensorflow:device_assignment.core_assignment: [[[0 0 0]]

 [[0 0 1]]

 [[0 1 0]]

 [[0 1 1]]

 [[1 0 0]]

 [[1 0 1]]

 [[1 1 0]]

 [[1 1 1]]]
INFO:tensorflow:SimdMeshImpl init: Shape[model=8] LayoutRules{('d_ff', 'model'), ('ensemble', 'ensemble'), ('batch', 'batch'), ('heads', 'model'), ('vocab', 'model'), ('experts', 'batch')}
INFO:tensorflow:Device Assignment: <tensorflow.python.tpu.device_assignment.DeviceAssignment object at 0x7fdb43c834a8>
INFO:tensorflow:serialize_num_microbatches: tokens_per_microbatch_per_replica=4096 batch_dim=Dimension(name='batch', size=16) sequence_length={'inputs': 128, 'targets': 32} batch_per_replica=16 num_microbatches=1
INFO:tensorflow:Create pnum_t

## Evaluate

We now evaluate on the validation sets of the tasks in our mixture. Accuracy results will be logged and written as TensorBoard summaries, and the predictions are saved under `validation_eval` in `MODEL_DIR`.

In [0]:
# Use a larger batch size for evaluation, which requires less memory.
print("train_batch_size: " + str(train_batch_size))
model.batch_size = train_batch_size * 4
model.eval(
    mixture_or_task_name="arc_challenge",
    checkpoint_steps="all"
)



train_batch_size: 16
INFO:tensorflow:Using config: {'_model_dir': 'gs://t5_nlp_arc/models/3B', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
cluster_def {
  job {
    name: "worker"
    tasks {
      key: 0
      value: "10.57.79.146:8470"
    }
  }
}
isolate_session_state: true
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({'worker': ['10.57.79.146:8470']}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': 'grpc://10.57.79.146:8470', '_evaluation_master': 'grpc://10.57.79.146:8470', '_is_chief': True, '_num_ps_replicas': 0, '_num_wor



INFO:tensorflow:Checkpoint path gs://t5_nlp_arc/models/3B/model.ckpt-1025000
INFO:tensorflow:Querying Tensorflow master (grpc://10.57.79.146:8470) for TPU system metadata.
INFO:tensorflow:Initializing TPU system (master: grpc://10.57.79.146:8470) to fetch topology for model parallelism. This might take a while.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 11164394008714535784)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 16752390102529776813)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 2854755398848372869)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 34013

Let's look at a few random predictions from the validation sets. Note that we measure accuracy based on an *exact match* of the predicted answer and the ground-truth answer. As a result, some of the answers are semantically correct but are counted wrong by the exact match score.
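
To see why a semantically correct answer can still be counted wrong, here is a minimal sketch of exact-match scoring in plain Python. `exact_match_accuracy` is a hypothetical stand-in rather than the library's metric function, and the second prediction is made up to show a near miss.

```python
def exact_match_accuracy(targets, predictions):
    """Percentage of predictions that equal their target string exactly."""
    if len(targets) != len(predictions):
        raise ValueError("targets and predictions must have the same length.")
    if not targets:
        return 0.0
    correct = sum(t == p for t, p in zip(targets, predictions))
    return 100.0 * correct / len(targets)


targets = ["choice3: trees", "choice4: wooden board"]
# The second prediction picks the right choice but words it differently.
predictions = ["choice3: trees", "choice4: a wooden board"]
print(exact_match_accuracy(targets, predictions))
```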

In [0]:
import random

def print_random_predictions(task_name, n=10):
  """Print n predictions from the validation split of a task."""
  # Grab the dataset for this task.
  ds = t5.data.TaskRegistry.get(task_name).get_dataset(
      split="validation",
      sequence_length={"inputs": 128, "targets": 32},
      shuffle=False)

  def _prediction_file_to_ckpt(path):
    """Extract the global step from a prediction filename."""
    return int(path.split("_")[-2])

  # Grab the paths of all logged predictions.
  prediction_files = tf.io.gfile.glob(
      os.path.join(
          MODEL_DIR,
          "validation_eval/%s_*_predictions" % task_name))
  # Get most recent prediction file by sorting by their step.
  latest_prediction_file = sorted(
      prediction_files, key=_prediction_file_to_ckpt)[-1]

  # Collect (inputs, targets, prediction) from the dataset and predictions file
  results = []
  with tf.io.gfile.GFile(latest_prediction_file) as preds:
    for ex, pred in zip(tfds.as_numpy(ds), preds):
      results.append((tf.compat.as_text(ex["inputs_plaintext"]),
                      tf.compat.as_text(ex["targets_plaintext"]),
                      pred.strip()))

  print("<== Random predictions for %s using checkpoint %s ==>\n" %
        (task_name, 
         _prediction_file_to_ckpt(latest_prediction_file)))

  for inp, tgt, pred in random.choices(results, k=n):
    print("Input:", inp)
    print("Target:", tgt)
    print("Prediction:", pred)
    print("Counted as Correct?", tgt == pred)
    print()

print_random_predictions("arc_challenge")

<== Random predictions for arc_challenge using checkpoint 1025000 ==>

Input: trivia question: arcch question: which of these is most likely the hardest to bend? choice1: rubber band choice2: cloth ribbon choice3: leather shoe choice4: wooden board
Target: choice4: wooden board
Prediction: choice4: wooden board
Counted as Correct? True

Input: trivia question: arcch question: there are four seasons in a year. which reason is most responsible for the changing seasons on earth? choice1: the way earth tilts on its axis choice2: the way earth rotates on its axis choice3: changes in the distance between earth and the sun choice4: changes in the amount of energy the sun produces
Target: choice1: the way earth tilts on its axis
Prediction: choice1: the way earth tilts on its axis
Counted as Correct? True

Input: trivia question: arcch question: which is a renewable resource? choice1: oil choice2: coal choice3: trees choice4: aluminum
Target: choice3: trees
Prediction: choice3: trees
Counted a

## Predict

Now that we have fine-tuned the model, we can feed T5 arbitrary questions and have it predict the answers!

There is a significant amount of overhead in initializing the model so this may take a few minutes to run each time even though the prediction itself is quite fast.


To avoid this overhead, you might consider exporting a `SavedModel` and running it on [Cloud ML Engine](https://cloud.google.com/ml-engine/).



In [0]:
question_1 = "ARCCH question: An astronomer observes that a planet rotates faster after a meteorite impact. Which is the most likely effect of this increase in rotation? choice1: Planetary density will decrease. choice2: Planetary years will become longer. choice3: Planetary days will become shorter. choice4: Planetary gravity will become stronger." #@param {type:"string"}
question_2 = "ARCCH question: Which of these gases is the most abundant greenhouse gas in the lower atmosphere of Earth? choice1: ozone choice2: methane choice3: water vapor choice4: carbon dioxide" #@param {type:"string"}
question_3 = "ARCCH question: A scientist maps a long region in which earthquakes originate and determines this region is a transform plate boundary. Which evidence would cause the scientist to reevaluate this determination? choice1: Volcanism also characterizes the region. choice2: Earthquake centers in the region occur at shallow depths. choice3: The region shows extensive faulting of sediments. choice4: Equal crust densities are found on opposite sides of the region." #@param {type:"string"}

questions = [question_1, question_2, question_3]

now = time.time()
# Write out the supplied questions to text files.
predict_inputs_path = os.path.join(MODEL_DIR, "predict_inputs_%d.txt" % now)
predict_outputs_path = os.path.join(MODEL_DIR, "predict_outputs_%d.txt" % now)
# Manually apply preprocessing by lowercasing and prepending "arc question:".
with tf.io.gfile.GFile(predict_inputs_path, "w") as f:
  for q in questions:
    f.write("arc question: %s\n" % q.lower())

# Ignore any logging so that we only see the model's answers to the questions.
with tf_verbosity_level('ERROR'):
  model.batch_size = 8  # Min size for small model on v2-8 with parallelism 1.
  model.predict(
      input_file=predict_inputs_path,
      output_file=predict_outputs_path,
      # Select the most probable output token at each step.
      temperature=0,
  )

# The output filename will have the checkpoint appended so we glob to get 
# the latest.
prediction_files = sorted(tf.io.gfile.glob(predict_outputs_path + "*"))
print("\nPredictions using checkpoint %s:\n" % prediction_files[-1].split("-")[-1])
with tf.io.gfile.GFile(prediction_files[-1]) as f:
  for q, a in zip(questions, f):
    if q:
      print("Q: " + q)
      print("A: " + a)
      print()
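
If you want to ask your own questions, the input string has to follow the layout used during fine-tuning. The sketch below assembles that string from a question and a list of choices. `format_arc_prompt` is a hypothetical helper whose layout is copied from the examples in this notebook.

```python
def format_arc_prompt(question, choices, task_prefix="arc question: "):
    """Build a lowercased prompt in the layout used by the fine-tuning data."""
    parts = ["ARCCH question: " + question.strip()]
    # Choices are numbered from 1: "choice1: ...", "choice2: ...", and so on.
    parts += ["choice%d: %s" % (i, choice.strip())
              for i, choice in enumerate(choices, start=1)]
    return (task_prefix + " ".join(parts)).lower()


prompt = format_arc_prompt(
    "Which is a renewable resource?", ["oil", "coal", "trees", "aluminum"])
print(prompt)
```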

# Export Model for Serving

As mentioned in the previous section, exporting a [`SavedModel`](https://www.tensorflow.org/guide/saved_model) can be useful for improving performance during inference or allowing your model to be deployed on a variety of platforms (e.g., TFLite, TensorFlow.js, TensorFlow Serving, or TensorFlow Hub).

**Note:** we currently only support exporting a SavedModel that runs on both CPU and GPU, not TPU.

## Export SavedModel

We first export the SavedModel. We set a batch size of 1 for simplicity, but it may be more efficient to use a larger batch size if you want to handle multiple requests per call.

For 3B and 11B models the export will take approximately 30-45 minutes.
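
If you do export with a larger batch size, callers need to group their requests accordingly. The sketch below assumes the exported signature expects exactly `batch_size` strings per call, so the last batch is padded with empty strings and the padded outputs are discarded. `predict_in_batches` and `fake_predict_fn` are hypothetical and stand in for the real prediction function.

```python
def predict_in_batches(predict_fn, inputs, batch_size, pad_value=""):
    """Call `predict_fn` on fixed-size batches, padding the final one."""
    outputs = []
    for start in range(0, len(inputs), batch_size):
        batch = list(inputs[start:start + batch_size])
        num_real = len(batch)
        batch += [pad_value] * (batch_size - num_real)
        # Keep only the outputs that correspond to real inputs.
        outputs.extend(list(predict_fn(batch))[:num_real])
    return outputs


def fake_predict_fn(batch):
    # Stand-in for an exported model with a fixed batch size of 4.
    if len(batch) != 4:
        raise ValueError("Expected exactly 4 inputs per call.")
    return [text.upper() for text in batch]


print(predict_in_batches(fake_predict_fn, ["a", "b", "c", "d", "e"], batch_size=4))
```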

In [0]:
export_dir = os.path.join(MODEL_DIR, "export")

model.batch_size = 1 # make one prediction per call
saved_model_path = model.export(
    export_dir,
    checkpoint_step=-1,  # use most recent
    beam_size=1,  # no beam search
    temperature=1.0,  # sample according to predicted distribution
)
print("Model saved to:", saved_model_path)

## Load SavedModel

One way to test our model is to load it either in eager mode or a TF 1.x session so that we can repeatedly predict from the model without the overhead of loading the graph and weights each time.

We pay the overhead once here, but it shouldn't take more than a few minutes.


### Optional: Switch to GPU Runtime

Changing the runtime type to GPU in the `Runtime` menu above before loading the SavedModel will speed up inference by using the GPU instead of CPU.



In [0]:
#@title Optional: Run this cell to re-initialize if you switched to GPU runtime.
%tensorflow_version 2.x
!pip install tensorflow-text
from google.colab import auth
auth.authenticate_user()

In [0]:
import tensorflow as tf
import tensorflow_text  # Required to run exported model.

def load_predict_fn(model_path):
  if tf.executing_eagerly():
    print("Loading SavedModel in eager mode.")
    imported = tf.saved_model.load(model_path, ["serve"])
    return lambda x: imported.signatures['serving_default'](tf.constant(x))['outputs'].numpy()
  else:
    print("Loading SavedModel in tf 1.x graph mode.")
    tf.compat.v1.reset_default_graph()
    sess = tf.compat.v1.Session()
    meta_graph_def = tf.compat.v1.saved_model.load(sess, ["serve"], model_path)
    signature_def = meta_graph_def.signature_def["serving_default"]
    return lambda x: sess.run(
        fetches=signature_def.outputs["outputs"].name, 
        feed_dict={signature_def.inputs["input"].name: x}
    )

predict_fn = load_predict_fn(saved_model_path)

## Predict

We can now call the predict method with different inputs each time and relatively quickly get results.

In [0]:
def answer(question):
  # Lowercase the input to match the normalization applied during fine-tuning.
  return predict_fn([question.lower()])[0].decode('utf-8')

for question in ["arc question: ARCCH question: An astronomer observes that a planet rotates faster after a meteorite impact. Which is the most likely effect of this increase in rotation? choice1: Planetary density will decrease. choice2: Planetary years will become longer. choice3: Planetary days will become shorter. choice4: Planetary gravity will become stronger.",
                 "arc question: ARCCH question: Which of these gases is the most abundant greenhouse gas in the lower atmosphere of Earth? choice1: ozone choice2: methane choice3: water vapor choice4: carbon dioxide",
                 "arc question: ARCCH question: A scientist maps a long region in which earthquakes originate and determines this region is a transform plate boundary. Which evidence would cause the scientist to reevaluate this determination? choice1: Volcanism also characterizes the region. choice2: Earthquake centers in the region occur at shallow depths. choice3: The region shows extensive faulting of sediments. choice4: Equal crust densities are found on opposite sides of the region."
                 ]:
    print(answer(question))