# ML with TensorFlow Extended (TFX) Pipelines
1. Extracting the new training data from the source using the [ExampleGen](https://www.tensorflow.org/tfx/guide/examplegen) component.
2. Validating the new training data:
    * Generating statistics from the incoming data using the [StatisticsGen](https://www.tensorflow.org/tfx/guide/statsgen) component.
    * Importing a fixed raw schema using the [ImporterNode](https://github.com/tensorflow/tfx/blob/master/tfx/components/common_nodes/importer_node.py) component.
    * Validating the data against the schema using the [ExampleValidator](https://www.tensorflow.org/tfx/guide/exampleval) component.
3. Transforming the data for ML using the [Transform](https://www.tensorflow.org/tfx/guide/transform) component.
4. Training the model using the [Trainer](https://www.tensorflow.org/tfx/guide/trainer) component.
5. Evaluating the model using the [Evaluator](https://www.tensorflow.org/tfx/guide/evaluator) component.
6. Validating the model using a [Custom TFX](https://www.tensorflow.org/tfx/guide/custom_component) component.
7. Pushing the blessed model to the serving location using the [Pusher](https://www.tensorflow.org/tfx/guide/pusher) component.
8. Querying the [ML Metadata](https://www.tensorflow.org/tfx/guide/mlmd) DB.

In [None]:
import os
import tfx
import tensorflow as tf
import tensorflow.io as tf_io
print("Tensorflow Version:", tf.__version__)
print("TFX Version:", tfx.__version__)

In [None]:
WORKSPACE = 'workspace'
DATA_DIR = WORKSPACE + '/data'
RAW_SCHEMA_DIR = WORKSPACE + '/raw_schema'
OUTPUT_DIR = WORKSPACE + '/artifacts'

REMOVE_ARTIFACTS = True
if REMOVE_ARTIFACTS and tf_io.gfile.exists(OUTPUT_DIR):
    # rmtree raises an error when the directory does not exist yet.
    print("Removing previous artifacts...")
    tf_io.gfile.rmtree(OUTPUT_DIR)

## 0. Create Interactive Context
Since no metadata connection config is provided, the context uses an ephemeral SQLite MLMD database stored in the pipeline_root directory under the file name "metadata.sqlite".

In [None]:
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

PIPELINE_NAME = 'tfx-census-classification'

context = InteractiveContext(
    pipeline_name=PIPELINE_NAME,
    pipeline_root=OUTPUT_DIR,
    metadata_connection_config=None
)

In [None]:
from pprint import pprint
print("Standard Artifact types:")
pprint([a for a in dir(tfx.types.standard_artifacts) if a[0].isupper()])

## 1. Data Ingestion (ExampleGen)
1. Read the CSV data files (which are expected to include headers).
2. Split the data into train and eval sets.
3. Write the data to TFRecords.

* **Inputs**: ExternalPath
* **Outputs**: Examples (TFRecords)
* **Properties**: split ratio
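
The split ratio is expressed through `hash_buckets`, which act as relative weights: with 4 buckets for `train` and 1 for `eval`, roughly 80% of the records should land in `train`. The sketch below illustrates the idea in plain Python. The `assign_split` helper and the use of MD5 are assumptions made for illustration only and are not the hashing that ExampleGen itself performs.

```python
import hashlib

# Hypothetical split definition mirroring 4:1 bucket weights for train and eval.
SPLITS = [("train", 4), ("eval", 1)]


def assign_split(record, splits=SPLITS):
    """Deterministically map a record to a split by hashing it into buckets."""
    total_buckets = sum(buckets for _, buckets in splits)
    # MD5 is an arbitrary stable hash chosen for this sketch.
    digest = hashlib.md5(record.encode("utf-8")).hexdigest()
    bucket = int(digest, 16) % total_buckets
    upper = 0
    for name, buckets in splits:
        upper += buckets
        if bucket < upper:
            return name
    raise AssertionError("unreachable: bucket is always below total_buckets")


records = ["row-{}".format(i) for i in range(10000)]
counts = {"train": 0, "eval": 0}
for record in records:
    counts[assign_split(record)] += 1

train_fraction = counts["train"] / len(records)
print(counts, round(train_fraction, 2))
```

Because the assignment depends only on the record content, re-running the split on the same data places every record in the same split again.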

In [None]:
from tfx.utils.dsl_utils import external_input
from tfx.proto import example_gen_pb2

output_config = example_gen_pb2.Output(
    split_config=example_gen_pb2.SplitConfig(splits=[
        example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=4),
        example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=1)
    ]))


example_gen = tfx.components.CsvExampleGen(
    instance_name='Data_Extraction_Spliting',
    input=external_input(DATA_DIR),
    output_config=output_config
)

context.run(example_gen)

### Read sample of the extracted data...
Since TF v1.15 is used, you need to enable eager execution to read the data. However, enabling it causes problems in subsequent steps, so reading is disabled by default.

In [None]:
READ = False

if READ:
    tf.enable_eager_execution()

    import tensorflow_data_validation as tfdv

    train_uri = example_gen.outputs['examples'].get()[0].uri

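    # Join with os.path so the pattern works with or without a trailing slash in the URI.
    tfrecord_filenames = tf.data.Dataset.list_files(os.path.join(train_uri, "*"))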

    # Create a `TFRecordDataset` to read these files
    dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")

    # Iterate over the first 3 records and decode them using a TFExampleDecoder
    for tfrecord in dataset.take(3):
        serialized_example = tfrecord.numpy()
        example = tfdv.TFExampleDecoder().decode(serialized_example)
        pprint(example)
        print("")

## 2. Data Validation
1. Generate the **statistics** for the data to validate.
2. Import the **raw_schema** created in the Data Analysis phase.
3. Validate the **statistics** against the schema and generate **anomalies** (if any).
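
To make the third step concrete, the sketch below compares summary statistics against a schema and collects anomalies, using plain Python dictionaries. It is a conceptual illustration only: the dictionary layouts, the feature values, and the `find_anomalies` helper are all hypothetical and are not the TFDV data structures or API used by the components in this section.

```python
# Hypothetical schema: expected type and, for categorical features,
# the set of allowed values.
raw_schema_sketch = {
    "age": {"type": "FLOAT"},
    "gender": {"type": "BYTES", "domain": {"Male", "Female"}},
}

# Hypothetical summary statistics computed from a new batch of data.
batch_statistics = {
    "age": {"type": "FLOAT"},
    "gender": {"type": "BYTES", "values": {"Male", "Female", "Unknown"}},
    "zip_code": {"type": "BYTES", "values": {"12345"}},
}


def find_anomalies(stats_by_feature, schema):
    """Return a dict mapping feature name to a list of anomaly descriptions."""
    anomalies = {}
    for name, stats in stats_by_feature.items():
        if name not in schema:
            anomalies.setdefault(name, []).append("New column not in schema")
            continue
        expected = schema[name]
        if stats["type"] != expected["type"]:
            anomalies.setdefault(name, []).append(
                "Expected type {}, found {}".format(expected["type"], stats["type"]))
        if "domain" in expected:
            unexpected = stats.get("values", set()) - expected["domain"]
            if unexpected:
                anomalies.setdefault(name, []).append(
                    "Unexpected values: {}".format(sorted(unexpected)))
    for name in schema:
        if name not in stats_by_feature:
            anomalies.setdefault(name, []).append("Column missing from data")
    return anomalies


anomalies = find_anomalies(batch_statistics, raw_schema_sketch)
for feature, issues in sorted(anomalies.items()):
    print(feature, issues)
```

An empty result means the new data conforms to the schema.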

### 2.1. Generating statistics for the data to validate (StatisticsGen)
* **Inputs**: Examples
* **Outputs**: ExampleStatistics
* **Properties**: None

In [None]:
statistics_gen = tfx.components.StatisticsGen(
    instance_name='Statistics_Generation',
    examples=example_gen.outputs['examples'])
context.run(statistics_gen)

In [None]:
context.show(statistics_gen.outputs['statistics'])

### 2.2. Import the fixed raw schema (ImporterNode)
The **ImporterNode** allows you to import an external artifact to a component.
You need to specify:
1. Artifact type
2. Artifact location

In [None]:
schema_importer = tfx.components.common_nodes.importer_node.ImporterNode(
    instance_name='Schema_Importer',
    source_uri=RAW_SCHEMA_DIR,
    artifact_type=tfx.types.standard_artifacts.Schema,
    reimport=False
)

context.run(schema_importer)

In [None]:
context.show(schema_importer.outputs['result'])

### 2.3. Validate the input data statistics (ExampleValidator)
* **Inputs**: ExampleStatistics, Schema
* **Outputs**: ExampleAnomalies (if any)
* **Properties**: None

In [None]:
example_validator = tfx.components.ExampleValidator(
    statistics=statistics_gen.outputs['statistics'],
    schema=schema_importer.outputs['result'],
    instance_name="Data_Validation"
)

context.run(example_validator)

In [None]:
context.show(example_validator.outputs['anomalies'])

## 3. Data Preprocessing

### 3.1. Implement the preprocessing logic

We need to implement the preprocessing logic in a Python module: **transform.py**.

* This module is expected to define a **preprocessing_fn** function, which accepts a dictionary of the raw features and returns a dictionary of the transformed features.
* We use the **raw schema** to identify feature types and the required transformation.
* The function is implemented using [TensorFlow Transform](https://www.tensorflow.org/tfx/guide/tft).

### 3.2. Transform train and eval data (Transform)

The component uses the transform output generated from analyzing the train data to transform the eval data.
That is, while the train data is **analyzed** and **transformed**, the eval data is **only transformed** using the output of the analyze phase (TransformGraph) on the train data.

* **Inputs**: train and eval data (Examples), raw schema (Schema), transformation module (file)
* **Outputs**: transformed train and eval data (Examples), transform output (TransformGraph)
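
The analyze-then-transform pattern can be sketched in plain Python with z-score scaling of a single numeric feature. This is a minimal illustration under stated assumptions: the `analyze` and `apply_scaling` helpers and the sample values are hypothetical, and the constants dictionary merely stands in for the role the TransformGraph plays.

```python
import statistics

# Hypothetical raw values of one numeric feature in each split.
train_ages = [25.0, 35.0, 45.0, 55.0]
eval_ages = [30.0, 60.0]


def analyze(values):
    """Analyze phase: a full pass over the data to compute constants."""
    return {"mean": statistics.mean(values), "std": statistics.pstdev(values)}


def apply_scaling(values, constants):
    """Transform phase: apply z-score scaling using precomputed constants."""
    return [(v - constants["mean"]) / constants["std"] for v in values]


# Only the train split is analyzed; the resulting constants are reused
# unchanged for every other split.
constants = analyze(train_ages)
transformed_train = apply_scaling(train_ages, constants)
transformed_eval = apply_scaling(eval_ages, constants)

print(constants)
print(transformed_eval)
```

Reusing the train constants keeps the eval data on the same scale the model saw during training and avoids leaking eval statistics into the transformation.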

In [None]:
tf.logging.set_verbosity(tf.logging.ERROR)

_transform_module_file = 'modules/transform.py'

transform = tfx.components.Transform(
    examples=example_gen.outputs['examples'],
    schema=schema_importer.outputs['result'],
    module_file=_transform_module_file,
    instance_name="Data_Transformation"
)

context.run(transform)

In [None]:
uri = transform.outputs['transform_graph'].get()[0].uri
os.listdir(uri)

### Read sample of the transformed data
Since TF v1.15 is used, you need to enable eager execution to read the data. However, enabling it causes problems in subsequent steps, so reading is disabled by default.

In [None]:
READ = False

if READ:
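    tf.enable_eager_execution()
    # Imported here as well so this cell does not depend on the earlier READ cell.
    import tensorflow_data_validation as tfdv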

    train_uri = transform.outputs['transformed_examples'].get()[1].uri
    tfrecord_filenames = [os.path.join(train_uri, name)
                          for name in os.listdir(train_uri)]
    dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")
    decoder = tfdv.TFExampleDecoder()
    for tfrecord in dataset.take(3):
        serialized_example = tfrecord.numpy()
        example = decoder.decode(serialized_example)
        pprint(example)
        print("")

## 4. Train the Model (Trainer)

### 4.1. Implement a train Python module

Create a Python module containing a **trainer_fn** function, which builds the estimator together with the specs needed to train and evaluate it.

The **trainer_fn** receives the following parameters:
* hparams: currently includes train_steps, eval_steps, and transform_output.
* schema: the raw data schema. This is used to create the serving input function to export the model.

The **trainer_fn** returns a dictionary of the following:
* estimator: The estimator that will be used for training and eval.
* train_spec: Spec for training.
* eval_spec: Spec for eval.
* eval_input_receiver_fn: Input function for eval.

### 4.2. Train the model using the Trainer component
* **Inputs**: train module file with the **trainer_fn**, raw schema (Schema), and transform output (TransformGraph)
* **Outputs**: saved_model (Model)
* **Properties**: train and eval args


In [None]:
tf.logging.set_verbosity(tf.logging.INFO)

_train_module_file = 'modules/train.py'

trainer = tfx.components.Trainer(
    module_file=_train_module_file,
    transformed_examples=transform.outputs['transformed_examples'],
    schema=schema_importer.outputs['result'],
    transform_graph=transform.outputs['transform_graph'],
    train_args=tfx.proto.trainer_pb2.TrainArgs(num_steps=1000),
    eval_args=tfx.proto.trainer_pb2.EvalArgs(num_steps=None),
    instance_name='Census_Classifier_Trainer'
)

context.run(trainer)

In [None]:
train_uri = trainer.outputs['model'].get()[0].uri
serving_model_path = os.path.join(train_uri, 'serving_model_dir', 'export', 'census')
latest_serving_model_path = os.path.join(serving_model_path, max(os.listdir(serving_model_path)))
print(latest_serving_model_path)

## 5. Evaluate the trained model (Evaluator)
* **Inputs**: eval data (Examples), trained model (Model)
* **Outputs**: eval metric (ModelEvaluation)
* **Properties**: Slicing Specs
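
A slicing spec controls how the eval data is grouped before metrics are computed: an empty spec yields the overall metric, while a spec on a column yields one metric per value of that column. The plain-Python sketch below shows the idea for accuracy. The records and the `accuracy_by_slice` helper are hypothetical, and the slice keys only loosely imitate how sliced results are keyed.

```python
from collections import defaultdict

# Hypothetical eval records: (occupation, true label, predicted label).
records = [
    ("Sales", 1, 1),
    ("Sales", 0, 1),
    ("Tech-support", 1, 1),
    ("Tech-support", 0, 0),
    ("Exec-managerial", 1, 0),
    ("Exec-managerial", 1, 1),
]


def accuracy_by_slice(records):
    """Return {slice_key: accuracy}; the empty tuple is the overall slice."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for occupation, label, prediction in records:
        # Every record contributes to the overall slice and to its own slice.
        for key in ((), (("occupation", occupation),)):
            total[key] += 1
            correct[key] += int(label == prediction)
    return {key: correct[key] / total[key] for key in total}


result = accuracy_by_slice(records)
for key, accuracy in result.items():
    print(key or "overall", accuracy)
```

Comparing each slice against the overall value makes it visible when a model performs well on average but poorly for a particular group.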

In [None]:
feature_slicing_spec=tfx.proto.evaluator_pb2.FeatureSlicingSpec(
    specs=[
        tfx.proto.evaluator_pb2.SingleSlicingSpec(),
        tfx.proto.evaluator_pb2.SingleSlicingSpec(column_for_slicing=['occupation'])
    ]
)


model_analyzer = tfx.components.Evaluator(
    examples=example_gen.outputs['examples'],
    model=trainer.outputs['model'],
    feature_slicing_spec=feature_slicing_spec,
    instance_name="Occupation_based_Evaluator"
)

context.run(model_analyzer)


In [None]:
import tensorflow_model_analysis as tfma

results_uri = model_analyzer.outputs['output'].get()[0].uri
tfma.load_eval_result(results_uri).slicing_metrics[:2]

## 6. Validate the Trained Model

We will create a Custom TFX Component that validates the trained model based on its produced evaluation metric.

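The custom validator will **bless** the model if:
1. Overall accuracy is at least `accuracy_threshold` (for example, 85%).
2. Accuracy for every **Occupation** slice is at least `accuracy_threshold - slice_accuracy_tolerance` (for example, with a tolerance of 10%, no slice may fall below 75%).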

* **Inputs**: Evaluation Metric (ModelEvaluation), trained model (Model)
* **Outputs**: blessing (ModelBlessing)
* **Properties**: accuracy_threshold, slice_accuracy_tolerance
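
The blessing rule can be isolated from the TFX plumbing as a small pure function, which makes it easy to try different thresholds before running the component. This is a sketch with hypothetical accuracy values. Mirroring the executor in section 6.2, the minimum slice accuracy is computed as the threshold minus the tolerance.

```python
def is_blessed(overall_accuracy, slice_accuracies,
               accuracy_threshold, slice_accuracy_tolerance):
    """Return True if the model passes both accuracy checks."""
    if overall_accuracy < accuracy_threshold:
        return False
    min_slice_accuracy = accuracy_threshold - slice_accuracy_tolerance
    return all(acc >= min_slice_accuracy for acc in slice_accuracies.values())


# Hypothetical per-occupation accuracies.
slices = {"Sales": 0.84, "Tech-support": 0.79, "Exec-managerial": 0.88}

# Passes: overall is above 0.85 and every slice is above 0.75.
print(is_blessed(0.86, slices, accuracy_threshold=0.85, slice_accuracy_tolerance=0.10))
# Fails: with a tighter tolerance the Tech-support slice is too low.
print(is_blessed(0.86, slices, accuracy_threshold=0.85, slice_accuracy_tolerance=0.05))
# Fails: overall accuracy is below the threshold.
print(is_blessed(0.80, slices, accuracy_threshold=0.85, slice_accuracy_tolerance=0.10))
```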

### 6.1. Create ComponentSpec

In [None]:
from tfx.types import standard_artifacts
from tfx.types.component_spec import ChannelParameter
from tfx.types.component_spec import ExecutionParameter

class AccuracyValidatorSpec(tfx.types.ComponentSpec):
    
    INPUTS = {
        'eval_results': ChannelParameter(type=standard_artifacts.ModelEvaluation),
        'model': ChannelParameter(type=standard_artifacts.Model),
    }
    
    OUTPUTS = {
        'blessing': ChannelParameter(type=standard_artifacts.ModelBlessing),
    }
    
    PARAMETERS = {
        'accuracy_threshold': ExecutionParameter(type=float),
        'slice_accuracy_tolerance': ExecutionParameter(type=float),
    }


### 6.2. Create Custom Executor

In [None]:
from tfx.components.base import base_executor
from tfx.types import artifact_utils
from tfx.utils import io_utils

class AccuracyValidatorExecutor(base_executor.BaseExecutor):
    
    def Do(self, input_dict, output_dict, exec_properties):
        
        valid = True
        
        self._log_startup(input_dict, output_dict, exec_properties)
        
        accuracy_threshold = exec_properties['accuracy_threshold']
        slice_accuracy_tolerance = exec_properties['slice_accuracy_tolerance']
        min_slice_accuracy = accuracy_threshold - slice_accuracy_tolerance
        print("Accuracy Threshold:", accuracy_threshold)
        print("Slice Accuracy Tolerance:", slice_accuracy_tolerance)
        print("Min Accuracy per Slice:", min_slice_accuracy)
        
        results_uri = input_dict['eval_results'][0].uri
        eval_results = tfma.load_eval_result(results_uri)
        
        overall_acc = eval_results.slicing_metrics[0][1]['']['']['accuracy']['doubleValue']
        print("Overall accuracy:", overall_acc)
        
        if overall_acc >= accuracy_threshold:
            for slicing_metric in eval_results.slicing_metrics:
                slice_acc = slicing_metric[1]['']['']['accuracy']['doubleValue']
                if slice_acc < min_slice_accuracy:
                    print("Slice accuracy value < min accuracy:", slice_acc )
                    valid = False
                    break
        else:
            valid = False
        
        print("Valid:", valid)
        
        blessing = artifact_utils.get_single_instance(output_dict['blessing'])
        
        # Current model.
        current_model = artifact_utils.get_single_instance(input_dict['model'])
        blessing.set_string_custom_property('current_model', current_model.uri)
        blessing.set_int_custom_property('current_model_id', current_model.id)

        if valid:
            io_utils.write_string_file(os.path.join(blessing.uri, 'BLESSED'), '')
            blessing.set_int_custom_property('blessed', 1)
        else:
            io_utils.write_string_file(os.path.join(blessing.uri, 'NOT_BLESSED'), '')
            blessing.set_int_custom_property('blessed', 0)


### 6.3. Create AccuracyModelValidator

In [None]:
from typing import Optional
from tfx import types
from tfx.components.base import base_component
from tfx.components.base import executor_spec

class AccuracyModelValidator(base_component.BaseComponent):

    SPEC_CLASS = AccuracyValidatorSpec
    EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(AccuracyValidatorExecutor)
    
    def __init__(self,
                 eval_results: types.Channel,
                 model: types.Channel,
                 accuracy_threshold: float,
                 slice_accuracy_tolerance: float,
                 blessing: Optional[types.Channel] = None,
                 instance_name=None):
        
        blessing = blessing or types.Channel(
            type=standard_artifacts.ModelBlessing,
            artifacts=[standard_artifacts.ModelBlessing()])
        
        spec = AccuracyValidatorSpec(
            eval_results=eval_results, model=model, blessing=blessing, 
            accuracy_threshold=accuracy_threshold, slice_accuracy_tolerance=slice_accuracy_tolerance
        )
        
        super().__init__(spec=spec, instance_name=instance_name)
    

In [None]:
accuracy_model_validator = AccuracyModelValidator(
    eval_results=model_analyzer.outputs['output'],
    model=trainer.outputs['model'],
    accuracy_threshold=0.75,
    slice_accuracy_tolerance=0.15,
    instance_name="Accuracy_Model_Validator"
)

context.run(accuracy_model_validator)

In [None]:
blessing_uri = accuracy_model_validator.outputs['blessing'].get()[0].uri
!ls -l {blessing_uri}

## 7. Pushing the Blessed Model (Pusher)
This step pushes the validated and blessed model to its final destination. This could be:
1. A Model Registry
2. Git Repository
3. API Serving Platform
4. A specific filesystem location
5. ...

### 7.1. Push the blessed model to model registry (filesystem location)

In [None]:
serving_models_location = WORKSPACE + '/model_registry'
!mkdir -p {serving_models_location}

In [None]:
push_destination=tfx.proto.pusher_pb2.PushDestination(
    filesystem=tfx.proto.pusher_pb2.PushDestination.Filesystem(
        base_directory=serving_models_location)
)

pusher = tfx.components.Pusher(
    model=trainer.outputs['model'],
    model_blessing=accuracy_model_validator.outputs['blessing'],
    push_destination=push_destination
)

context.run(pusher)

### 7.2. Test the pushed model

In [None]:
latest_serving_model_path = os.path.join(serving_models_location, max(os.listdir(serving_models_location)))
print(latest_serving_model_path)
!saved_model_cli show --dir={latest_serving_model_path} --all

In [None]:
tf.logging.set_verbosity(tf.logging.ERROR)

predictor_fn = tf.contrib.predictor.from_saved_model(
    export_dir = latest_serving_model_path,
    signature_def_key="predict"
)
print("")

output = predictor_fn(
    {
        'age': [34.0],
        'workclass': ['Private'],
        'education': ['Doctorate'],
        'education_num': [10.0],
        'marital_status': ['Married-civ-spouse'],
        'occupation': ['Prof-specialty'],
        'relationship': ['Husband'],
        'race': ['White'],
        'gender': ['Male'],
        'capital_gain': [0.0], 
        'capital_loss': [0.0], 
        'hours_per_week': [40.0],
        'native_country':['Egyptian']
    }
)

output

## 8. Querying Metadata database

In [None]:
import sqlite3
connection = sqlite3.connect(os.path.join(OUTPUT_DIR, 'metadata.sqlite'))
cursor = connection.cursor()

### List tables

In [None]:
cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
pprint(cursor.fetchall())

### Query Artifact table

In [None]:
cursor.execute("SELECT * FROM Artifact;")
for entry in cursor.fetchall():
    print(entry)
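
Artifact rows only carry a numeric type reference, so joining against the type table makes the listing easier to read. The sketch below runs against an in-memory SQLite database with simplified stand-in tables. The `Type` table, the column names `id`, `type_id`, `name`, and `uri`, and the sample rows are assumptions about the MLMD layout; check them against the table list printed above before running the same query on `metadata.sqlite`.

```python
import sqlite3

# In-memory stand-in for metadata.sqlite with simplified, assumed tables.
demo = sqlite3.connect(":memory:")
demo.executescript("""
    CREATE TABLE Type (id INTEGER PRIMARY KEY, name TEXT);
    CREATE TABLE Artifact (id INTEGER PRIMARY KEY, type_id INTEGER, uri TEXT);
    INSERT INTO Type VALUES (1, 'Examples'), (2, 'Model');
    INSERT INTO Artifact VALUES
        (1, 1, 'workspace/artifacts/examples'),
        (2, 2, 'workspace/artifacts/model');
""")

# Resolve each artifact's type_id to a readable type name.
query = """
    SELECT Artifact.id, Type.name, Artifact.uri
    FROM Artifact JOIN Type ON Artifact.type_id = Type.id
    ORDER BY Artifact.id;
"""
rows = demo.execute(query).fetchall()
for row in rows:
    print(row)
demo.close()
```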