Fine-tune T5 with an in-context tuning objective.

This notebook is adapted from the T5 trivia notebook in the original T5 [repo](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/notebooks/t5-trivia.ipynb).

# Set up

In [None]:
print("Installing dependencies...")
# Colab-only magic that selects the TF 2.x runtime before TensorFlow is imported.
# NOTE(review): newer Colab runtimes may no longer support `%tensorflow_version`;
# verify it is still needed / still works.
%tensorflow_version 2.x
# seqio is pinned to a specific version, but `t5` is installed unpinned after it.
# NOTE(review): if the latest `t5` requires a different seqio, pip may replace
# the pinned version; consider pinning a compatible `t5` release as well.
!pip install seqio==0.0.12
!pip install -q t5

import functools
import os
import sys
import time
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds

import t5
import t5.models
import seqio

# Keep only the program name in argv.
# NOTE(review): presumably so absl flag parsing inside t5 / mesh_tensorflow does
# not trip over the arguments the notebook kernel was launched with — confirm.
sys.argv = sys.argv[:1]

# Input data and model directories both live under this GCS bucket.
BASE_DIR = "gs://in-context-tuning"
DATA_DIR = os.path.join(BASE_DIR, "data")
MODELS_DIR = os.path.join(BASE_DIR, "models")

# Set credentials for GCS reading/writing from Colab and TPU.
# (The actual user authentication happens at the end of this block.)
print("Setting up GCS access...")
# NOTE(review): presumably selects the non-ephemeral credential flow for
# google.colab.auth below so the TPU can also reach GCS — confirm.
os.environ['USE_AUTH_EPHEM'] = '0'
# Topology string passed to MtfModel later; must match the attached TPU.
TPU_TOPOLOGY = "v2-8"
tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
TPU_ADDRESS = tpu.get_master()
print('Running on TPU:', TPU_ADDRESS)
# Order matters here: eager mode is enabled only long enough to connect to the
# TPU host, then disable_v2_behavior() switches back to TF1-style semantics.
# NOTE(review): presumably the TF1 graph mode is what Mesh TensorFlow
# (MtfModel) expects — do not reorder these three calls without verifying.
tf.enable_eager_execution()
tf.config.experimental_connect_to_host(TPU_ADDRESS)
tf.disable_v2_behavior()

# Authenticate the Colab user so gs:// reads and writes succeed.
from google.colab import auth
auth.authenticate_user()

# Improve logging.
from contextlib import contextmanager
import logging as py_logging
# Stop TensorFlow's logger from forwarding records to the root logger (which
# would print them twice), then let INFO-level records through the root logger.
tf.get_logger().propagate = False
py_logging.getLogger().setLevel(py_logging.INFO)

@contextmanager
def tf_verbosity_level(level):
  """Temporarily set the TF1 logging verbosity inside a ``with`` block.

  Args:
    level: Verbosity level passed to ``tf.logging.set_verbosity``.

  The verbosity that was active on entry is restored on exit, including when
  the body of the ``with`` block raises.
  """
  og_level = tf.logging.get_verbosity()
  tf.logging.set_verbosity(level)
  try:
    yield
  finally:
    # Without try/finally, an exception raised inside the `with` body would
    # skip this line and leave the temporary verbosity set for the session.
    tf.logging.set_verbosity(og_level)

# Define task

In [None]:
import json
import functools
import pandas as pd
import t5
import seqio
import tensorflow.compat.v1 as tf

train_tsv_file = "train-random.tsv" #@param { type: "string" }
validation_tsv_file = "dev-random.tsv" #@param { type: "string" }
test_tsv_file = "test-random.tsv" #@param { type: "string" }
num_examples_json_file = "counts-random.json" #@param { type: "string" }
MAX_TARGET_LEN = 32  #@param { type: "integer" }

# Inputs and targets are tokenized with the same default T5 vocabulary and get
# an EOS token appended. Build the vocabulary once and share it.
_VOCABULARY = t5.data.get_default_vocabulary()
DEFAULT_OUTPUT_FEATURES = {
    "inputs": seqio.Feature(vocabulary=_VOCABULARY, add_eos=True),
    "targets": seqio.Feature(vocabulary=_VOCABULARY, add_eos=True),
}

# Per-split TSV locations; "validation" is None when no validation file is set.
TSV_PATH = {
    "train": os.path.join(DATA_DIR, train_tsv_file),
    "validation": os.path.join(DATA_DIR, validation_tsv_file) if validation_tsv_file else None,
    "test": os.path.join(DATA_DIR, test_tsv_file),
}

# Per-split example counts, passed to the task as `num_input_examples`.
# Use a context manager so the GCS file handle is closed after reading.
with tf.io.gfile.GFile(os.path.join(DATA_DIR, num_examples_json_file)) as _counts_file:
    num_examples = json.load(_counts_file)

def read_prompt_dict(filename: str) -> dict:
    """Load the per-task prompt table from a headerless TSV file.

    Each row has the columns ``task_name, task_prefix, prompt, prompt_len,
    io_sep``; only ``task_prefix``, ``prompt`` and ``io_sep`` are used.

    Args:
        filename: Path of the TSV file, opened through ``tf.io.gfile`` (so
            both local and ``gs://`` paths work).

    Returns:
        A dict mapping each ``task_prefix`` to a ``(prompt, io_sep)`` tuple.
        If a prefix appears on several rows, the last row wins.
    """
    with tf.io.gfile.GFile(filename) as f:
        df = pd.read_csv(
            f, header=None, sep="\t",
            names=["task_name", "task_prefix", "prompt", "prompt_len", "io_sep"],
            # Read prompt text verbatim. By default pandas maps empty cells
            # and strings such as "NA", "None" or "null" to NaN, and a NaN
            # prompt/separator breaks the string concatenation done by callers.
            dtype={"prompt": str, "io_sep": str},
            keep_default_na=False)
    return {
        prefix: (prompt, io_sep)
        for prefix, prompt, io_sep in zip(df.task_prefix, df.prompt, df.io_sep)
    }

PROMPT_DICT = read_prompt_dict(os.path.join(DATA_DIR, "prompt/prompt.tsv"))


def preprocessor_fn(ds):
    """Rename the raw example fields and normalize their text.

    Every ``{"input": ..., "target": ...}`` element of ``ds`` becomes
    ``{"inputs": ..., "targets": ...}``, with both strings passed through
    ``_clean``. Elements are mapped in parallel (AUTOTUNE).
    """

    def _clean(s):
        """Lowercase ``s`` and unwrap one single-quoted span.

        NOTE(review): the pattern ``'(.*)'`` is greedy, so (within a line) it
        spans from the first to the last single quote and removes only those
        two characters. For example, the apostrophes in "don't ... won't" are
        both dropped. It does not strip every quote.
        """
        lowered = tf.strings.lower(s)
        return tf.strings.regex_replace(lowered, "'(.*)'", r"\1")

    def _rename(example):
        """Map {"input", "target"} to {"inputs", "targets"}, cleaning both."""
        cleaned_input = _clean(example["input"])
        cleaned_target = _clean(example["target"])
        return {"inputs": cleaned_input, "targets": cleaned_target}

    autotune = tf.data.experimental.AUTOTUNE
    return ds.map(_rename, num_parallel_calls=autotune)


def dataset_fn(split, shuffle_files=False):
    """Build the raw ``tf.data`` pipeline for one split.

    Reads ``TSV_PATH[split]``, a headerless TSV whose first four columns are
    ``task_name, task_prefix, input, target``; any further columns are ignored.
    Each input is wrapped with its task's prompt and input/output separator
    from ``PROMPT_DICT``.

    Args:
        split: Key into ``TSV_PATH`` ("train", "validation" or "test").
        shuffle_files: Ignored; each split is a single file.

    Returns:
        A ``tf.data.Dataset`` of ``{"input": tf.string, "target": tf.string}``
        scalar dicts, where ``input`` is ``"<prompt> <input> <io_sep>"``.

    Raises:
        KeyError: If a row's ``task_prefix`` has no entry in ``PROMPT_DICT``.
    """
    del shuffle_files  # We only have one file for each split.

    with tf.io.gfile.GFile(TSV_PATH[split]) as f:
        df = pd.read_csv(
            f, header=None, sep="\t",
            # Keep the input/target text verbatim. By default pandas parses
            # numeric-looking cells (so "01" becomes 1) and maps empty cells
            # and strings like "NA"/"None"/"null" to NaN. That turned such
            # targets into "nan" and made such inputs crash the "+" below.
            dtype={2: str, 3: str},
            keep_default_na=False)
    df = df[range(4)]  # Only take the first 4 columns.
    df.columns = ["task_name", "task_prefix", "input", "target"]

    inputs = []
    for task_prefix, text in zip(df.task_prefix, df.input):
        prompt_prefix, io_sep = PROMPT_DICT[task_prefix]
        inputs.append(prompt_prefix + " " + text + " " + io_sep)
    targets = [str(target) for target in df.target]

    # Build the dataset straight from the two string lists. Joining each pair
    # with a tab and re-splitting it with decode_csv would mis-split (or fail
    # on) any input or target that itself contains a tab. An explicit
    # tf.string dtype keeps an empty split from becoming a float tensor.
    return tf.data.Dataset.from_tensor_slices({
        "input": tf.constant(inputs, dtype=tf.string),
        "target": tf.constant(targets, dtype=tf.string),
    })


# The validation split is only registered when a validation file is configured.
splits = ["train", "validation", "test"] if validation_tsv_file else ["train", "test"]
# seqio refuses to register the same task name twice, so re-running this cell
# would raise. Drop any earlier registration first (a no-op on the first run).
seqio.TaskRegistry.remove("in_context_tuning")
seqio.TaskRegistry.add(
    "in_context_tuning",
    source=seqio.FunctionDataSource(
        dataset_fn=dataset_fn,
        splits=splits,
        num_input_examples=num_examples),
    preprocessors=[
        preprocessor_fn,
        seqio.preprocessors.tokenize_and_append_eos,
    ],
    postprocess_fn=t5.data.postprocessors.lower_text,
    metric_fns=[t5.evaluation.metrics.accuracy],
    output_features=DEFAULT_OUTPUT_FEATURES,
)

# Define model

In [None]:
MODEL_SIZE = "t5.1.1.large" #@param["small", "base", "large", "t5.1.1.small", "t5.1.1.base", "t5.1.1.large"]
MODEL_DIR_NAME = "train_random_large1.1" #@param { type: "string" }
PRETRAINED_DIR = os.path.join("gs://t5-data/pretrained_models", MODEL_SIZE)
MODEL_DIR = os.path.join(MODELS_DIR, MODEL_DIR_NAME)

# (model_parallelism, train_batch_size, keep_checkpoint_max) per model size;
# each T5.1.1 variant uses the same settings as the equally sized original.
_SMALL_SETTINGS = (1, 256, 16)
_BASE_SETTINGS = (2, 128, 8)
_LARGE_SETTINGS = (8, 64, 4)
_RUN_SETTINGS = {
    "small": _SMALL_SETTINGS,
    "t5.1.1.small": _SMALL_SETTINGS,
    "base": _BASE_SETTINGS,
    "t5.1.1.base": _BASE_SETTINGS,
    "large": _LARGE_SETTINGS,
    "t5.1.1.large": _LARGE_SETTINGS,
}
model_parallelism, train_batch_size, keep_checkpoint_max = _RUN_SETTINGS[MODEL_SIZE]

tf.io.gfile.makedirs(MODEL_DIR)

# Long prompted inputs, short targets.
_SEQUENCE_LENGTH = {"inputs": 1024, "targets": MAX_TARGET_LEN}
model = t5.models.MtfModel(
    model_dir=MODEL_DIR,
    tpu=TPU_ADDRESS,
    tpu_topology=TPU_TOPOLOGY,
    model_parallelism=model_parallelism,
    batch_size=train_batch_size,
    sequence_length=_SEQUENCE_LENGTH,
    learning_rate_schedule=0.003,
    save_checkpoints_steps=5000,
    keep_checkpoint_max=keep_checkpoint_max,
    iterations_per_loop=100,
)

# Fine-tune

In [None]:
if ON_CLOUD:
  %reload_ext tensorboard
%tensorboard --logdir="$MODEL_DIR" --port=0

In [None]:
FINETUNE_STEPS = 30000 #@param {type: "integer"}

# Start from the public checkpoint in PRETRAINED_DIR and train on the
# registered "in_context_tuning" task for FINETUNE_STEPS steps.
model.finetune(
    mixture_or_task_name="in_context_tuning",
    finetune_steps=FINETUNE_STEPS,
    pretrained_model_dir=PRETRAINED_DIR,
)

# Predict and evaluate on the test set

In [None]:
CHECKPOINT_STEPS = 1030000 #@param {type: "integer"}

# Evaluate the checkpoint saved at CHECKPOINT_STEPS on the "test" split.
# NOTE(review): 1030000 presumably equals the pretrained checkpoint step
# (1,000,000) plus FINETUNE_STEPS (30000); keep the two in sync.
model.eval(
    mixture_or_task_name="in_context_tuning",
    split="test",
    checkpoint_steps=CHECKPOINT_STEPS,
    compute_sequence_length=False,
)