# Pretraining a model

## Introduction

This page outlines the steps to pretrain a model with T5X on common tasks defined with [SeqIO](https://github.com/google/seqio/blob/main/README.md) on Vertex AI Training.

## Overview

Pretraining a model with T5X on Vertex AI Training consists of the following steps:

1.  Choose the model architecture.
2.  Choose the SeqIO Task/Mixture for training.
3.  Write a Gin file that configures the model, SeqIO Task/Mixture and other
    details of your pretraining run.
4.  Configure and launch a training job on Vertex AI.
5.  Monitor your job.

These steps are explained in detail in the following sections. An example run that trains a T5 1.1 Small checkpoint from scratch on the C4 dataset using the span corruption pretraining objective is also showcased.

## Step 1: Choose a model architecture

To train a model, you need a Gin config file that defines the model params. For
your convenience, Gin configs for common models have been made available for use
in T5X. Following is a list of these models and their Gin locations.

Model                                 | Gin File Location
------------------------------------- | -----------------
T5 Small                              | [t5_1_0/small.gin](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/t5_1_0/small.gin)
T5 Base                               | [t5_1_0/base.gin](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/t5_1_0/base.gin)
T5 Large                              | [t5_1_0/large.gin](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/t5_1_0/large.gin)
T5 3B                                 | [t5_1_0/3B.gin](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/t5_1_0/3B.gin)
T5 11B                                | [t5_1_0/11B.gin](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/t5_1_0/11B.gin)
T5 1.1 Small                          | [t5_1_1/small.gin](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/t5_1_1/small.gin)
T5 1.1 Base                           | [t5_1_1/base.gin](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/t5_1_1/base.gin)
T5 1.1 Large                          | [t5_1_1/large.gin](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/t5_1_1/large.gin)
T5 1.1 XL                             | [t5_1_1/xl.gin](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/t5_1_1/xl.gin)
T5 1.1 XXL                            | [t5_1_1/xxl.gin](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/t5_1_1/xxl.gin)
MT5 Small                             | [mt5/small.gin](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/mt5/small.gin)
MT5 Base                              | [mt5/base.gin](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/mt5/base.gin)
MT5 Large                             | [mt5/large.gin](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/mt5/large.gin)
MT5 XL                                | [mt5/xl.gin](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/mt5/xl.gin)
MT5 XXL                               | [mt5/xxl.gin](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/mt5/xxl.gin)

For the example run, you will use the T5 1.1 Small model. The Gin file for this
model is located at
[`t5x/examples/t5/t5_1_1/small.gin`](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/t5_1_1/small.gin).

## Step 2: Choose a SeqIO Task/Mixture

A SeqIO Task encapsulates the data source, the preprocessing logic to be
performed on the data before querying the model, the postprocessing logic to be
performed on model outputs, and the metrics to be computed given the
postprocessed outputs and targets. A SeqIO Mixture denotes a collection of Tasks
and enables pretraining a model on multiple Tasks simultaneously.
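To make the four parts of a Task and the role of a Mixture concrete, here is a toy sketch in plain Python. It is not the SeqIO API: `ToyTask`, `sample_mixture`, and `accuracy` are hypothetical names invented for illustration, and real SeqIO Tasks operate on `tf.data` pipelines rather than Python lists.

```python
import itertools
import random
from dataclasses import dataclass
from typing import Callable, Dict, List

Example = Dict[str, str]


@dataclass
class ToyTask:
    """Toy stand-in for a SeqIO Task (hypothetical, not the SeqIO API)."""
    name: str
    source: List[str]                                # raw records
    preprocessor: Callable[[str], Example]           # record -> model features
    postprocessor: Callable[[str], str]              # model output -> answer
    metric: Callable[[List[str], List[str]], float]  # (targets, preds) -> score

    def examples(self) -> List[Example]:
        return [self.preprocessor(record) for record in self.source]


def accuracy(targets: List[str], predictions: List[str]) -> float:
    """Fraction of predictions that match their target exactly."""
    return sum(t == p for t, p in zip(targets, predictions)) / len(targets)


def sample_mixture(tasks_and_rates, num_examples, seed=0):
    """Draw (task name, example) pairs, picking tasks in proportion to rate."""
    rng = random.Random(seed)
    tasks = [task for task, _ in tasks_and_rates]
    rates = [rate for _, rate in tasks_and_rates]
    # Cycle through each task's examples so sampling never runs dry.
    streams = [itertools.cycle(task.examples()) for task in tasks]
    drawn = []
    for _ in range(num_examples):
        index = rng.choices(range(len(tasks)), weights=rates)[0]
        drawn.append((tasks[index].name, next(streams[index])))
    return drawn


copy_task = ToyTask(
    name="copy",
    source=["a b c", "d e f"],
    preprocessor=lambda r: {"inputs": "copy: " + r, "targets": r},
    postprocessor=str.strip,
    metric=accuracy,
)
reverse_task = ToyTask(
    name="reverse",
    source=["a b c"],
    preprocessor=lambda r: {"inputs": "reverse: " + r, "targets": r[::-1]},
    postprocessor=str.strip,
    metric=accuracy,
)

# A "mixture" that draws from copy_task three times as often as reverse_task.
mixed = sample_mixture([(copy_task, 3.0), (reverse_task, 1.0)], num_examples=8)
for task_name, example in mixed:
    print(task_name, example)
```

The point of the sketch is only the division of labor: each Task owns its data source, preprocessing, postprocessing, and metric, while the Mixture decides how often each Task is sampled.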

Many common datasets and benchmarks, such as
[GLUE](https://gluebenchmark.com/),
[SuperGLUE](https://super.gluebenchmark.com/),
[WMT](https://www.tensorflow.org/datasets/catalog/wmt_t2t_translate),
[SQuAD](https://rajpurkar.github.io/SQuAD-explorer/), and
[CNN/Daily Mail](https://github.com/abisee/cnn-dailymail), have been
implemented as SeqIO Tasks/Mixtures and can be used directly. These
Tasks/Mixtures are defined in
[`t5/data/tasks.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/tasks.py)
and
[`t5/data/mixtures.py`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/mixtures.py).

For the example run, you will train the model on the
[`c4_v220_span_corruption`](https://github.com/google-research/text-to-text-transfer-transformer/tree/main/t5/data/tasks.py)
Task, which implements the span corruption pretraining objective using the C4
dataset. This is the final pretraining Task used in the
[T5 paper](https://arxiv.org/pdf/1910.10683.pdf).

TIP: Want to use a custom Task or Mixture? See section below called "Adding
SeqIO Task/Mixture modules and Gin files"

## Step 3: Write a Gin Config

After choosing the model architecture and SeqIO Task/Mixture for your run, the
next step is to configure your run using Gin. If you're not familiar with Gin,
reading the [T5X Gin Primer](gin.md) is recommended.

T5X provides a Gin file that configures the T5X trainer for pretraining (located
at [`runs/pretrain.gin`](https://github.com/google-research/t5x/tree/main/t5x/configs/runs/pretrain.gin)),
and expects a few params from you. These params can be specified in a separate
Gin file or via command-line flags. The required params are:

+   `TRAIN_STEPS`: Number of training steps. For the example run, set this to
    `10000`, matching the example Gin file.
+   `MIXTURE_OR_TASK_NAME`: This is the SeqIO Task or Mixture name to run (from
    Step 2). For the example run, set this to `'c4_v220_span_corruption'`.
+   `TASK_FEATURE_LENGTHS`: This is a dict mapping feature key to maximum int
    length for that feature. After preprocessing, features are truncated to the
    provided value. For the example run, set this to `{"inputs": 512, "targets":
    114}`, following the original T5 pretraining setup.
+   `MODEL_DIR`: A path to write pretrained checkpoints to.
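The target length of 114 is not arbitrary: it follows from the input length once the span corruption hyperparameters are fixed. The sketch below re-derives it under the assumption that the objective masks 15% of tokens in spans of mean length 3, with one sentinel token per span and one EOS token per sequence. Those values and the helper name `span_corruption_lengths` are assumptions for illustration; this is a simplified re-derivation, not the library's code.

```python
def span_corruption_lengths(max_inputs_length,
                            noise_density=0.15,
                            mean_noise_span_length=3.0):
    """Return (raw_tokens, inputs_length, targets_length).

    Finds the longest raw token sequence whose corrupted inputs still fit in
    max_inputs_length, then reports the matching targets length.
    """

    def lengths(num_tokens):
        num_noise = int(round(num_tokens * noise_density))
        num_spans = int(round(num_noise / mean_noise_span_length))
        # Inputs keep the non-noise tokens, plus one sentinel per span and EOS.
        inputs = (num_tokens - num_noise) + num_spans + 1
        # Targets hold the noise tokens, plus one sentinel per span and EOS.
        targets = num_noise + num_spans + 1
        return inputs, targets

    num_tokens = 1
    while lengths(num_tokens + 1)[0] <= max_inputs_length:
        num_tokens += 1
    inputs, targets = lengths(num_tokens)
    return num_tokens, inputs, targets


raw_tokens, inputs_length, targets_length = span_corruption_lengths(512)
print(raw_tokens, inputs_length, targets_length)  # 568 512 114
```

Under these assumptions, 568 raw tokens produce exactly 512 input tokens and 114 target tokens, which is why `{"inputs": 512, "targets": 114}` wastes no positions on padding.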

In addition to the above params, you will need to include
[`pretrain.gin`](https://github.com/google-research/t5x/tree/main/t5x/configs/runs/pretrain.gin)
and the Gin file for the model architecture, which for the example run is
[`t5_1_1/small.gin`](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/t5_1_1/small.gin).

```gin
include 't5x/configs/runs/pretrain.gin'
include 't5x/examples/t5/t5_1_1/small.gin'
```

Note that the `include` statements in this example use relative paths. For Gin
to locate these files, pass an appropriate `gin_search_paths` flag when
launching your run. However, we recommend that you use absolute paths.
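As a rough mental model of how a relative `include` is located, the sketch below tries each search path prefix in order and returns the first file that exists. This is a simplified illustration of the idea, not Gin's actual implementation; `resolve_include` is a hypothetical helper, and the directory layout is created in a temporary folder purely for the demonstration.

```python
import os
import tempfile


def resolve_include(include_path, search_paths):
    """Return the first existing file for include_path, or None.

    Absolute paths are checked as-is; relative paths are joined onto each
    search path prefix in order.
    """
    if os.path.isabs(include_path):
        return include_path if os.path.isfile(include_path) else None
    for prefix in search_paths:
        candidate = os.path.join(prefix, include_path)
        if os.path.isfile(candidate):
            return candidate
    return None


# Build a throwaway directory that mimics a checkout containing pretrain.gin.
demo_root = tempfile.mkdtemp()
relative_path = os.path.join("t5x", "configs", "runs", "pretrain.gin")
os.makedirs(os.path.dirname(os.path.join(demo_root, relative_path)))
with open(os.path.join(demo_root, relative_path), "w") as f:
    f.write("# placeholder\n")

# The first prefix does not contain the file, so the second one is used.
resolved = resolve_include(relative_path, ["/nonexistent/prefix", demo_root])
print(resolved)
```

This also shows why absolute paths are the safer choice: they do not depend on which prefixes happen to be on the search path at launch time.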

You will also need to import the Python module(s) that register the SeqIO Tasks
and Mixtures used in your run. For the example run, we add
`import t5.data.mixtures`, since importing it registers the
`c4_v220_span_corruption` Task. Note that this module
must also be included as a dependency in the T5X trainer
[binary](https://github.com/google-research/t5x/tree/main/t5x/BUILD). Most
common Task/Mixture modules, such as this one, are already included. If your
module is not included, see the [Advanced Topics section](#custom-t5x-binaries)
at the end of this tutorial for instructions to add it.

Finally, your Gin file should look like this:

```gin
include 't5x/examples/t5/t5_1_1/small.gin'
include 't5x/configs/runs/pretrain.gin'

# Register necessary SeqIO Tasks/Mixtures.
import t5.data.mixtures

MIXTURE_OR_TASK_NAME = "c4_v220_span_corruption"
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114}
TRAIN_STEPS = 10000
DROPOUT_RATE = 0.0
BATCH_SIZE = 256
```

See
[`t5x/examples/t5/t5_1_1/examples/small_c4_pretrain.gin`](https://github.com/google-research/t5x/tree/main/t5x/examples/t5/t5_1_1/examples/small_c4_pretrain.gin)
for this example.
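For a quick sense of scale, the values in the example Gin file determine how many token positions each training step processes. The arithmetic below is a back-of-the-envelope upper bound (it counts every position, including any padding), not a measurement from a real run.

```python
# Values copied from the example Gin file above.
BATCH_SIZE = 256
TRAIN_STEPS = 10000
TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114}

# Positions processed per step on the encoder and decoder sides.
input_positions_per_step = BATCH_SIZE * TASK_FEATURE_LENGTHS["inputs"]
target_positions_per_step = BATCH_SIZE * TASK_FEATURE_LENGTHS["targets"]

# Upper bound on input positions seen over the whole run.
total_input_positions = input_positions_per_step * TRAIN_STEPS

print(input_positions_per_step)   # 131072
print(target_positions_per_step)  # 29184
print(total_input_positions)      # 1310720000
```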

## Step 4: Configure and launch a Vertex AI Training job to pretrain the model

Start by importing the Vertex AI SDK, defining the project, bucket, model, and
image names, and initializing the SDK:

```python
import os
import time

from google.cloud import aiplatform as vertex_ai

# Project definitions
PROJECT_ID = 'renatoleite-dev'  # Change to your project id.
REGION = 'us-central1'  # Change to your region.

# Bucket definitions
BUCKET = 'rl-language'  # Change to your bucket.

# Model definitions
VERSION = 'v01'
MODEL_NAME = 'pretrain-en-de'  # Change to your model name.
MODEL_DISPLAY_NAME = f'{MODEL_NAME}-{VERSION}'
WORKSPACE = f'gs://{BUCKET}/{MODEL_DISPLAY_NAME}'

# Docker definitions for training
IMAGE_NAME = 't5x-training'
IMAGE_URI = f'gcr.io/{PROJECT_ID}/{IMAGE_NAME}'

# Initialize the SDK with the project, region, and staging bucket.
vertex_ai.init(
    project=PROJECT_ID,
    location=REGION,
    staging_bucket=f'gs://{BUCKET}/staging'
)
```

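Build the training container image with Cloud Build. The command is written as
a notebook cell and expects a Dockerfile in the current directory:

```python
# Notebook shell magic: IPython substitutes {IMAGE_URI} before running gcloud.
! gcloud builds submit --tag {IMAGE_URI} --timeout=2h
```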

Choose the hardware for the job, and set the output locations and the Gin file:

```python
# Hardware: a single replica with 8 TPU v3 cores.
MACHINE_TYPE = 'cloud-tpu'
ACCELERATOR_TYPE = 'TPU_V3'
ACCELERATOR_NUM = 8
REPLICA_COUNT = 1

# Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
MODEL_DIR = f'gs://{BUCKET}/model/{MODEL_DISPLAY_NAME}'

# Data dir to save the processed dataset in "gs://data_dir" format.
TFDS_DATA_DIR = f'gs://{BUCKET}/dataset/{MODEL_DISPLAY_NAME}'

# Gin file for the run. Change to the Gin file you wrote in Step 3.
GIN_FILE = './base_wmt_from_scratch.gin'
```

```python
worker_pool_specs = [
    {
        "machine_spec": {
            "machine_type": MACHINE_TYPE,
            "accelerator_type": ACCELERATOR_TYPE,
            "accelerator_count": ACCELERATOR_NUM,
        },
        "replica_count": REPLICA_COUNT,
        "container_spec": {
            "image_uri": IMAGE_URI,
            # Run the T5X trainer with the Python interpreter inside the image.
            "command": ["/opt/conda/envs/t5x/bin/python", "/llm/t5x/t5x/train.py"],
            "args": [
                f'--gin_file={GIN_FILE}',
                # Gin string values must keep their quotes.
                f'--gin.MODEL_DIR="{MODEL_DIR}"',
                f'--tfds_data_dir={TFDS_DATA_DIR}',
                '--gin.USE_CACHED_TASKS=False'
            ],
        },
    }
]
```
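The required params from Step 3 can also be passed as `--gin.NAME=value` flags, where each value is written as a Python literal (which is why `MODEL_DIR` keeps its quotes in the args list). The helper below is a minimal sketch of generating such flags from a dict; `to_gin_flags` is a hypothetical name, not part of T5X or Gin.

```python
def to_gin_flags(params):
    """Turn {name: value} into '--gin.NAME=<python literal>' strings.

    repr() yields quoted strings and bare numbers, booleans, and dicts, which
    is the literal form expected for Gin bindings.
    """
    return [f"--gin.{name}={value!r}" for name, value in params.items()]


flags = to_gin_flags({
    "MODEL_DIR": "gs://my-bucket/model",  # placeholder path
    "TRAIN_STEPS": 10000,
    "TASK_FEATURE_LENGTHS": {"inputs": 512, "targets": 114},
    "USE_CACHED_TASKS": False,
})
for flag in flags:
    print(flag)
```

Because container args are passed as a list rather than through a shell, the quotes produced by `repr()` reach the trainer unchanged.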

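Finally, create the custom job and launch it:

```python
# Use a timestamped job name so that each run gets its own output directory.
job_name = 't5x_{}'.format(time.strftime("%Y%m%d_%H%M%S"))
base_output_dir = os.path.join(WORKSPACE, job_name)

job = vertex_ai.CustomJob(
    display_name=job_name,
    worker_pool_specs=worker_pool_specs,
    base_output_dir=base_output_dir
)
# sync=False returns immediately instead of blocking until the job finishes.
job.run(
    sync=False
)
```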
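Because the job is launched with `sync=False`, the call returns before training finishes, so you need some way to follow the job afterwards (step 5 in the overview). Besides the Vertex AI page in the Cloud console, one option is a small polling loop. The sketch below is a generic helper that runs offline against a fake job; `wait_until_done` is a hypothetical name, and the commented-out call at the end assumes the job resource has already been created.

```python
import time

# Common terminal state names of the Vertex AI JobState enum.
TERMINAL_STATES = {
    "JOB_STATE_SUCCEEDED",
    "JOB_STATE_FAILED",
    "JOB_STATE_CANCELLED",
}


def wait_until_done(get_state, poll_seconds=60.0, max_polls=10_000,
                    sleep=time.sleep):
    """Poll get_state() until it returns a terminal state name."""
    state = get_state()
    polls = 0
    while state not in TERMINAL_STATES and polls < max_polls:
        sleep(poll_seconds)
        state = get_state()
        polls += 1
    return state


# Offline demonstration with a fake sequence of states.
fake_states = iter(
    ["JOB_STATE_PENDING", "JOB_STATE_RUNNING", "JOB_STATE_SUCCEEDED"])
final_state = wait_until_done(lambda: next(fake_states), poll_seconds=0.0)
print(final_state)  # JOB_STATE_SUCCEEDED

# With the job launched above, the call might look like this (assumption:
# the job resource exists by the time the state is first read):
# final_state = wait_until_done(lambda: job.state.name)
```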