# TFX pipeline example - Chicago Taxi tips prediction

## Overview
[TensorFlow Extended (TFX)](https://github.com/tensorflow/tfx) is a Google-production-scale machine
learning platform based on TensorFlow. It provides a configuration framework for expressing ML pipelines
that consist of TFX components, giving users large-scale ML task orchestration, artifact lineage tracking, and access to the various [TFX libraries](https://www.tensorflow.org/resources/libraries-extensions). Kubeflow Pipelines (KFP) can serve as the orchestrator that
executes a TFX pipeline.

This sample demonstrates how to author an ML pipeline in TFX and run it on a KFP deployment.

## Permission

This pipeline requires Google Cloud Storage permissions to run.
If KFP was deployed through the GCP Marketplace, follow the instructions in [the guide](https://github.com/kubeflow/pipelines/blob/master/manifests/gcp_marketplace/guide.md#gcp-service-account-credentials)
to make sure the service account has the `storage.admin` role.

In [None]:
!python3 -m pip install pip --upgrade --quiet --user
!python3 -m pip install kfp --upgrade --quiet --user
!python3 -m pip install tfx --upgrade --quiet --user

In this example we need a very recent version of the TFX SDK to use the [`RuntimeParameter`](https://github.com/tensorflow/tfx/blob/93ea0b4eda5a6000a07a1e93d93a26441094b6f5/tfx/orchestration/data_types.py#L137) feature.

## RuntimeParameter in TFX DSL
Currently, the TFX DSL only supports parameterizing fields in the `PARAMETERS` section of `ComponentSpec` (see [here](https://github.com/tensorflow/tfx/blob/93ea0b4eda5a6000a07a1e93d93a26441094b6f5/tfx/types/component_spec.py#L126)). As a result, the pipeline topology cannot be parameterized at runtime. Also, if the declared type of a field is a protobuf message, the user needs to pass in a dictionary whose keys exactly match the message's field names, with one or more values given as `RuntimeParameter` objects. In other words, the dictionary must be something that can be passed to the [`ParseDict()` method](https://github.com/protocolbuffers/protobuf/blob/04a11fc91668884d1793bff2a0f72ee6ce4f5edd/python/google/protobuf/json_format.py#L433) to produce the correct protobuf message.
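To make the dictionary rule concrete, here is a minimal pure-Python model that imports neither TFX nor protobuf. `Param` is a hypothetical stand-in for `RuntimeParameter`, `check_fields` mimics the way `ParseDict()` rejects keys that are not field names, and `SLICING_SCHEMA` is an assumed field layout modeled on the `feature_slicing_spec` used later in this notebook. It illustrates the shape constraint only, not how TFX actually serializes runtime parameters.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass(frozen=True)
class Param:
    """Hypothetical stand-in for data_types.RuntimeParameter (name and default only)."""
    name: str
    default: Any


def check_fields(value: Any, schema: Optional[Dict[str, Any]], path: str = '') -> None:
    """Raise KeyError for any dict key that is not a field name in `schema`.

    `schema` maps field names to the nested message's schema, or to None for
    scalar fields. Lists (repeated fields) are checked element by element.
    """
    if isinstance(value, dict):
        for key, sub in value.items():
            if schema is None or key not in schema:
                raise KeyError(f'unknown field {path + key!r}')
            check_fields(sub, schema[key], path + key + '.')
    elif isinstance(value, list):
        for item in value:
            check_fields(item, schema, path)


def resolve(value: Any, overrides: Dict[str, Any]) -> Any:
    """Replace every Param leaf with its override (if given) or its default."""
    if isinstance(value, Param):
        return overrides.get(value.name, value.default)
    if isinstance(value, dict):
        return {key: resolve(sub, overrides) for key, sub in value.items()}
    if isinstance(value, list):
        return [resolve(item, overrides) for item in value]
    return value


# Assumed field layout of the slicing spec: specs (repeated) -> column_for_slicing.
SLICING_SCHEMA = {'specs': {'column_for_slicing': None}}

demo_slicing_column = Param('slicing-column', 'trip_start_hour')
spec = {'specs': [{'column_for_slicing': [demo_slicing_column]}]}

check_fields(spec, SLICING_SCHEMA)  # passes: every key is a known field name
print(resolve(spec, {}))  # {'specs': [{'column_for_slicing': ['trip_start_hour']}]}
print(resolve(spec, {'slicing-column': 'trip_start_day'}))
```

A misspelled key such as `'spec'` makes `check_fields` raise, which is the same class of mistake that would stop `ParseDict()` from building the message.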

In [None]:
!python3 -m pip install --quiet --index-url https://test.pypi.org/simple/ tfx==0.16.0.dev20191212 --user

In [1]:
import os
from typing import Text

import kfp
from kfp import dsl

from tfx.components import Evaluator
from tfx.components import CsvExampleGen
from tfx.components import ExampleValidator
from tfx.components import ModelValidator
from tfx.components import Pusher
from tfx.components import SchemaGen
from tfx.components import StatisticsGen
from tfx.components import Trainer
from tfx.components import Transform
from tfx.orchestration import data_types
from tfx.orchestration import pipeline
from tfx.orchestration.kubeflow import kubeflow_dag_runner
from tfx.proto import pusher_pb2
from tfx.utils.dsl_utils import external_input



In [2]:
# In the TFX MLMD schema, the pipeline name is used as the unique ID of each pipeline.
# Including the workflow (run) ID in the pipeline name lets the user bypass
# some schema checks that are redundant for experimental pipelines.
pipeline_name = 'taxi_pipeline_with_parameters_' + kfp.dsl.RUN_ID_PLACEHOLDER

# Path of pipeline data root, should be a GCS path.
# Note that when running on KFP, the pipeline root is always a runtime parameter.
pipeline_root = os.path.join('gs://my-bucket', 'tfx_taxi_simple',
                              kfp.dsl.RUN_ID_PLACEHOLDER)

# Location of the input data; should be a GCS path containing a CSV file.
data_root_param = data_types.RuntimeParameter(
    name='data-root',
    default='gs://ml-pipeline-playground/tfx_taxi_simple/data',
    ptype=Text,
)

# Path to the module file; should be a GCS path.
# A module file is one of the recommended ways to provide custom logic for
# components, including Trainer and Transform.
# See https://github.com/tensorflow/tfx/blob/93ea0b4eda5a6000a07a1e93d93a26441094b6f5/tfx/components/trainer/component.py#L38
taxi_module_file_param = data_types.RuntimeParameter(
    name='module-file',
    default='gs://ml-pipeline-playground/tfx_taxi_simple/modules/tfx_taxi_utils_1205.py',
    ptype=Text,
)

# Number of steps in training.
train_steps = data_types.RuntimeParameter(
    name='train-steps',
    default=10,
    ptype=int,
)

# Number of steps in evaluation.
eval_steps = data_types.RuntimeParameter(
    name='eval-steps',
    default=5,
    ptype=int,
)

# Column name for slicing.
slicing_column = data_types.RuntimeParameter(
    name='slicing-column',
    default='trip_start_hour',
    ptype=Text,
)
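Each `RuntimeParameter` above pairs a KFP argument name with a default and a `ptype` (`typing.Text` is an alias of `str`). The sketch below is a hypothetical pure-Python model of how a run's arguments could be resolved against these declarations. It assumes overrides arrive as strings and are converted with `ptype`; `ParamSpec` and `resolve_all` are illustrative names, not TFX APIs.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Type


@dataclass(frozen=True)
class ParamSpec:
    """Hypothetical mirror of a RuntimeParameter's name, default and ptype."""
    name: str
    default: Any
    ptype: Type

    def resolve(self, arguments: Dict[str, str]) -> Any:
        """Return the run argument converted to ptype, or the default if absent."""
        if self.name not in arguments:
            return self.default
        return self.ptype(arguments[self.name])  # e.g. int('100') -> 100


# Same names, defaults and types as the parameters declared above.
PARAMS: List[ParamSpec] = [
    ParamSpec('data-root', 'gs://ml-pipeline-playground/tfx_taxi_simple/data', str),
    ParamSpec('module-file',
              'gs://ml-pipeline-playground/tfx_taxi_simple/modules/tfx_taxi_utils_1205.py',
              str),
    ParamSpec('train-steps', 10, int),
    ParamSpec('eval-steps', 5, int),
    ParamSpec('slicing-column', 'trip_start_hour', str),
]


def resolve_all(arguments: Dict[str, str]) -> Dict[str, Any]:
    """Resolve every declared parameter, rejecting names that were never declared."""
    unknown = set(arguments) - {p.name for p in PARAMS}
    if unknown:
        raise ValueError(f'unknown parameters: {sorted(unknown)}')
    return {p.name: p.resolve(arguments) for p in PARAMS}


print(resolve_all({'train-steps': '100'}))
```

Rejecting undeclared names is a design choice of this sketch: a typo such as `'train-step'` fails loudly instead of silently leaving the default in place.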

## TFX Components

Please refer to the [official guide](https://www.tensorflow.org/tfx/guide#tfx_pipeline_components) for the detailed explanation and purpose of each TFX component.

In [3]:
# The input data location is parameterized by data_root_param.
examples = external_input(data_root_param)
example_gen = CsvExampleGen(input=examples)

In [4]:
statistics_gen = StatisticsGen(input_data=example_gen.outputs['examples'])



In [5]:
infer_schema = SchemaGen(
    stats=statistics_gen.outputs['statistics'], infer_feature_shape=False)



In [6]:
validate_stats = ExampleValidator(
  stats=statistics_gen.outputs['statistics'],
  schema=infer_schema.outputs['schema'])



In [7]:
# The module file used in the Transform and Trainer components is parameterized
# by taxi_module_file_param.
transform = Transform(
  input_data=example_gen.outputs['examples'],
  schema=infer_schema.outputs['schema'],
  module_file=taxi_module_file_param)



In [8]:
# The numbers of steps in train_args and eval_args are specified as
# RuntimeParameters named 'train-steps' and 'eval-steps', respectively.
trainer = Trainer(
  module_file=taxi_module_file_param,
  transformed_examples=transform.outputs['transformed_examples'],
  schema=infer_schema.outputs['schema'],
  transform_output=transform.outputs['transform_graph'],
  train_args={'num_steps': train_steps},
  eval_args={'num_steps': eval_steps})



In [9]:
# The name of slicing column is specified as a RuntimeParameter.
model_analyzer = Evaluator(
  examples=example_gen.outputs['examples'],
  model_exports=trainer.outputs['model'],
  feature_slicing_spec=dict(specs=[{
      'column_for_slicing': [slicing_column]
  }]))



In [10]:
model_validator = ModelValidator(
  examples=example_gen.outputs['examples'], model=trainer.outputs['model'])


In [11]:
# Currently we use this hack to ensure that push_destination can be
# correctly parameterized and interpreted.
# The pipeline root will be specified as a dsl.PipelineParam named
# 'pipeline-root'; see:
# https://github.com/tensorflow/tfx/blob/1c670e92143c7856f67a866f721b8a9368ede385/tfx/orchestration/kubeflow/kubeflow_dag_runner.py#L226
pipeline_root_param = dsl.PipelineParam(name='pipeline-root')
pusher = Pusher(
  model_export=trainer.outputs['model'],
  model_blessing=model_validator.outputs['blessing'],
  push_destination=pusher_pb2.PushDestination(
      filesystem=pusher_pb2.PushDestination.Filesystem(
          base_directory=os.path.join(
              str(pipeline_root_param), 'model_serving'))))
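The hack works because `str(pipeline_root_param)` produces a placeholder string rather than a real path, so `os.path.join` only embeds that placeholder in a template that the orchestrator fills in at run time. The sketch below models the idea in plain Python. The `{{param:NAME}}` token format and the `placeholder` / `substitute` helpers are hypothetical and are not KFP's actual serialization. It uses `posixpath.join` because GCS paths are always `/`-separated.

```python
import posixpath
import re
from typing import Dict

# Hypothetical token format; this is not KFP's actual PipelineParam serialization.
_TOKEN = re.compile(r'\{\{param:([A-Za-z0-9_-]+)\}\}')


def placeholder(name: str) -> str:
    """Return a hypothetical placeholder token standing in for a pipeline parameter."""
    return '{{param:%s}}' % name


def substitute(template: str, arguments: Dict[str, str]) -> str:
    """Replace each placeholder token with the run-time value of that parameter."""
    return _TOKEN.sub(lambda match: arguments[match.group(1)], template)


# Authoring time: only the token is known, so the path is just a template.
base_directory = posixpath.join(placeholder('pipeline-root'), 'model_serving')
print(base_directory)  # {{param:pipeline-root}}/model_serving

# Run time: the orchestrator knows the real value and fills it in.
print(substitute(base_directory, {'pipeline-root': 'gs://my-bucket/tfx_taxi_simple/run-123'}))
# gs://my-bucket/tfx_taxi_simple/run-123/model_serving
```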




In [15]:
# Create the DSL pipeline object.
# This pipeline object carries the business logic of the pipeline but no
# runner-specific information.
dsl_pipeline = pipeline.Pipeline(
  pipeline_name=pipeline_name,
  pipeline_root=pipeline_root,
  components=[
      example_gen, statistics_gen, infer_schema, validate_stats, transform,
      trainer, model_analyzer, model_validator, pusher
  ],
  enable_cache=False,
  beam_pipeline_args=['--direct_num_workers=%d' % 4],
)

In [16]:
# Specify a TFX Docker image. For the full list of official tags, see:
# https://hub.docker.com/r/tensorflow/tfx/tags
# Note: the image below is a patched build, not one of the official tags.
tfx_image = 'gcr.io/jxzheng-helloworld/patched-tfx:latest'
config = kubeflow_dag_runner.KubeflowDagRunnerConfig(
    kubeflow_metadata_config=kubeflow_dag_runner
    .get_default_kubeflow_metadata_config(),
    tfx_image=tfx_image)
kfp_runner = kubeflow_dag_runner.KubeflowDagRunner(config=config)
# KubeflowDagRunner compiles the DSL pipeline object into a KFP pipeline package,
# named <pipeline_name>.tar.gz by default.
kfp_runner.run(dsl_pipeline)
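To check what the runner produced, the package can be opened with the standard `tarfile` module. This is a minimal sketch that assumes the `.tar.gz` package holds a single workflow file (its file name is not assumed); `list_package_members` and `read_workflow` are hypothetical helpers.

```python
import tarfile
from typing import List


def list_package_members(package_path: str) -> List[str]:
    """Return the member names of a compiled pipeline package (.tar.gz)."""
    with tarfile.open(package_path, 'r:gz') as tar:
        return tar.getnames()


def read_workflow(package_path: str) -> str:
    """Return the text of the single file stored in the package.

    Assumption: the package contains exactly one regular file.
    """
    with tarfile.open(package_path, 'r:gz') as tar:
        files = [m for m in tar.getmembers() if m.isfile()]
        if len(files) != 1:
            raise ValueError(f'expected one file, found {[m.name for m in files]}')
        return tar.extractfile(files[0]).read().decode('utf-8')
```

In this notebook, `list_package_members(pipeline_name + '.tar.gz')` would inspect the same file that the next cell passes to `create_run_from_pipeline_package`.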

In [17]:
# Replace host with the endpoint of your own KFP deployment.
run_result = kfp.Client(
    host='450a951aa8610299-dot-us-central2.pipelines.googleusercontent.com'
).create_run_from_pipeline_package(
    pipeline_name + '.tar.gz',
    arguments={
        'pipeline-root': 'gs://jxzheng-helloworld/taxi_simple/' + kfp.dsl.RUN_ID_PLACEHOLDER,
        # 'module-file': '<gcs path to the module file>',  # uncomment to override the default module file.
        # 'data-root': '<gcs path to the data>',  # uncomment to override the default data.
    })
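Instead of commenting lines in and out, the `arguments` dictionary can be built so that only explicitly supplied overrides are sent; any parameter left out falls back to the default compiled into the package. `build_run_arguments` is a hypothetical helper, and mapping Python keyword names to dashed parameter names (`module_file` to `module-file`) is a convention of this sketch. Values are converted to strings as a conservative choice.

```python
from typing import Any, Dict


def build_run_arguments(pipeline_root: str, **overrides: Any) -> Dict[str, str]:
    """Build the `arguments` dict for create_run_from_pipeline_package.

    Keyword names are mapped to dashed parameter names (module_file ->
    module-file). Overrides whose value is None are skipped so that the
    default compiled into the package is used instead.
    """
    arguments = {'pipeline-root': pipeline_root}
    for key, value in overrides.items():
        if value is not None:
            arguments[key.replace('_', '-')] = str(value)
    return arguments


print(build_run_arguments('gs://my-bucket/taxi_simple/run-1',
                          module_file=None, train_steps=100))
# {'pipeline-root': 'gs://my-bucket/taxi_simple/run-1', 'train-steps': '100'}
```

For example, `build_run_arguments('gs://my-bucket/taxi_simple/' + kfp.dsl.RUN_ID_PLACEHOLDER, train_steps=100)` could be passed as `arguments=` in the cell above.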