In [1]:
import shutil
import tempfile

import pandas as pd
import numpy as np

import tensorflow as tf

import tensorflow_transform as tft
from tensorflow_transform.beam import impl as beam_impl
from tensorflow_transform.beam.tft_beam_io import transform_fn_io
from tensorflow_transform.tf_metadata import dataset_metadata, dataset_schema

import apache_beam as beam
from apache_beam.io import tfrecordio

In [2]:
# Reduce TensorFlow log noise: only ERROR-level (and more severe) messages are emitted.
tf.logging.set_verbosity(tf.logging.ERROR)

In [3]:
# Remove outputs left by a previous run (presumably written by
# `WriteTransformFn('data')` in the pipeline cell) so this run starts clean.
# shutil.rmtree is used instead of `!rm -Rf` so the cell works outside POSIX
# shells; ignore_errors mirrors `rm -f` when a directory does not exist yet.
for stale_dir in ('data/transform_fn', 'data/transformed_metadata'):
    shutil.rmtree(stale_dir, ignore_errors=True)

### Convert CSV into TFRecords

In [4]:
# TODO: this function is useful; put into `mobe-py`!
def dataframe_to_tfrecords(df, path):
    """Write every row of ``df`` to ``path`` as a serialized ``tf.train.Example``.

    The column dtype chooses the feature list for every value in that column:

    * integer dtypes -> ``int64_list`` (values cast with ``int``)
    * float dtypes   -> ``float_list`` (values cast with ``float``; NaN kept as-is)
    * anything else  -> ``bytes_list`` holding ``str(value)`` encoded as UTF-8

    Missing values (NaN/None) in non-numeric columns are written as empty
    bytes instead of ``str(nan) == 'nan'``, so they cannot show up downstream
    as a spurious "nan" token.

    Args:
        df: pandas DataFrame to serialize; column names become feature keys.
        path: output TFRecord file path (overwritten if it already exists).
    """
    int_cols = {col for col, dtype in df.dtypes.items()
                if pd.api.types.is_integer_dtype(dtype)}
    float_cols = {col for col, dtype in df.dtypes.items()
                  if pd.api.types.is_float_dtype(dtype)}

    with tf.python_io.TFRecordWriter(path) as writer:
        for record in df.to_dict(orient='records'):
            example = tf.train.Example()
            features = example.features.feature
            for key, value in record.items():
                if key in int_cols:
                    # Explicit cast: depending on pandas version, records may
                    # hold numpy scalars rather than Python ints.
                    features[key].int64_list.value.append(int(value))
                elif key in float_cols:
                    features[key].float_list.value.append(float(value))
                else:
                    text = '' if pd.isna(value) else str(value)
                    features[key].bytes_list.value.append(text.encode('utf-8'))
            writer.write(example.SerializeToString())


leads_df = pd.read_csv('data/leads.csv')
dataframe_to_tfrecords(leads_df, 'data/leads.tfrecords')

### Analyze the data with TFT and save the transform function

In [5]:
# Raw (pre-transform) feature spec: each example carries one `dx` string and
# one `enrolled` integer, each stored as a dense single-element feature.
RAW_DATA_FEATURE = dict(
    dx=tf.FixedLenFeature(shape=[1], dtype=tf.string),
    enrolled=tf.FixedLenFeature(shape=[1], dtype=tf.int64),
)

# TFT dataset metadata built from the feature spec above; used both to decode
# the TFRecords and as the input schema for AnalyzeDataset.
RAW_DATA_METADATA = dataset_metadata.DatasetMetadata(
    dataset_schema.from_feature_spec(RAW_DATA_FEATURE)
)

In [6]:
# train our tft transformer

# Settings for the `dx` vocabulary.
DX_DELIMITER = '|'              # separator between codes inside one `dx` string
DX_FREQUENCY_THRESHOLD = 2      # per TFT docs, codes rarer than this get no vocab entry — confirm for installed version
DX_VOCAB_FILENAME = 'dx_vocab'  # name of the vocab file in `transform_fn/assets`
TRANSFORM_OUTPUT_DIR = 'data'   # WriteTransformFn writes its output under this directory


def preprocessing_fn(inputs):
    """TFT preprocessing: split `dx` into codes and map them to vocabulary ids.

    Args:
        inputs: dict of raw tensors keyed like RAW_DATA_FEATURE. `dx` and
            `enrolled` are dense FixedLenFeature(shape=[1]) values, i.e.
            batched as [batch, 1].

    Returns:
        dict with `dx` replaced by the output of
        `tft.compute_and_apply_vocabulary` over the split codes, and
        `enrolled` passed through unchanged.
    """
    # Flatten [batch, 1] -> [batch]: tf.string_split takes a 1-D string tensor.
    codes = tf.string_split(tf.reshape(inputs['dx'], [-1]), DX_DELIMITER)
    codes_indices = tft.compute_and_apply_vocabulary(
        codes,
        frequency_threshold=DX_FREQUENCY_THRESHOLD,
        vocab_filename=DX_VOCAB_FILENAME)
    return {
        'dx': codes_indices,
        'enrolled': inputs['enrolled'],
    }


# The scratch directory is entered OUTSIDE the pipeline: `with beam.Pipeline()`
# runs the pipeline on exit, so the temp dir must still exist at that point.
# It is removed afterwards instead of leaking like a bare `mkdtemp()`.
with tempfile.TemporaryDirectory() as tft_temp_dir:
    with beam.Pipeline() as pipeline:
        with beam_impl.Context(temp_dir=tft_temp_dir):
            coder = tft.coders.ExampleProtoCoder(RAW_DATA_METADATA.schema)

            data = (
                pipeline
                | 'Read' >> tfrecordio.ReadFromTFRecord('data/leads.tfrecords')
                | 'Decode' >> beam.Map(coder.decode))

            transform_fn = (
                (data, RAW_DATA_METADATA)
                | 'Analyze' >> beam_impl.AnalyzeDataset(preprocessing_fn))

            _ = (
                transform_fn
                | 'WriteTransformFn' >> transform_fn_io.WriteTransformFn(TRANSFORM_OUTPUT_DIR))