# TFX - Run Training Pipeline locally using BeamDagRunner

## Setup

In [None]:
# Load IPython's autoreload extension and select mode 2, so edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import kfp
import tfx
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
import tensorflow as tf
import ml_metadata as mlmd
from ml_metadata.proto import metadata_store_pb2
import logging

# Show INFO-level messages from every library that logs through the root logger.
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)

In [None]:
# Scratch directory for this notebook, and the path of the SQLite file that
# is later used as the ML Metadata store.
LOCAL_WORKSPACE = '_workspace'
MLMD_SQLLITE = os.path.join(LOCAL_WORKSPACE, 'mlmd.sqllite')

# Set to False to keep the contents of an existing workspace between runs.
REMOVE_LOCAL_WORKSPACE = True

# Wipe a workspace left over from a previous run, if requested.
if tf.io.gfile.exists(LOCAL_WORKSPACE):
    if REMOVE_LOCAL_WORKSPACE:
        print("Removing previous local workspace...")
        tf.io.gfile.rmtree(LOCAL_WORKSPACE)

print("Creating new local workspace...")
tf.io.gfile.mkdir(LOCAL_WORKSPACE)

## Set pipeline configurations

In [None]:
# Pipeline settings are passed through environment variables; they must be
# set before `tfx_pipeline.config` is imported in the next cell.
# NOTE(review): presumably `config` reads these at import time — confirm.
os.environ.update({
    "DATASET_DISPLAYNAME": 'chicago_taxi_tips',
    "PROJECT": 'ksalama-cloudml',
    "REGION": 'us-central1',
    "GCS_LOCATION": "gs://ksalama-cloudml-us/ucaip_demo/chicago_taxi/beam_runner",
    "TRAIN_LIMIT": "85000",
    "TEST_LIMIT": "15000",
    "BEAM_RUNNER": "DirectRunner",
    "TRAINING_RUNNER": "local",
})

In [None]:
from tfx_pipeline import config
# Echo every upper-case attribute of the config module so the effective
# settings are visible in the notebook output.
for key, value in vars(config).items():
    if not key.isupper():
        continue
    print(f'{key}: {value}')

## Configure local metadata store

In [None]:
# Artifact location, as set through the GCS_LOCATION environment variable.
gcs_location = os.environ["GCS_LOCATION"]
print(gcs_location)

# Start from a clean slate: delete artifacts left at that location by a
# previous run.
if tf.io.gfile.exists(gcs_location):
    print("Removing previous artifacts...")
    tf.io.gfile.rmtree(gcs_location)

# Likewise drop any previous local metadata database.
if tf.io.gfile.exists(MLMD_SQLLITE):
    print("Removing local mlmd SQLite...")
    tf.io.gfile.remove(MLMD_SQLLITE)

# Point ML Metadata at a local SQLite file. READWRITE_OPENCREATE is the
# named form of connection mode 3.
metadata_connection_config = metadata_store_pb2.ConnectionConfig(
    sqlite=metadata_store_pb2.SqliteMetadataSourceConfig(
        filename_uri=MLMD_SQLLITE,
        connection_mode=(
            metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE_OPENCREATE
        ),
    )
)
print("ML metadata store is ready.")

## Run the pipeline

In [None]:
from tfx_pipeline import pipeline as pipeline_module

In [None]:
# Root path handed to the pipeline as `pipeline_root`, built from the
# artifact store URI and the pipeline name in the config module.
pipeline_root = os.path.join(
    config.ARTIFACT_STORE_URI,
    config.PIPELINE_NAME,
)

# Orchestrate the pipeline with the Apache Beam based TFX runner.
runner = BeamDagRunner()

# Build the pipeline definition with the local metadata store and the
# training hyperparameters for this run.
# NOTE(review): `hidden_units` is passed as a comma-separated string;
# presumably one layer size per entry — confirm against `create_pipeline`.
pipeline = pipeline_module.create_pipeline(
    metadata_connection_config=metadata_connection_config,
    pipeline_root=pipeline_root,
    num_epochs=50,
    batch_size=512,
    learning_rate=0.0003,
    hidden_units="256,128",
)

runner.run(pipeline)

# Status message (typo "exection" corrected).
print("Pipeline finished execution.")