# TFX on Kubeflow Pipelines Example

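This notebook should be run inside a Kubeflow Pipelines (KFP) cluster. Later cells connect to the KFP API and to the ML Metadata MySQL database using in-cluster defaults.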

### Install TFX and KFP packages

In [None]:
!pip3 install https://storage.googleapis.com/ml-pipeline/tfx/tfx-0.12.0rc0-py2.py3-none-any.whl 
!pip3 install kfp --upgrade


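### Enable the Dataflow API for your GCP project

The pipeline uses Dataflow, so enable the Dataflow API in the GCP project that hosts your GKE cluster:
<https://console.developers.google.com/apis/api/dataflow.googleapis.com/overview>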


## Clone the TFX repo with the sample pipeline


In [None]:
!git clone https://github.com/tensorflow/tfx

In [None]:
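# Copy the trainer module to a GCS bucket: the TFX pipeline loads this file from GCS.
# tf.gfile is the TF 1.x file API (in TF 2.x, use tf.io.gfile.copy instead).
from tensorflow import gfile

# overwrite=True lets you re-run this cell without a "file exists" error
gfile.Copy('tfx/examples/chicago_taxi_pipeline/taxi_utils.py',
           'gs://<my bucket>/<path>/taxi_utils.py',
           overwrite=True)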

## Configure the TFX pipeline example

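The cell below uses the `%load` magic to pull the pipeline configuration file into the notebook. After it runs, the cell contains the file's contents; edit the settings listed below, then run the cell again to execute the configured pipeline code.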
```
%load tfx/examples/chicago_taxi_pipeline/taxi_pipeline_kubeflow.py
```

Configure:
- Set `_input_bucket` to the GCS directory where you copied `taxi_utils.py`, i.e., `gs://<my bucket>/<path>/`.
- Set `_output_bucket` to the GCS directory where you want the results to be written.
- Set the GCP project ID (replace `my-gcp-project`). Use the project ID, not the project name.
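A project name (which may contain capitals and spaces) is easy to confuse with a project ID, so a quick check can catch configuration mistakes before you submit a run. The helper below is a hypothetical sketch, not part of `taxi_pipeline_kubeflow.py`. It assumes both bucket settings must be `gs://` URIs and that project IDs follow the usual GCP rules: 6 to 30 characters of lowercase letters, digits and hyphens, starting with a letter and not ending with a hyphen.

```python
import re

# Hypothetical helper, not part of taxi_pipeline_kubeflow.py.
# Assumed GCP project-ID rule: 6-30 chars, lowercase letters/digits/hyphens,
# starts with a letter, does not end with a hyphen.
_PROJECT_ID_RE = re.compile(r'^[a-z][a-z0-9-]{4,28}[a-z0-9]$')


def check_pipeline_config(input_bucket, output_bucket, project_id):
    """Return a list of problems; an empty list means the values look OK."""
    problems = []
    for name, uri in (('_input_bucket', input_bucket),
                      ('_output_bucket', output_bucket)):
        if not uri.startswith('gs://'):
            problems.append('%s should be a gs:// URI, got %r' % (name, uri))
    if not _PROJECT_ID_RE.match(project_id):
        problems.append('%r does not look like a project ID '
                        '(did you use the project name?)' % project_id)
    return problems


print(check_pipeline_config('gs://my-bucket/tfx/', 'gs://my-bucket/out/',
                            'my-gcp-project'))  # []
```

Run it with the same values you assign in the configuration cell; any returned message points at the setting to fix.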

The BigQuery dataset has 100M rows; to limit the number of rows used, change the query conditions in the `WHERE` clause.
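One way to limit rows is a random-sampling condition in the `WHERE` clause. This is a hypothetical sketch: the function name, the `sample_rate` parameter and the column names are illustrative, and the actual query in `taxi_pipeline_kubeflow.py` may be structured differently. It assumes BigQuery standard SQL, where `RAND()` returns a uniform value in [0, 1).

```python
# Hypothetical sketch: the table is the public Chicago taxi dataset, but the
# helper, its parameters and the column names are illustrative only.
TAXI_TABLE = 'bigquery-public-data.chicago_taxi_trips.taxi_trips'


def build_taxi_query(columns, sample_rate=None):
    """Build a BigQuery standard SQL query, optionally sampling rows.

    sample_rate is the approximate fraction of rows to keep, in (0, 1].
    """
    query = 'SELECT %s FROM `%s`' % (', '.join(columns), TAXI_TABLE)
    if sample_rate is not None:
        if not 0 < sample_rate <= 1:
            raise ValueError('sample_rate must be in (0, 1], got %r' % sample_rate)
        # RAND() is uniform on [0, 1), so roughly sample_rate of the rows pass.
        query += ' WHERE RAND() < %g' % sample_rate
    return query


print(build_taxi_query(['fare', 'trip_miles'], sample_rate=0.001))
```

Sampling this way reduces the number of rows the pipeline processes, but BigQuery still reads the selected columns in full, so it does not reduce the amount of data scanned.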


In [None]:
%load tfx/examples/chicago_taxi_pipeline/taxi_pipeline_kubeflow.py

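## Compile the pipeline and submit a run to the Kubeflow cluster

Running the configured cell above compiles the pipeline into `chicago_taxi_pipeline_kubeflow.tar.gz`; the next cell submits that package as a run.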

In [None]:
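# Get the ID of an existing experiment, or create the experiment if it doesn't exist yet
import kfp

client = kfp.Client()  # in-cluster default connection to the KFP API
experiment_name = 'TFX Examples'
try:
    experiment_id = client.get_experiment(experiment_name=experiment_name).id
except Exception:
    # The lookup fails when no experiment with this name exists yet
    experiment_id = client.create_experiment(experiment_name).id

# Compiled pipeline package written by the configuration cell above
pipeline_filename = 'chicago_taxi_pipeline_kubeflow.tar.gz'

# Submit a pipeline run (no runtime parameters)
run_name = 'Run 1'
run_result = client.run_pipeline(experiment_id, run_name, pipeline_filename, {})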
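The get-or-create step can be factored into a reusable helper, and a timestamped run name keeps repeated submissions distinguishable in the UI instead of all being called "Run 1". This is a sketch under stated assumptions: `get_or_create_experiment_id` and `unique_run_name` are hypothetical names, the helper assumes `get_experiment` raises when the experiment is missing, and `FakeClient` is a hypothetical in-memory stand-in for `kfp.Client` so the logic can be tried without a cluster.

```python
import datetime


def get_or_create_experiment_id(client, name):
    """Return the ID of the experiment called `name`, creating it if needed.

    `client` is assumed to expose get_experiment(experiment_name=...) and
    create_experiment(name), as kfp.Client does.
    """
    try:
        return client.get_experiment(experiment_name=name).id
    except Exception:
        # Assumed: the lookup raises when no experiment has this name yet.
        return client.create_experiment(name).id


def unique_run_name(prefix='Run'):
    """Append a UTC timestamp so repeated submissions get distinct run names."""
    now = datetime.datetime.now(datetime.timezone.utc)
    return '%s %s' % (prefix, now.strftime('%Y%m%d-%H%M%S'))


class FakeExperiment(object):
    """Hypothetical stand-in for the experiment object kfp returns."""

    def __init__(self, experiment_id):
        self.id = experiment_id


class FakeClient(object):
    """Hypothetical in-memory stand-in for kfp.Client, for trying the helper."""

    def __init__(self):
        self._experiments = {}

    def get_experiment(self, experiment_name=None):
        if experiment_name not in self._experiments:
            raise ValueError('No experiment named %r' % experiment_name)
        return self._experiments[experiment_name]

    def create_experiment(self, name):
        experiment = FakeExperiment('exp-%d' % (len(self._experiments) + 1))
        self._experiments[name] = experiment
        return experiment


fake = FakeClient()
print(get_or_create_experiment_id(fake, 'TFX Examples'))  # exp-1 (created)
print(get_or_create_experiment_id(fake, 'TFX Examples'))  # exp-1 (found)
```

With the real client, the same calls would read `experiment_id = get_or_create_experiment_id(client, experiment_name)` followed by `client.run_pipeline(experiment_id, unique_run_name(), pipeline_filename, {})`.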


### Connect to the ML Metadata Store

In [None]:
!pip3 install ml_metadata

In [None]:
import os

from ml_metadata.metadata_store import metadata_store
from ml_metadata.proto import metadata_store_pb2

# Kubernetes injects MYSQL_SERVICE_HOST/MYSQL_SERVICE_PORT for the in-cluster
# `mysql` service, which hosts the ML Metadata database used by the pipeline.
connection_config = metadata_store_pb2.ConnectionConfig()
connection_config.mysql.host = os.getenv('MYSQL_SERVICE_HOST')
connection_config.mysql.port = int(os.getenv('MYSQL_SERVICE_PORT'))
connection_config.mysql.database = 'mlmetadata'
connection_config.mysql.user = 'root'
store = metadata_store.MetadataStore(connection_config)
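If the notebook is not running inside the cluster, `os.getenv` returns `None` and `int(None)` fails with an unhelpful `TypeError`. The hypothetical helper below reads the same variables and fails with a clearer message instead. It assumes the variables come from Kubernetes service environment injection for a service named `mysql`; accepting a plain dict instead of `os.environ` makes it easy to test.

```python
import os


def mysql_settings_from_env(env=None, database='mlmetadata', user='root'):
    """Read the MySQL connection settings for the ML Metadata store.

    Hypothetical helper: raises a descriptive error when the Kubernetes
    service variables are missing, e.g. outside the KFP cluster.
    """
    env = os.environ if env is None else env
    host = env.get('MYSQL_SERVICE_HOST')
    port = env.get('MYSQL_SERVICE_PORT')
    if not host or not port:
        raise RuntimeError(
            'MYSQL_SERVICE_HOST/MYSQL_SERVICE_PORT are not set; run this '
            'notebook inside the Kubeflow Pipelines cluster.')
    return {'host': host, 'port': int(port), 'database': database, 'user': user}


print(mysql_settings_from_env({'MYSQL_SERVICE_HOST': '10.0.0.5',
                               'MYSQL_SERVICE_PORT': '3306'}))
```

The returned keys map one-to-one onto the `connection_config.mysql` fields set in the cell above.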

In [None]:
# Get all output artifacts
store.get_artifacts()

In [None]:
# Get all artifacts of a specific type.
# Artifact type names used by TFX:
# 'ModelExportPath', 'ExamplesPath', 'ModelBlessingPath', 'ModelPushPath', 'TransformPath', 'SchemaPath'
store.get_artifacts_by_type('ExamplesPath')
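To get an overview of what a run produced, you can group artifact URIs by their type. The sketch below is hypothetical: `uris_by_type_id` is an illustrative name, and `FakeArtifact` is a stand-in with the same `id`, `type_id` and `uri` field names as `metadata_store_pb2.Artifact`, so the grouping logic runs without a metadata store.

```python
import collections

# Hypothetical stand-in exposing the same field names as
# metadata_store_pb2.Artifact (id, type_id, uri), so the helper runs offline.
FakeArtifact = collections.namedtuple('FakeArtifact', ['id', 'type_id', 'uri'])


def uris_by_type_id(artifacts):
    """Group artifact URIs by type_id, keeping the order they were returned in."""
    grouped = collections.defaultdict(list)
    for artifact in artifacts:
        grouped[artifact.type_id].append(artifact.uri)
    return dict(grouped)


sample = [
    FakeArtifact(1, 10, 'gs://bucket/out/examples/1'),
    FakeArtifact(2, 11, 'gs://bucket/out/schema/2'),
    FakeArtifact(3, 10, 'gs://bucket/out/examples/3'),
]
print(uris_by_type_id(sample))
```

Against the real store, `uris_by_type_id(store.get_artifacts())` should work the same way, since MLMD artifacts carry `type_id` and `uri` fields.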