In [None]:
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Fine Tuning T5 1.1 XL on WMT task

In [None]:
import os
import time
from datetime import datetime
import pandas as pd

from google.cloud import aiplatform as vertex_ai

## Configure environment settings

In [None]:
PROJECT_ID = 'jk-mlops-dev' # Change to your project id.
REGION = 'us-central1'  # Change to your region.
BUCKET = 'jk-t5x-staging' # Change to your bucket.
TENSORBOARD_NAME = '
TENSORBOARD_ID = ! gcloud ai tensorboards list --filter="displayName={TENSORBOARD_NAME}" --format="value(name)" --region=us-central1 --limit=1
TENSORBOARD_ID = TENSORBOARD_ID[1]
IMAGE_NAME = 't5x-base'
IMAGE_URI = f'gcr.io/{PROJECT_ID}/{IMAGE_NAME}'

## Configure experiment settings

In [None]:
# Name of the experiment and the GCS prefix derived from it.
EXPERIMENT_NAME = 'fine-tune-t5-xl'

EXPERIMENT_WORKSPACE = f'gs://{BUCKET}/experiments/{EXPERIMENT_NAME}'



In [None]:
# Model identity: the display name is "<model name>-<version>".
MODEL_NAME = 'finetune-en-de'
VERSION = 'v02'
MODEL_DISPLAY_NAME = '-'.join([MODEL_NAME, VERSION])

# GCS prefix for this model; used as the parent of each job's base_output_dir.
WORKSPACE = 'gs://{}/{}'.format(BUCKET, MODEL_DISPLAY_NAME)

# Docker definitions for training (IMAGE_NAME / IMAGE_URI) are set in the environment settings cell above.


In [None]:
# Minute-resolution timestamp embedded in the per-run paths and names below.
TIMESTAMP = f'{datetime.now():%Y%m%d%H%M}'

# Per-run location for logs, checkpoints, etc.
MODEL_DIR = f'gs://{BUCKET}/model/{MODEL_DISPLAY_NAME}/{TIMESTAMP}'

# Location of the processed dataset.
TFDS_DATA_DIR = f'gs://{BUCKET}/dataset/{MODEL_DISPLAY_NAME}'

# Run mode, plus the Gin config: local file name and its upload destination in GCS.
RUN_MODE = 'train'
GIN_FILE = 'small_finetune_wmt.gin'
GIN_FILE_GCS = f'gs://{BUCKET}/staging/{GIN_FILE}'

In [None]:
# Copy gin file to GCS bucket and use the local GCSFuse mount
! gsutil cp {GIN_FILE} {GIN_FILE_GCS}
GIN_FILE_GCS = GIN_FILE_GCS.replace('gs://', '/gcs/')

#### Initialize Vertex AI client and log metadata

In [None]:
# Identifiers used for experiment tracking.
# EXPERIMENT_ID is passed as `experiment` to vertex_ai.init and as the
# experiment name in the TensorBoard uploader command.
EXPERIMENT_ID = f'{MODEL_DISPLAY_NAME}-xl-2-partitions'
EXECUTION_NAME = 'execution-1'  # plain literal: there is nothing to interpolate
RUN_NAME = f'run--{TIMESTAMP}'

In [None]:
# Configure the Vertex AI SDK defaults: project, region, staging location and experiment.
staging_bucket_uri = f'gs://{BUCKET}/staging'
vertex_ai.init(
    experiment=EXPERIMENT_ID,
    staging_bucket=staging_bucket_uri,
    location=REGION,
    project=PROJECT_ID,
)

#### Define infra and submit job

In [None]:
# Number of replicas in the worker pool.
REPLICA_COUNT = 1

# Hardware requested for the worker pool's machine spec.
MACHINE_TYPE, ACCELERATOR_TYPE = 'cloud-tpu', 'TPU_V2'
ACCELERATOR_NUM = 32

In [None]:
# Hardware section of the worker pool.
machine_spec = {
    "machine_type": MACHINE_TYPE,
    "accelerator_type": ACCELERATOR_TYPE,
    "accelerator_count": ACCELERATOR_NUM,
}

# Arguments handed to the container: run mode, gin config path,
# model output dir, dataset dir, and a gin override.
container_args = [
    f'--run_mode={RUN_MODE}',
    f'--gin_file={GIN_FILE_GCS}',
    f'--gin.MODEL_DIR="{MODEL_DIR}"',
    f'--tfds_data_dir={TFDS_DATA_DIR}',
    '--gin.USE_CACHED_TASKS=False',
]

# Container section of the worker pool.
container_spec = {
    "image_uri": IMAGE_URI,
    "args": container_args,
}

# A single worker pool combining the pieces above.
worker_pool_specs = [
    {
        "machine_spec": machine_spec,
        "replica_count": REPLICA_COUNT,
        "container_spec": container_spec,
    },
]

In [None]:
# Name the job with a second-resolution timestamp.
job_name = 't5x_{}'.format(time.strftime("%Y%m%d_%H%M%S"))
# GCS URIs always use '/', so build the path explicitly instead of relying on
# os.path.join, whose separator is platform-dependent.
base_output_dir = f"{WORKSPACE.rstrip('/')}/{job_name}"

# Define the custom job from the worker pool specs and submit it.
job = vertex_ai.CustomJob(
    display_name=job_name,
    worker_pool_specs=worker_pool_specs,
    base_output_dir=base_output_dir,
)
# sync=False: do not block this cell while the job runs.
job.run(
    sync=False
)

In [None]:
# Assemble the single-line tb-gcp-uploader command (wrapped in blank lines)
# and show it so it can be copied.
cmd = (
    '\n'
    f'tb-gcp-uploader --tensorboard_resource_name {TENSORBOARD_ID} '
    f'--logdir {MODEL_DIR} '
    f'--experiment_name {EXPERIMENT_ID}'
    '\n'
)

print(cmd)

In [None]:
# End the current Vertex AI experiment run.
# NOTE(review): no start_run call is visible in this notebook -- confirm a run is active before this cell.
vertex_ai.end_run()

Now run the `tb-gcp-uploader` command printed above, open the URL presented in its output, and analyse the training and inference logs.

![Tensorboard](./images/tb-sample.png)

### Metric Explanations

By default, t5x logs many metrics to TensorBoard. Many of these seem similar but
have important distinctions.

The first two graphs you will see are the `accuracy` and `cross_ent_loss`
graphs. These are the *token-level teacher-forced* accuracy and cross entropy
loss respectively. Each of these graphs can have multiple curves on them. The
first curve is the `train` curve. This is calculated as a running sum that is
then normalized over the whole training set. The second class of curves have the
form `training_eval/${task_name}`. These curves are created by running a subset
(controlled by the `eval_steps` parameter of the main train function) of the
validation split of `${task_name}` through the model and calculating these
metrics using teacher-forcing. These graphs can commonly be used to find
"failure to learn" cases and as a warning sign of overfitting, but these are
often not the final metrics one would report on.

The second set of graphs are the ones under the collapsible `eval` section in
TensorBoard. These graphs are created based on the `metric_fns` defined in the
SeqIO task. The curves on these graphs have the form
`inference_eval/${task_name}`. Values are calculated by running the whole
validation split through the model in inference mode, commonly auto-regressive
decoding or output scoring. Most likely these are the metrics that will be
reported.

More information about the configuration of the datasets used for these
different metrics can be found [here](#train-train-eval-and-infer-eval).

In summary, the metric you actually care about most likely lives under the
`eval` tab rather than in the `accuracy` graph.

In [None]:
import t5x

In [None]:
import t5