# TensorFlow on TPUs with Vertex AI

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mugglmenzel/ml-code-examples/blob/main/simple-ml-on-vertex-ai/notebooks/Tensorflow%20on%20TPUs.ipynb)
[![Open In Workbench](https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32)](https://console.cloud.google.com/ai-platform/notebooks/deploy-notebook?name=Tensorflow%2520on%2520GPUs&download_url=https%3A%2F%2Fraw.githubusercontent.com%2Fmugglmenzel%2Fml-code-examples%2Fmain%2Fsimple-ml-on-vertex-ai%2Fnotebooks%2FTensorflow%2520on%2520TPUs.ipynb)

Contributor: michaelmenzel@google.com

Disclaimer: This is a code example and not intended to be used in production. The author does not take any liability for the use of this code example.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Train a TF model locally with TPU acceleration

We start by running a simple training program on the MNIST dataset in TensorFlow:

In [None]:
import logging

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_datasets as tfds

resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)
print('Devices used by strategy in training loop:', strategy.extended.worker_devices)

tf.get_logger().setLevel(logging.getLevelName('INFO'))
logging.basicConfig(level=logging.getLevelName('INFO'))
tf.debugging.set_log_device_placement(True)

(train_data, val_data), mnist_info = tfds.load("mnist",
                                               split=['train', 'test'], as_supervised=True,
                                               try_gcs=True, with_info=True)

@tf.function
def norm_data(image, label):
    return tf.cast(image, tf.float32) / 255., label

train_ds = (train_data
            .map(norm_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            .batch(128, drop_remainder=True)
            .cache()
            .prefetch(tf.data.experimental.AUTOTUNE))
val_ds = (val_data
          .map(norm_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
          .batch(128, drop_remainder=True)
          .cache()
          .prefetch(tf.data.experimental.AUTOTUNE))

with strategy.scope():
  model = keras.Sequential([
          keras.layers.Reshape(target_shape=(28, 28, 1), input_shape=(28, 28)),
          keras.layers.Conv2D(filters=64, kernel_size=(5, 5), padding='same', activation='elu'),
          keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
          keras.layers.Conv2D(filters=128, kernel_size=(5, 5), padding='same', activation='elu'),
          keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
          keras.layers.Conv2D(filters=256, kernel_size=(5, 5), padding='same', activation='elu'),
          keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
          keras.layers.Flatten(),
          keras.layers.Dense(256, activation='elu'),
          keras.layers.Dense(10, activation='softmax')
      ])

  model.compile(optimizer='adam', jit_compile=True,
                loss='sparse_categorical_crossentropy',
                metrics=['sparse_categorical_accuracy'])

model.fit(train_ds, validation_data=val_ds, epochs=1)
model.evaluate(val_ds)
model.save('my_model',
           options=tf.saved_model.SaveOptions(
               experimental_io_device='/job:localhost'))

## Launch a TPU-accelerated Training on Vertex AI

In this part, we launch a training job on Vertex AI and register the resulting model in the Vertex AI Model Registry.

In [None]:
#@title Install Vertex AI Python SDK
try:
    from google.cloud import aiplatform
except ImportError:
    !pip install --user -q google-cloud-aiplatform
    # Exit so the kernel restarts and picks up the newly installed package.
    exit()

In [None]:
#@title Parameters
from datetime import datetime
from google.cloud import aiplatform

try:
    from google.colab import auth
    auth.authenticate_user()
except ImportError:
    # google.colab is only importable inside a Colab runtime.
    print('Not on Colab.')

PROJECT_ID = 'sandbox-michael-menzel' #@param
STAGING_BUCKET='gs://sandbox-michael-menzel-training-europe-west4/trainings/mnist-distributed-vertex' #@param


TIMESTAMP = datetime.now().strftime("%Y%m%d%H%M%S")
EXPERIMENT = f'{PROJECT_ID}-mnist-pysdk'
JOB_NAME = f'{EXPERIMENT}-{TIMESTAMP}'


aiplatform.init(location='europe-west4', project=PROJECT_ID, experiment=EXPERIMENT)
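
As a small illustration of the naming scheme above, the following sketch rebuilds the job name from a fixed point in time. The helper `build_job_name` and the project ID `example-project` are hypothetical and only serve the example:

```python
from datetime import datetime


def build_job_name(project_id: str, now: datetime) -> str:
    """Mirror the notebook's scheme: <project>-mnist-pysdk-<timestamp>."""
    timestamp = now.strftime("%Y%m%d%H%M%S")
    experiment = f"{project_id}-mnist-pysdk"
    return f"{experiment}-{timestamp}"


# 'example-project' is a placeholder, not a real project ID.
print(build_job_name("example-project", datetime(2023, 1, 31, 9, 5, 7)))
# prints: example-project-mnist-pysdk-20230131090507
```

Because the timestamp has second resolution, re-running the parameters cell yields a new job name each time.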

Next, we write the training script file:

In [None]:
%%writefile train.py
#@title Write the training script
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import argparse
import json
import logging
import math
import os
import sys
import time

import numpy as np

import tensorflow as tf
from tensorflow import keras
import tensorflow_datasets as tfds

logging.info(f"Using Tensorflow version {tf.__version__}")

import hypertune

hpt = hypertune.HyperTune()
recorder = {'previous': 0, 'steps': []}

def record(step, writer):
    # Log the time elapsed since the most recently recorded step (0 for the first step).
    current = time.time()
    previous = recorder['steps'][-1]['time'] if recorder['steps'] else current
    logging.info(f"[{step}]: +{current - previous} sec ({current} UNIX)")
    with writer.as_default():
        tf.summary.scalar(step, current, step=0)
    hpt.report_hyperparameter_tuning_metric(
                hyperparameter_metric_tag=step,
                metric_value=current)
    recorder['steps'].append({'name': step, 'time': current})
    # Keep the index of the latest step up to date after appending it.
    recorder['previous'] = len(recorder['steps']) - 1


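def summarize_recorder():
    logging.info("Summary of processing steps (in seconds):")
    # Start from the first recorded timestamp so the first duration is 0 rather than the UNIX time.
    previous = recorder['steps'][0]['time'] if recorder['steps'] else 0
    for step in recorder['steps']:
        logging.info(f"  Step: {step['name']}, Time: {step['time']}, Duration: {step['time'] - previous}")
        previous = step['time']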


class LossReporterCallback(tf.keras.callbacks.Callback):

    def on_epoch_end(self, epoch, logs=None):
        if logs:
            print(f"loss: {logs['loss']} in epoch: {epoch}")
            tf.summary.scalar('loss', logs['loss'], step=epoch)
            hpt.report_hyperparameter_tuning_metric(
                hyperparameter_metric_tag='loss',
                metric_value=logs['loss'],
                global_step=epoch)


def _is_chief(strategy):
    task_type = strategy.cluster_resolver.task_type
    return task_type == 'chief' or task_type is None


def _model_save_path(strategy):
    if strategy.cluster_resolver:
        task_type = strategy.cluster_resolver.task_type
        task_id = strategy.cluster_resolver.task_id
        subfolder = () if _is_chief(strategy) else (str(task_type), str(task_id))
    else:
        subfolder = ()
    return os.path.join(os.environ['AIP_MODEL_DIR'], *subfolder)


def _compile_model(strategy):
    model = keras.Sequential([
        keras.layers.Reshape(target_shape=(28, 28, 1), input_shape=(28, 28)),
        keras.layers.Conv2D(filters=64, kernel_size=(5, 5), padding='same', activation='elu'),
        keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
        keras.layers.Conv2D(filters=128, kernel_size=(5, 5), padding='same', activation='elu'),
        keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
        keras.layers.Conv2D(filters=256, kernel_size=(5, 5), padding='same', activation='elu'),
        keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2)),
        keras.layers.Flatten(),
        keras.layers.Dense(256, activation='elu'),
        keras.layers.Dense(10, activation='softmax')
    ])

    optimizer_config = {
        'class_name': 'adam',
        'config': {
            'learning_rate': params.learning_rate
        }
    }
    optimizer = tf.keras.optimizers.get(optimizer_config)


    model.compile(optimizer=optimizer,
                  loss='sparse_categorical_crossentropy',
                  metrics=['sparse_categorical_accuracy'])
    return model

def _train(params, strategy, writer):
    num_workers = strategy.num_replicas_in_sync or 1

    TRAIN_BATCH_SIZE = params.batch_size * num_workers
    VAL_BATCH_SIZE = params.batch_size * num_workers
    logging.info(f"Running with {TRAIN_BATCH_SIZE} train batch size and {VAL_BATCH_SIZE} validation batch size.")

    (train_data, val_data), mnist_info = tfds.load("mnist",
                                                   try_gcs=True,
                                                   with_info=True,
                                                   split=['train', 'test'],
                                                   as_supervised=True)

    @tf.function
    def norm_data(image, label):
        return tf.cast(image, tf.float32) / 255., label

    TRAIN_STEPS_EPOCH = int(mnist_info.splits['train'].num_examples // TRAIN_BATCH_SIZE)
    VAL_STEPS_EPOCH = int(mnist_info.splits['test'].num_examples // VAL_BATCH_SIZE)
    logging.info(f"Running with {TRAIN_STEPS_EPOCH} train steps and {VAL_STEPS_EPOCH} validation steps.")

    ds_options = tf.data.Options()
    ds_options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF

    train_ds = (train_data
                .with_options(ds_options)
                .map(norm_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
                .batch(TRAIN_BATCH_SIZE, drop_remainder=True)
                .cache()
                .repeat(params.num_epochs)
                .prefetch(tf.data.experimental.AUTOTUNE))
    val_ds = (val_data
              .with_options(ds_options)
              .map(norm_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
              .batch(VAL_BATCH_SIZE, drop_remainder=True)
              .cache()
              .repeat(params.num_epochs)
              .prefetch(tf.data.experimental.AUTOTUNE))
    record('dataset_ready', writer)

    with strategy.scope():
        model = _compile_model(strategy)

    model.summary()
    record('model_ready', writer)

    model.fit(train_ds, validation_data=val_ds,
              steps_per_epoch=TRAIN_STEPS_EPOCH, validation_steps=VAL_STEPS_EPOCH,
              epochs=params.num_epochs,
              callbacks=[
                  LossReporterCallback(),
                  tf.keras.callbacks.TensorBoard(os.environ['AIP_TENSORBOARD_LOG_DIR'], profile_batch=0)
              ])
    # tf.keras.callbacks.experimental.BackupAndRestore(os.path.join(params.job_dir, 'backups'))
    record('model_trained', writer)

    model_save_path = _model_save_path(strategy)
    logging.info(f'Saving model to {model_save_path}.')
    model.save(model_save_path)
    record('model_saved', writer)

    logging.info('Model training complete.')
    record('done', writer)

    logging.info(params)
    summarize_recorder()


def _get_args():
    """Argument parser.
    Returns:
    Dictionary of arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--num-epochs',
        type=int,
        default=10,
        help='number of times to go through the data, default=10')
    parser.add_argument(
        '--batch-size',
        default=100,
        type=int,
        help='number of records to read during each training step, default=100')
    parser.add_argument(
        '--learning-rate',
        default=.01,
        type=float,
        help='learning rate for optimizer, default=.01')
    parser.add_argument(
        '--long-runner',
        default='False',
        type=str,
        help='long running job indicator, default=False')
    parser.add_argument(
        '--verbosity',
        choices=['DEBUG', 'ERROR', 'FATAL', 'INFO', 'WARN'],
        default='DEBUG')
    return parser.parse_args()

def _detect_strategy():
    strategy = None
    try:
        logging.info('TPU_CONFIG:' + str(os.environ.get('TPU_CONFIG')))
        logging.info('TF_CONFIG:' + str(os.environ.get('TF_CONFIG')))
        tf_config = json.loads(os.environ.get('TF_CONFIG')) if os.environ.get('TF_CONFIG') else None
        tpu_config = json.loads(os.environ.get('TPU_CONFIG')) if os.environ.get('TPU_CONFIG') else None
        tf_cluster = tf_config['cluster'] if tf_config and 'cluster' in tf_config else {}
        worker_count = len(tf_cluster['worker']) if tf_cluster and 'worker' in tf_cluster else 0

        if tpu_config:
            resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='local')
            tf.config.experimental_connect_to_cluster(resolver)
            tf.tpu.experimental.initialize_tpu_system(resolver)
            strategy = tf.distribute.TPUStrategy(resolver)
        elif worker_count > 0:
            strategy = tf.distribute.MultiWorkerMirroredStrategy()
        else:
            strategy = tf.distribute.MirroredStrategy()
    except Exception as e:
        logging.error('Could not detect TF and TPU configuration.' + str(e))

    return strategy


def _fix_os_vars():
    if 'AIP_TENSORBOARD_LOG_DIR' not in os.environ:
        os.environ['AIP_TENSORBOARD_LOG_DIR'] = os.environ['AIP_MODEL_DIR']

if __name__ == "__main__":
    params = _get_args()
    if params:
        tf.get_logger().setLevel(logging.getLevelName(params.verbosity))
        logging.basicConfig(level=logging.getLevelName(params.verbosity))

    strategy = _detect_strategy()
    _fix_os_vars()

    if params and strategy:
        writer = tf.summary.create_file_writer(os.environ['AIP_TENSORBOARD_LOG_DIR'])
        record('program_start', writer)
        logging.info(f'Running training program with strategy:{strategy}')
        _train(params, strategy, writer)
    else:
        logging.error('Could not parse parameters and configuration.')
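
The strategy detection in `train.py` depends only on two environment variables. The sketch below mirrors that decision logic in plain Python so it can be tried without TensorFlow or a TPU. The helper `choose_strategy_name` is hypothetical, and the example values for `TF_CONFIG` and `TPU_CONFIG` are placeholders, since the exact content Vertex AI provides is not shown in this notebook:

```python
import json


def choose_strategy_name(environ: dict) -> str:
    """Return the name of the strategy train.py would pick for this environment."""
    tf_config = json.loads(environ["TF_CONFIG"]) if environ.get("TF_CONFIG") else None
    tpu_config = json.loads(environ["TPU_CONFIG"]) if environ.get("TPU_CONFIG") else None
    cluster = tf_config.get("cluster", {}) if tf_config else {}
    worker_count = len(cluster.get("worker", []))

    if tpu_config:
        # A non-empty TPU_CONFIG takes precedence over any worker cluster.
        return "TPUStrategy"
    if worker_count > 0:
        return "MultiWorkerMirroredStrategy"
    return "MirroredStrategy"


# Placeholder values for illustration only.
multi_worker_env = {
    "TF_CONFIG": json.dumps({
        "cluster": {"chief": ["host0:2222"], "worker": ["host1:2222", "host2:2222"]},
        "task": {"type": "chief", "index": 0},
    })
}
tpu_env = {"TPU_CONFIG": json.dumps({"placeholder": True})}

print(choose_strategy_name(tpu_env))           # prints: TPUStrategy
print(choose_strategy_name(multi_worker_env))  # prints: MultiWorkerMirroredStrategy
print(choose_strategy_name({}))                # prints: MirroredStrategy
```

Checking `TPU_CONFIG` first means a TPU job never falls back to a mirrored strategy, even if `TF_CONFIG` is also set.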


With the training script in place, we can now launch a training job that uses TPUs and registers the resulting model:

In [None]:
vertex_ai_custom_job = aiplatform.CustomTrainingJob(
    display_name=JOB_NAME,
    script_path='train.py',
    container_uri='europe-docker.pkg.dev/vertex-ai/training/tf-tpu.2-8:latest',
    requirements=['cloudml-hypertune', 'tensorflow-datasets'],
    model_serving_container_image_uri='europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-11:latest',
    model_description='TPU-trained MNIST model',
    staging_bucket=STAGING_BUCKET
)

vertex_ai_custom_job.run(
    machine_type='cloud-tpu',
    replica_count=1,
    accelerator_type = 'TPU_V2',
    accelerator_count = 8,
    args=['--num-epochs=20'],
    sync=True
)
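
To see what the job above effectively runs with, the following sketch parses the same `args` list with a reduced copy of the parser from `train.py` and repeats the batch-size arithmetic from `_train`. It assumes that `TPUStrategy` reports 8 replicas for the requested `TPU_V2` configuration with `accelerator_count = 8`; the helper `effective_config` is hypothetical:

```python
import argparse

MNIST_TRAIN_EXAMPLES = 60_000  # size of the MNIST 'train' split
MNIST_TEST_EXAMPLES = 10_000   # size of the MNIST 'test' split
ASSUMED_REPLICAS = 8           # assumption: one replica per TPU core


def effective_config(argv, num_replicas=ASSUMED_REPLICAS):
    """Parse job arguments and derive the global batch size and steps per epoch."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-epochs", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=100)
    parser.add_argument("--learning-rate", type=float, default=0.01)
    params = parser.parse_args(argv)

    # train.py multiplies the per-replica batch size by the number of replicas.
    global_batch = params.batch_size * num_replicas
    return {
        "num_epochs": params.num_epochs,
        "global_batch_size": global_batch,
        # Integer division matches drop_remainder=True: partial batches are skipped.
        "train_steps_per_epoch": MNIST_TRAIN_EXAMPLES // global_batch,
        "val_steps_per_epoch": MNIST_TEST_EXAMPLES // global_batch,
    }


print(effective_config(["--num-epochs=20"]))
# prints: {'num_epochs': 20, 'global_batch_size': 800, 'train_steps_per_epoch': 75, 'val_steps_per_epoch': 12}
```

Only `--num-epochs` is overridden here, so the batch size and learning rate keep their defaults from the script.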


## Launch a TPU-accelerated, Container-based Training on Vertex AI

In this part, we build a container image with our training script, publish it to Artifact Registry, and launch a container-based training job on Vertex AI Training.

In [None]:
%%writefile Dockerfile
#@title Write the Dockerfile
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

FROM python:3.8-slim

RUN apt update && apt install -y wget

RUN wget -q https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/tensorflow/tf-2.10.0/tensorflow-2.10.0-cp38-cp38-linux_x86_64.whl
RUN pip3 install tensorflow-2.10.0-cp38-cp38-linux_x86_64.whl
RUN rm tensorflow-2.10.0-cp38-cp38-linux_x86_64.whl

RUN wget -q https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/libtpu/1.4.0/libtpu.so -O /lib/libtpu.so
RUN chmod 777 /lib/libtpu.so

RUN pip install cloudml-hypertune tensorflow-datasets

ENV PYTHONUNBUFFERED="true"

COPY train.py /trainer/

ENTRYPOINT ["python3", "/trainer/train.py"]

We build the container and publish it to Artifact Registry:

In [None]:
!gcloud builds submit --tag=eu.gcr.io/$PROJECT_ID/mnist-trainer:$TIMESTAMP-tpu --project=$PROJECT_ID --region=europe-west4

Now we can use the container to launch a training job and register the resulting model in Vertex AI:

In [None]:
vertex_ai_custom_container_job = aiplatform.CustomContainerTrainingJob(
    display_name=JOB_NAME,
    container_uri=f'eu.gcr.io/{PROJECT_ID}/mnist-trainer:{TIMESTAMP}-tpu',
    model_serving_container_image_uri='europe-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-11:latest',
    model_description='TPU-trained MNIST model',
    staging_bucket=STAGING_BUCKET
)

vertex_ai_custom_container_job.run(
    machine_type='cloud-tpu',
    replica_count=1,
    accelerator_type = 'TPU_V2',
    accelerator_count = 8,
    args=['--num-epochs=20'],
    sync=True
)
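
Registering the model relies on `train.py` writing the SavedModel to the directory given in `AIP_MODEL_DIR`. The sketch below mirrors the path logic of `_model_save_path` in plain Python: the chief (or a single replica without a task type) writes to the model directory itself, while any other task writes to a subfolder. The helper `model_save_path` and the bucket path are hypothetical:

```python
import posixpath


def model_save_path(model_dir, task_type=None, task_id=None):
    """Return where a task saves its model, following train.py's convention."""
    is_chief = task_type in (None, "chief")
    subfolder = () if is_chief else (str(task_type), str(task_id))
    # posixpath keeps the result identical on any OS; train.py uses os.path.join,
    # which behaves the same way inside a Linux container.
    return posixpath.join(model_dir, *subfolder)


# 'gs://example-bucket/model' is a placeholder for the AIP_MODEL_DIR value.
print(model_save_path("gs://example-bucket/model"))
# prints: gs://example-bucket/model
print(model_save_path("gs://example-bucket/model", "worker", 1))
# prints: gs://example-bucket/model/worker/1
```

With `replica_count=1`, as in both jobs in this notebook, only the first case applies.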
