<a href="https://colab.research.google.com/github/bizzengine/tfx/blob/main/iris_tutorial_tfx.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
try:
  import colab
  !pip install --upgrade pip
except:
  pass

[0m

In [None]:
!pip install -U tfx

[0m

In [None]:
# Report the library versions this notebook is running against.
import tensorflow as tf
print(f'TensorFlow version: {tf.__version__}')

from tfx import v1 as tfx
print(f'TFX version: {tfx.__version__}')

TensorFlow version: 2.7.0
TFX version: 1.6.1


In [None]:
import os

from absl import logging

PIPELINE_NAME = "iris-simple"

# Root directory under which the pipeline writes every generated artifact.
PIPELINE_ROOT = os.path.join('pipelines', PIPELINE_NAME)
# SQLite database file used as the ML Metadata (MLMD) store.
METADATA_PATH = os.path.join('metadata', PIPELINE_NAME, 'metadata.db')
# Directory that receives the models pushed at the end of the pipeline.
SERVING_MODEL_DIR = os.path.join('serving_model', PIPELINE_NAME)

# Show INFO-level absl logs so pipeline progress is visible in the output.
logging.set_verbosity(logging.INFO)

# Echo the resolved configuration, one value per line.
print('\n'.join([PIPELINE_NAME, PIPELINE_ROOT, METADATA_PATH, SERVING_MODEL_DIR]))

iris-simple
pipelines/iris-simple
metadata/iris-simple/metadata.db
serving_model/iris-simple


In [None]:
import os
import tempfile
import urllib.request

# Fresh temporary directory that serves as the CsvExampleGen input_base.
DATA_ROOT = tempfile.mkdtemp(prefix='tfx-data')
_data_url = 'https://raw.githubusercontent.com/bizzengine/tfx/main/data.csv'
_data_filepath = os.path.join(DATA_ROOT, "data.csv")

# Download the CSV exactly once (previously the file was fetched a second time
# just to print urlretrieve's return value).
urllib.request.urlretrieve(_data_url, _data_filepath)
assert os.path.getsize(_data_filepath) > 0, f'empty download from {_data_url}'

print(DATA_ROOT)
print(_data_url)
print(_data_filepath)

/tmp/tfx-data2klzvjvw
https://raw.githubusercontent.com/bizzengine/tfx/main/data.csv
/tmp/tfx-data2klzvjvw/data.csv
('/tmp/tfx-data2klzvjvw/data.csv', <http.client.HTTPMessage object at 0x7feeed287690>)


In [None]:
!head {_data_filepath}

SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm,Species
5.1,3.5,1.4,0.2,0
4.9,3,1.4,0.2,0
4.7,3.2,1.3,0.2,0
4.6,3.1,1.5,0.2,0
5,3.6,1.4,0.2,0
5.4,3.9,1.7,0.4,0
4.6,3.4,1.4,0.3,0
5,3.4,1.5,0.2,0
4.4,2.9,1.4,0.2,0


In [None]:
_trainer_module_file = 'iris_trainer.py'

In [None]:
%%writefile {_trainer_module_file}

from typing import List
from absl import logging
import tensorflow as tf
from tensorflow import keras
from tensorflow_transform.tf_metadata import schema_utils

from tfx import v1 as tfx
from tfx_bsl.public import tfxio
from tensorflow_metadata.proto.v0 import schema_pb2

# Raw feature columns of the iris CSV; each holds one float per example.
_FEATURE_KEYS = [
    'SepalLengthCm',
    'SepalWidthCm',
    'PetalLengthCm',
    'PetalWidthCm',
]
# Column holding the integer species id used as the label.
_LABEL_KEY = 'Species'

# Batch sizes for the training and evaluation datasets.
_TRAIN_BATCH_SIZE = 20
_EVAL_BATCH_SIZE = 10

# No schema is generated in this pipeline, so the input format is described by
# a hand-written feature spec instead; with so few features this stays easy to
# maintain. Every feature is a float32 of shape [1]; the label is an int64 of
# shape [1].
_FEATURE_SPEC = {
    key: tf.io.FixedLenFeature(shape=[1], dtype=tf.float32)
    for key in _FEATURE_KEYS
}
_FEATURE_SPEC[_LABEL_KEY] = tf.io.FixedLenFeature(shape=[1], dtype=tf.int64)


def _input_fn(file_pattern: List[str],
              data_accessor: tfx.components.DataAccessor,
              schema: schema_pb2.Schema,
              batch_size: int = 200) -> tf.data.Dataset:
  """Builds an endlessly repeating, batched dataset of (features, label) pairs.

  Args:
    file_pattern: List of paths or glob patterns of the input tfrecord files.
    data_accessor: DataAccessor used to turn the input files into RecordBatches.
    schema: Schema describing the input data.
    batch_size: Number of consecutive examples combined into each batch.

  Returns:
    A `tf.data.Dataset` yielding `(features, labels)` tuples, where `features`
    is a dict of Tensors and `labels` is the Tensor for `_LABEL_KEY`. The
    dataset repeats indefinitely.
  """
  dataset_options = tfxio.TensorFlowDatasetOptions(
      batch_size=batch_size, label_key=_LABEL_KEY)
  dataset = data_accessor.tf_dataset_factory(
      file_pattern, dataset_options, schema=schema)
  # Repeat forever; run_fn bounds each pass via steps_per_epoch and
  # validation_steps.
  return dataset.repeat()


def _build_keras_model(hidden_units: int = 8,
                       num_hidden_layers: int = 2,
                       num_classes: int = 3,
                       learning_rate: float = 1e-2) -> tf.keras.Model:
  """Creates a small DNN Keras model for classifying iris species.

  One scalar input per entry in `_FEATURE_KEYS` is concatenated and passed
  through `num_hidden_layers` ReLU Dense layers, followed by a Dense layer that
  emits one unnormalized logit per class. The defaults reproduce the original
  architecture (two hidden layers of 8 units, 3 classes, Adam with lr=1e-2).

  Args:
    hidden_units: Width of each hidden Dense layer.
    num_hidden_layers: Number of hidden Dense layers.
    num_classes: Number of output logits. Labels must be integer class ids in
      [0, num_classes) because the loss is sparse categorical cross-entropy.
    learning_rate: Learning rate for the Adam optimizer.

  Returns:
    A compiled Keras Model.
  """
  # The model below is built with the Functional API; see
  # https://www.tensorflow.org/guide/keras/overview for all API options.
  inputs = [keras.layers.Input(shape=(1,), name=f) for f in _FEATURE_KEYS]
  d = keras.layers.concatenate(inputs)
  for _ in range(num_hidden_layers):
    d = keras.layers.Dense(hidden_units, activation='relu')(d)
  # Raw logits: the loss below is configured with from_logits=True.
  outputs = keras.layers.Dense(num_classes)(d)

  model = keras.Model(inputs=inputs, outputs=outputs)
  model.compile(
      optimizer=keras.optimizers.Adam(learning_rate),
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=[keras.metrics.SparseCategoricalAccuracy()])

  model.summary(print_fn=logging.info)
  return model


# Entry point invoked by the TFX Trainer component.
def run_fn(fn_args: tfx.components.FnArgs):
  """Trains the iris model and saves it for serving.

  Args:
    fn_args: Training arguments supplied by the Trainer (train/eval file lists,
      data accessor, train/eval step counts and the serving model directory).
  """
  # A schema normally comes from SchemaGen, a manually curated file, or the TFT
  # graph when a Transform component is used. None of those exist in this
  # pipeline, so a (very primitive) schema is derived from the simple
  # hand-written feature spec via `schema_from_feature_spec`.
  schema = schema_utils.schema_from_feature_spec(_FEATURE_SPEC)

  def _dataset(files, batch_size):
    return _input_fn(files, fn_args.data_accessor, schema,
                     batch_size=batch_size)

  train_dataset = _dataset(fn_args.train_files, _TRAIN_BATCH_SIZE)
  eval_dataset = _dataset(fn_args.eval_files, _EVAL_BATCH_SIZE)

  model = _build_keras_model()
  model.fit(
      x=train_dataset,
      validation_data=eval_dataset,
      steps_per_epoch=fn_args.train_steps,
      validation_steps=fn_args.eval_steps)

  # The Trainer expects the trained model in fn_args.serving_model_dir.
  # NOTE(review): save_format='tf' works with the TensorFlow version recorded
  # in this notebook's output; newer Keras releases may reject this argument —
  # confirm against the installed version.
  model.save(fn_args.serving_model_dir, save_format='tf')

Writing iris_trainer.py


In [None]:
def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,
                     module_file: str, serving_model_dir: str,
                     metadata_path: str, train_steps: int = 100,
                     eval_steps: int = 5) -> tfx.dsl.Pipeline:
  """Creates a three-component iris pipeline with TFX.

  The pipeline ingests CSV files with CsvExampleGen, trains a model with the
  user-provided Trainer module, and pushes the trained model to a filesystem
  directory with Pusher.

  Args:
    pipeline_name: Name recorded for the pipeline.
    pipeline_root: Directory where the pipeline writes its output artifacts.
    data_root: Directory containing the input CSV file(s).
    module_file: Path to the Python module that defines `run_fn` for Trainer.
    serving_model_dir: Base directory Pusher copies the trained model into.
    metadata_path: Path of the SQLite file used as the ML Metadata store.
    train_steps: Number of training steps passed to Trainer (default 100).
    eval_steps: Number of evaluation steps passed to Trainer (default 5).

  Returns:
    A `tfx.dsl.Pipeline` ready to be run by an orchestrator.
  """
  # Brings data into the pipeline.
  example_gen = tfx.components.CsvExampleGen(input_base=data_root)

  # Uses user-provided Python function that trains a model.
  trainer = tfx.components.Trainer(
      module_file=module_file,
      examples=example_gen.outputs['examples'],
      train_args=tfx.proto.TrainArgs(num_steps=train_steps),
      eval_args=tfx.proto.EvalArgs(num_steps=eval_steps))

  # Pushes the model to a filesystem destination.
  pusher = tfx.components.Pusher(
      model=trainer.outputs['model'],
      push_destination=tfx.proto.PushDestination(
          filesystem=tfx.proto.PushDestination.Filesystem(
              base_directory=serving_model_dir)))

  # The following three components will be included in the pipeline.
  components = [
      example_gen,
      trainer,
      pusher,
  ]

  return tfx.dsl.Pipeline(
      pipeline_name=pipeline_name,
      pipeline_root=pipeline_root,
      metadata_connection_config=tfx.orchestration.metadata
      .sqlite_metadata_connection_config(metadata_path),
      components=components)

In [None]:
# Assemble the iris pipeline, then run it locally with LocalDagRunner.
_iris_pipeline = _create_pipeline(
    pipeline_name=PIPELINE_NAME,
    pipeline_root=PIPELINE_ROOT,
    data_root=DATA_ROOT,
    module_file=_trainer_module_file,
    serving_model_dir=SERVING_MODEL_DIR,
    metadata_path=METADATA_PATH)
tfx.orchestration.LocalDagRunner().run(_iris_pipeline)

INFO:absl:Generating ephemeral wheel package for '/content/iris_trainer.py' (including modules: ['iris_trainer']).
INFO:absl:User module package has hash fingerprint version 62447b7ba09247617285cfc09d22d2d4bcd425fea2035be10e0b6df224da85b7.
INFO:absl:Executing: ['/usr/bin/python3', '/tmp/tmpzbz24stw/_tfx_generated_setup.py', 'bdist_wheel', '--bdist-dir', '/tmp/tmpixbvws4s', '--dist-dir', '/tmp/tmp60e8a5zp']
INFO:absl:Successfully built user code wheel distribution at 'pipelines/iris-simple/_wheels/tfx_user_code_Trainer-0.0+62447b7ba09247617285cfc09d22d2d4bcd425fea2035be10e0b6df224da85b7-py3-none-any.whl'; target user module is 'iris_trainer'.
INFO:absl:Full user module path is 'iris_trainer@pipelines/iris-simple/_wheels/tfx_user_code_Trainer-0.0+62447b7ba09247617285cfc09d22d2d4bcd425fea2035be10e0b6df224da85b7-py3-none-any.whl'
INFO:absl:Using deployment config:
 executor_specs {
  key: "CsvExampleGen"
  value {
    beam_executable_spec {
      python_executor_spec {
        class_path: 

INFO:absl:Processing input csv data /tmp/tfx-data2klzvjvw/* to TFExample.
INFO:absl:Examples generated.
INFO:absl:Value type <class 'NoneType'> of key version in exec_properties is not supported, going to drop it
INFO:absl:Value type <class 'list'> of key _beam_pipeline_args in exec_properties is not supported, going to drop it
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 1 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'examples': [Artifact(artifact: uri: "pipelines/iris-simple/CsvExampleGen/examples/1"
custom_properties {
  key: "input_fingerprint"
  value {
    string_value: "split:single_split,num_files:1,total_bytes:2771,xor_checksum:1644889963,sum_checksum:1644889963"
  }
}
custom_properties {
  key: "name"
  value {
    string_value: "iris-simple:2022-02-15T01:52:44.591870:CsvExampleGen:examples:0"
  }
}
custom_properties {
  key: "span"
  value {
    int_value: 0
  }
}
custom_pr





INFO:tensorflow:Assets written to: pipelines/iris-simple/Trainer/model/2/Format-Serving/assets


INFO:tensorflow:Assets written to: pipelines/iris-simple/Trainer/model/2/Format-Serving/assets
INFO:absl:Training complete. Model written to pipelines/iris-simple/Trainer/model/2/Format-Serving. ModelRun written to pipelines/iris-simple/Trainer/model_run/2
INFO:absl:Cleaning up stateless execution info.
INFO:absl:Execution 2 succeeded.
INFO:absl:Cleaning up stateful execution info.
INFO:absl:Publishing output artifacts defaultdict(<class 'list'>, {'model': [Artifact(artifact: uri: "pipelines/iris-simple/Trainer/model/2"
custom_properties {
  key: "name"
  value {
    string_value: "iris-simple:2022-02-15T01:52:44.591870:Trainer:model:0"
  }
}
custom_properties {
  key: "tfx_version"
  value {
    string_value: "1.6.1"
  }
}
, artifact_type: name: "Model"
)], 'model_run': [Artifact(artifact: uri: "pipelines/iris-simple/Trainer/model_run/2"
custom_properties {
  key: "name"
  value {
    string_value: "iris-simple:2022-02-15T01:52:44.591870:Trainer:model_run:0"
  }
}
custom_properties {


In [None]:
# List files in created model directory.
!find {SERVING_MODEL_DIR}

serving_model/iris-simple
serving_model/iris-simple/1644889979
serving_model/iris-simple/1644889979/variables
serving_model/iris-simple/1644889979/variables/variables.index
serving_model/iris-simple/1644889979/variables/variables.data-00000-of-00001
serving_model/iris-simple/1644889979/saved_model.pb
serving_model/iris-simple/1644889979/keras_metadata.pb
serving_model/iris-simple/1644889979/assets
