In [1]:
import tempfile
import glob

import pandas as pd
import numpy as np

import tensorflow as tf

import tensorflow_transform as tft
from tensorflow_transform.beam import impl as beam_impl
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
from tensorflow_transform.tf_metadata import dataset_metadata, dataset_schema

import apache_beam as beam
from apache_beam.io import tfrecordio

from IPython.display import display

In [2]:
# Only emit TensorFlow log messages at ERROR level to keep cell output quiet.
tf.logging.set_verbosity(tf.logging.ERROR)

### Transform TFRecords with TFT

In [3]:
# Feature spec for the raw (untransformed) examples: each record carries one
# string feature and one int64 feature, both of fixed length 1.
RAW_DATA_FEATURE = dict(
    letters=tf.FixedLenFeature(shape=[1], dtype=tf.string),
    yvar=tf.FixedLenFeature(shape=[1], dtype=tf.int64),
)

# Dataset metadata wrapping the schema derived from the feature spec above.
RAW_DATA_METADATA = dataset_metadata.DatasetMetadata(
    schema=dataset_schema.from_feature_spec(RAW_DATA_FEATURE))

In [4]:
# Transform the raw TFRecords with a transform function read from `data/`.
#
# Input : data/test.tfrecords
# Output: data/test_transformed.tfrecords-* (Beam appends a shard suffix), which
#         is the pattern the inspection cell further down reads back.
with beam.Pipeline() as pipeline:
    with beam_impl.Context(temp_dir=tempfile.mkdtemp()):
        # Decodes serialized tf.Example protos using the raw schema.
        coder = tft.coders.ExampleProtoCoder(RAW_DATA_METADATA.schema)

        data = (
            pipeline
            | 'Read' >> tfrecordio.ReadFromTFRecord('data/test.tfrecords')
            | 'Decode' >> beam.Map(coder.decode))

        transform_fn = (
            pipeline
            | transform_fn_io.ReadTransformFn('data'))

        # NOTE: `transform_fn` is not passed as a call argument. Beam applies a
        #       PTransform to whatever sits on the left of `|`, and that may be
        #       a nested tuple of PCollections/values. TransformDataset consumes
        #       the whole `(dataset, transform_fn)` tuple, where the dataset is
        #       itself the pair `(data, metadata)`.
        transformed_data, _ = (
            ((data, RAW_DATA_METADATA), transform_fn)
            | 'Transform' >> beam_impl.TransformDataset())

        # Metadata for the transformed data is read from the same directory as
        # the transform function; it is reused below to parse the output.
        tf_transform_output = tft.TFTransformOutput('data')
        transformed_data_coder = tft.coders.ExampleProtoCoder(
            tf_transform_output.transformed_metadata.schema)

        # The output prefix must agree with the glob used by the inspection
        # cell ('data/test_transformed.tfrecords*').
        _ = (
            transformed_data
            | 'Encode' >> beam.Map(transformed_data_coder.encode)
            | 'Write' >> tfrecordio.WriteToTFRecord('data/test_transformed.tfrecords'))

### Inspect transformed TFRecords

In [5]:
# load data
# TODO: this function is useful; put into `djr-py`!
def fetch_tf_records(input_file_pattern, feature_spec, top=None):
    """Load parsed TFRecord examples into memory as one batch.

    Args:
        input_file_pattern: glob pattern matching one or more TFRecord files.
        feature_spec: feature spec handed to `tf.parse_single_example`.
        top: maximum number of records to return. When falsy (`None` or 0),
            all records in the matched files are counted and returned.

    Returns:
        The evaluated batch of parsed examples, keyed by feature name (e.g.
        numpy arrays for fixed-length features, `SparseTensorValue` for
        sparse ones).

    Raises:
        ValueError: if the pattern matches no files, or if the resulting
            batch size is not positive (empty files or a negative `top`).
    """
    # Sorted so record order does not depend on filesystem listing order.
    input_filenames = sorted(glob.glob(input_file_pattern))
    if not input_filenames:
        raise ValueError(
            'no TFRecord files match pattern {!r}'.format(input_file_pattern))

    if top:
        n = top
    else:
        # Count records up front so a single batch holds the whole dataset.
        n = sum(
            1
            for f in input_filenames
            for _ in tf.python_io.tf_record_iterator(f))

    if n <= 0:
        raise ValueError(
            'batch size must be positive, got {} for pattern {!r}'.format(
                n, input_file_pattern))

    # Build the input ops in a private graph so that calling this function
    # repeatedly (e.g. re-running a notebook cell) does not keep adding ops
    # to the default graph.
    with tf.Graph().as_default():
        ds = tf.data.TFRecordDataset(input_filenames)
        ds = ds.map(lambda x: tf.parse_single_example(x, feature_spec))
        ds = ds.batch(n).repeat(1)
        batch = ds.make_one_shot_iterator().get_next()

        with tf.Session() as sess:
            return sess.run(batch)


# Raw examples, parsed with the raw feature spec.
ds_pre = fetch_tf_records('data/test.tfrecords', RAW_DATA_FEATURE)

# Transformed examples, parsed with the feature spec taken from the
# transform output.
post_feature_spec = tf_transform_output.transformed_feature_spec()
ds_post = fetch_tf_records('data/test_transformed.tfrecords*', post_feature_spec)

# Show both for a side-by-side comparison of before/after.
display(ds_pre, ds_post)

{'letters': array([['A|B|C'],
        ['B|D'],
        ['D|E|A']], dtype=object), 'yvar': array([[1],
        [0],
        [1]])}

{u'letters': SparseTensorValue(indices=array([[0, 0],
        [0, 1],
        [0, 2],
        [1, 0],
        [1, 1],
        [2, 0],
        [2, 1],
        [2, 2]]), values=array([ 2,  1, -1,  1,  0,  0, -1,  2]), dense_shape=array([3, 3])),
 u'yvar': array([[1],
        [0],
        [1]])}