In [None]:
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Running inference on a Model


## Introduction

This page outlines the steps to run inference on a model with T5X on Tasks/Mixtures
defined with [SeqIO](https://github.com/google/seqio/blob/main/README.md).

## Overview

Running inference on a model with T5X using SeqIO Task/Mixtures consists of the
following steps:

1.  Choose the model to run inference on.
1.  Choose the SeqIO Task/Mixture to run inference on.
1.  Write a Gin file that configures the model, SeqIO Task/Mixture and other
    details of your inference run.
1.  Launch your experiment on Vertex AI.
1.  Monitor your experiment and access predictions.

These steps are explained in detail in the following sections.

### Configure and launch a Vertex AI custom job to run inference on the model

In [1]:
import os
import time
from datetime import datetime

from google.cloud import aiplatform as vertex_ai

In [2]:
# --- Google Cloud project settings ---
# Replace the values below with the ones for your own environment.
PROJECT_ID = "renatoleite-dev"  # Project used for Vertex AI and the container image.
REGION = "us-central1"  # Region handed to the Vertex AI client.

# --- Cloud Storage ---
BUCKET = "rl-language"  # Bucket name only; the gs:// prefix is added where needed.

In [3]:
# Naming for this inference run: the display name is "<name>-<version>".
INFER_NAME = 'infer-wmt-en-de'
VERSION = 'v01'
INFER_DISPLAY_NAME = '-'.join([INFER_NAME, VERSION])

# GCS workspace for this run, rooted at the display name inside BUCKET.
WORKSPACE = 'gs://{}/{}'.format(BUCKET, INFER_DISPLAY_NAME)

# Container image that the job runs, hosted in the project's registry.
IMAGE_NAME = 't5x-base'
IMAGE_URI = 'gcr.io/{}/{}'.format(PROJECT_ID, IMAGE_NAME)

In [4]:
# Minute-resolution timestamp; it is embedded in the output path below.
TIMESTAMP = f'{datetime.now():%Y%m%d%H%M}'

# Gin configuration: local file name, its staged location on GCS, and the
# run mode passed to the container.
RUN_MODE = 'infer'
GIN_FILE = 'base_wmt_infer.gin'
GIN_FILE_GCS = 'gs://' + BUCKET + '/staging/' + GIN_FILE

# GCS directory that receives the output of this inference run.
INFER_DIR = 'gs://' + BUCKET + '/infer/' + INFER_DISPLAY_NAME + '/' + TIMESTAMP

# GCS directory passed to the job as the TFDS data dir.
TFDS_DATA_DIR = 'gs://' + BUCKET + '/dataset/' + INFER_DISPLAY_NAME

In [5]:
# Copy gin file to GCS bucket and use the local GCSFuse mount
! gsutil cp {GIN_FILE} {GIN_FILE_GCS}
GIN_FILE_GCS = GIN_FILE_GCS.replace('gs://', '/gcs/')

Copying file://base_wmt_infer.gin [Content-Type=application/octet-stream]...
/ [1 files][  578.0 B/  578.0 B]                                                
Operation completed over 1 objects/578.0 B.                                      


#### Initialize Vertex AI client and log metadata

In [6]:
# Identifiers for the Vertex AI experiment, its run, and the metadata execution.
RUN_NAME = 'run-1'
EXECUTION_NAME = 'execution-1'
EXPERIMENT_ID = '-'.join((INFER_DISPLAY_NAME, TIMESTAMP))

In [7]:
# Configure the Vertex AI SDK for this session: target project and region,
# the experiment to log under, and gs://<BUCKET>/staging as staging bucket.
vertex_ai.init(
    experiment=EXPERIMENT_ID,
    staging_bucket='gs://{}/staging'.format(BUCKET),
    location=REGION,
    project=PROJECT_ID,
)

In [8]:
# Start the experiment run. Kept as the cell's last expression so the
# returned run object is still displayed as the cell output.
vertex_ai.start_run(run=RUN_NAME)

Associating projects/375468928805/locations/us-central1/metadataStores/default/contexts/infer-wmt-en-de-v01-202207221327-run-1 to Experiment: infer-wmt-en-de-v01-202207221327


<google.cloud.aiplatform.metadata.experiment_run_resource.ExperimentRun at 0x7f21b5a1e550>

In [9]:
# Record the lineage of this inference run in Vertex AI metadata: open an
# execution, create artifacts for the dataset, the gin configuration and the
# output location, then attach them to the execution as inputs/outputs.
with vertex_ai.start_execution(
    schema_title="system.ContainerExecution", display_name=EXECUTION_NAME
) as execution:

    # Dataset artifact carrying only metadata (no uri is passed).
    # NOTE(review): 'split' and 'task' are hard-coded literals here, not read
    # from the gin file -- keep them in sync with the gin configuration.
    dataset_seqio_artifact = vertex_ai.Artifact.create(
        schema_title="system.Dataset", 
        display_name='infer_dataset',
        metadata= {
            'split': 'test',
            'task': 'wmt_t2t_translate'
        }
    )

    # Read the local gin file so its full text can be stored in the metadata.
    with open(GIN_FILE) as fp:
        gin_content = fp.read()

    # GIN_FILE_GCS is expected to hold the '/gcs/...' mount path at this
    # point; it is converted back to a gs:// URI for the artifact's uri.
    gin_config_artifact = vertex_ai.Artifact.create(
        schema_title="system.Artifact", 
        display_name='gin_configuration_file', 
        uri=GIN_FILE_GCS.replace('/gcs/', 'gs://'),
        metadata= {
            'gin_file': gin_content
        }
    )

    # Artifact whose uri is the inference output directory.
    infer_artifact = vertex_ai.Artifact.create(
        schema_title="system.Artifact", 
        display_name='infer_wmt_finetuned_model', 
        uri=INFER_DIR
    )

    # Link inputs (dataset, gin config) and output (inference dir) to the execution.
    execution.assign_input_artifacts([dataset_seqio_artifact, gin_config_artifact])
    execution.assign_output_artifacts([infer_artifact])

    # Log the lineage console URI of the first output artifact under the
    # 'lineage' key of the current experiment run.
    vertex_ai.log_metrics(
        {"lineage": execution.get_output_artifacts()[0].lineage_console_uri}
    )

#### Define infra and submit job

In [10]:
# Hardware for the worker pool: one replica of a 'cloud-tpu' machine with an
# accelerator count of 8 of type TPU_V2.
REPLICA_COUNT = 1
MACHINE_TYPE = "cloud-tpu"
ACCELERATOR_TYPE = "TPU_V2"
ACCELERATOR_NUM = 8

In [11]:
# Hardware requested for each replica of the worker pool.
machine_spec = dict(
    machine_type=MACHINE_TYPE,
    accelerator_type=ACCELERATOR_TYPE,
    accelerator_count=ACCELERATOR_NUM,
)

# Flags passed to the container: run mode, gin file path, the output dir as
# a quoted gin string binding, the TFDS data dir, and cached tasks disabled.
container_args = [
    '--run_mode={}'.format(RUN_MODE),
    '--gin_file={}'.format(GIN_FILE_GCS),
    '--gin.INFER_OUTPUT_DIR="{}"'.format(INFER_DIR),
    '--tfds_data_dir={}'.format(TFDS_DATA_DIR),
    '--gin.USE_CACHED_TASKS=False',
]

# Single worker pool combining the machine spec, replica count and container.
worker_pool_specs = [
    dict(
        machine_spec=machine_spec,
        replica_count=REPLICA_COUNT,
        container_spec=dict(image_uri=IMAGE_URI, args=container_args),
    )
]

In [12]:
# Name the job with a second-resolution timestamp.
job_name = 't5x_{}'.format(time.strftime("%Y%m%d_%H%M%S"))

# GCS URIs always use '/' as separator, so the output path is built
# explicitly; os.path.join follows the host OS convention and would produce
# a backslash-separated (invalid) URI on Windows.
base_output_dir = f'{WORKSPACE}/{job_name}'

# Custom job running the worker pool defined above, with its output rooted
# under this run's workspace.
job = vertex_ai.CustomJob(
    display_name=job_name,
    worker_pool_specs=worker_pool_specs,
    base_output_dir=base_output_dir
)

# sync=False requests asynchronous execution so the notebook is not blocked
# while the job runs; follow progress from the `job` object or the console.
job.run(
    sync=False
)

Creating CustomJob
