# Custom TFX Components

## Download example dataset

For this workshop we'll be using the public cats & dogs dataset created by Microsoft. The dataset contains two folders: "Dog" and "Cat".

In [None]:
# Start from a clean state: drop a previous extraction and any zip archive
# in the current working directory.
!rm -rf /content/PetImages/
!rm *.zip

# wget stores the archive in the current working directory while unzip reads it
# from /content/ -- assumes the notebook runs with /content as working directory.
!wget https://download.microsoft.com/download/3/E/1/3E1C3F21-ECDB-4869-8368-6DEBA77B919F/kagglecatsanddogs_3367a.zip
!unzip -q -d /content/ /content/kagglecatsanddogs_3367a.zip

!echo "Count images"
!ls -U /content/PetImages/Cat | wc -l
!ls -U /content/PetImages/Dog | wc -l

# Delete the first 12000 directory entries (unsorted listing) of each class.
!echo "Reduce images for demo purposes"
!cd /content/PetImages/Cat && ls -U | head -12000 | xargs rm 
!cd /content/PetImages/Dog && ls -U | head -12000 | xargs rm 

!echo "Count images after removal"
!ls -U /content/PetImages/Cat | wc -l
!ls -U /content/PetImages/Dog | wc -l

# Prefix every remaining file with its class name ("cat-" / "dog-") and collect
# both classes in the single folder /content/images/.
!echo "Rename filename and move them"
!rm -rf /content/images/
!mkdir -p /content/images/
%cd /content/PetImages/Cat
!rename 's/^/cat-/' *.jpg
!mv *.jpg /content/images/
%cd /content/PetImages/Dog
!rename 's/^/dog-/' *.jpg
!mv *.jpg /content/images/
!ls -U /content/images/ | wc -l

## Install required Python packages

In [None]:
!pip install tfx==0.27.0

import tfx 

## Import required packages & modules


In [None]:
import logging
import os
import random
from pathlib import Path
from typing import Any, Dict, Iterable, List, Text, Optional

import absl
import apache_beam as beam
import tensorflow as tf
import tfx
from tfx import types
from tfx.components import StatisticsGen
from tfx.components.base import base_component, base_driver, base_executor, executor_spec
from tfx.components.example_gen import driver, utils
from tfx.components.example_gen.base_example_gen_executor import BaseExampleGenExecutor
from tfx.components.example_gen.component import FileBasedExampleGen
from tfx.orchestration import data_types, pipeline
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext
from tfx.proto import example_gen_pb2
from tfx.types import Channel, artifact_utils, channel_utils, standard_artifacts
from tfx.types.component_spec import ChannelParameter, ExecutionParameter
from tfx.utils import proto_utils
from tfx.utils.dsl_utils import external_input

In [None]:
# Fetch the root logger (getLogger() without a name) and raise its level to
# CRITICAL so that records below that level are filtered out by it.
logger = logging.getLogger()
logger.setLevel(logging.CRITICAL)

## Define helper functions


In [None]:
%%writefile {"helpers.py"}

import tensorflow as tf


def _int64_feature(value):
    """Build an int64 feature for an Example proto.

    A value that is not already a list is wrapped in a one-element list.
    """
    values = value if isinstance(value, list) else [value]
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _bytes_feature(value):
    """Build a bytes feature holding the single given value for an Example proto."""
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)


def get_label_from_filename(filename):
    """Derive the integer class label of an image from its file path.

    The file name (last path component) is inspected first, so that a
    directory whose name happens to contain "dog" or "cat" (for example
    ``/data/dogs_vs_cats/cat-1.jpg``) cannot override the label encoded in the
    file name. If the file name carries no label, the whole path is searched
    as a fallback, which keeps layouts working where every class lives in its
    own folder (for example ``/content/PetImages/Cat/12.jpg``).

    The comparison is case-insensitive. When both markers occur in the
    inspected text, "dog" takes precedence.

    Args:
      filename: string, full file path
    Returns:
      int - 0 for a dog image, 1 for a cat image
    Raises:
      NotImplementedError if no label category was detected
    """
    # Local import: helpers.py only imports tensorflow at module level.
    import os

    def _label_in(text):
        """Return the label found in `text`, or None when there is none."""
        if "dog" in text:
            return 0
        if "cat" in text:
            return 1
        return None

    lowered_filename = filename.lower()
    # Most specific candidate first: file name, then the complete path.
    for candidate in (os.path.basename(lowered_filename), lowered_filename):
        label = _label_in(candidate)
        if label is not None:
            return label
    raise NotImplementedError("Found unknown image")
    

def _convert_to_example(image_buffer, label):
    """Pack an image byte string and its label into a tf.Example structure.

      Args:
        image_buffer: byte string representing the image
        label: int
      Returns:
        TFExample whose feature 'image/raw' holds the image bytes and whose
        feature 'label' holds the int encoded label
    """
    feature_map = {
        'image/raw': _bytes_feature(image_buffer),
        'label': _int64_feature(label),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))


def get_image_data(filename):
    """Load one image file and convert it into a labelled TFExample.

    Args:
      filename: string, path to an image file e.g., '/path/to/example.JPG'.
    Returns:
      TFExample data structure containing the image (byte string) and the label (int encoded)
    """
    # The label is resolved before the file is read, so an unknown image is
    # rejected without touching the file system.
    image_label = get_label_from_filename(filename)
    raw_bytes = tf.io.read_file(filename).numpy()
    return _convert_to_example(raw_bytes, image_label)


### Custom Component Specifications

https://github.com/tensorflow/tfx/blob/master/tfx/types/standard_artifacts.py

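Difference between `ChannelParameter` and `ExecutionParameter`: a `ChannelParameter` declares an input or output channel that carries artifacts of a given type (here the `Examples` output), whereas an `ExecutionParameter` declares a plain configuration value that is handed to the executor (here the `input_base` path).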

In [None]:
class CustomIngestionComponentSpec(types.ComponentSpec):
    """Interface declaration of the custom ingestion component.

    The component takes one execution parameter ('input_base'), declares no
    input channels and exposes a single output channel ('examples') of
    artifact type Examples.
    """

    PARAMETERS = {'input_base': ExecutionParameter(type=(str, Text))}
    INPUTS = {}
    OUTPUTS = {'examples': ChannelParameter(type=standard_artifacts.Examples)}

### Custom Component Executor

In [None]:
from helpers import get_image_data


class CustomIngestionExecutor(base_executor.BaseExecutor):
    """Executor for CustomIngestionComponent.

    Lists the entries of the `input_base` directory, shuffles them, holds out
    100 of them as the 'eval' split and writes each split as one uncompressed
    TFRecord file ('images.tfrecords') of serialized tf.train.Example protos.
    """

    def Do(self, input_dict: Dict[Text, List[types.Artifact]],
           output_dict: Dict[Text, List[types.Artifact]],
           exec_properties: Dict[Text, Any]) -> None:
        """Convert the files below `input_base` into train/eval TFRecord files.

        Args:
          input_dict: input artifacts; not read by this executor.
          output_dict: output artifacts; the single artifact stored under
            utils.EXAMPLES_KEY receives the split names and provides the
            split URIs the records are written to.
          exec_properties: execution properties; utils.INPUT_BASE_KEY must
            hold the directory that contains the image files.
        """
        input_base_uri = exec_properties[utils.INPUT_BASE_KEY]
        image_files = tf.io.gfile.listdir(input_base_uri)
        random.shuffle(image_files)

        # After shuffling, the first 100 entries become the eval split and
        # everything else the train split.
        train_images, eval_images = image_files[100:], image_files[:100]
        splits = [('train', train_images), ('eval', eval_images)]

        examples_artifact = artifact_utils.get_single_instance(
            output_dict[utils.EXAMPLES_KEY]
        )
        # The split names are set before the per-split URIs are requested below.
        examples_artifact.split_names = artifact_utils.encode_split_names(['train', 'eval'])

        for split_name, images in splits:
            output_dir = artifact_utils.get_split_uri(
                output_dict[utils.EXAMPLES_KEY], split_name)

            Path(output_dir).mkdir(parents=True, exist_ok=True)
            tfrecords_filename = os.path.join(output_dir, 'images.tfrecords')
            options = tf.io.TFRecordOptions(compression_type=None)

            # Use the writer as a context manager so it is closed (and its
            # buffered records flushed) even if converting an image raises.
            with tf.io.TFRecordWriter(tfrecords_filename, options=options) as writer:
                for image_filename in images:
                    image_path = os.path.join(input_base_uri, image_filename)
                    example = get_image_data(image_path)
                    writer.write(example.SerializeToString())

### Custom Component Driver

In [None]:
class CustomIngestionDriver(base_driver.BaseDriver):
    """Custom driver for CustomIngestion component.

    This driver supports file based ExampleGen, it registers external file path as
    an artifact, similar to the use cases CsvExampleGen and ImportExampleGen.
    """

    def resolve_input_artifacts(
        self,
        input_channels: Dict[Text, types.Channel],
        exec_properties: Dict[Text, Any],
        driver_args: data_types.DriverArgs,
        pipeline_info: data_types.PipelineInfo,
    ) -> Dict[Text, List[types.Artifact]]:
        """Overrides BaseDriver.resolve_input_artifacts().

        Unwraps the input channels into their artifacts and publishes every
        artifact through the driver's metadata handler.

        Args:
          input_channels: mapping from input key to the channel to unwrap.
          exec_properties: execution properties; unused.
          driver_args: driver arguments; unused.
          pipeline_info: pipeline information; unused.

        Returns:
          The unwrapped channel dict, mapping each input key to its artifacts.
        """
        del exec_properties  # unused
        del driver_args  # unused
        del pipeline_info  # unused

        input_dict = channel_utils.unwrap_channel_dict(input_channels)
        for input_list in input_dict.values():
            for single_input in input_list:
                # Artifacts are published one at a time.
                self._metadata_handler.publish_artifacts([single_input])

        return input_dict

### Component Setup 

Putting all pieces together.

In [None]:
class CustomIngestionComponent(base_component.BaseComponent):
    """Component that ties the custom spec, executor and driver together."""

    SPEC_CLASS = CustomIngestionComponentSpec
    EXECUTOR_SPEC = executor_spec.ExecutorClassSpec(CustomIngestionExecutor)
    DRIVER_CLASS = CustomIngestionDriver

    def __init__(self,
                 input_base: Optional[Text] = None,
                 output_data: types.Channel = None):
        """Create the component.

        Args:
          input_base: value for the 'input_base' execution parameter.
          output_data: channel used for the 'examples' output; when it is
            missing (falsy) a new Examples channel is created.
        """
        examples_channel = output_data or types.Channel(
            type=standard_artifacts.Examples)
        super().__init__(
            spec=CustomIngestionComponentSpec(
                input_base=input_base,
                examples=examples_channel,
            ))

## Basic Pipeline

In [None]:
# Interactive context used to run and inspect components from the notebook.
context = InteractiveContext()

# Folder that holds the renamed cat-*/dog-* images.
data_root = '/content/images'

# Instantiate the custom ingestion component on that folder and execute it.
ingest_images = CustomIngestionComponent(
    input_base=data_root
)
context.run(ingest_images)

In [None]:
# Build the path of the train split's TFRecord file below the URI of the first
# 'examples' output artifact and open it as a TFRecord dataset.
filepath = '{}/train/images.tfrecords'.format(ingest_images.outputs['examples'].get()[0].uri)
dataset = tf.data.TFRecordDataset([filepath])

In [None]:
# Parse the first serialized record back into a tf.train.Example and print it.
for data in dataset.take(1):
  print(tf.train.Example.FromString(data.numpy()))

In [None]:
# Feed the examples written by the custom ingestion component into StatisticsGen.
statistics_gen = StatisticsGen(
    examples=ingest_images.outputs['examples'])
context.run(statistics_gen)

# Display the 'statistics' output in the notebook.
context.show(statistics_gen.outputs['statistics'])

## Implement a component by overwriting the component executor

![TFX **Component**](https://drive.google.com/uc?export=view&id=1Hg-iUp8UF5Jh3dpdL-htG-Cw5g7GKqF3)

### Things to know:

* Decorator `@beam.ptransform_fn`: https://github.com/apache/beam/blob/master/sdks/python/apache_beam/transforms/ptransform.py
* `BaseExampleGenExecutor` class: https://github.com/tensorflow/tfx/blob/v0.22.1/tfx/components/example_gen/base_example_gen_executor.py#L90-L243

In [None]:
from helpers import get_image_data


@beam.ptransform_fn
def image_to_example(
    pipeline: beam.Pipeline,
    exec_properties: Dict[Text, Any],
    split_pattern: Text,
) -> beam.pvalue.PCollection:
    """Turn the image files matched by a split pattern into TF examples.

    Note that each input split will be transformed by this function separately.

    Args:
        pipeline: beam pipeline.
        exec_properties: A dict of execution properties; the entry stored
          under utils.INPUT_BASE_KEY is the input dir that contains the
          image data.
        split_pattern: Split.pattern in Input config, glob relative file pattern
          that maps to input files with root directory given by input_base.

    Returns:
        PCollection of TF examples.

    Raises:
        RuntimeError: if the pattern does not match any file.
    """
    image_pattern = os.path.join(
        exec_properties[utils.INPUT_BASE_KEY], split_pattern)
    absl.logging.info(
        'Processing input image data {} to TFExample.'.format(image_pattern))

    # The file list is resolved eagerly, while the pipeline is constructed.
    matched_files = tf.io.gfile.glob(image_pattern)
    if not matched_files:
        raise RuntimeError(
            'Split pattern {} does not match any files.'.format(image_pattern))

    file_names = pipeline | beam.Create(matched_files)
    return file_names | 'ConvertImagesToBase64' >> beam.Map(get_image_data)

In [None]:
class ImageExampleGenExecutor(BaseExampleGenExecutor):
    """TFX example gen executor for processing jpeg format.

    Example usage:

      from tfx.components.example_gen.component import FileBasedExampleGen
      from tfx.utils.dsl_utils import external_input

      example_gen = FileBasedExampleGen(
          input=external_input("/content/images/"),
          input_config=input_config,
          output_config=output,
          custom_executor_spec=executor_spec.ExecutorClassSpec(
              ImageExampleGenExecutor))
    """

    def GetInputSourceToExamplePTransform(self) -> beam.PTransform:
        """Returns PTransform for image to TF examples."""
        # image_to_example is the @beam.ptransform_fn-decorated function itself;
        # it is returned without being called.
        return image_to_example

## Building your ML Pipeline

In [None]:
# Name passed to the new interactive context below.
pipeline_name = "dogs_cats_pipeline"

# Rebinds `context`, replacing the context object created earlier in the notebook.
context = InteractiveContext(pipeline_name=pipeline_name)

In [None]:
# Output split configuration: 'train' gets 4 hash buckets and 'eval' gets 1,
# i.e. a 4:1 bucket ratio between the two output splits.
output = example_gen_pb2.Output(
    split_config=example_gen_pb2.SplitConfig(splits=[
        example_gen_pb2.SplitConfig.Split(
            name='train',
            hash_buckets=4
        ),
        example_gen_pb2.SplitConfig.Split(
            name='eval',
            hash_buckets=1
        )
    ])
)

# A single input split named 'images' with the glob pattern '*.jpg'.
input_config = example_gen_pb2.Input(splits=[
    example_gen_pb2.Input.Split(name='images', pattern='*.jpg'),
])

# File based ExampleGen over /content/images/ that uses the custom image
# executor instead of a built-in one.
example_gen = FileBasedExampleGen(
    input=external_input("/content/images/"),
    input_config=input_config,
    output_config=output,
    custom_executor_spec=executor_spec.ExecutorClassSpec(ImageExampleGenExecutor))

In [None]:
# Execute the image ExampleGen component in the interactive context.
context.run(example_gen)

In [None]:
# StatisticsGen over the examples produced by the image ExampleGen; rebinds
# the `statistics_gen` name used earlier in the notebook.
statistics_gen = StatisticsGen(
    examples=example_gen.outputs['examples']
)

In [None]:
# Execute StatisticsGen in the interactive context.
context.run(statistics_gen)

In [None]:
# Display the 'statistics' output of StatisticsGen in the notebook.
context.show(statistics_gen.outputs['statistics'])