In [1]:
import os
import tempfile
import glob

import pandas as pd
import numpy as np

import tensorflow as tf

import tensorflow_transform as tft
from tensorflow_transform.beam import impl as beam_impl
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
from tensorflow_transform.tf_metadata import dataset_metadata, dataset_schema

import apache_beam as beam
from apache_beam.io import tfrecordio

from IPython.display import display

In [2]:
# Limit TensorFlow logging to the ERROR level.
tf.logging.set_verbosity(tf.logging.ERROR)

In [3]:
!rm -Rf data/transform_fn
!rm -Rf data/transformed_metadata

### Convert CSV to TFRecords

In [4]:
# TODO: this function is useful; put into `djr-py`!
# Serialize every row of data/test.csv as a tf.train.Example in data/test.tfrecords.
# The pandas dtype of each column selects which Example value list receives the value.
csv = pd.read_csv('data/test.csv')
field_types = dict(csv.dtypes)
csv_records = csv.to_dict(orient='records')

with tf.python_io.TFRecordWriter('data/test.tfrecords') as writer:
    for record in csv_records:
        example = tf.train.Example()
        for column, value in record.items():
            feature = example.features.feature[column]
            column_dtype = field_types[column]
            if column_dtype == 'int64':
                feature.int64_list.value.append(value)
            elif column_dtype == 'float64':
                feature.float_list.value.append(value)
            else:
                # Every other dtype is stored as the UTF-8 bytes of its string form.
                feature.bytes_list.value.append(str(value).encode('utf-8'))
        writer.write(example.SerializeToString())

### Use TFT/Beam to transform data for model

In [5]:
# Feature spec for the raw examples: each record carries one string (`letters`)
# and one int64 (`yvar`) value.
RAW_DATA_FEATURE = dict(
    letters=tf.FixedLenFeature(dtype=tf.string, shape=[1]),
    yvar=tf.FixedLenFeature(dtype=tf.int64, shape=[1]),
)

# Dataset metadata built from that feature spec; its schema describes the raw data.
RAW_DATA_METADATA = dataset_metadata.DatasetMetadata(
    dataset_schema.from_feature_spec(RAW_DATA_FEATURE)
)

In [6]:
# train our tft transformer
def preprocessing_fn(inputs):
    """Map raw feature tensors to the features written to the transformed dataset.

    `letters` is flattened, split on '|' into tokens, and every token is
    replaced by its integer index in a vocabulary computed over the dataset
    with `frequency_threshold=2`. `yvar` is passed through unchanged.

    Args:
        inputs: dict of input tensors keyed by feature name; must contain
            'letters' and 'yvar'.

    Returns:
        dict with the vocabulary-indexed 'letters' and the original 'yvar'.
    """
    flat_letters = tf.reshape(inputs['letters'], [-1])
    tokens = tf.string_split(flat_letters, '|')
    # NOTE: vocab_filename specifies name of vocab file in `transform_fn/assets`
    token_ids = tft.compute_and_apply_vocabulary(
        tokens,
        frequency_threshold=2,
        vocab_filename='letters_vocab')
    return {'letters': token_ids, 'yvar': inputs['yvar']}


# NOTE(review): the scratch directory created by mkdtemp() is not removed by this cell.
with beam.Pipeline() as pipeline, beam_impl.Context(temp_dir=tempfile.mkdtemp()):
    coder = tft.coders.ExampleProtoCoder(RAW_DATA_METADATA.schema)

    # Read the serialized examples, then decode them with the raw-data schema.
    raw_records = pipeline | 'Read' >> tfrecordio.ReadFromTFRecord('data/test.tfrecords')
    data = raw_records | 'Decode' >> beam.Map(coder.decode)

    # Analyze the dataset and apply preprocessing_fn in one step; this yields the
    # transformed dataset (data + metadata) and the transform function itself.
    transformed_dataset, transform_fn = (
        (data, RAW_DATA_METADATA)
        | 'AnalyzeAndTransform' >> beam_impl.AnalyzeAndTransformDataset(preprocessing_fn))
    transformed_data, transformed_metadata = transformed_dataset

    transformed_data_coder = tft.coders.ExampleProtoCoder(transformed_metadata.schema)

    # Encode the transformed rows with the transformed schema and write them out as TFRecords.
    encoded = transformed_data | 'Encode' >> beam.Map(transformed_data_coder.encode)
    _ = encoded | 'Write' >> tfrecordio.WriteToTFRecord('data/test_transformed.tfrecords')

    # Persist the transform function under `data/`.
    _ = transform_fn | 'WriteTransformFn' >> transform_fn_io.WriteTransformFn('data')

### Inspect transformed TFRecords

In [7]:
# load data
# TODO: this function is useful; put into `djr-py`!
def fetch_tf_records(input_file_pattern, feature_spec, top=None):
    """Read TFRecord files into memory as a single batch of parsed features.

    Args:
        input_file_pattern: glob pattern selecting the TFRecord file(s) to read.
        feature_spec: dict mapping feature names to the parsing specs handed
            to `tf.parse_single_example`.
        top: maximum number of records to return. When falsy (`None` or 0)
            every record in the matched files is returned.

    Returns:
        The value of `Session.run` on the batched, parsed examples: a dict
        keyed by feature name.

    Raises:
        ValueError: if no file matches `input_file_pattern`, or the matched
            files contain no records.
    """
    # glob returns matches in arbitrary order; sort so multi-shard reads are
    # reproducible from run to run.
    input_filenames = sorted(glob.glob(input_file_pattern))
    if not input_filenames:
        raise ValueError(
            'No TFRecord files match pattern {!r}'.format(input_file_pattern))

    if top:
        batch_size = top
    else:
        # Count every record so one batch holds the whole dataset.
        batch_size = sum(
            1
            for filename in input_filenames
            for _ in tf.python_io.tf_record_iterator(filename))
        if batch_size == 0:
            raise ValueError(
                'Files matching {!r} contain no records'.format(input_file_pattern))

    # Build the input ops in a throwaway graph: adding them to the default
    # graph would grow it on every call (e.g. each time the cell is re-run).
    with tf.Graph().as_default():
        ds = tf.data.TFRecordDataset(input_filenames)
        ds = ds.map(lambda x: tf.parse_single_example(x, feature_spec))
        ds = ds.batch(batch_size)
        batch = ds.make_one_shot_iterator().get_next()

        with tf.Session() as sess:
            return sess.run(batch)

# Load the raw records and their transformed counterparts, then show both
# (raw first) for a side-by-side comparison.
ds_pre = fetch_tf_records(
    input_file_pattern='data/test.tfrecords',
    feature_spec=RAW_DATA_FEATURE)

ds_post = fetch_tf_records(
    input_file_pattern='data/test_transformed.tfrecords*',
    feature_spec=transformed_metadata.schema.as_feature_spec())

display(ds_pre, ds_post)

{'letters': array([['A|B|C'],
        ['B|D'],
        ['D|E|A']], dtype=object), 'yvar': array([[1],
        [0],
        [1]])}

{'letters': SparseTensorValue(indices=array([[0, 0],
        [0, 1],
        [0, 2],
        [1, 0],
        [1, 1],
        [2, 0],
        [2, 1],
        [2, 2]]), values=array([ 2,  1, -1,  1,  0,  0, -1,  2]), dense_shape=array([3, 3])),
 'yvar': array([[1],
        [0],
        [1]])}