# My Machine Learning Pipelines practice with TFX. Unfortunately, I couldn't complete registration on 'Google Cloud Platform' due to my Nigerian credit card not being acccepted, which means I cannot use GCP. 

• Data ingestion with ExampleGen

• Data validation with StatisticsGen, SchemaGen, and the ExampleValidator

• Data preprocessing with Transform

• Model training with Trainer

• Checking for previously trained models with ResolverNode

• Model analysis and validation with Evaluator

• Model deployments with Pusher

https://www.consumerfinance.gov/data-research/consumer-complaints/


Data was gotten from the above URL

In [2]:
# import Libraries here
import tensorflow as tf
from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext

import apache_beam as beam
import warnings
import os

warnings.filterwarnings('ignore')

In [3]:
# instantiate the PipeLine manager
context = InteractiveContext()




# Step 1: Data Ingestion

In [7]:
# import Libraries for Data Ingestion here
from tfx.v1.components import CsvExampleGen, ImportExampleGen
from tfx.components.example_gen.custom_executors import parquet_executor # parquet file executor
from tfx.components import FileBasedExampleGen # generic file loader component
from tfx.components.base import executor_spec
import csv


In [20]:
# import CSVfiles and run in Pipeline
example_gen = CsvExampleGen(input_base=r'files/consumer.csv')
# context.run(example_gen)

# import TFRecord using ImportExampleGen
example_gen_1 = ImportExampleGen(input_base='tfrecord file name')

# load the parquet file and override the executor
example_gen_2 = FileBasedExampleGen(input_base='parquet file name', custom_executor_spec=executor_spec.ExecutorClassSpec(parquet_executor.Executor))
# the above method can also be used to load AVRO files. Just import the AVRO-EXECUTOR


In [40]:
# functions to convert CSV data-records to tf.Feature
def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value.encode()]))

def _float_feature(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))

def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def clean_rows(row):
    if not row['zip_code']:
        row['zip_code'] = '99999'
    return row

def convert_zipcode_to_int(zipcode):
    if isinstance(zipcode, str) and 'XX' in zipcode:
        zipcode = zipcode.replace('XX', '00')
    int_zipcode = int(zipcode)
    return int_zipcode

#convert CSV file to TFRecord
original_data_file = r'files\consumer.csv'
tf_record_writer = tf.io.TFRecordWriter('consumer_complaints.tfrecord')

with open(original_data_file) as csv_file:
    reader = csv.DictReader(csv_file, delimiter=',', quotechar='"')
    for row in reader:
        row = clean_rows(row)
        example = tf.train.Example(
            features=tf.train.Features(feature={
                "product": _bytes_feature(row["product"]), 
                "sub_product": _bytes_feature(row["sub_product"]),
                "issue": _bytes_feature(row["issue"]),
                "sub_issue": _bytes_feature(row["sub_issue"]),
                "state": _bytes_feature(row["state"]),
                "zip_code": _int64_feature(convert_zipcode_to_int(row["zip_code"])),
                "company": _bytes_feature(row["company"]),
                "company_response": _bytes_feature(row["company_response"]),
                "consumer_complaint_narrative": _bytes_feature(row["consumer_complaint_narrative"]), 
                "timely_response": _bytes_feature(row["timely_response"]),
                "consumer_disputed": _bytes_feature(row["consumer_disputed"]),
            })
        )
        tf_record_writer.write(example.SerializeToString())
    tf_record_writer.close()
# the generated file can be imported with *ImportExampleGen



In [None]:
# for GOOGLE BigQuery----------------
import os
from tfx.extensions.google_cloud_big_query.example_gen import component

# setup GCP credentials and environment
os.environ['GOOGLE_APPLICATION_CREDENTIALS"'] = '/path/to/credenial_file.json'
query = """ 
    SELECT * FROM <project_id>.<database>.<table_name>
"""

example_gen = component.BigQueryExampleGen(query=query)


In [None]:
# splitting one datasets into subsets
from tfx.proto import example_gen_pb2

output = example_gen_pb2.Output(
    split_config=example_gen_pb2.SplitConfig(splits=[
        example_gen_pb2.SplitConfig.Split(name='train', hash_buckets=6), 
        example_gen_pb2.SplitConfig.Split(name='eval', hash_buckets=2), 
        example_gen_pb2.SplitConfig.Split(name='test', hash_buckets=2)
    ])
)

''' 
The INPUT_BASE must point to a folder that contains folders named 'train', 'eval', 
    and 'test' for this to work
'''
example_gen = CsvExampleGen(input_base=original_data_file, output_config=output)
context.run(example_gen) 

# to inspect the Artifacts produced by the EXAMPLE_GEN
for artifact in example_gen.outputs['examples'].get():
    print(artifact)

# to split existing Splits
input = example_gen_pb2.Input(splits=[
    example_gen_pb2.Input.Split(name='train', pattern='train/*'),
    example_gen_pb2.Input.Split(name='eval', pattern='eval/*'), 
    example_gen_pb2.Input.Split(name='test', pattern='test/*')
])

example_gen_4 = CsvExampleGen(input='external input path', input_config=input)

# for images, convert to TFRecord files. Do not decode
base_path = 'path/to/iamges'
filenames = os.listdir(base_path)

def generate_label_from_path(image_path):
    pass
# we would also use the BYTE & INT64 feature functions
with tf.io.TFRecordWriter('tfrecord filename') as writer:
    for img_path in filenames:
        image_path = os.path.join(base_path, img_path)
        try:
            raw_file = tf.io.read_file(image_path)
        except FileNotFoundError:
            print(f"File {image_path} not found.")
            continue
    example = tf.train.Example(features=tf.train.Features(feature={
        'image_raw': _bytes_feature(raw_file.numpy()), 
        'label': _int64_feature(generate_label_from_path(image_path))
    }))
    writer.write(example.SerializeToString())



# Step 2: Data Validation

In [35]:
# import Libraries here
import tensorflow_data_validation as tfdv


In [38]:
# stats can be generated either from CSV or TFRecord
stats = tfdv.generate_statistics_from_tfrecord('consumer_complaints.tfrecord')
stats



datasets {
  num_examples: 66799
  features {
    type: STRING
    string_stats {
      common_stats {
        num_non_missing: 66799
        min_num_values: 1
        max_num_values: 1
        avg_num_values: 1.0
        num_values_histogram {
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 6679.9
          }
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 6679.9
          }
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 6679.9
          }
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 6679.9
          }
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 6679.9
          }
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 6679.9
          }
          buckets {
            low

In [39]:
# generate SCHEMA
schema = tfdv.infer_schema(stats)
tfdv.display_schema(schema)
''' 
* Certain schema are generated for NUMEERICAL features while CATEGORICAL are different.
* Presence means whether the feature must be present in 100% of data examples.
* Valency is the minimum number of values required per training examples

'''

Unnamed: 0_level_0,Type,Presence,Valency,Domain
Feature name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
'company',BYTES,required,,-
'company_response',STRING,required,,'company_response'
'consumer_complaint_narrative',BYTES,required,,-
'consumer_disputed',BYTES,required,,-
'issue',STRING,required,,'issue'
'product',STRING,required,,'product'
'state',STRING,required,,'state'
'sub_issue',STRING,required,,'sub_issue'
'sub_product',STRING,required,,'sub_product'
'timely_response',STRING,required,,'timely_response'


Unnamed: 0_level_0,Values
Domain,Unnamed: 1_level_1
'company_response',"'Closed', 'Closed with explanation', 'Closed with monetary relief', 'Closed with non-monetary relief', 'Untimely response'"
'issue',"'APR or interest rate', 'Account opening, closing, or management', 'Account terms and changes', 'Adding money', 'Advertising and marketing', 'Advertising, marketing or disclosures', 'Application processing delay', 'Application, originator, mortgage broker', 'Applied for loan/did not receive money', 'Arbitration', 'Balance transfer', 'Balance transfer fee', 'Bankruptcy', 'Billing disputes', 'Billing statement', 'Can\'t contact lender', 'Can\'t repay my loan', 'Can\'t stop charges to bank account', 'Cash advance', 'Cash advance fee', 'Charged bank acct wrong day or amt', 'Charged fees or interest I didn\'t expect', 'Closing/Cancelling account', 'Communication tactics', 'Cont\'d attempts collect debt not owed', 'Convenience checks', 'Credit card protection / Debt protection', 'Credit decision / Underwriting', 'Credit determination', 'Credit line increase/decrease', 'Credit monitoring or identity protection', 'Credit reporting company\'s investigation', 'Customer service / Customer relations', 'Customer service/Customer relations', 'Dealing with my lender or servicer', 'Delinquent account', 'Deposits and withdrawals', 'Disclosure verification of debt', 'Disclosures', 'Excessive fees', 'False statements or representation', 'Fees', 'Forbearance / Workout plans', 'Fraud or scam', 'Getting a loan', 'Identity theft / Fraud / Embezzlement', 'Improper contact or sharing of info', 'Improper use of my credit report', 'Incorrect exchange rate', 'Incorrect information on credit report', 'Incorrect/missing disclosures or info', 'Late fee', 'Lender damaged or destroyed vehicle', 'Lender repossessed or sold the vehicle', 'Lender sold the property', 'Loan modification,collection,foreclosure', 'Loan servicing, payments, escrow account', 'Lost or stolen check', 'Lost or stolen money order', 'Making/receiving payments, sending money', 'Managing the line of credit', 'Managing the loan or lease', 'Managing, opening, or closing account', 'Money was not available when promised', 'Other', 'Other fee', 'Other service issues', 'Other transaction issues', 'Overdraft, savings or rewards features', 'Overlimit fee', 'Payment to acct not credited', 'Payoff process', 'Privacy', 'Problems caused by my funds being low', 'Problems when you are unable to pay', 'Received a loan I didn\'t apply for', 'Rewards', 'Sale of account', 'Settlement process and costs', 'Shopping for a line of credit', 'Shopping for a loan or lease', 'Taking out the loan or lease', 'Taking/threatening an illegal action', 'Transaction issue', 'Unable to get credit report/credit score', 'Unauthorized transactions/trans. issues', 'Unexpected/Other fees', 'Unsolicited issuance of credit card', 'Using a debit or ATM card', 'Wrong amount charged or received'"
'product',"'Bank account or service', 'Consumer Loan', 'Credit card', 'Credit reporting', 'Debt collection', 'Money transfers', 'Mortgage', 'Other financial service', 'Payday loan', 'Prepaid card', 'Student loan'"
'state',"'', 'AA', 'AE', 'AK', 'AL', 'AP', 'AR', 'AS', 'AZ', 'CA', 'CO', 'CT', 'DC', 'DE', 'FL', 'FM', 'GA', 'GU', 'HI', 'IA', 'ID', 'IL', 'IN', 'KS', 'KY', 'LA', 'MA', 'MD', 'ME', 'MI', 'MN', 'MO', 'MP', 'MS', 'MT', 'NC', 'ND', 'NE', 'NH', 'NJ', 'NM', 'NV', 'NY', 'OH', 'OK', 'OR', 'PA', 'PR', 'RI', 'SC', 'SD', 'TN', 'TX', 'UT', 'VA', 'VI', 'VT', 'WA', 'WI', 'WV', 'WY'"
'sub_issue',"'', 'Account status', 'Account terms', 'Account terms and changes', 'Applied for loan/did not receive money', 'Attempted to collect wrong amount', 'Attempted to/Collected exempt funds', 'Billing dispute', 'Called after sent written cease of comm', 'Called outside of 8am-9pm', 'Can\'t contact lender', 'Can\'t decrease my monthly payments', 'Can\'t get flexible payment options', 'Can\'t qualify for a loan', 'Can\'t stop charges to bank account', 'Can\'t temporarily postpone payments', 'Charged bank acct wrong day or amt', 'Charged fees or interest I didn\'t expect', 'Contacted employer after asked not to', 'Contacted me after I asked not to', 'Contacted me instead of my attorney', 'Debt is not mine', 'Debt resulted from identity theft', 'Debt was discharged in bankruptcy', 'Debt was paid', 'Don\'t agree with fees charged', 'Frequent or repeated calls', 'Having problems with customer service', 'Impersonated an attorney or official', 'Inadequate help over the phone', 'Indicated committed crime not paying', 'Indicated shouldn\'t respond to lawsuit', 'Information is not mine', 'Investigation took too long', 'Keep getting calls about my loan', 'Need information about my balance/terms', 'No notice of investigation status/result', 'Not disclosed as an attempt to collect', 'Not given enough info to verify debt', 'Payment to acct not credited', 'Personal information', 'Problem cancelling or closing account', 'Problem getting my free annual report', 'Problem getting report or credit score', 'Problem with fraud alerts', 'Problem with statement of dispute', 'Public record', 'Qualify for a better loan than offered', 'Received a loan I didn\'t apply for', 'Received bad information about my loan', 'Received marketing offer after opted out', 'Receiving unwanted marketing/advertising', 'Reinserted previously deleted info', 'Report improperly shared by CRC', 'Report shared with employer w/o consent', 'Right to dispute notice not received', 'Seized/Attempted to seize property', 'Sued w/o proper notification of suit', 'Sued where didn\'t live/sign for debt', 'Talked to a third party about my debt', 'Threatened arrest/jail if do not pay', 'Threatened to sue on too old debt', 'Threatened to take legal action', 'Trouble with how payments are handled', 'Used obscene/profane/abusive language'"
'sub_product',"'', '(CD) Certificate of deposit', 'Auto', 'Cashing a check without an account', 'Check cashing', 'Checking account', 'Conventional adjustable mortgage (ARM)', 'Conventional fixed mortgage', 'Credit card', 'Credit repair', 'Debt settlement', 'Domestic (US) money transfer', 'Electronic Benefit Transfer / EBT card', 'FHA mortgage', 'Federal student loan', 'Foreign currency exchange', 'General purpose card', 'Gift or merchant card', 'Government benefit payment card', 'Home equity loan or line of credit', 'I do not know', 'ID prepaid card', 'Installment loan', 'International money transfer', 'Medical', 'Mobile wallet', 'Money order', 'Mortgage', 'Non-federal student loan', 'Other (i.e. phone, health club, etc.)', 'Other bank product/service', 'Other mortgage', 'Other special purpose card', 'Pawn loan', 'Payday loan', 'Payroll card', 'Personal line of credit', 'Refund anticipation check', 'Reverse mortgage', 'Savings account', 'Title loan', 'Transit card', 'TravelerÃ¢Â€Â™s/CashierÃ¢Â€Â™s checks', 'VA mortgage', 'Vehicle lease', 'Vehicle loan'"
'timely_response',"'No', 'Yes'"


In [41]:
# recognizing Problems in Data using TFDV
train_stats = tfdv.generate_statistics_from_tfrecord('consumer_complaints.tfrecord')
val_stats = tfdv.generate_statistics_from_tfrecord('eval.tfrecord')

# visualize the difference between the EVAL & TRAINING DATASETS
tfdv.visualize_statistics(lhs_statistics=val_stats, rhs_statistics=train_stats, lhs_name='VAL_DS', rhs_name='TRAIN_DS')




In [42]:
# detect ANOMALIES
anomalies = tfdv.validate_statistics(val_stats, schema=schema)
tfdv.display_anomalies(anomalies)
# reason for not detecting ANOMALIES is because the EVAL set was taken from the Original Data
# As long as we understand how it works, we are good to go!

In [44]:
# updating the Schema
# schema = tfdv.load_schema_text('path to saved schema text') # do this if you're loading the schema from somewhere
sub_issue = tfdv.get_feature(schema, 'sub_issue')
''' 
Anomalies were detected in the original tutorial, so I would assume so. The COMPANY column was dropped because it contained
many null values.

Below we set the minimum acceptance to 90%, this means SUB_ISSUE column must have 90% of values present before it can be accepted.
'''
sub_issue.presence.min_fraction = 0.9 

# we can also update the list of US states to remove STATE OF ALASKA
state_domain = tfdv.get_domain(schema, 'state')
state_domain.value.remove('AK')

# save the SCHEMA
tfdv.write_schema_text(schema, 'my_schema')

# revalidate the Statistics to view the updated Schema
updated_anomalies = tfdv.validate_statistics(val_stats, schema)
tfdv.display_anomalies(updated_anomalies)
# As we can see below, it was noted that ALASKA was missing. We purposely removed it to show a working examaple

Unnamed: 0_level_0,Anomaly short description,Anomaly long description
Feature name,Unnamed: 1_level_1,Unnamed: 2_level_1
'state',Unexpected string values,Examples contain values missing from the schema: AK (<1%).


In [None]:
# check for Data Skew & Drift
tfdv.get_feature(schema, 'company').skew_comparator.infinity_norm.threshold = 0.01
skew_anomalies = tfdv.validate_statistics(train_stats, schema=schema, serving_statistics='')

In [47]:
# Performing statistics on a subset of the Dataset
from tensorflow_data_validation.utils import slicing_util 
from tensorflow_metadata.proto.v0 import statistics_pb2 


slice_fn1 = slicing_util.get_feature_value_slicer(features={'state': [b'CA']})
slicing_options = tfdv.StatsOptions(slice_functions=[slice_fn1])
slice_stats = tfdv.generate_statistics_from_csv(original_data_file, stats_options=slicing_options)
slice_stats




datasets {
  name: "All Examples"
  num_examples: 66799
  features {
    type: STRING
    string_stats {
      common_stats {
        num_non_missing: 66799
        min_num_values: 1
        max_num_values: 1
        avg_num_values: 1.0
        num_values_histogram {
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 6679.9
          }
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 6679.9
          }
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 6679.9
          }
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 6679.9
          }
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 6679.9
          }
          buckets {
            low_value: 1.0
            high_value: 1.0
            sample_count: 6679.9
          }
          bu

In [53]:
# functions to copy sliced stats to the visualization

def get_sliced_stats(stats, slice_key):
    for sliced_stats in stats.datasets:
        if sliced_stats.name == slice_key:
            result = statistics_pb2.DatasetFeatureStatisticsList()
            result.datasets.add().CopyFrom(sliced_stats)
            return result
        print('Invalid Slice key')

def compare_slices(stats, slice_key1, slice_key2):
    lhs_stats = get_sliced_stats(stats, slice_key1)
    rhs_stats = get_sliced_stats(stats, slice_key2)
    tfdv.visualize_statistics(lhs_stats, rhs_stats)


compare_slices(slice_stats, 'state_CA', 'All Examples')

Invalid Slice key


In [55]:
# something is wrong somewhere, cannot figure it out now.
for ss in slice_stats.datasets:
    print(ss.name)

All Examples
state_CA


Processing Datasets with GCP.

I was unable to setup GCP with my Nigerian cloud credentials. So....

In [None]:
# I was unable to setup GCP with my credit card. 
from apache_beam.options.pipeline_options import (PipelineOptions, GoogleCloudOptions, StandardOptions, SetupOptions)

# setup Google cloud credentials
options = PipelineOptions()
google_cloud_options = options.view_as(GoogleCloudOptions)
google_cloud_options.project = 'GCP project ID'
google_cloud_options.job_name = 'Job name'
google_cloud_options.staging_location = 'Bucket staging location path'
google_cloud_options.temp_location = 'GCP bucket temp path'
options.view_as(StandardOptions).runner = 'DataFlowRunner'

# configure worker setup options e.g. download packages
setup_options = options.view_as(SetupOptions)
setup_options.extra_packages = ['enter package link']

# process can be started on local machine but is excuted on the cloud
data_set_path = 'GCP dataset bucket path'
output_path = 'GCP bucket path to store output'
tfdv.generate_statistics_from_tfrecord(data_set_path, output_path=output_path, pipeline_options=options)



In [57]:
# this is the Artifact we would pass to the StatisticsGen
example_gen.outputs['examples']

0,1
.type_name,Examples
._artifacts,"[0] function toggleTfxObject(element) {  var objElement = element.parentElement;  if (objElement.classList.contains('collapsed')) {  objElement.classList.remove('collapsed');  objElement.classList.add('expanded');  } else {  objElement.classList.add('collapsed');  objElement.classList.remove('expanded');  } } Artifact of type 'Examples' (uri: C:\Users\DELL\AppData\Local\Temp\tfx-interactive-2021-11-17T11_48_53.020822-shwg29u6\CsvExampleGen\examples\7) at 0x1f57a256278.type<class 'tfx.types.standard_artifacts.Examples'>.uriC:\Users\DELL\AppData\Local\Temp\tfx-interactive-2021-11-17T11_48_53.020822-shwg29u6\CsvExampleGen\examples\7.span0.split_names[""train"", ""eval"", ""test""].version0"

0,1
[0],"function toggleTfxObject(element) {  var objElement = element.parentElement;  if (objElement.classList.contains('collapsed')) {  objElement.classList.remove('collapsed');  objElement.classList.add('expanded');  } else {  objElement.classList.add('collapsed');  objElement.classList.remove('expanded');  } } Artifact of type 'Examples' (uri: C:\Users\DELL\AppData\Local\Temp\tfx-interactive-2021-11-17T11_48_53.020822-shwg29u6\CsvExampleGen\examples\7) at 0x1f57a256278.type<class 'tfx.types.standard_artifacts.Examples'>.uriC:\Users\DELL\AppData\Local\Temp\tfx-interactive-2021-11-17T11_48_53.020822-shwg29u6\CsvExampleGen\examples\7.span0.split_names[""train"", ""eval"", ""test""].version0"

0,1
.type,<class 'tfx.types.standard_artifacts.Examples'>
.uri,C:\Users\DELL\AppData\Local\Temp\tfx-interactive-2021-11-17T11_48_53.020822-shwg29u6\CsvExampleGen\examples\7
.span,0
.split_names,"[""train"", ""eval"", ""test""]"
.version,0


In [None]:
# integrating TFDV into our ML pipeline
from tfx.components import StatisticsGen, SchemaGen, ExampleValidator

# run the StatisticsGen
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])
context.run(statistics_gen)

# run the ScehmaGen, this only generates a Schema if one doesn't exist
schema_gen = SchemaGen(statistics=statistics_gen.outputs['statistics'], infer_feature_shape=True)
context.run(schema_gen)

# run the Example Validator which is the final step of this part to validate the data
example_validator = ExampleValidator(statistics=statistics_gen.outputs['statistics'], schema=schema_gen.outputs['schema'])
context.run(example_validator)


# Step 3: Data Processing with TRANSFORM

In [6]:
import tensorflow_transform as tft # for transformation
from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils
import tensorflow_transform.beam as tft_beam

import tempfile

In [17]:
# pass the function to a TFT object, all operations within must be TF ops
def preprocessing_fn(inputs):
    x = inputs['x']
    x_normalized = tft.scale_to_0_1(x)
    return {'x_xf': x_normalized}

# some TFT functions
tft.scale_to_z_score() # mean of Zero, STD of 1
tft.bucketize() # bucketize a feature into bins
tft.pca() # dimensionality reduction
tft.compute_and_apply_vocabulary() # maps most frequent values to an index

# for NLP processing
tft.ngrams() # generates NGRAMS compliant with TF graphs
tft.bag_of_words() # uses ngrams to generate BOW vector
tft.tfidf() # generates vector with Token indices & TFIDF weights

# for image processing & computer vision problems
def process_image(raw_image):
    raw_image = tf.reshape(raw_image, [-1]) # reshape Image
    img_rgb = tf.io.decode_jpeg(raw_image, channels=3) # decode JPEG encoded image
    img_gray = tf.image.rgb_to_grayscale(img_rgb) # convert RGB to GrayScale
    img = tf.image.convert_image_dtype(raw_image, tf.float32)
    resized_img = tf.image.resize_with_pad(img, target_height=300, target_width=300)
    img_grayscale = tf.image.rgb_to_grayscale(resized_img)
    return tf.reshape(img_grayscale, [-1, 300, 300, 1])


In [None]:
from tensorflow_transform.tf_metadata import dataset_metadata, schema_utils
import tempfile
import tensorflow_transform.beam.impl as tft_beam


raw_data = [
    {'x': 1.20}, {'x': 2.99}, {'x': 100.00}
]

raw_data_metadata = dataset_metadata.DatasetMetadata(
    schema_utils.schema_from_feature_spec({
        'x': tf.io.FixedLenFeature([], tf.float32),
        }))

with beam.Pipeline() as pipeline:
    with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
        tfrecord_file = "practice_file1.tfrecord"
        raw_data = (
            pipeline | beam.io.ReadFromTFRecord(tfrecord_file))
        transformed_dataset, transform_fn = (
            (raw_data, raw_data_metadata) | tft_beam.AnalyzeAndTransformDataset(
                preprocessing_fn))

In [None]:
# integrating TFT into ML Pipeline
# we first define helper functions to clean up our data

LABEL_KEY = 'consumer_disputed'

# Feature name, feature dimensionality
ONE_HOT_FEATURES = {
    "product": 11,
    "sub_product": 45,
    "company_response": 5,
    "state": 60,
    "issue": 90
}

# Feature name, Bucket count
BUCKET_FEATURES ={
    "zip_code": 10
}

# Feature name, value is unused
TEXT_FEATURES = {
    "consumer_complaint_narrative": None
}

# to differientiate between Input & Output values
def transformed_name(key):
    return key + '_xf'

# convert Sparse values to Dense values, and fill missing values
def fill_in_missing(x):
    default_value = '' if x.dtype == tf.string else 0
    if type(x) == tf.SparseTensor:
        x = tf.sparse.to_dense(
            tf.SparseTensor(x.indices, x.values, [x.dense_shape[0], 1], default_value)
        )
    return tf.squeeze(x, axis=1)

def convert_num_to_onehot(label_tensor, num_labels=2):
    one_hot_tensor = tf.one_hot(label_tensor, num_labels)
    return tf.reshape(one_hot_tensor, [-1, num_labels])

def convert_zip_code(zip_code):
    if zip_code == '':
        zip_code = '00000'
    zip_code = tf.strings.regex_replace(zip_code, r'x{0, 5}', '0')
    zip_code = tf.strings.to_number(zip_code, out_type=tf.float32)
    return zip_code

def preprocessing_fn(inputs):
    outputs = {}
    for key in ONE_HOT_FEATURES.keys():
        dim = ONE_HOT_FEATURES[key]
        index = tft.compute_and_apply_vocabulary(
            fill_in_missing(inputs[key]), top_k=dim+1
        )
        outputs[transformed_name(key)] = convert_num_to_onehot(index, num_labels=dim+1)

    # bucketize the "ZIP_CODE" column
    for key, bucket_count in BUCKET_FEATURES.items():
        temp_feature = tft.bucketize(
            convert_zip_code(fill_in_missing(inputs[key])),
            bucket_count, always_return_num_quantiles=False
        )
        outputs[transformed_name(key)] = convert_num_to_onehot(
            temp_feature, num_labels=bucket_count+1
        )

    # convert any Dense TEXT columns or labels to sparse
    for key in TEXT_FEATURES.keys():
        outputs[transformed_name(key)] = fill_in_missing(inputs[key])

    outputs[transformed_name(LABEL_KEY)] = fill_in_missing(inputs[LABEL_KEY])
    return outputs


In [None]:
'''
The TRANSFORM component expects the transformation code to be in a seperate Python file.
The name of the Module file can be set by the user but the entry point PREPROCESSING_FN()
needs to be contained in the file and the function cannot be renamed.
'''

transform = Transform(
    examples=example_gen.outputs['examples'],
    schema = schema_gen.outputs['schema'],
    module_file = os.path.abspath('module.py')
)
context.run(transform)

# Step 4: Model training with Trainer

In [3]:
import tensorflow_hub as tf_hub
from tfx.components import Trainer
from tfx.components.base import executor_spec # imported in Step 1
from tfx.components.trainer.executor import GenericExecutor
from tfx.proto import trainer_pb2

In [None]:

def get_model():
    # One-hot categorical features
    input_features = []
    for key, dim in ONE_HOT_FEATURES.items():
        input_features.append(
            tf.keras.Input(shape=(dim + 1,), name=transformed_name(key))
        )
    # Adding bucketized features
    for key, dim in BUCKET_FEATURES.items():
        input_features.append(
            tf.keras.Input(shape=(dim + 1), name=transformed_name(key))
        )
    # adding text input features
    input_texts = []
    for key in TEXT_FEATURES.keys():
        input_texts.append(
            tf.keras.Input(shape=(1,), name=transformed_name(key), dtype=tf.string)
        )
    inputs = input_features + input_texts

    # Embed text features
    MODULE_URL = "https://tfhub.dev/google/universal-sentence-encoder/4"
    embed = tf_hub.KerasLayer(MODULE_URL)
    reshaped_narrative = tf.reshape(input_texts[0], [-1])
    embed_narrative = embed(reshaped_narrative)
    deep_ff = tf.keras.layers.Reshape((512, ), input_shape=(1, 512))(embed_narrative)

    deep = tf.keras.layers.Dense(256, activation='relu')(deep_ff)
    deep = tf.keras.layers.Dense(64, activation='relu')(deep)
    deep = tf.keras.layers.Dense(16, activation='relu')(deep)

    wide_ff = tf.keras.layers.concatenate(input_features)
    wide = tf.keras.layers.Dense(16, activation='relu')(wide_ff)

    both = tf.keras.layers.concatenate([deep, wide])

    output = tf.keras.layers.Dense(1, activation='sigmoid')(both)
    keras_model = tf.keras.models.Model(inputs, output)

    keras_model.compile(
        optimizer = tf.keras.optimizers.Adam(learning_rate=0.001), loss='binary_crossentropy',
        metrics = [tf.keras.metrics.BinaryAccuracy(), tf.keras.metrics.TruePositives()]
    )
    return keras_model

'''
Just like the TRANSFORM expects a *preprocessing_fn, the TRAINER expects a *run_fn*
'''
LABEL_KEY = 'labels'

def _gzip_reader_fn(filenames):
    return tf.data.TFRecordDataset(filenames, compression_type='GZIP')

def input_fn(file_pattern, tf_transform_output, batch_size=32):
    transformed_feature_spec = (tf_transform_output.transformed_feature_spec().copy())
    dataset = tf.data.experimental.make_batched_features_dataset(
        file_pattern=file_pattern, batch_size=batch_size,
        features=transformed_feature_spec, reader=_gzip_reader_fn,
        label_key=transformed_name(LABEL_KEY)
    )
    return dataset

def get_serve_tf_examples_fn(model, tf_transform_output):
    model.tft_layer = tf_transform_output.transform_features_layer()

    @tf.function
    def serve_tf_examples_fn(serialized_tf_examples):
        # load the Preprocessing graph
        feature_spec = tf_transform_output.raw_feature_spec()
        feature_spec.pop(LABEL_KEY)
        # parse the raw TF.EXAMPLE records from the request
        parsed_features = tf.io.parse_example(
            serialized_tf_examples, feature_spec
        )
        # apply the preprocessing transformation to the raw data
        transformed_features = model.tft_layer(parsed_features)
        # perform predictions with preprocessed data
        outputs = model(transformed_features)
        return {'outputs': outputs}
    return serve_tf_examples_fn

def run_fn(fn_args):
    tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)
    train_dataset = input_fn(fn_args.train_files, tf_transform_output)
    eval_dataset = input_fn(fn_args.eval_files, tf_transform_output)

    model = get_model()
    # TFX trainer uses "Training steps" instead of EPOCHS
    model.fit(
        train_dataset, steps_per_epoch=fn_args.train_steps, validation_data=eval_dataset,
        validation_steps=fn_args.eval_steps
    )
    # model export
    signatures = {
        'serving_default': get_serve_tf_examples_fn(
            model, tf_transform_output
        ).get_concrete_function(tf.TensorSpec(
            shape=[None], dtype=tf.string, name='examples'
        ))
    }
    model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)


In [None]:
# Trainer component

TRAINING_STEPS = 1000
EVALUATION_STEPS = 100

trainer = Trainer(
    module_file = os.path.abspath('module.py')
    custom_executor_spec=executor_spec.executor_spec.ExecutorClassSpec(GenericExecutor),
    transformed_examples=transform.outputs['transformed_examples'],
    transform_graph=tranform.outputs['transform_graph'],
    schema=schema.outputs['schema'],
    train_args=trainer_pb2.trainer_pb2.TrainArgs(num_steps=TRAINING_STEPS),
    eval_args=trainer_pb2.EvalArgs(num_steps)
)




