In [1]:
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Vertex AI Training with TPU Pods

In [1]:
import os
import time

from google.cloud import aiplatform as vertex_ai

In [2]:
# GCP project and region used for the Vertex AI job and the container image.
PROJECT_ID = 'jk-mlops-dev'
REGION = 'us-central1'
# Cloud Storage bucket URI; note it already includes the 'gs://' scheme.
BUCKET = 'gs://jk-t5x-staging'

In [3]:
# Configure the Vertex AI SDK for this project/region, staging artifacts under BUCKET.
# BUCKET already includes the 'gs://' scheme, so it is used as-is here rather than
# prefixed again (prefixing would produce an invalid 'gs://gs://...' URI).
vertex_ai.init(
    project=PROJECT_ID,
    location=REGION,
    staging_bucket=f'{BUCKET}/staging',
)

In [4]:
# Registry path (under PROJECT_ID) of the training image that the next cell builds and pushes.
IMAGE_URI = 'gcr.io/{}/jax_tpu_test'.format(PROJECT_ID)

In [5]:
!docker build -t {IMAGE_URI} .
!docker push {IMAGE_URI}

Sending build context to Docker daemon  776.7kB
Step 1/9 : FROM gcr.io/deeplearning-platform-release/base-cpu
 ---> 234ba2bc2b77
Step 2/9 : WORKDIR /llm
 ---> Using cache
 ---> 24fa0638d711
Step 3/9 : RUN conda create -n t5x python=3.9
 ---> Using cache
 ---> a167478b4738
Step 4/9 : RUN git clone --branch=main https://github.com/google-research/t5x
 ---> Using cache
 ---> a61661948b8b
Step 5/9 : RUN /opt/conda/envs/t5x/bin/pip install --upgrade pip
 ---> Using cache
 ---> 86cb1ebe0f8f
Step 6/9 : RUN /opt/conda/envs/t5x/bin/pip install -e 't5x[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
 ---> Using cache
 ---> 9b7ef795b3b7
Step 7/9 : RUN /opt/conda/envs/t5x/bin/pip uninstall seqio seqio-nightly -y
 ---> Using cache
 ---> 1cf74a384619
Step 8/9 : RUN /opt/conda/envs/t5x/bin/pip install seqio-nightly
 ---> Using cache
 ---> 3d834608198a
Step 9/9 : ADD test.py ./
 ---> Using cache
 ---> e3896659d3b7
Successfully built e3896659d3b7
Successfully tagged gcr.io/jk-

In [6]:
# TPU configuration for the single worker pool: a TPU v2 accelerator with 32 cores,
# requested through Vertex AI's 'cloud-tpu' machine type.
TPU_ACCELERATOR_TYPE = 'TPU_V2'
TPU_ACCELERATOR_COUNT = 32

worker_pool_specs = [
    {
        "machine_spec": {
            "machine_type": "cloud-tpu",
            "accelerator_type": TPU_ACCELERATOR_TYPE,
            "accelerator_count": TPU_ACCELERATOR_COUNT,
        },
        "replica_count": 1,
        "container_spec": {
            "image_uri": IMAGE_URI,
            # Run the script ADDed to /llm (the image WORKDIR) with the t5x conda env's interpreter.
            "command": ["/opt/conda/envs/t5x/bin/python", "/llm/test.py"],
        },
    }
]

worker_pool_specs

[{'machine_spec': {'machine_type': 'cloud-tpu',
   'accelerator_type': 'TPU_V2',
   'accelerator_count': 32},
  'replica_count': 1,
  'container_spec': {'image_uri': 'gcr.io/jk-mlops-dev/jax_tpu_test',
   'command': ['/opt/conda/envs/t5x/bin/python', '/llm/test.py']}}]

In [7]:
# Timestamp the name so every submission gets a distinct display name and output folder.
job_name = f't5x_{time.strftime("%Y%m%d_%H%M%S")}'
base_output_dir = f'{BUCKET}/{job_name}'

# Wrap the worker pool spec in a Vertex AI CustomJob writing its outputs under base_output_dir.
job = vertex_ai.CustomJob(
    display_name=job_name,
    worker_pool_specs=worker_pool_specs,
    base_output_dir=base_output_dir,
)

# sync=True: the cell waits on the job, logging its state as the SDK polls it.
job.run(sync=True)

Creating CustomJob
CustomJob created. Resource name: projects/895222332033/locations/us-central1/customJobs/6802329071230386176
To use this CustomJob in another session:
custom_job = aiplatform.CustomJob.get('projects/895222332033/locations/us-central1/customJobs/6802329071230386176')
View Custom Job:
https://console.cloud.google.com/ai/platform/locations/us-central1/training/6802329071230386176?project=895222332033
CustomJob projects/895222332033/locations/us-central1/customJobs/6802329071230386176 current state:
JobState.JOB_STATE_PENDING
CustomJob projects/895222332033/locations/us-central1/customJobs/6802329071230386176 current state:
JobState.JOB_STATE_PENDING
CustomJob projects/895222332033/locations/us-central1/customJobs/6802329071230386176 current state:
JobState.JOB_STATE_PENDING
CustomJob projects/895222332033/locations/us-central1/customJobs/6802329071230386176 current state:
JobState.JOB_STATE_PENDING
CustomJob projects/895222332033/locations/us-central1/customJobs/6802329