In [None]:
import sys
import os

# Make the project root (three directory levels above the notebook's working
# directory) importable so the `utils` package imported below can be found.
# os.path.dirname keeps native separators instead of mixing os.sep and '/'.
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.getcwd())))
# Guard so re-running this cell does not append a duplicate sys.path entry.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
from utils.helper_metastore import *
from utils.configurations.config import Config

### Custom data connector

To compute data statistics, TFDV provides several convenient methods for handling input data in various formats (e.g. TFRecord files of tf.train.Example, CSV, pandas DataFrame, etc.). If your data format is not in this list, you need to write a custom data connector that reads the input data, and connect it to the TFDV core API for computing data statistics.

The TFDV core API for computing data statistics is a Beam PTransform that takes a PCollection of batches of input examples (a batch of input examples is represented as an Arrow RecordBatch), and outputs a PCollection containing a single DatasetFeatureStatisticsList protocol buffer.

Once you have implemented a custom data connector that batches your input examples into Arrow RecordBatches, you need to connect it to the tfdv.GenerateStatistics API to compute the data statistics.

In [None]:
import os

import apache_beam as beam
import tensorflow as tf
import tensorflow_data_validation as tfdv
from tensorflow_metadata.proto.v0 import statistics_pb2
from tensorflow_data_validation.coders import tf_example_decoder

Let's consider a scenario where we want to generate statistics for a dataset stored in a pipe-delimited text file. To do so, we write a custom connector with the following flow:

   >Read_Text_File -> Serialize Row into tf.train.Example -> DecodeData -> GenerateStatistics

In [None]:
# Helper functions that wrap a single value in a tf.train.Feature.
def _bytes_feature(value):
    """Return a tf.train.Feature holding `value` in a one-element BytesList.

    An EagerTensor is unwrapped with `.numpy()` first, because BytesList
    cannot unpack a string from an EagerTensor directly.
    """
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        value = value.numpy()
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)

def _float_feature(value):
    """Return a tf.train.Feature holding `value` in a one-element FloatList."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
    """Return a tf.train.Feature holding `value` in a one-element Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

# Function which converts the row to tf.train.Example
def serialize_example(row):
    """Convert one pipe-delimited text row into a serialized tf.train.Example.

    Fields are matched positionally to ``headers`` below. Each field is stripped
    of surrounding whitespace first; an empty field is treated as missing and
    replaced with a sentinel: -999 for integer columns, -999.0 for float columns
    and the string 'None' for text columns. Because missing values are replaced
    rather than omitted, every feature is present in every Example, so TFDV
    sees the sentinels instead of missing values.

    Args:
        row: One line of the input file, fields separated by '|'.

    Returns:
        The tf.train.Example serialized to a bytes string.

    Raises:
        ValueError: if the row has fewer fields than ``headers`` or a numeric
            field cannot be parsed as a number.
    """
    headers = ['PassengerId', 'Survived', 'Pclass', 'Name', 'Sex', 'Age', 'SibSp', 'Parch',
               'Ticket', 'Fare', 'Cabin', 'Embarked']
    integer_columns = {'PassengerId', 'Survived', 'Pclass', 'Age', 'SibSp', 'Parch'}
    float_columns = {'Fare'}

    fields = row.split('|')
    if len(fields) < len(headers):
        raise ValueError(
            f'Expected {len(headers)} pipe-delimited fields, got {len(fields)}: {row!r}')

    feature = {}
    # Any extra trailing fields are ignored: zip stops after the last header.
    for name, raw in zip(headers, fields):
        # Strip first so whitespace-only fields are treated as missing too.
        text = raw.strip()
        if name in integer_columns:
            # NOTE(review): int(float(...)) truncates fractional values,
            # e.g. an Age of 28.5 is stored as 28.
            value = -999 if text == '' else int(float(text))
            feature[name] = _int64_feature(value)
        elif name in float_columns:
            value = -999.0 if text == '' else float(text)
            feature[name] = _float_feature(value)
        else:
            value = 'None' if text == '' else text
            feature[name] = _bytes_feature(value.encode())

    example_proto = tf.train.Example(features=tf.train.Features(feature=feature))
    return example_proto.SerializeToString()

In [None]:
# Resolve input/output paths relative to the parent of the notebook's working directory.
root_dir = os.path.dirname(os.getcwd())
dataset_name = Config.ADD_ONS_DATASET_NAME + '.txt'
INPUT_LOCATION = os.path.join(root_dir, 'data', dataset_name)

# exist_ok avoids the check-then-create race and keeps the cell safe to re-run.
OUTPUT_DIR = os.path.join(root_dir, 'outputs')
os.makedirs(OUTPUT_DIR, exist_ok=True)

# File that the Beam pipeline writes the computed statistics to.
OUTPUT_LOCATION = os.path.join(OUTPUT_DIR, 'statistics.tfrecord')

In [None]:
# Run the custom connector end to end:
# text rows -> serialized tf.train.Example -> decoded batches -> TFDV statistics,
# then write the statistics to OUTPUT_LOCATION. The PCollection gets its own
# name so `stats` is only ever the loaded statistics proto (see next cell).
with beam.Pipeline() as pipeline:
    stats_pcoll = (
        pipeline
        | 'Readtxt' >> beam.io.ReadFromText(INPUT_LOCATION, skip_header_lines=1)
        | 'Serialize to tf.Example' >> beam.Map(serialize_example)
        | 'DecodeData' >> tf_example_decoder.DecodeTFExample()
        | 'GenerateStatistics' >> tfdv.GenerateStatistics()
    )

    _ = stats_pcoll | 'WriteStatsOutput' >> tfdv.WriteStatisticsToTFRecord(OUTPUT_LOCATION)

In [None]:
# Load the statistics written by the pipeline and infer a baseline schema from them.
stats = tfdv.load_statistics(input_path=OUTPUT_LOCATION)
schema = tfdv.infer_schema(statistics=stats)

In [None]:
# Interactive overview of the computed dataset statistics.
tfdv.visualize_statistics(lhs_statistics=stats)

In [None]:
# Show the schema exactly as inferred, before any manual edits.
tfdv.display_schema(schema=schema)

The statistics show missing values in the Age column, which appear as the sentinel value -999. Suppose we discuss this with our domain experts and they confirm that Age is optional and is expected to be present in at least 80% of records.

Because `serialize_example` replaces every missing Age with -999, every record contains an Age value, so TFDV infers the column as required. We have to relax this manually in the schema. Let's see how to do that.

In [None]:
# Relax the presence constraint on Age: it only has to appear in at least 80%
# of examples, so it is no longer required in every record.
age_feature = tfdv.get_feature(schema, 'Age')
age_feature.presence.min_fraction = 0.8

In [None]:
# Re-render the schema to confirm the updated Age presence constraint.
tfdv.display_schema(schema=schema)

# Now you can see that the Age column is marked as optional.