# TFX - Interactive Training Pipeline

The purpose of this notebook is to interactively run the following TFX pipeline steps:
1. Receive hyperparameters using the hyperparameters_gen custom python component
2. Extract data from BigQuery using BigQueryExampleGen
3. Validate the raw data using StatisticsGen and ExampleValidator
4. Process the data using Transform
5. Train a custom model using Trainer
6. Train an AutoML Tables model using automl_trainer custom python component
7. Evaluate the custom model using Evaluator
8. Validate the custom model against the AutoML Tables model using a custom python component
9. Save the blessed model to the model registry location using Pusher
10. Upload the model to AI Platform using the aip_model_uploader custom python component

The custom components are implemented in the [tfx_pipeline/components.py](tfx_pipeline/components.py) module.

## Setup

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import json
import numpy as np
import tfx
import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_data_validation as tfdv
import tensorflow_model_analysis as tfma
from tensorflow_transform.tf_metadata import schema_utils
import logging

from model_src import data, features
from tfx_pipeline import components

# Raise the root logger to INFO so messages logged at that level are shown.
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)

# Record the TensorFlow version this run executes against.
print("Tensorflow Version:", tf.__version__)

In [None]:
PROJECT = 'ksalama-cloudml'
REGION = 'us-central1'
BUCKET = 'ksalama-cloudml-us'

DATASET_DISPLAYNAME = 'chicago_taxi_tips'
CUSTOM_MODEL_DISPLAYNAME = f'{DATASET_DISPLAYNAME}_classifier_custom'
AUTOML_MODEL_DISPLAYNAME = f'{DATASET_DISPLAYNAME}_classifier_automl'

WORKSPACE = f"gs://{BUCKET}/ucaip_demo/chicago_taxi/pipelines_interactive"
RAW_SCHEMA_DIR = 'model_src/raw_schema'

MLMD_SQLLITE = 'mlmd.sqllite'
ARTIFACT_STORE = os.path.join(WORKSPACE, 'tfx_artifacts')
MODEL_REGISTRY = os.path.join(WORKSPACE, 'model_registry')
PIPELINE_NAME = f'{DATASET_DISPLAYNAME}_training_pipeline'
PIPELINE_ROOT = os.path.join(ARTIFACT_STORE, PIPELINE_NAME)

!gcloud config set project $PROJECT

## Create Interactive Context

In [None]:
# Set to False to keep the artifacts and metadata of a previous run.
CLEAN_ARTIFACTS = True

if CLEAN_ARTIFACTS:
    if tf.io.gfile.exists(ARTIFACT_STORE):
        print("Removing previous artifacts...")
        tf.io.gfile.rmtree(ARTIFACT_STORE)

    if tf.io.gfile.exists(MLMD_SQLLITE):
        # Remove the metadata file itself; the artifact store was already
        # handled above, so deleting it again here would be wrong.
        print("Deleting previous mlmd.sqllite...")
        tf.io.gfile.remove(MLMD_SQLLITE)

if not tf.io.gfile.exists(ARTIFACT_STORE):
    print("Creating local tfx artifact directory...")
    # makedirs also creates any missing parent directories.
    tf.io.gfile.makedirs(ARTIFACT_STORE)

print(f'Pipeline artifacts directory: {PIPELINE_ROOT}')
print(f'Local metadata SQLite path: {MLMD_SQLLITE}')

In [None]:
import ml_metadata as mlmd
from ml_metadata.proto import metadata_store_pb2
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

# Configure a SQLite-backed ML Metadata store in the file named by MLMD_SQLLITE.
connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.sqlite.filename_uri = MLMD_SQLLITE
# Use the named enum value instead of the bare integer 3.
connection_config.sqlite.connection_mode = (
    metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE_OPENCREATE
)
mlmd_store = mlmd.metadata_store.MetadataStore(connection_config)

# The interactive context runs components one at a time and records their
# executions through the same metadata connection configuration.
context = InteractiveContext(
  pipeline_name=PIPELINE_NAME,
  pipeline_root=PIPELINE_ROOT,
  metadata_connection_config=connection_config
)

## 1. Hyperparameter Generation

In [None]:
# Hyperparameter values handed to the custom generator component.
hyperparameter_values = dict(
    num_epochs=5,
    learning_rate=0.001,
    batch_size=512,
    hidden_units='64,64',
)
hyperparams_gen = components.hyperparameters_gen(**hyperparameter_values)

context.run(hyperparams_gen, enable_cache=False)

In [None]:
# Read back the hyperparameters.json file written by the component.
hyperparameters_path = os.path.join(
    hyperparams_gen.outputs.hyperparameters.get()[0].uri, 'hyperparameters.json')

# Use a context manager so the file handle is closed after reading.
with tf.io.gfile.GFile(hyperparameters_path) as hyperparameters_file:
    hyperparameters = json.load(hyperparameters_file)

hyperparameters

## 2. Data Extraction

In [None]:
from utils import datasource_utils
from tfx.extensions.google_cloud_big_query.example_gen.component import BigQueryExampleGen
from tfx.proto import example_gen_pb2, transform_pb2

### Extract train and eval splits

In [None]:
# Query for the rows whose data split is 'UNASSIGNED', capped at 10000 rows.
sql_query = datasource_utils.get_source_query(
    PROJECT,
    REGION,
    DATASET_DISPLAYNAME,
    data_split='UNASSIGNED',
    limit=10000,
)

# Hash-bucket split: 4 buckets go to "train" and 1 bucket goes to "eval".
train_eval_splits = [
    example_gen_pb2.SplitConfig.Split(name="train", hash_buckets=4),
    example_gen_pb2.SplitConfig.Split(name="eval", hash_buckets=1),
]
split_config = example_gen_pb2.SplitConfig(splits=train_eval_splits)
output_config = example_gen_pb2.Output(split_config=split_config)

train_example_gen = BigQueryExampleGen(query=sql_query, output_config=output_config)

# Beam options: the project to run in and a temp location in the bucket.
beam_pipeline_args = [
    f"--project={PROJECT}",
    f"--temp_location=gs://{BUCKET}/bq_tmp",
]

context.run(
    train_example_gen,
    beam_pipeline_args=beam_pipeline_args,
    enable_cache=False,
)

### Extract test split

In [None]:
# Query for the rows whose data split is 'TEST', capped at 1000 rows.
sql_query = datasource_utils.get_source_query(
    PROJECT,
    REGION,
    DATASET_DISPLAYNAME,
    data_split='TEST',
    limit=1000,
)

# All extracted rows go into a single "test" split.
test_splits = [
    example_gen_pb2.SplitConfig.Split(name="test", hash_buckets=1),
]
split_config = example_gen_pb2.SplitConfig(splits=test_splits)
output_config = example_gen_pb2.Output(split_config=split_config)

test_example_gen = BigQueryExampleGen(query=sql_query, output_config=output_config)

# Beam options: the project to run in and a temp location in the bucket.
beam_pipeline_args = [
    f"--project={PROJECT}",
    f"--temp_location=gs://{BUCKET}/bq_tmp",
]

context.run(
    test_example_gen,
    beam_pipeline_args=beam_pipeline_args,
    enable_cache=False,
)

In [None]:
# File pattern matching the "train" split written by the example generator.
train_examples_artifact = train_example_gen.outputs.examples.get()[0]
train_uri = os.path.join(train_examples_artifact.uri, "train/*")
print(train_uri)

# Derive a parsing feature spec from the raw schema file.
raw_schema_path = os.path.join(RAW_SCHEMA_DIR, 'schema.pbtxt')
source_raw_schema = tfdv.load_schema_text(raw_schema_path)
raw_feature_spec = schema_utils.schema_as_feature_spec(source_raw_schema).feature_spec


def _parse_tf_example(tfrecord):
    """Parse one serialized example according to the raw feature spec."""
    return tf.io.parse_single_example(tfrecord, raw_feature_spec)


tfrecord_filenames = tf.data.Dataset.list_files(train_uri)
dataset = tf.data.TFRecordDataset(
    tfrecord_filenames, compression_type="GZIP"
).map(_parse_tf_example)

# Print one shuffled batch of three parsed examples.
sample_batches = dataset.shuffle(1000).batch(3).take(1)
for raw_features in sample_batches:
    for key in raw_features:
        feature_values = np.squeeze(raw_features[key], -1)
        print(f"{key}: {feature_values}")
    print("")

## 3. Data Validation

### Import raw schema

In [None]:
# Register the schema directory as a Schema artifact; reimport=False is
# passed unchanged to the importer node.
schema_artifact_type = tfx.types.standard_artifacts.Schema

schema_importer = tfx.components.common_nodes.importer_node.ImporterNode(
    instance_name='Schema_Importer',
    source_uri=RAW_SCHEMA_DIR,
    artifact_type=schema_artifact_type,
    reimport=False,
)

context.run(schema_importer)

### Generate statistics

In [None]:
# Compute statistics over the examples produced by the train/eval extraction.
extracted_examples = train_example_gen.outputs.examples

statistics_gen = tfx.components.StatisticsGen(
    instance_name='Statistics_Generation',
    examples=extracted_examples,
)
context.run(statistics_gen)

In [None]:
!rm -r {RAW_SCHEMA_DIR}/.ipynb_checkpoints/

### Validate statistics against schema

In [None]:
# Validate the generated statistics against the imported schema.
generated_statistics = statistics_gen.outputs.statistics
imported_schema = schema_importer.outputs.result

example_validator = tfx.components.ExampleValidator(
    statistics=generated_statistics,
    schema=imported_schema,
    instance_name="Data_Validation",
)

context.run(example_validator)

In [None]:
# Display the anomalies output of the validator.
detected_anomalies = example_validator.outputs.anomalies
context.show(detected_anomalies)

## 4. Data Transformation

In [None]:
_transform_module_file = 'model_src/preprocessing.py'

# Analyze on "train" only; apply the transformation to "train" and "eval".
transform_splits = transform_pb2.SplitsConfig(
    analyze=['train'],
    transform=['train', 'eval'],
)

transform = tfx.components.Transform(
    examples=train_example_gen.outputs.examples,
    schema=schema_importer.outputs.result,
    module_file=_transform_module_file,
    splits_config=transform_splits,
    instance_name="Data_Transformation",
)

context.run(transform, enable_cache=False)

In [None]:
# Locate the transformed "train" split and the transform graph.
transformed_examples_artifact = transform.outputs.transformed_examples.get()[0]
train_uri = os.path.join(transformed_examples_artifact.uri, "train/*")
transform_graph_uri = transform.outputs.transform_graph.get()[0].uri

tft_output = tft.TFTransformOutput(transform_graph_uri)
transform_feature_spec = tft_output.transformed_feature_spec()

# Print one batch of three transformed examples together with the target.
sample_dataset = data.get_dataset(train_uri, transform_feature_spec, batch_size=3)
for input_features, target in sample_dataset.take(1):
    for key in input_features:
        feature_tensor = input_features[key]
        print(f"{key} ({feature_tensor.dtype}): {feature_tensor.numpy().tolist()}")
    print(f"target: {target.numpy().tolist()}")

## 5. Custom Model Training

In [None]:
from tfx.components.base import executor_spec
from tfx.components.trainer import executor as trainer_executor

In [None]:
_train_module_file = 'model_src/runner.py'

# Run the training module through the generic trainer executor.
generic_executor_spec = executor_spec.ExecutorClassSpec(
    trainer_executor.GenericExecutor)

# Step counts are passed exactly as configured here (0 for train, None for eval).
train_args = tfx.proto.trainer_pb2.TrainArgs(num_steps=0)
eval_args = tfx.proto.trainer_pb2.EvalArgs(num_steps=None)

trainer = tfx.components.Trainer(
    custom_executor_spec=generic_executor_spec,
    module_file=_train_module_file,
    transformed_examples=transform.outputs.transformed_examples,
    schema=schema_importer.outputs.result,
    transform_graph=transform.outputs.transform_graph,
    train_args=train_args,
    eval_args=eval_args,
    hyperparameters=hyperparams_gen.outputs.hyperparameters,
    instance_name='Model_Trainer',
)

context.run(trainer, enable_cache=False)

## 6. AutoML Model Training

In [None]:
# Columns passed to the AutoML trainer as a comma-separated exclusion list.
excluded_column_names = ['trip_start_timestamp']
exclude_columns = ','.join(excluded_column_names)

# NOTE(review): the keyword `exclude_cloumns` is kept exactly as written because
# it must match the parameter name declared by the component - confirm the
# spelling in tfx_pipeline/components.py before renaming.
automl_trainer = components.automl_trainer(
    project=PROJECT,
    region=REGION,
    dataset_display_name=DATASET_DISPLAYNAME,
    model_display_name=AUTOML_MODEL_DISPLAYNAME,
    target_column=features.TARGET_FEATURE_NAME,
    data_split_column='data_split',
    exclude_cloumns=exclude_columns,
    schema=schema_importer.outputs.result,
)

context.run(automl_trainer, enable_cache=False)

In [None]:
# Show the 'model_uri' custom property recorded on the AutoML model artifact.
automl_uploaded_model = automl_trainer.outputs.uploaded_model.get()[0]
automl_uploaded_model.get_string_custom_property('model_uri')

## 7. Custom Model Evaluation

In [None]:
from tfx.components import Evaluator

In [None]:
# Which model signature, label and prediction key to evaluate.
model_spec = tfma.ModelSpec(
    signature_name='serving_tf_example',
    label_key=features.TARGET_FEATURE_NAME,
    prediction_key='probabilities',
)

# BinaryAccuracy carries a value threshold with a lower bound of 0.8.
accuracy_threshold = tfma.MetricThreshold(
    value_threshold=tfma.GenericValueThreshold(lower_bound={'value': 0.8}))

metrics_spec = tfma.MetricsSpec(
    metrics=[
        tfma.MetricConfig(class_name='ExampleCount'),
        tfma.MetricConfig(class_name='BinaryAccuracy', threshold=accuracy_threshold),
    ])

# A single empty SlicingSpec is the only slice configured.
eval_config = tfma.EvalConfig(
    model_specs=[model_spec],
    slicing_specs=[tfma.SlicingSpec()],
    metrics_specs=[metrics_spec],
)

# Evaluate the custom model on the separately extracted "test" split.
evaluator = Evaluator(
    examples=test_example_gen.outputs.examples,
    example_splits=['test'],
    model=trainer.outputs.model,
    eval_config=eval_config,
    schema=schema_importer.outputs.result,
)

context.run(evaluator, enable_cache=False)

In [None]:
evaluation_results = evaluator.outputs.evaluation.get()[0].uri

# Report whether the configured thresholds were met.
validation_result = tfma.load_validation_result(evaluation_results)
print("validation_ok:", validation_result.validation_ok)

# Print every metric of the first loaded metrics record, rounded to 3 decimals.
first_metrics_record = list(tfma.load_metrics(evaluation_results))[0]
for entry in first_metrics_record.metric_keys_and_values:
    metric_value = round(entry.value.double_value.value, 3)
    print(entry.key.name, ":", metric_value)

## 8. Models Validation

### Get AutoML evaluation results

In [None]:
# Produce evaluation metrics for the AutoML model trained earlier.
automl_uploaded_model_channel = automl_trainer.outputs.uploaded_model

automl_metric_gen = components.automl_metrics_gen(
    project=PROJECT,
    region=REGION,
    uploaded_model=automl_uploaded_model_channel,
)

context.run(automl_metric_gen, enable_cache=False)

### Compare the evaluation results of the custom model and the AutoML model

In [None]:
# Feed both evaluation outputs (custom model and AutoML model) to the validator.
custom_model_evaluation = evaluator.outputs.evaluation
automl_model_evaluation = automl_metric_gen.outputs.evaluation

validator = components.custom_model_validator(
    model_evaluation=custom_model_evaluation,
    uploaded_model_evaluation=automl_model_evaluation,
)

context.run(validator, enable_cache=False)

## 9. Model Pushing

In [None]:
# Directory in the model registry that the pushed model is written under.
exported_model_location = os.path.join(MODEL_REGISTRY, f'{DATASET_DISPLAYNAME}_classifier')

filesystem_destination = tfx.proto.pusher_pb2.PushDestination.Filesystem(
    base_directory=exported_model_location,
)
push_destination = tfx.proto.pusher_pb2.PushDestination(
    filesystem=filesystem_destination)

pusher = tfx.components.Pusher(
    model=trainer.outputs.model,
    # The blessing is taken from the custom validator rather than from the evaluator.
    model_blessing=validator.outputs.blessing,
    push_destination=push_destination,
)

context.run(pusher, enable_cache=False)

## 10. Model Upload to AI Platform

In [None]:
# Serving container image, built from the runtime identifier below.
serving_runtime = 'tf2-cpu.2-3'
serving_image_uri = f"gcr.io/cloud-aiplatform/prediction/{serving_runtime}:latest"

# Upload the model found at the pushed location under the custom model display name.
uploader_arguments = dict(
    project=PROJECT,
    region=REGION,
    model_display_name=CUSTOM_MODEL_DISPLAYNAME,
    pushed_model_location=exported_model_location,
    serving_image_uri=serving_image_uri,
)
aip_model_uploader = components.aip_model_uploader(**uploader_arguments)

context.run(aip_model_uploader, enable_cache=False)

In [None]:
# Show the 'model_uri' custom property recorded on the uploaded model artifact.
custom_uploaded_model = aip_model_uploader.outputs.uploaded_model.get()[0]
custom_uploaded_model.get_string_custom_property('model_uri')