In [None]:
# Import TensorFlow and report which version this kernel is running.
import tensorflow as tf

print(
    'TensorFlow version: {}'.format(tf.__version__)
)

In [None]:
# Import TFX and report its version.
import tfx

print(
    'TFX version: {}'.format(tfx.__version__)
)

# Import TensorFlow Data Validation and report its version.
import tensorflow_data_validation as tfdv

print(
    'TFDV version: {}'.format(tfdv.__version__)
)

In [None]:
# Import the Kubeflow Pipelines SDK and report its version.
import kfp

print(
    'KFP version: {}'.format(kfp.__version__)
)

In [None]:
# Placeholders: replace these with your own project, region and bucket before
# running the rest of the notebook.
GOOGLE_CLOUD_PROJECT = 'YOUR PROJECT'
GOOGLE_CLOUD_REGION = 'us-central1'
GCS_BUCKET_NAME = 'YOUR BUCKET'

# The data directory contains the tfrecords prepared in
# https://betterprogramming.pub/a-step-by-step-guide-to-train-a-model-on-google-clouds-vertex-ai-47faafae1330
# The pipeline reads 'train.tfrecord' and 'val.tfrecord' from this location.
DATA_ROOT = 'DATA LOCATION'

In [None]:
# Imported here so this cell runs on a fresh kernel: os.path.join is used for
# METADATA_PATH below.
import os

PIPELINE_NAME = 'cifar10'

# Root path for the pipeline's artifacts.
PIPELINE_ROOT = 'gs://{}/pipeline_root/{}'.format(
    GCS_BUCKET_NAME, PIPELINE_NAME)

# Path for the user's Python module (the trainer module is uploaded here).
MODULE_ROOT = 'gs://{}/pipeline_module/{}'.format(
    GCS_BUCKET_NAME, PIPELINE_NAME)

# The golden schema comes from a previous run.
GOLDEN_SCHEMA = 'gs://{}/pipeline_root/{}/schema'.format(
    GCS_BUCKET_NAME, PIPELINE_NAME)

# This is the path where your model will be pushed for serving.
SERVING_MODEL_DIR = 'gs://{}/serving_model/{}'.format(
    GCS_BUCKET_NAME, PIPELINE_NAME)

# Metadata is only used for local processing.
METADATA_PATH = os.path.join('metadata', PIPELINE_NAME, 'metadata.db')

In [None]:
# File name of the trainer module. The next cell writes the module to this
# name with %%writefile, and later cells pass it to the pipeline as module_file.
_trainer_module_file = 'cifar10_tfx.py'

In [None]:
%%writefile {_trainer_module_file}

# Copied from https://www.tensorflow.org/tfx/tutorials/tfx/penguin_simple

from typing import List
from absl import logging
import tensorflow as tf
from tensorflow import keras
from tensorflow_transform.tf_metadata import schema_utils


from tfx import v1 as tfx
import tensorflow_data_validation as tfdv
from tfx_bsl.public import tfxio

from tensorflow_metadata.proto.v0 import schema_pb2

# Batch size used for both the training and the evaluation dataset in run_fn.
_BATCH_SIZE = 32
# Number of epochs passed to model.fit in run_fn.
_EPOCH = 15


def _input_fn(file_pattern: List[str],
              data_accessor: tfx.components.DataAccessor,
              schema: schema_pb2.Schema,
              batch_size: int) -> tf.data.Dataset:
  """Creates a shuffled, batched and endlessly repeating dataset.

  Args:
    file_pattern: File paths or patterns handed to the dataset factory.
    data_accessor: Accessor whose `tf_dataset_factory` builds the dataset.
    schema: Schema passed through to the dataset factory.
    batch_size: Number of examples per batch.

  Returns:
    The dataset produced by the factory with `.repeat()` applied; the
    'label' feature is used as the label key.
  """
  dataset_options = tfxio.TensorFlowDatasetOptions(
      batch_size=batch_size,
      shuffle=True,
      shuffle_buffer_size=1000,
      label_key='label')
  dataset = data_accessor.tf_dataset_factory(
      file_pattern, dataset_options, schema=schema)
  # Repeat indefinitely; the caller bounds training via step counts.
  return dataset.repeat()


def _make_keras_model() -> tf.keras.Model:
  """Builds and compiles the convolutional classifier.

  The model accepts a flat 3072-element input, reshapes it to (32, 32, 3),
  runs it through three convolution layers (with max-pooling after the first
  two), then dropout and a dense layer, and ends in a 10-way softmax.

  Returns:
    A compiled Keras Sequential model; its summary is written to the log.
  """
  layers = tf.keras.layers
  model = tf.keras.models.Sequential([
      layers.Reshape((32, 32, 3), input_shape=(3072,)),
      layers.Conv2D(32, (3, 3), activation='relu'),
      layers.MaxPooling2D((2, 2)),
      layers.Conv2D(64, (3, 3), activation='relu'),
      layers.MaxPooling2D((2, 2)),
      layers.Conv2D(64, (3, 3), activation='relu'),
      layers.Flatten(),
      layers.Dropout(0.4),
      layers.Dense(64, activation='relu'),
      layers.Dense(10, activation='softmax'),
  ])

  # The final layer already applies softmax, hence from_logits=False.
  loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
  model.compile(optimizer='adam', loss=loss, metrics=['accuracy'])
  model.summary(print_fn=logging.info)
  return model


# Entry point invoked by the TFX Trainer component.
def run_fn(fn_args: tfx.components.FnArgs):
  """Trains the CIFAR-10 model and saves it to the serving model directory.

  Args:
    fn_args: Trainer arguments. This function reads `schema_path`,
      `train_files`, `eval_files`, `data_accessor` and `serving_model_dir`.
  """
  schema = tfdv.load_schema_text(input_path=fn_args.schema_path)

  def _dataset_for(files):
    # Train and eval datasets are built identically apart from their files.
    return _input_fn(
        files,
        fn_args.data_accessor,
        schema=schema,
        batch_size=_BATCH_SIZE)

  train_dataset = _dataset_for(fn_args.train_files)
  eval_dataset = _dataset_for(fn_args.eval_files)

  model = _make_keras_model()
  # Both datasets repeat forever, so each epoch is bounded by explicit step
  # counts.
  model.fit(
      train_dataset,
      validation_data=eval_dataset,
      epochs=_EPOCH,
      steps_per_epoch=100,
      validation_steps=20)

  model.save(fn_args.serving_model_dir, save_format='tf')

In [None]:
# Copy the trainer module to MODULE_ROOT on GCS; the Vertex AI run below
# references the module from that location.
!gsutil cp {_trainer_module_file} {MODULE_ROOT}/

In [None]:
def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,
                     module_file: str, serving_model_dir: str, golden_schema: str, metadata_path: str=''
                     ) -> tfx.v1.dsl.Pipeline:
  """Assembles the CIFAR-10 TFX pipeline.

  Args:
    pipeline_name: Name given to the pipeline.
    pipeline_root: Root location for the pipeline's artifacts.
    data_root: Location containing 'train.tfrecord' and 'val.tfrecord'.
    module_file: Path of the trainer module handed to the Trainer component.
    serving_model_dir: Currently unused; no Pusher component is part of the
      pipeline. Kept so existing callers continue to work.
    golden_schema: Location of the schema imported for example validation.
    metadata_path: Path of a SQLite metadata database for local runs. When
      empty (the default), no metadata connection config is set on the
      pipeline.

  Returns:
    The assembled pipeline.
  """
  # Ingest the pre-built tfrecords as 'train' and 'eval' splits.
  input_config = tfx.proto.example_gen_pb2.Input(splits=[
      tfx.proto.example_gen_pb2.Input.Split(name='train', pattern='train.tfrecord'),
      tfx.proto.example_gen_pb2.Input.Split(name='eval', pattern='val.tfrecord'),
  ])
  example_gen = tfx.components.ImportExampleGen(input_base=data_root,
                                               input_config=input_config)

  statistics_gen = tfx.components.StatisticsGen(examples=example_gen.outputs['examples'])
  schema_gen = tfx.components.SchemaGen(statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)
  schema_importer = tfx.v1.dsl.Importer(source_uri=golden_schema,
                                    artifact_type=tfx.types.standard_artifacts.Schema).with_id('schema_importer')

  # Validation compares the fresh statistics against the imported golden
  # schema, not against the schema inferred in this run.
  example_validator = tfx.components.ExampleValidator(statistics=statistics_gen.outputs['statistics'],
                                                     schema=schema_importer.outputs['result'])

  # The trainer consumes the schema inferred by SchemaGen in this run.
  trainer = tfx.components.Trainer(
      module_file=module_file,
      examples=example_gen.outputs['examples'],
      schema=schema_gen.outputs['schema'],
      train_args=tfx.proto.trainer_pb2.TrainArgs(splits=['train']),
      eval_args=tfx.proto.trainer_pb2.EvalArgs(splits=['eval']))

  components = [
      example_gen,
      statistics_gen,
      schema_gen,
      schema_importer,
      example_validator,
      trainer,
  ]

  # Metadata config is only used for local processing, so it is only set when
  # the caller supplies a metadata path.
  metadata_connection_config = None
  if metadata_path:
    metadata_connection_config = (
        tfx.v1.orchestration.metadata.sqlite_metadata_connection_config(
            metadata_path))

  return tfx.v1.dsl.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=metadata_connection_config,
      components=components)

In [None]:
# Process on Vertex AI: compile the pipeline into a job spec file and submit it.
import os

PIPELINE_DEFINITION_FILE = PIPELINE_NAME + '_pipeline.json'

# The runner is configured with the job spec file name as its output file.
runner_config = tfx.v1.orchestration.experimental.KubeflowV2DagRunnerConfig()
runner = tfx.v1.orchestration.experimental.KubeflowV2DagRunner(
    config=runner_config,
    output_filename=PIPELINE_DEFINITION_FILE)

# The trainer module is referenced by its uploaded location under MODULE_ROOT.
vertex_pipeline = _create_pipeline(
    pipeline_name=PIPELINE_NAME,
    pipeline_root=PIPELINE_ROOT,
    data_root=DATA_ROOT,
    module_file=os.path.join(MODULE_ROOT, _trainer_module_file),
    serving_model_dir=SERVING_MODEL_DIR,
    golden_schema=GOLDEN_SCHEMA)
_ = runner.run(vertex_pipeline)

from kfp.v2.google import client

# Submit the job spec file as a pipeline run in the configured project/region.
pipelines_client = client.AIPlatformClient(
    project_id=GOOGLE_CLOUD_PROJECT,
    region=GOOGLE_CLOUD_REGION,
)

_ = pipelines_client.create_run_from_job_spec(PIPELINE_DEFINITION_FILE)

In [None]:
# Process locally: run the pipeline with the local DAG runner.
local_runner = tfx.v1.orchestration.LocalDagRunner()

# Unlike the Vertex AI run, the trainer module is passed by its local file
# name, and a metadata path is supplied.
local_pipeline = _create_pipeline(
    pipeline_name=PIPELINE_NAME,
    pipeline_root=PIPELINE_ROOT,
    data_root=DATA_ROOT,
    module_file=_trainer_module_file,
    serving_model_dir=SERVING_MODEL_DIR,
    golden_schema=GOLDEN_SCHEMA,
    metadata_path=METADATA_PATH)

local_runner.run(local_pipeline)