# Fine-Tuning T5X on SQuAD with Vertex AI

## Introduction

This page outlines the steps to fine-tune an existing pre-trained model with T5X on common downstream tasks defined with SeqIO using Vertex AI Training. This is one of the simplest and most common use cases of T5X. If you're new to T5X, this tutorial is the recommended starting point.

## Overview

Fine-tuning a model with T5X consists of the following steps:

1. Choose the pre-trained model to fine-tune.
2. Choose the SeqIO Task/Mixture to fine-tune the model on.
3. Write a Gin file that configures the pre-trained model, SeqIO Task/Mixture and other details of your fine-tuning run.
4. Configure a Vertex AI Training job to fine-tune the model.
5. Monitor your job and parse metrics.

These steps are explained in detail in the following sections. An example run that fine-tunes a T5 1.1 Small checkpoint on the SQuAD question-answering benchmark is also showcased.

### Step 1: Choose a pre-trained model

To use a pre-trained model, you need a Gin config file that defines the model params, and the model checkpoint to load from. For your convenience, TensorFlow checkpoints and Gin configs for common T5 pre-trained models have been made available for use in T5X. A list of all the available pre-trained models (with model checkpoints and Gin config files) is available in the Models documentation.

For the example run, you will use the T5 1.1 Small model. The Gin file for this model is located at `t5x/examples/t5/t5_1_1/small.gin`, and the checkpoint is located at `gs://t5-data/pretrained_models/t5x/t5_1_1_small`.

### Step 2: Choose a SeqIO Task/Mixture

A SeqIO Task encapsulates the data source, the preprocessing logic to be performed on the data before querying the model, the postprocessing logic to be performed on model outputs, and the metrics to be computed given the postprocessed outputs and targets. A SeqIO Mixture denotes a collection of Tasks and enables fine-tuning a model on multiple Tasks simultaneously.
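
The four parts of a Task can be pictured with a small pure-Python sketch. This is a toy model for illustration only, not the SeqIO API: `ToyTask`, its fields, the `exact_match` helper, and the sample data are all hypothetical.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List

Example = Dict[str, str]


@dataclass
class ToyTask:
    """Toy stand-in for a Task: source, preprocessing, postprocessing, metrics."""
    source: Callable[[], Iterable[Example]]
    preprocess: Callable[[Example], Example]
    postprocess: Callable[[str], str]
    metric_fns: List[Callable[[List[str], List[str]], Dict[str, float]]]

    def evaluate(self, model: Callable[[str], str]) -> Dict[str, float]:
        # 1. Read raw examples and turn them into "inputs"/"targets" pairs.
        examples = [self.preprocess(ex) for ex in self.source()]
        # 2. Query the model, then postprocess both targets and outputs.
        targets = [self.postprocess(ex["targets"]) for ex in examples]
        predictions = [self.postprocess(model(ex["inputs"])) for ex in examples]
        # 3. Compute every metric on the postprocessed pairs.
        results: Dict[str, float] = {}
        for metric_fn in self.metric_fns:
            results.update(metric_fn(targets, predictions))
        return results


def exact_match(targets: List[str], predictions: List[str]) -> Dict[str, float]:
    hits = sum(t == p for t, p in zip(targets, predictions))
    return {"exact_match": 100.0 * hits / len(targets)}


toy_task = ToyTask(
    source=lambda: [
        {"question": "2+2?", "answer": "4"},
        {"question": "Capital of France?", "answer": "Paris"},
    ],
    preprocess=lambda ex: {
        "inputs": "question: " + ex["question"],
        "targets": ex["answer"],
    },
    postprocess=lambda text: text.strip().lower(),
    metric_fns=[exact_match],
)

# A fake "model" that always answers "Paris " gets one of two examples right.
scores = toy_task.evaluate(lambda inputs: "Paris ")
print(scores)  # {'exact_match': 50.0}
```

A real SeqIO Task is registered by name and operates on tokenized datasets, but the division of responsibilities is the same as described above.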

#### Standard Tasks
Many common datasets and benchmarks, e.g. GLUE, SuperGLUE, WMT, SQuAD, and CNN/Daily Mail, have been implemented as SeqIO Tasks/Mixtures and can be used directly.
For the example run, you will fine-tune the model on the SQuAD Q&A benchmark, which has been implemented as the `squad_v010_allanswers` Task.

The details of the implementation can be found here:
https://github.com/google-research/text-to-text-transfer-transformer/blob/7db665af4fe395398a0fc20038632584cca2a99a/t5/data/tasks.py#L336

#### Custom Tasks
It is also possible to define your own custom task. See the SeqIO documentation for how to do this.  
When defining a custom task, you have the option to cache it on disk before fine-tuning. Caching may improve performance for tasks with expensive preprocessing. By default, T5X expects tasks to be cached. To fine-tune on a task that has not been cached, set `--gin.USE_CACHED_TASKS=False`.

### Step 3: Write a Gin Config

After choosing the pre-trained model and SeqIO Task/Mixture for your run, the next step is to configure your run using Gin. If you're not familiar with Gin, reading the T5X Gin Primer is recommended.

T5X provides a Gin file that configures the T5X trainer for fine-tuning (located at `t5x/configs/runs/finetune.gin`) and expects a few params from you. These params can be specified in a separate Gin file or via command-line flags. The required params are:

 - `INITIAL_CHECKPOINT_PATH`: This is the path to the pre-trained checkpoint (from Step 1). For the example run, set this to `gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000`.
 - `TRAIN_STEPS`: The total number of training steps. This includes the steps the model was pre-trained for, so make sure to add the step number from the `INITIAL_CHECKPOINT_PATH`. For the example run, to fine-tune for `20_000` steps, set this to `1_020_000`, since the initial checkpoint is at step `1_000_000`.
 - `MIXTURE_OR_TASK_NAME`: This is the SeqIO Task or Mixture name to run (from Step 2). For the example run, set this to `squad_v010_allanswers`.
 - `TASK_FEATURE_LENGTHS`: This is a dict mapping feature key to maximum int length for that feature. After preprocessing, features are truncated to the provided value. For the example run, set this to `{'inputs': 256, 'targets': 256}`.
 - `MODEL_DIR`: A path to write fine-tuned checkpoints to. For the example run, this is a Google Cloud Storage path passed as a command-line flag in Step 4.
 - `LOSS_NORMALIZING_FACTOR`: When fine-tuning a model that was pre-trained using Mesh TensorFlow (e.g. the public T5 / mT5 / ByT5 models), this should be set to the pre-training `batch_size * target_token_length`. For T5 and T5.1.1: `2048 * 114`. For mT5: `1024 * 229`. For ByT5: `1024 * 189`.
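
The arithmetic behind `TRAIN_STEPS` and `LOSS_NORMALIZING_FACTOR` can be checked with a few lines of Python. This is a minimal sketch: the `checkpoint_step` helper and the `PRETRAINING_SHAPES` table are hypothetical names introduced here, and the helper assumes the checkpoint path ends in `checkpoint_<step>` as in the example run.

```python
import re


def checkpoint_step(checkpoint_path: str) -> int:
    """Extracts the step number from a path ending in `checkpoint_<step>`."""
    match = re.search(r"checkpoint_(\d+)/?$", checkpoint_path)
    if match is None:
        raise ValueError(f"No step number found in {checkpoint_path!r}")
    return int(match.group(1))


INITIAL_CHECKPOINT_PATH = (
    "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000"
)
FINETUNE_STEPS = 20_000

# TRAIN_STEPS counts from step 0, so add the pre-training steps.
train_steps = checkpoint_step(INITIAL_CHECKPOINT_PATH) + FINETUNE_STEPS

# (pre-training batch size, pre-training target token length), as listed above.
PRETRAINING_SHAPES = {
    "t5": (2048, 114),
    "mt5": (1024, 229),
    "byt5": (1024, 189),
}
loss_normalizing_factor = {
    name: batch_size * target_length
    for name, (batch_size, target_length) in PRETRAINING_SHAPES.items()
}

print(train_steps)                    # 1020000
print(loss_normalizing_factor["t5"])  # 233472
```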

 In addition to the above params, you will need to include `finetune.gin` and the Gin file for the pre-trained model, which for the example run is `t5_1_1/small.gin`.

```
include 't5x/configs/runs/finetune.gin'
include 't5x/examples/t5/t5_1_1/small.gin'
```

You will also need to import the Python module(s) that register the SeqIO Tasks and Mixtures used in your run. For the example run, add `import t5.data.tasks`, since that is where `squad_v010_allanswers` is registered.

Finally, your Gin file should look like this:

```
include 't5x/configs/runs/finetune.gin'
include 't5x/examples/t5/t5_1_1/small.gin'

# Register necessary SeqIO Tasks/Mixtures.
import t5.data.tasks

MIXTURE_OR_TASK_NAME = "squad_v010_allanswers"
TASK_FEATURE_LENGTHS = {"inputs": 256, "targets": 256}
TRAIN_STEPS = 1_020_000  # 1000000 pre-trained steps + 20000 fine-tuning steps.
DROPOUT_RATE = 0.0
INITIAL_CHECKPOINT_PATH = "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000"
LOSS_NORMALIZING_FACTOR = 233472
```

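Step 4 expects this file to be saved as `small_finetune_squad.gin`. For a closely related example that fine-tunes on WMT instead of SQuAD, see `t5x/examples/t5/t5_1_1/examples/small_wmt_finetune.gin`.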
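
If you would rather generate the Gin file from Python than edit it by hand, the sketch below renders the same content from variables. `render_finetune_gin` is a hypothetical helper written for this tutorial, not part of T5X or Gin; it only builds a string.

```python
import json
from typing import Dict


def render_finetune_gin(
    task_name: str,
    feature_lengths: Dict[str, int],
    pretrained_steps: int,
    finetune_steps: int,
    checkpoint_path: str,
    loss_normalizing_factor: int,
    dropout_rate: float = 0.0,
) -> str:
    """Renders the fine-tuning Gin config shown above as a string."""
    lines = [
        "include 't5x/configs/runs/finetune.gin'",
        "include 't5x/examples/t5/t5_1_1/small.gin'",
        "",
        "# Register necessary SeqIO Tasks/Mixtures.",
        "import t5.data.tasks",
        "",
        # json.dumps produces double-quoted strings and dict literals.
        f"MIXTURE_OR_TASK_NAME = {json.dumps(task_name)}",
        f"TASK_FEATURE_LENGTHS = {json.dumps(feature_lengths)}",
        f"TRAIN_STEPS = {pretrained_steps + finetune_steps}",
        f"DROPOUT_RATE = {dropout_rate!r}",
        f"INITIAL_CHECKPOINT_PATH = {json.dumps(checkpoint_path)}",
        f"LOSS_NORMALIZING_FACTOR = {loss_normalizing_factor}",
    ]
    return "\n".join(lines) + "\n"


gin_text = render_finetune_gin(
    task_name="squad_v010_allanswers",
    feature_lengths={"inputs": 256, "targets": 256},
    pretrained_steps=1_000_000,
    finetune_steps=20_000,
    checkpoint_path=(
        "gs://t5-data/pretrained_models/t5x/t5_1_1_small/checkpoint_1000000"
    ),
    loss_normalizing_factor=2048 * 114,
)
print(gin_text)
```

Writing `gin_text` to a file with `pathlib.Path(...).write_text(gin_text)` gives you a config you can pass to the trainer with `--gin_file`.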

### Step 4: Configure and launch a Vertex AI Training job to fine-tune the model

#### Import required Python packages

```python
import os
import time
from datetime import datetime

from google.cloud import aiplatform as vertex_ai
```

#### Define variables for training job

```python
# Project definitions
PROJECT_ID = 'renatoleite-dev'  # Change to your project id.
REGION = 'us-central1'  # Change to your region.

# Bucket definitions
BUCKET = 'rl-language'  # Change to your bucket.
```

```python
# Model definitions
VERSION = 'v01'
MODEL_NAME = 'finetune-squad'
MODEL_DISPLAY_NAME = f'{MODEL_NAME}-{VERSION}'
WORKSPACE = f'gs://{BUCKET}/{MODEL_DISPLAY_NAME}'

# Docker definitions for training
IMAGE_NAME = 't5x-training'
IMAGE_URI = f'gcr.io/{PROJECT_ID}/{IMAGE_NAME}'
```

#### Initialize Vertex AI client and log metadata

```python
TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M")

EXPERIMENT_ID = f'{MODEL_DISPLAY_NAME}-{TIMESTAMP}'
EXECUTION_NAME = 'execution-1'
RUN_NAME = 'run-1'
```

```python
vertex_ai.init(
    project=PROJECT_ID,
    location=REGION,
    staging_bucket=f'gs://{BUCKET}/staging',
    experiment=EXPERIMENT_ID
)
```

```python
vertex_ai.start_run(RUN_NAME)
```

```python
# Model architecture and data parameters for the T5 1.1 Small example run.
metaparams = {}
metaparams['emb_dim'] = 512
metaparams['num_heads'] = 6
metaparams['num_encoder_layers'] = 8
metaparams['num_decoder_layers'] = 8
metaparams['head_dim'] = 64
metaparams['mlp_dim'] = 1024
metaparams['inputs_feature_len'] = 256
metaparams['outputs_feature_len'] = 256
metaparams['mixture_task_name'] = 'squad_v010_allanswers'
vertex_ai.log_params(metaparams)

# Training hyperparameters, matching the Gin file from Step 3.
hyperparams = {}
hyperparams['train_steps'] = 1_020_000
hyperparams['dropout_rate'] = 0.0
hyperparams['loss_normalizing_factor'] = 233472
vertex_ai.log_params(hyperparams)
```

```python
dataset_seqio_artifact = vertex_ai.Artifact.create(
    schema_title="system.Dataset", display_name='seqio_task_mixture', uri='squad_v010_allanswers'
)

model_artifact = vertex_ai.Artifact.create(
    schema_title="system.Model", display_name='squad_finetuned_model', uri=WORKSPACE
)
```

```python
with vertex_ai.start_execution(
    schema_title="system.ContainerExecution", display_name=EXECUTION_NAME
) as execution:
    execution.assign_input_artifacts([dataset_seqio_artifact])
    execution.assign_output_artifacts([model_artifact])

    vertex_ai.log_metrics(
        {"lineage": execution.get_output_artifacts()[0].lineage_console_uri}
    )
```

#### Build Docker image to run the training job

```
! gcloud builds submit --tag {IMAGE_URI} --timeout=2h
```

#### Define infrastructure and submit job

```python
MACHINE_TYPE = 'cloud-tpu'
ACCELERATOR_TYPE = 'TPU_V3'
ACCELERATOR_NUM = 8
REPLICA_COUNT = 1
```

```python
# Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
MODEL_DIR = f'gs://{BUCKET}/model/{MODEL_DISPLAY_NAME}'

# Data dir to save the processed dataset in "gs://data_dir" format.
TFDS_DATA_DIR = f'gs://{BUCKET}/dataset/{MODEL_DISPLAY_NAME}'
GIN_FILE = './small_finetune_squad.gin'
```

```python
worker_pool_specs = [
    {
        "machine_spec": {
            "machine_type": MACHINE_TYPE,
            "accelerator_type": ACCELERATOR_TYPE,
            "accelerator_count": ACCELERATOR_NUM,
        },
        "replica_count": REPLICA_COUNT,
        "container_spec": {
            "image_uri": IMAGE_URI,
            "command": ["/opt/conda/envs/t5x/bin/python", "/llm/t5x/t5x/train.py"],
            "args": [
                f'--gin_file={GIN_FILE}',
                f'--gin.MODEL_DIR="{MODEL_DIR}"',
                f'--tfds_data_dir={TFDS_DATA_DIR}',
                '--gin.USE_CACHED_TASKS=False'
            ],
        },
    }
]
```
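
Note that in the `args` list above the `MODEL_DIR` value is wrapped in double quotes while `False` is not. Values passed through `--gin.NAME=VALUE` are parsed as Gin literals, so strings need explicit quotes. The helper below is a minimal sketch of that quoting rule; `gin_override` is a hypothetical name, and the bucket path is a placeholder.

```python
import json
from typing import Union


def gin_override(name: str, value: Union[str, bool, int, float]) -> str:
    """Formats a `--gin.NAME=VALUE` flag, quoting string values."""
    if isinstance(value, str):
        # json.dumps adds the surrounding double quotes and escapes the content.
        rendered = json.dumps(value)
    else:
        # Booleans and numbers are written as bare Python-style literals.
        rendered = repr(value)
    return f"--gin.{name}={rendered}"


example_args = [
    gin_override("MODEL_DIR", "gs://my-bucket/model/finetune-squad-v01"),
    gin_override("USE_CACHED_TASKS", False),
]
print(example_args)
# ['--gin.MODEL_DIR="gs://my-bucket/model/finetune-squad-v01"', '--gin.USE_CACHED_TASKS=False']
```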

```python
job_name = 't5x_{}'.format(time.strftime("%Y%m%d_%H%M%S"))
base_output_dir = os.path.join(WORKSPACE, job_name)

job = vertex_ai.CustomJob(
    display_name=job_name,
    worker_pool_specs=worker_pool_specs,
    base_output_dir=base_output_dir
)

# sync=False returns immediately instead of blocking until the job finishes.
job.run(
    sync=False
)
```

### Step 5: Explore and log metrics

After fine-tuning has completed, you can parse metrics into CSV format using the following script:

```python
GCS_VAL_DIR = f'gs://{BUCKET}/model/{MODEL_DISPLAY_NAME}/inference_eval'
VAL_DIR = '/home/renatoleite/workspace/t5x-sandbox/rl-tests/finetune_small_squad/inference_eval'  # Change to a local path.
```

```
! mkdir -p {VAL_DIR}
! gsutil -m cp -r {GCS_VAL_DIR} {VAL_DIR}
```

```
Copying gs://rl-language/model/finetune-squad-v01/inference_eval/squad_v010_allanswers-1001000.jsonl...
Copying gs://rl-language/model/finetune-squad-v01/inference_eval/squad_v010_allanswers-1002000.jsonl...
Copying gs://rl-language/model/finetune-squad-v01/inference_eval/squad_v010_allanswers-1003000.jsonl...
Copying gs://rl-language/model/finetune-squad-v01/inference_eval/squad_v010_allanswers-1004000.jsonl...
Copying gs://rl-language/model/finetune-squad-v01/inference_eval/squad_v010_allanswers-1015000.jsonl...
Copying gs://rl-language/model/finetune-squad-v01/inference_eval/squad_v010_allanswers-1011000.jsonl...
Copying gs://rl-language/model/finetune-squad-v01/inference_eval/squad_v010_allanswers-1006000.jsonl...
Copying gs://rl-language/model/finetune-squad-v01/inference_eval/squad_v010_allanswers-1020000.jsonl...
Copying gs://rl-language/model/finetune-squad-v01/inference_eval/squad_v010_allanswers-1016000.jsonl...
Copying gs://rl-language/model/finetune-squad-v01/inference_eval
```
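
The file names in the copy log above follow the pattern `<task_name>-<step>.jsonl`, and `gsutil -m` copies them in no particular order. The sketch below recovers the evaluated checkpoint steps from such names and sorts them. `eval_steps_from_names` is a hypothetical helper, and the pattern is inferred only from the log shown here.

```python
import re
from typing import Iterable, List, Tuple

# Matches "<task_name>-<step>.jsonl" at the end of a file name.
_EVAL_FILE = re.compile(r"(?P<task>.+)-(?P<step>\d+)\.jsonl$")


def eval_steps_from_names(names: Iterable[str]) -> List[Tuple[int, str]]:
    """Returns sorted (step, task_name) pairs parsed from eval file names."""
    found = []
    for name in names:
        basename = name.rsplit("/", 1)[-1]
        match = _EVAL_FILE.search(basename)
        if match:
            found.append((int(match.group("step")), match.group("task")))
    return sorted(found)


names = [
    "gs://rl-language/model/finetune-squad-v01/inference_eval/squad_v010_allanswers-1015000.jsonl",
    "gs://rl-language/model/finetune-squad-v01/inference_eval/squad_v010_allanswers-1001000.jsonl",
    "gs://rl-language/model/finetune-squad-v01/inference_eval/squad_v010_allanswers-1020000.jsonl",
]
steps = [step for step, _ in eval_steps_from_names(names)]
print(steps)  # [1001000, 1015000, 1020000]
```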

```
! python -m t5.scripts.parse_tb \
  --summary_dir={VAL_DIR} \
  --seqio_summaries \
  --out_file=./results.csv \
  --alsologtostderr
```

```
 This a JAX bug; please report an issue at https://github.com/google/jax/issues
  _warn(f"cloud_tpu_init failed: {repr(exc)}\n This a JAX bug; please report "
2022-07-13 14:02:27.723792: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
I0713 14:02:27.726068 140507816245056 parse_tb.py:62] No evaluation events found in /home/renatoleite/workspace/t5x-sandbox/rl-tests/finetune_small_squad/inference_eval
```


### Metric Explanations

By default, T5X logs many metrics to TensorBoard. Many of these seem similar but
have important distinctions.

The first two graphs you will see are the `accuracy` and `cross_ent_loss`
graphs. These are the *token-level teacher-forced* accuracy and cross entropy
loss respectively. Each of these graphs can have multiple curves on them. The
first curve is the `train` curve. This is calculated as a running sum that is
then normalized over the whole training set. The second class of curves have the
form `training_eval/${task_name}`. These curves are created by running a subset
(controlled by the `eval_steps` parameter of the main train function) of the
validation split of `${task_name}` through the model and calculating these
metrics using teacher-forcing. These graphs can commonly be used to find
"failure to learn" cases and as a warning sign of overfitting, but these are
often not the final metrics one would report on.
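
To make "token-level teacher-forced" concrete, the sketch below computes accuracy and cross-entropy from per-position logits. Under teacher forcing, the logits at each position are produced with the gold target prefix fed to the decoder, so every position is scored independently against its gold token. This is a pure-Python illustration with made-up numbers; `teacher_forced_metrics` is a hypothetical helper, and its plain per-token mean may not match the normalization T5X applies.

```python
import math
from typing import Dict, List, Sequence


def teacher_forced_metrics(
    logits: Sequence[Sequence[float]], targets: Sequence[int]
) -> Dict[str, float]:
    """Token-level accuracy and mean cross-entropy over target positions."""
    total_loss = 0.0
    correct = 0
    for position_logits, gold in zip(logits, targets):
        # Numerically stable log-sum-exp for the softmax normalizer.
        max_logit = max(position_logits)
        log_z = max_logit + math.log(
            sum(math.exp(x - max_logit) for x in position_logits)
        )
        # Negative log-probability of the gold token.
        total_loss += log_z - position_logits[gold]
        # Greedy prediction at this position only.
        predicted = max(range(len(position_logits)), key=lambda i: position_logits[i])
        correct += int(predicted == gold)
    num_tokens = len(targets)
    return {
        "accuracy": correct / num_tokens,
        "cross_ent_loss": total_loss / num_tokens,
    }


# Two target positions over a 3-token vocabulary.
example_logits: List[List[float]] = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
example_targets = [0, 2]
metrics = teacher_forced_metrics(example_logits, example_targets)
print(metrics["accuracy"])  # 0.5
```

Because each position sees the gold prefix, a model can score well here and still produce poor text when it has to decode from its own earlier outputs.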

The second set of graphs are the ones under the collapsible `eval` section in
TensorBoard. These graphs are created based on the `metric_fns` defined in the
SeqIO task. The curves on these graphs have the form
`inference_eval/${task_name}`. Values are calculated by running the whole
validation split through the model in inference mode, commonly auto-regressive
decoding or output scoring. Most likely these are the metrics that will be
reported.
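
For a question-answering task such as SQuAD, inference-mode metrics are typically exact match and token-level F1 between the decoded answer and the reference answers. The sketch below is a simplified version of that style of scoring, written for illustration; it is not the `metric_fns` implementation used by the `squad_v010_allanswers` Task, and all function names and sample strings are hypothetical.

```python
import collections
import re
import string
from typing import Dict, List, Sequence


def normalize_answer(text: str) -> str:
    """Lowercases, strips punctuation and articles, and collapses whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def f1_score(target: str, prediction: str) -> float:
    """Token-overlap F1 between one reference answer and one prediction."""
    target_tokens = normalize_answer(target).split()
    prediction_tokens = normalize_answer(prediction).split()
    common = collections.Counter(target_tokens) & collections.Counter(prediction_tokens)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(prediction_tokens)
    recall = overlap / len(target_tokens)
    return 2 * precision * recall / (precision + recall)


def squad_scores(
    all_answers: Sequence[List[str]], predictions: Sequence[str]
) -> Dict[str, float]:
    """Averages the best EM and F1 over all reference answers per example."""
    em_total = 0.0
    f1_total = 0.0
    for answers, prediction in zip(all_answers, predictions):
        em_total += max(
            float(normalize_answer(a) == normalize_answer(prediction)) for a in answers
        )
        f1_total += max(f1_score(a, prediction) for a in answers)
    count = len(predictions)
    return {"em": 100.0 * em_total / count, "f1": 100.0 * f1_total / count}


scores = squad_scores(
    all_answers=[["Denver Broncos", "The Denver Broncos"], ["in 1066"]],
    predictions=["the Denver Broncos", "1066"],
)
print(scores["em"])  # 50.0
```

Unlike the teacher-forced curves, these scores depend on the full decoded string, which is why they are the numbers usually reported.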

More information about the configuration of the datasets used for these
different metrics can be found [here](#train-train-eval-and-infer-eval).

In summary, the metric you actually care about most likely lives under the
`eval` tab rather than in the `accuracy` graph.