In [None]:
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Pre-train a Translation Model with T5X on a Vertex AI TPU Slice

### Step 4: Configure and launch a Vertex AI Training job to pre-train the model

In [1]:
import os
import time
from datetime import datetime
import pandas as pd

from google.cloud import aiplatform as vertex_ai

In [2]:
# GCP project settings -- edit these to match your environment.
PROJECT_ID = 'jk-mlops-dev'   # GCP project hosting the Vertex AI resources
REGION = 'us-central1'        # Region used for Vertex AI and TensorBoard calls

# Cloud Storage bucket name (without the gs:// prefix); the dataset, staged
# files and model outputs below are all written under it.
BUCKET = 'rl-mlops-language'

In [3]:
# Model definitions: the display name joins model name and version and is
# reused in the GCS paths and experiment name below.
MODEL_NAME = 'pretrain-en-de'
VERSION = 'v01'
MODEL_DISPLAY_NAME = '{}-{}'.format(MODEL_NAME, VERSION)

# GCS prefix under which each CustomJob gets its own base_output_dir.
WORKSPACE = 'gs://' + BUCKET + '/' + MODEL_DISPLAY_NAME

# Training container image in Container Registry (no explicit tag, so the
# default tag is used).
IMAGE_NAME = 't5x-base'
IMAGE_URI = 'gcr.io/{}/{}'.format(PROJECT_ID, IMAGE_NAME)

In [4]:
# Minute-resolution timestamp that gives every run its own model directory.
TIMESTAMP = datetime.now().strftime('%Y%m%d%H%M')

# Per-run model directory for logs, checkpoints and eval summaries.
MODEL_DIR = 'gs://{}/model/{}/{}'.format(BUCKET, MODEL_DISPLAY_NAME, TIMESTAMP)

# Directory holding the processed TFDS dataset (not timestamped).
TFDS_DATA_DIR = 'gs://{}/dataset/{}'.format(BUCKET, MODEL_DISPLAY_NAME)

# Gin configuration: local file, its GCS staging URI, and the T5X run mode.
GIN_FILE = 'wmt19_ende_from_scratch.gin'
GIN_FILE_GCS = 'gs://{}/staging/{}'.format(BUCKET, GIN_FILE)
RUN_MODE = 'train'

In [5]:
# Copy gin file to GCS bucket and use the local GCSFuse mount
! gsutil cp {GIN_FILE} {GIN_FILE_GCS}
GIN_FILE_GCS = GIN_FILE_GCS.replace('gs://', '/gcs/')

Copying file://wmt19_ende_from_scratch.gin [Content-Type=application/octet-stream]...
/ [1 files][  1.5 KiB/  1.5 KiB]                                                
Operation completed over 1 objects/1.5 KiB.                                      


#### Initialize Vertex AI client and log metadata

In [6]:
# Names for the Vertex AI experiment run and metadata execution; the
# experiment name is made unique per session via TIMESTAMP.
RUN_NAME = 'run-1'
EXECUTION_NAME = 'execution-1'
EXPERIMENT_ID = '{}-{}'.format(MODEL_DISPLAY_NAME, TIMESTAMP)

In [7]:
# Configure the SDK: project/region, the bucket's staging/ prefix for
# SDK-staged files, and EXPERIMENT_ID as the active experiment.
vertex_ai.init(
    experiment=EXPERIMENT_ID,
    staging_bucket='gs://{}/staging'.format(BUCKET),
    location=REGION,
    project=PROJECT_ID,
)

In [8]:
# Start the experiment run RUN_NAME under the experiment set in init(); later
# `vertex_ai.log_metrics` calls in this notebook record to this run.
vertex_ai.start_run(RUN_NAME)

Associating projects/895222332033/locations/us-central1/metadataStores/default/contexts/pretrain-en-de-v01-202207221528-run-1 to Experiment: pretrain-en-de-v01-202207221528


<google.cloud.aiplatform.metadata.experiment_run_resource.ExperimentRun at 0x7efc4dac0a10>

In [9]:
# Record lineage for this training run in Vertex ML Metadata: the TFDS dataset
# and the gin config are the inputs, the model directory is the output.
with vertex_ai.start_execution(
    schema_title="system.ContainerExecution",
    display_name=EXECUTION_NAME,
) as execution:

    # Input 1: location of the processed TFDS dataset.
    dataset_seqio_artifact = vertex_ai.Artifact.create(
        schema_title="system.Dataset",
        display_name='tfds_dataset',
        uri=TFDS_DATA_DIR,
    )

    # Input 2: the gin config. Its full text is also stored as artifact
    # metadata; the URI is converted back from the /gcs/ FUSE path to gs://.
    with open(GIN_FILE) as gin_fp:
        gin_content = gin_fp.read()
    gin_uri = GIN_FILE_GCS.replace('/gcs/', 'gs://')
    gin_config_artifact = vertex_ai.Artifact.create(
        schema_title="system.Artifact",
        display_name='gin_configuration_file',
        uri=gin_uri,
        metadata={'gin_file': gin_content},
    )

    # Output: the per-run model directory.
    model_artifact = vertex_ai.Artifact.create(
        schema_title="system.Model",
        display_name='wmt_pretrained_model',
        uri=MODEL_DIR,
    )

    execution.assign_input_artifacts([dataset_seqio_artifact, gin_config_artifact])
    execution.assign_output_artifacts([model_artifact])

    # Attach the console link of the lineage graph to the experiment run.
    lineage_uri = execution.get_output_artifacts()[0].lineage_console_uri
    vertex_ai.log_metrics({"lineage": lineage_uri})

#### Define infra and submit job

In [10]:
# Worker hardware: a single 'cloud-tpu' replica with 32 TPU_V2 accelerators.
REPLICA_COUNT = 1
ACCELERATOR_NUM = 32
ACCELERATOR_TYPE = 'TPU_V2'
MACHINE_TYPE = 'cloud-tpu'

In [11]:
# Arguments forwarded to the T5X container. MODEL_DIR is wrapped in double
# quotes because gin string bindings must be quoted.
t5x_args = [
    '--run_mode={}'.format(RUN_MODE),
    '--gin_file={}'.format(GIN_FILE_GCS),
    '--gin.MODEL_DIR="{}"'.format(MODEL_DIR),
    '--tfds_data_dir={}'.format(TFDS_DATA_DIR),
    '--gin.USE_CACHED_TASKS=False',
]

# A single worker pool: the TPU machine defined above running the T5X image.
worker_pool_specs = [
    dict(
        machine_spec=dict(
            machine_type=MACHINE_TYPE,
            accelerator_type=ACCELERATOR_TYPE,
            accelerator_count=ACCELERATOR_NUM,
        ),
        replica_count=REPLICA_COUNT,
        container_spec=dict(
            image_uri=IMAGE_URI,
            args=t5x_args,
        ),
    )
]

In [12]:
# Timestamped job name; the job's outputs go to its own sub-directory of
# WORKSPACE.
job_name = 't5x_' + time.strftime("%Y%m%d_%H%M%S")
base_output_dir = os.path.join(WORKSPACE, job_name)

job = vertex_ai.CustomJob(
    display_name=job_name,
    worker_pool_specs=worker_pool_specs,
    base_output_dir=base_output_dir,
)

# Submit without blocking the notebook. The job keeps running on Vertex AI;
# wait for it to finish before running the metrics cells below.
job.run(sync=False)

Creating CustomJob


### Step 5: Explore metrics

After training has completed, you can parse the evaluation metrics into CSV format using the following commands:

In [None]:
# Inference-eval summaries written by T5X under this run's MODEL_DIR, the local
# directory to download them into, and the CSV that parse_tb will produce.
# The source path must be built from MODEL_DIR: that directory includes
# TIMESTAMP, so gs://BUCKET/model/MODEL_DISPLAY_NAME/inference_eval does not
# exist. NOTE: if the kernel was restarted after submitting the job, TIMESTAMP
# (and therefore MODEL_DIR) has changed; set MODEL_DIR to the job's directory.
GCS_VAL_DIR = f'{MODEL_DIR}/inference_eval/*'
VAL_DIR = './inference_eval'
OUTPUT_FILE = './results.csv'

In [None]:
! mkdir {VAL_DIR}
! gsutil -m cp -r {GCS_VAL_DIR} {VAL_DIR}

In [None]:
# Convert the TensorBoard event files downloaded into VAL_DIR into one CSV file
# (OUTPUT_FILE) with the T5 `parse_tb` script. `--seqio_summaries` selects the
# SeqIO summary layout (per the flag name -- confirm against the t5 docs).
! python -m t5.scripts.parse_tb \
  --summary_dir={VAL_DIR} \
  --seqio_summaries \
  --out_file={OUTPUT_FILE} \
  --alsologtostderr

In [None]:
# Load the CSV written by parse_tb; read the same OUTPUT_FILE path that was
# passed via --out_file instead of repeating the file name.
results = pd.read_csv(OUTPUT_FILE, sep=',')
results

In [None]:
def _metric_key(column_name: str) -> str:
    """Turn a parse_tb column header into a metric key, e.g. 'SQuAD (EM)' -> 'squad_em'."""
    cleaned = ''.join(ch if ch.isalnum() else ' ' for ch in column_name.lower())
    return '_'.join(cleaned.split())


# The last two CSV rows are read as summary rows: each metric's best value
# ('max'), then the step at which it was reached ('step').
# NOTE(review): confirm this layout against the generated results.csv.
# Every numeric column is logged rather than hard-coded metric names, because
# the columns depend on the SeqIO task: this WMT En-De model has no
# 'SQuAD (EM)'/'SQuAD (F1)' columns, so indexing them would raise a KeyError.
assert len(results) >= 2, 'results.csv has fewer than 2 rows; did parse_tb run?'
max_row = results.iloc[-2]
best_step_row = results.iloc[-1]

metrics = {}
for column in results.select_dtypes(include='number').columns:
    key = _metric_key(str(column))
    # float() turns numpy scalars into plain Python floats for the SDK.
    metrics[f'max_{key}'] = float(max_row[column])
    metrics[f'step_{key}'] = float(best_step_row[column])
vertex_ai.log_metrics(metrics)

In [None]:
# Close the experiment run started earlier with vertex_ai.start_run(RUN_NAME).
vertex_ai.end_run()

#### Analyse with Vertex AI Tensorboard

In [None]:
# Vertex AI TensorBoard instance name, plus the two summary directories T5X
# writes under this run's MODEL_DIR. MODEL_DIR includes TIMESTAMP; paths built
# from BUCKET/MODEL_DISPLAY_NAME alone miss the run directory.
TENSORBOARD_INSTANCE_NAME = 't5x-analyse'
VALIDATION_TB_LOGS = f'{MODEL_DIR}/inference_eval'
TRAINING_TB_LOGS = f'{MODEL_DIR}/training_eval'

In [None]:
# Create a Vertex AI TensorBoard instance to upload the T5X logs to.
# NOTE(review): re-running this cell presumably creates another instance with
# the same display name; the lookup below keeps only one (--limit=1).
! gcloud ai tensorboards create --display-name={TENSORBOARD_INSTANCE_NAME} --region={REGION} --project={PROJECT_ID}

In [None]:
tensorboard_id = ! gcloud ai tensorboards list --filter="displayName=t5x-analyse" --format="value(name)" --region=us-central1 --limit=1 

In [None]:
! tb-gcp-uploader --tensorboard_resource_name \
  {tensorboard_id[1]} \
  --logdir={VALIDATION_TB_LOGS} \
  --experiment_name={EXPERIMENT_ID} --one_shot=True

! tb-gcp-uploader --tensorboard_resource_name \
  {tensorboard_id[1]} \
  --logdir={TRAINING_TB_LOGS} \
  --experiment_name={EXPERIMENT_ID} --one_shot=True

Now open the URL printed in the output of these commands and analyse the training and inference logs.

![Tensorboard](./images/tb-sample.png)

### Metric Explanations

By default, T5X logs many metrics to TensorBoard; many of these seem similar but
have important distinctions.

The first two graphs you will see are the `accuracy` and `cross_ent_loss`
graphs. These are the *token-level teacher-forced* accuracy and cross entropy
loss respectively. Each of these graphs can have multiple curves on them. The
first curve is the `train` curve. This is calculated as a running sum that is
then normalized over the whole training set. The second class of curves has the
form `training_eval/${task_name}`. These curves are created by running a subset
(controlled by the `eval_steps` parameter of the main train function) of the
validation split of `${task_name}` through the model and calculating these
metrics using teacher-forcing. These graphs can commonly be used to find
"failure to learn" cases and as a warning sign of overfitting, but these are
often not the final metrics one would report on.

The second set of graphs are the ones under the collapsible `eval` section in
TensorBoard. These graphs are created based on the `metric_fns` defined in the
SeqIO task. The curves on these graphs have the form
`inference_eval/${task_name}`. Values are calculated by running the whole
validation split through the model in inference mode, commonly auto-regressive
decoding or output scoring. Most likely these are the metrics that will be
reported.

More information about the configuration of the datasets used for these
different metrics (train, train-eval and infer-eval) can be found in the [T5X documentation](https://github.com/google-research/t5x).

In summary, the metric you actually care about most likely lives under the
`eval` tab rather than in the `accuracy` graph.