In [None]:
import os
import tempfile
import logging

import tensorflow as tf
import tensorflow_model_analysis as tfma
from tfx import v1 as tfx
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

%load_ext tfx.orchestration.experimental.interactive.notebook_extensions.skip

In [None]:
%%skip_for_export

# Quiet TensorFlow (errors only) and the root Python logger (critical only) so
# component runs do not flood the notebook output.
_tf_v1_logging = tf.compat.v1.logging
_tf_v1_logging.set_verbosity(_tf_v1_logging.ERROR)
logger = logging.getLogger()
logger.setLevel(level=logging.CRITICAL)

In [None]:
# Filesystem locations used by the notebook.
_tfx_root = tfx.__path__[0]
# NOTE(review): `tfx` is bound to `tfx.v1` by the import cell, so this joins under
# that package's directory; `_account_root` is not used by any visible cell.
_account_root = os.path.join(_tfx_root, 'examples/account')
# Pusher destination inside a new temporary directory (created each time this
# cell runs).
_scratch_dir = tempfile.mkdtemp()
_serving_model_dir = os.path.join(_scratch_dir, 'serving_model/account_simple')

In [None]:
# In-notebook TFX runner; every component below is executed via context.run().
# Called with no arguments, so pipeline root / metadata store use the library
# defaults (presumably a temporary directory — confirm before relying on outputs).
context = InteractiveContext()

In [None]:
# Raw CSV input directory. Override it with the ACCOUNT_DATA_DIR environment
# variable instead of editing the notebook; the default is the original location.
_data_root = os.environ.get('ACCOUNT_DATA_DIR', '/home/jupyter/data')
example_gen = tfx.components.CsvExampleGen(input_base=_data_root)
context.run(example_gen)

In [None]:
%%writefile component.py

from typing import Optional

from tfx import types
from tfx.dsl.components.base import base_component
from tfx.dsl.components.base import executor_spec
from tfx.types import channel_utils
from tfx.types import standard_artifacts
from tfx.types.component_spec import ChannelParameter
from tfx.types.component_spec import ExecutionParameter

import executor

class DownSampleSpec(types.ComponentSpec):
  """Component spec for DownSample.

  Parameters:
    ratio: int controlling how many "False" train examples the executor keeps.
  Inputs:
    input_data: Examples channel to downsample.
  Outputs:
    output_data: Examples channel holding the downsampled copy.
  """

  PARAMETERS = dict(ratio=ExecutionParameter(type=int))
  INPUTS = dict(input_data=ChannelParameter(type=standard_artifacts.Examples))
  OUTPUTS = dict(output_data=ChannelParameter(type=standard_artifacts.Examples))


class DownSample(base_component.BaseComponent):
  """Custom TFX component that downsamples "False" examples of the train split.

  The work is done by `executor.Executor`; this class only wires channels and
  the `ratio` execution parameter into a `DownSampleSpec`.
  """

  SPEC_CLASS = DownSampleSpec
  EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(executor.Executor)

  def __init__(self,
               input_data: types.Channel = None,
               output_data: types.Channel = None,
               ratio: Optional[int] = 50):
    """Construct a DownSample component.

    Args:
      input_data: Channel of Examples to downsample.
      output_data: Optional output channel; when omitted a new channel holding
        a fresh Examples artifact is created.
      ratio: Keep roughly one in `ratio` "False" training examples. Must be >= 1.

    Raises:
      ValueError: if `ratio` is an int smaller than 1.
    """
    # Fail while the pipeline is being built rather than inside the executor,
    # where random.randint(1, ratio) would raise only after processing started.
    if isinstance(ratio, int) and ratio < 1:
      raise ValueError(f'ratio must be >= 1, got {ratio!r}')

    if not output_data:
      output_data = channel_utils.as_channel([standard_artifacts.Examples()])

    spec = DownSampleSpec(input_data=input_data,
                          output_data=output_data,
                          ratio=ratio)
    super().__init__(spec=spec)

In [None]:
%%writefile executor.py

import json
import os
import random
from typing import Any, Dict, List

from tfx import types
from tfx.dsl.components.base import base_executor
from tfx.dsl.io import fileio
from tfx.types import artifact_utils
from tfx.utils import io_utils
import tensorflow as tf

def downsamplefile(input_dir, output_dir, filename, ratio, rng=None):
    """Write a downsampled copy of one GZIP TFRecord file of tf.Examples.

    Every record whose "Valid" feature is b"True" is kept. A record whose
    "Valid" feature is b"False" is kept when `rng.randint(1, ratio) == 1`, i.e.
    with probability 1/`ratio`. Kept records are written as their original
    serialized bytes, so nothing is dropped or re-encoded.

    Args:
      input_dir: Directory containing `filename`.
      output_dir: Directory the downsampled file is written to under the same
        name. The caller is expected to have created it.
      filename: Name of the GZIP-compressed TFRecord file.
      ratio: Positive int; roughly one in `ratio` "False" records is kept.
      rng: Optional object with a `randint(a, b)` method (e.g.
        `random.Random(seed)`) for reproducible sampling. Defaults to the
        module-level `random` functions.

    Raises:
      ValueError: if `ratio` < 1.
    """
    if ratio < 1:
        raise ValueError(f'ratio must be >= 1, got {ratio!r}')
    if rng is None:
        rng = random

    input_uri = os.path.join(input_dir, filename)
    output_uri = os.path.join(output_dir, filename)

    # Parse every expected feature, not just the label. Parsing fails on any
    # record missing Digit0..Digit7 or Valid, which keeps the original check.
    feature_spec = {
        f'Digit{index}': tf.io.FixedLenFeature([], dtype=tf.int64)
        for index in range(8)
    }
    feature_spec['Valid'] = tf.io.FixedLenFeature([], dtype=tf.string)

    def with_label(record_bytes):
        # Keep the raw bytes next to the parsed label so they can be re-emitted.
        parsed = tf.io.parse_single_example(record_bytes, feature_spec)
        return record_bytes, parsed['Valid']

    dataset = tf.data.TFRecordDataset(input_uri, compression_type="GZIP")
    options = tf.io.TFRecordOptions(compression_type="GZIP")
    with tf.io.TFRecordWriter(output_uri, options=options) as writer:
        for record_bytes, label in dataset.map(with_label):
            label = label.numpy()
            # rng is only consulted for "False" records (short-circuit), matching
            # the original sequence of random draws.
            if label == b'True' or (label == b'False' and rng.randint(1, ratio) == 1):
                writer.write(record_bytes.numpy())

class Executor(base_executor.BaseExecutor):
    """DownSample executor: downsamples the 'train' split and copies the others."""

    def Do(self, input_dict: Dict[str, List[types.Artifact]],
           output_dict: Dict[str, List[types.Artifact]],
           exec_properties: Dict[str, Any]) -> None:
        """Produce the downsampled Examples artifact.

        Args:
          input_dict: Must contain 'input_data' with exactly one Examples artifact.
          output_dict: Must contain 'output_data' with exactly one Examples artifact.
          exec_properties: Must contain 'ratio', forwarded to `downsamplefile`.
        """
        self._log_startup(input_dict, output_dict, exec_properties)

        input_artifact = artifact_utils.get_single_instance(input_dict['input_data'])
        output_artifact = artifact_utils.get_single_instance(output_dict['output_data'])
        output_artifact.split_names = input_artifact.split_names
        ratio = exec_properties['ratio']

        for split in json.loads(input_artifact.split_names):
            input_dir = artifact_utils.get_split_uri([input_artifact], split)
            output_dir = artifact_utils.get_split_uri([output_artifact], split)
            # Create the split directory once. The previous version called
            # io_utils.copy_dir inside the per-file loop, which re-copied the
            # whole directory on every iteration and undid the downsampling
            # already written for earlier files.
            fileio.makedirs(output_dir)
            for filename in fileio.listdir(input_dir):
                # Decide by split name, not by searching the full path for
                # "train" (any parent directory containing it would match).
                if split == 'train':
                    downsamplefile(input_dir, output_dir, filename, ratio)
                else:
                    io_utils.copy_file(
                        src=os.path.join(input_dir, filename),
                        dst=os.path.join(output_dir, filename),
                        overwrite=True)


In [None]:
from component import DownSample

# Keep every "True" training example and roughly 1 in 50 "False" ones.
_DOWNSAMPLE_RATIO = 50
downsample_gen = DownSample(
    input_data=example_gen.outputs['examples'],
    ratio=_DOWNSAMPLE_RATIO,
)
context.run(downsample_gen)

In [None]:
# Compute dataset statistics over the downsampled examples.
statistics_gen = tfx.components.StatisticsGen(
    examples=downsample_gen.outputs['output_data'],
)
context.run(statistics_gen)

In [None]:
# Infer a schema from the statistics. infer_feature_shape=False is presumably
# why the Transform module has to densify sparse inputs — confirm.
schema_gen = tfx.components.SchemaGen(
    infer_feature_shape=False,
    statistics=statistics_gen.outputs['statistics'],
)
context.run(schema_gen)

In [None]:
# Check the computed statistics against the inferred schema.
example_validator = tfx.components.ExampleValidator(
    schema=schema_gen.outputs['schema'],
    statistics=statistics_gen.outputs['statistics'],
)
context.run(example_validator)

In [None]:
# Path (relative to the notebook's working directory) that the next cell writes
# the Transform preprocessing module to; also handed to the Transform component.
_account_transform_module_file = 'account_transform.py'

In [None]:
%%writefile {_account_transform_module_file}

import tensorflow as tf
import tensorflow_transform as tft

DIGIT_KEYS = ['Digit0','Digit1','Digit2','Digit3','Digit4','Digit5','Digit6','Digit7']
LABEL_KEY = 'Valid'

def _transformed_name(key):
  return key + '_xf'

def preprocessing_fn(inputs):
    """tf.Transform callback: densify digit features and encode the label.

    Args:
      inputs: Map from raw feature name to Tensor/SparseTensor; must contain
        every key in DIGIT_KEYS plus LABEL_KEY.

    Returns:
      Dict with one '<key>_xf' entry per digit key (values passed through
      `_fill_in_missing`) and '<LABEL_KEY>_xf' holding an int64 label that is
      1 where the raw label equals 'True' and 0 otherwise.
    """
    outputs = {}
    for key in DIGIT_KEYS:
        outputs[_transformed_name(key)] = _fill_in_missing(inputs[key])

    # Encode the label explicitly instead of with tft.compute_and_apply_vocabulary.
    # That vocabulary is ordered by frequency, so whichever class is more common
    # gets index 0. After downsampling the class counts are close and the 0/1
    # meaning could flip from one run to the next.
    label = _fill_in_missing(inputs[LABEL_KEY])
    outputs[_transformed_name(LABEL_KEY)] = tf.cast(tf.equal(label, 'True'), tf.int64)

    return outputs

def _fill_in_missing(x):
  """Turn a sparse feature column into a dense 1-D tensor; pass dense input through.

  Missing entries are filled with '' for string tensors and 0 otherwise.
  NOTE(review): assumes at most one value per row (the sparse shape is forced
  to (batch, 1)) — confirm against the schema.
  """
  if isinstance(x, tf.sparse.SparseTensor):
    fill_value = '' if x.dtype == tf.string else 0
    # Re-declare the sparse shape as (batch, 1) so squeezing axis 1 gives (batch,).
    column = tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1])
    dense_column = tf.sparse.to_dense(column, fill_value)
    return tf.squeeze(dense_column, axis=1)
  return x

In [None]:
# Run tf.Transform over the downsampled examples with the module written above.
_transform_module_path = os.path.abspath(_account_transform_module_file)
transform = tfx.components.Transform(
    module_file=_transform_module_path,
    schema=schema_gen.outputs['schema'],
    examples=downsample_gen.outputs['output_data'],
)
context.run(transform)

In [None]:
# Build a dataset over the transformed training split for inspection. Use
# tf.io.gfile rather than os so this also works when artifacts live on a remote
# filesystem (e.g. gs://), where os.listdir fails.
train_uri = os.path.join(transform.outputs['transformed_examples'].get()[0].uri, 'Split-train')
tfrecord_filenames = [os.path.join(train_uri, name)
                      for name in tf.io.gfile.listdir(train_uri)]
dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type="GZIP")

In [None]:
# Path (relative to the notebook's working directory) that the next cell writes
# the Trainer module to; also handed to the Trainer component.
_account_trainer_module_file = 'account_trainer.py'

In [None]:
%%writefile {_account_trainer_module_file}

from typing import List, Text

import os
import tensorflow as tf
import tensorflow_transform as tft
from tfx import v1 as tfx
from tfx_bsl.public import tfxio

DIGIT_KEYS = ['Digit0','Digit1','Digit2','Digit3','Digit4','Digit5','Digit6','Digit7']
LABEL_KEY = 'Valid'

def _transformed_name(key):
  return key + '_xf'

def _transformed_names(keys):
  return [_transformed_name(key) for key in keys]

def _input_fn(file_pattern: List[Text],
              data_accessor: tfx.components.DataAccessor,
              tf_transform_output: tft.TFTransformOutput,
              batch_size: int = 200) -> tf.data.Dataset:
  """Build a batched dataset of transformed examples.

  Args:
    file_pattern: Files (or patterns) holding transformed tf.Examples.
    data_accessor: TFX DataAccessor used to create the dataset.
    tf_transform_output: Transform output; supplies the transformed schema.
    batch_size: Number of examples per batch.

  Returns:
    Dataset from `data_accessor.tf_dataset_factory` with the transformed label
    ('<LABEL_KEY>_xf') designated as the label key.
  """
  label_name = _transformed_name(LABEL_KEY)
  dataset_options = tfxio.TensorFlowDatasetOptions(
      batch_size=batch_size, label_key=label_name)
  transformed_schema = tf_transform_output.transformed_metadata.schema
  return data_accessor.tf_dataset_factory(
      file_pattern, dataset_options, transformed_schema)

def run_fn(fn_args: tfx.components.FnArgs):
  """TFX Trainer entry point: train the `Valid` classifier and export it.

  Builds a small dense network over the transformed digit features, trains it
  on `fn_args.train_files`, validates on `fn_args.eval_files`, and saves it to
  `fn_args.serving_model_dir` with a 'serving_default' signature that accepts
  serialized tf.Examples.

  Args:
    fn_args: Arguments supplied by the TFX Trainer component.
  """
  tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)

  batch_size = 40
  train_dataset = _input_fn(fn_args.train_files, fn_args.data_accessor,
                            tf_transform_output, batch_size)
  eval_dataset = _input_fn(fn_args.eval_files, fn_args.data_accessor,
                           tf_transform_output, batch_size)

  feature_names = _transformed_names(DIGIT_KEYS)
  # The raw digit columns are int64 (see the DownSample executor's parse spec),
  # and the Transform step only densifies them. Declare the Keras inputs as
  # int64 instead of relying on an implicit int64 -> int32 cast.
  visible = {
      name: tf.keras.layers.Input(name=name, shape=(), dtype=tf.int64)
      for name in feature_names
  }
  feature_columns = [
      tf.feature_column.numeric_column(name, shape=()) for name in feature_names
  ]

  features = tf.keras.layers.DenseFeatures(feature_columns)(visible)
  hidden1 = tf.keras.layers.Dense(256, activation='relu')(features)
  hidden2 = tf.keras.layers.Dense(128, activation='relu')(hidden1)
  output = tf.keras.layers.Dense(1, activation='sigmoid')(hidden2)
  model = tf.keras.Model(inputs=visible, outputs=output)
  # Binary cross-entropy is the loss that matches a single sigmoid unit on a
  # 0/1 label. MSE on a sigmoid gives vanishing gradients for confident mistakes.
  model.compile(loss='binary_crossentropy', optimizer='rmsprop',
                metrics=["accuracy"])

  model.fit(
      train_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_data=eval_dataset,
      validation_steps=fn_args.eval_steps)

  signatures = {
      'serving_default':
          _get_serve_tf_examples_fn(model,
                                    tf_transform_output).get_concrete_function(
                                        tf.TensorSpec(
                                            shape=[None],
                                            dtype=tf.string,
                                            name='examples')),
  }
  model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)
    
def _get_serve_tf_examples_fn(model, tf_transform_output):
    """Return a tf.function that runs serialized tf.Examples through Transform + model.

    Side effect: stores the Transform graph as `model.tft_layer`, presumably so it
    is tracked and exported with the SavedModel (the notebook's eval config refers
    to a preprocessing function named "tft_layer").

    Args:
      model: Trained Keras model consuming the transformed features.
      tf_transform_output: TFTransformOutput supplying the raw feature spec and
        the transform layer.

    Returns:
      A tf.function taking a batch of serialized tf.Examples and returning the
      model's output for them.
    """

    model.tft_layer = tf_transform_output.transform_features_layer()

    @tf.function
    def serve_tf_examples_fn(serialized_tf_examples):
        # Raw feature spec; it still includes the label at this point.
        feature_spec = tf_transform_output.raw_feature_spec()
        # If the transform layer has not been built yet, call it once with the
        # label included. NOTE(review): this looks like a workaround that builds
        # the layer with the full raw spec before the label-free call below —
        # confirm it is still needed with the TFT version in use.
        if not model.tft_layer.built:
            parsed_features_with_label = tf.io.parse_example(
                serialized_tf_examples, feature_spec)
            _ = model.tft_layer(parsed_features_with_label)
        # Serving requests are not expected to carry the label, so drop it from
        # the parse spec.
        feature_spec.pop(LABEL_KEY)
        parsed_features = tf.io.parse_example(serialized_tf_examples, feature_spec)
        transformed_features = model.tft_layer(parsed_features)
        return model(transformed_features)
    return serve_tf_examples_fn

In [None]:
# Train for a fixed number of steps on the transformed examples.
_TRAIN_STEPS = 10000
_EVAL_STEPS = 5000
trainer = tfx.components.Trainer(
    module_file=os.path.abspath(_account_trainer_module_file),
    examples=transform.outputs['transformed_examples'],
    transform_graph=transform.outputs['transform_graph'],
    schema=schema_gen.outputs['schema'],
    train_args=tfx.proto.TrainArgs(num_steps=_TRAIN_STEPS),
    eval_args=tfx.proto.EvalArgs(num_steps=_EVAL_STEPS),
)
context.run(trainer)

In [None]:
# The Transform/Trainer modules are only written to disk by %%writefile, so
# their label helpers are not defined in this kernel; mirror them here.
LABEL_KEY = 'Valid'

def _transformed_name(key):
  """Return the Transform output name for `key` ('<key>_xf')."""
  return ''.join((key, '_xf'))

# Evaluate the serving signature, applying the "tft_layer" preprocessing to get
# the transformed label. Report confusion-matrix counts and example count, and
# gate blessing on BinaryAccuracy: it must be >= 0.5 and must not drop relative
# to the baseline (absolute change > -1e-10).
eval_config = tfma.EvalConfig(
    model_specs=[
        tfma.ModelSpec(
            signature_name="serving_default",
            label_key=_transformed_name(LABEL_KEY),
            preprocessing_function_names=["tft_layer"]),
    ],
    metrics_specs=[
        tfma.MetricsSpec(metrics=[
            *(tfma.MetricConfig(class_name=metric_name) for metric_name in (
                'FalsePositives', 'TruePositives', 'FalseNegatives',
                'TrueNegatives', 'ExampleCount')),
            tfma.MetricConfig(
                class_name='BinaryAccuracy',
                threshold=tfma.MetricThreshold(
                    value_threshold=tfma.GenericValueThreshold(
                        lower_bound={'value': 0.5}),
                    change_threshold=tfma.GenericChangeThreshold(
                        direction=tfma.MetricDirection.HIGHER_IS_BETTER,
                        absolute={'value': -1e-10}))),
        ]),
    ],
)

In [None]:
# Resolve the most recent blessed model (if any) to act as the Evaluator baseline.
model_resolver = tfx.dsl.Resolver(
    strategy_class=tfx.dsl.experimental.LatestBlessedModelStrategy,
    model=tfx.dsl.Channel(type=tfx.types.standard_artifacts.Model),
    model_blessing=tfx.dsl.Channel(type=tfx.types.standard_artifacts.ModelBlessing),
).with_id('latest_blessed_model_resolver')
context.run(model_resolver)

# Evaluate the new model against that baseline using eval_config; the resulting
# blessing is consumed by the Pusher cell.
evaluator = tfx.components.Evaluator(
    examples=downsample_gen.outputs['output_data'],
    model=trainer.outputs['model'],
    baseline_model=model_resolver.outputs['model'],
    eval_config=eval_config,
)
context.run(evaluator)

In [None]:
import tensorflow_model_analysis as tfma

# Load the Evaluator's results from its output artifact and display them.
_evaluation_artifact = evaluator.outputs['evaluation'].get()[0]
PATH_TO_RESULT = _evaluation_artifact.uri
tfma_result = tfma.load_eval_result(PATH_TO_RESULT)
tfma_result

In [None]:
# Show the contents of the blessing artifact directory. This uses pure Python
# instead of `!ls {blessing_uri}`, which breaks on paths containing spaces and
# needs a POSIX shell; tf.io.gfile also handles remote URIs.
# NOTE(review): the file name(s) here presumably record the verdict
# (e.g. BLESSED / NOT_BLESSED) — confirm for your TFX version.
blessing_uri = evaluator.outputs['blessing'].get()[0].uri
sorted(tf.io.gfile.listdir(blessing_uri))

In [None]:
# Push the trained model to _serving_model_dir; the Evaluator's blessing is
# passed in so the push is gated on it.
_push_destination = tfx.proto.PushDestination(
    filesystem=tfx.proto.PushDestination.Filesystem(
        base_directory=_serving_model_dir))
pusher = tfx.components.Pusher(
    model=trainer.outputs['model'],
    model_blessing=evaluator.outputs['blessing'],
    push_destination=_push_destination,
)
context.run(pusher)

In [None]:
# Smoke-test the pushed model's serving_default signature with one example.
# The example has no "Valid" feature because the serving function drops the label.
push_uri = pusher.outputs['pushed_model'].get()[0].uri
%env MODELDIR = {push_uri+'/'}
# NOTE(review): if the model was not blessed, this directory may not contain a
# SavedModel and the command will fail — confirm before relying on it.
!saved_model_cli run --dir $MODELDIR --tag_set serve --signature_def serving_default \
  --input_examples 'examples=[{"Digit0":[3],"Digit1":[1],"Digit2":[6],"Digit3":[0],"Digit4":[0],"Digit5":[4],"Digit6":[9],"Digit7":[4]}]'