In [1]:
import tempfile
import glob

import pandas as pd
import numpy as np

import tensorflow as tf

import tensorflow_transform as tft
from tensorflow_transform.beam import impl as beam_impl
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
from tensorflow_transform.tf_metadata import dataset_metadata, dataset_schema

import apache_beam as beam
from apache_beam.io import tfrecordio

In [2]:
# Raise TensorFlow's logging threshold to ERROR so lower-severity messages
# (INFO / WARN) are not emitted into the notebook output.
tf.logging.set_verbosity(tf.logging.ERROR)

### Transform TFRecords with TFT

In [3]:
# Feature spec for the raw (untransformed) examples: every record carries a
# single string `dx` value and a single int64 `enrolled` value.
RAW_DATA_FEATURE = dict(
    dx=tf.FixedLenFeature(shape=[1], dtype=tf.string),
    enrolled=tf.FixedLenFeature(shape=[1], dtype=tf.int64),
)

# Metadata wrapper around the schema derived from the raw feature spec; this is
# the form handed to the tf.Transform coder and TransformDataset below.
RAW_DATA_METADATA = dataset_metadata.DatasetMetadata(
    dataset_schema.from_feature_spec(RAW_DATA_FEATURE)
)

In [4]:
# Transform data with tf.Transform: read the raw TFRecords, apply a transform
# function loaded from the 'data' directory, and write the transformed examples
# back out as TFRecords.
with beam.Pipeline() as pipeline:
    # The tf.Transform Beam context is given a freshly created temp directory.
    # NOTE(review): the directory from mkdtemp() is never removed here.
    with beam_impl.Context(temp_dir=tempfile.mkdtemp()):
        # Coder built from the raw schema; used to decode the serialized records.
        coder = tft.coders.ExampleProtoCoder(RAW_DATA_METADATA.schema)

        # Raw input: read serialized records, then decode each with the coder.
        data = (
            pipeline
            | 'Read' >> tfrecordio.ReadFromTFRecord('data/leads.tfrecords')
            | 'Decode' >> beam.Map(coder.decode))

        # Load the transform function stored under 'data'.
        # NOTE(review): presumably written there by an earlier analyze step
        # (e.g. WriteTransformFn) that is not part of this notebook - confirm.
        transform_fn = (
            pipeline 
            | transform_fn_io.ReadTransformFn('data'))
        
        # NOTE: the original question here was how `transform_fn` (a pcoll)
        #       becomes an argument for TransformDataset. As far as I understand
        #       (see the Beam PTransform docs - TODO confirm), Beam's `|` operator
        #       accepts a (nested) tuple of PCollections as the *input* of a
        #       PTransform, so `((data, RAW_DATA_METADATA), transform_fn)` is the
        #       single input value that TransformDataset receives and unpacks.
        # The second element of the returned pair is discarded.
        transformed_data, _ = (
            ((data, RAW_DATA_METADATA), transform_fn)
            | 'Transform' >> beam_impl.TransformDataset())
        
        # Re-open the same 'data' directory to obtain the transformed metadata,
        # whose schema drives the coder for the output records. This name is
        # also used by the inspection cell further down.
        tf_transform_output = tft.TFTransformOutput('data')
        transformed_data_coder = tft.coders.ExampleProtoCoder(tf_transform_output.transformed_metadata.schema)

        # Encode the transformed rows and write them out. The inspection cell
        # reads this path back with a trailing `*` glob (presumably because the
        # writer adds shard suffixes - confirm).
        _ = (
            transformed_data
            | 'Encode' >> beam.Map(transformed_data_coder.encode)
            | 'Write' >> tfrecordio.WriteToTFRecord('data/leads_transformed.tfrecords'))



### Inspect transformed TFRecords

In [5]:
# load data
# TODO: this function is useful; put into `mobe-py`!
def fetch_tf_records(input_file_pattern, feature_spec, top=None):
    """Load parsed examples from TFRecord files as one batch of tensors.

    Args:
        input_file_pattern: Glob pattern (or plain path) selecting the
            TFRecord file(s) to read.
        feature_spec: Feature spec passed to `tf.parse_single_example` for
            every record.
        top: Optional number of records to put in the batch. When falsy
            (`None` or `0`), all matched files are scanned once to count
            their records and the batch holds every record.

    Returns:
        The `get_next()` result of a one-shot iterator over the batched
        dataset; evaluating it once in a session yields the batch.

    Raises:
        ValueError: If `top` is negative, or if the matched files contain no
            records when `top` is not given.
        IOError: If no file matches `input_file_pattern`.
    """
    # Validate cheap arguments first so mistakes surface before any file I/O.
    if top is not None and top < 0:
        raise ValueError('top must be non-negative, got {!r}'.format(top))

    # glob returns matches in arbitrary order; sort so that sharded outputs are
    # always read in the same (shard) order and results are reproducible.
    input_filenames = sorted(glob.glob(input_file_pattern))
    if not input_filenames:
        # Without this guard the failure would surface later as an opaque
        # TensorFlow error about a zero batch size.
        raise IOError(
            'No files match pattern: {!r}'.format(input_file_pattern))

    if top:
        n = top
    else:
        # Count every record so the whole dataset fits in a single batch.
        n = sum(
            1
            for f in input_filenames
            for _ in tf.python_io.tf_record_iterator(f))
        if n == 0:
            raise ValueError(
                'No records found in files matching: {!r}'.format(
                    input_file_pattern))

    ds = tf.data.TFRecordDataset(input_filenames)
    ds = ds.map(lambda x: tf.parse_single_example(x, feature_spec))
    ds = ds.batch(n)

    return ds.make_one_shot_iterator().get_next()


# Batched tensors for the raw records and for the transformed output; the
# transformed path is matched with a trailing `*` glob.
ds_pre = fetch_tf_records('data/leads.tfrecords', RAW_DATA_FEATURE)
ds_post = fetch_tf_records(
    'data/leads_transformed.tfrecords*',
    tf_transform_output.transformed_feature_spec())

# Evaluate each batch in turn (raw first, then transformed) and print it.
with tf.Session() as sess:
    for batch in (ds_pre, ds_post):
        print(sess.run(batch))

{'dx': array([['A|B|C'],
       ['D|A'],
       ['E|B|C']], dtype=object), 'enrolled': array([[1],
       [0],
       [1]])}
{u'dx': SparseTensorValue(indices=array([[0, 0],
       [0, 1],
       [0, 2],
       [1, 0],
       [1, 1],
       [2, 0],
       [2, 1],
       [2, 2]]), values=array(['A', 'B', 'C', 'D', 'A', 'E', 'B', 'C'], dtype=object), dense_shape=array([3, 3])), u'enrolled': array([[1],
       [0],
       [1]])}
