In [None]:
import json
import os
import time

from google.cloud import aiplatform as vertex_ai

In [None]:
# Google Cloud project and region used to initialise the Vertex AI SDK.
PROJECT_ID = "renatoleite-dev"  # Change to your project id.
REGION = "us-central1"  # Change to your region.

# Cloud Storage bucket name, without the "gs://" prefix.
BUCKET = "rl-language"  # Change to your bucket.

In [None]:
# Model identity: the display name is "<model name>-<version>".
MODEL_NAME = "t5-en-de"
VERSION = "v01"
MODEL_DISPLAY_NAME = f"{MODEL_NAME}-{VERSION}"

# Workspace root inside the bucket; job output directories are created under it.
WORKSPACE = f"gs://{BUCKET}/{MODEL_DISPLAY_NAME}"

# Docker image used by the training job.
IMAGE_NAME = "t5x-training"
IMAGE_URI = f"gcr.io/{PROJECT_ID}/{IMAGE_NAME}"

In [None]:
# Configure the Vertex AI SDK with the project and region defined above,
# and give it a staging folder inside the bucket.
vertex_ai.init(
    staging_bucket=f"gs://{BUCKET}/staging",
    location=REGION,
    project=PROJECT_ID,
)

In [None]:
! gcloud builds submit --tag {IMAGE_URI} --timeout=2h

In [None]:
# Hardware settings consumed by the worker pool spec.
REPLICA_COUNT = 1
MACHINE_TYPE = "cloud-tpu"
ACCELERATOR_TYPE = "TPU_V3"
ACCELERATOR_NUM = 8  # passed as accelerator_count in the machine spec

In [None]:
# Destination for logs, checkpoints, etc., given as a "gs://..." URI.
MODEL_DIR = f"gs://{BUCKET}/model"

# Location of the processed dataset, given as a "gs://..." URI.
TFDS_DATA_DIR = f"gs://{BUCKET}/dataset"

# Gin configuration handed to train.py through --gin_file.
GIN_FILE = "t5x/examples/t5/t5_1_1/examples/base_wmt_from_scratch.gin"

In [None]:
# One worker pool: the training image runs train.py on the machine and
# accelerator configuration defined in the earlier cells.
worker_pool_specs = [
    {
        "machine_spec": {
            "machine_type": MACHINE_TYPE,
            "accelerator_type": ACCELERATOR_TYPE,
            "accelerator_count": ACCELERATOR_NUM,
        },
        "replica_count": REPLICA_COUNT,
        "container_spec": {
            "image_uri": IMAGE_URI,
            # Interpreter and script paths inside the container image.
            "command": [
                "/opt/conda/envs/t5x/bin/python",
                "/llm/t5x/t5x/train.py",
            ],
            "args": [
                f"--gin_file={GIN_FILE}",
                # The double quotes are part of the argument value itself
                # (presumably so gin reads it as a string literal -- confirm).
                f'--gin.MODEL_DIR="{MODEL_DIR}"',
                f"--tfds_data_dir={TFDS_DATA_DIR}",
            ],
        },
    },
]

In [None]:
# Timestamped job name (second resolution), reused for both the display name
# and the output folder.
job_name = 't5x_{}'.format(time.strftime("%Y%m%d_%H%M%S"))

# Cloud Storage URIs always use "/" as the separator. os.path.join would
# insert the host OS separator (a backslash on Windows), so join explicitly.
base_output_dir = '{}/{}'.format(WORKSPACE.rstrip('/'), job_name)

# Define the custom training job from the worker pool spec built above.
job = vertex_ai.CustomJob(
    display_name=job_name,
    worker_pool_specs=worker_pool_specs,
    base_output_dir=base_output_dir,
)

# sync=False submits the job without blocking this cell until it finishes.
job.run(
    sync=False
)