## TensorFlow Development Template
Thomas Crosley, October 2017

This is a template for developing TensorFlow models using the following high-level APIs:
<ul>
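    <li><a href='https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset'>Datasets</a> --- <a href='https://www.tensorflow.org/programmers_guide/datasets'>Importing Data</a></li>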
    <li><a href='https://www.tensorflow.org/api_docs/python/tf/estimator/Estimator'>Estimators</a> --- <a href='https://www.tensorflow.org/extend/estimators'>Creating Estimators</a></li>  
    <li><a href='https://www.tensorflow.org/api_docs/python/tf/contrib/learn/Experiment'>Experiment</a></li>
</ul>

In [9]:
# Imports
import tensorflow as tf
import numpy as np
import time, datetime

In [None]:
# Reproducibility
# Seed both random number sources used by the notebook so repeated runs match.
tf.set_random_seed(0)  # Graph-level random seed for TensorFlow ops
np.random.seed(0)  # NumPy's global seed (NumPy has no `set_random_seed`)

In [None]:
# Hyperparameters
# Input files holding the training and validation records
train_file = 'train.tfrecord'
valid_file = 'valid.tfrecord'

# Every run writes into its own directory under /tmp, named after the
# local wall-clock time at which this cell was executed.
ts = time.time()
timestamp = '{:%Y-%m-%d_%H-%M-%S}'.format(datetime.datetime.fromtimestamp(ts))
output_dir = '/tmp/' + timestamp

# Passed to the model function through the estimator's `params` argument
hyper_params = dict(batch_size=32, learning_rate=0.001)

In [6]:
# Create Dataset
def parser(record):
    '''
    Parses one serialized record of the TFRecord file (template stub).

    Inputs:
        record: A single serialized example, as handed over by dataset.map()

    Outputs:
        The parsed example. Until the formatting step below is filled in,
        this is the unmodified result of tf.parse_single_example.
    '''
    # Parse the TF record
    # TODO: describe the stored fields in `features`, e.g.
    #       {'x': tf.FixedLenFeature([], tf.float32)} -- the empty dict is
    #       only a placeholder.
    parsed = tf.parse_single_example(record, features={})

    # Load the data and format it
    # TODO: decode/reshape the parsed fields; an Estimator input function is
    #       expected to yield (features, labels).
    data = parsed

    # Data augmentation

    return data

def load_dataset(tfrecord, batch_size=None, num_epochs=None):
    '''
    Builds the input pipeline for one TFRecord file.

    Inputs:
        tfrecord: Path of the TFRecord file to read
        batch_size: Number of examples per batch; defaults to
                    hyper_params['batch_size'] when left as None
        num_epochs: Number of passes over the data; None repeats indefinitely

    Outputs:
        A dataset of parsed, shuffled and batched examples
    '''
    if batch_size is None:
        # Fall back to the notebook-level hyperparameters
        batch_size = hyper_params['batch_size']

    # Load the dataset
    dataset = tf.contrib.data.TFRecordDataset(tfrecord)

    # Parse the tf record entries
    dataset = dataset.map(parser, num_threads=8, output_buffer_size=1024)

    # Shuffle the data, batch it and run this for multiple epochs
    dataset = dataset.shuffle(buffer_size=10000)
    dataset = dataset.batch(batch_size)
    dataset = dataset.repeat(num_epochs)
    return dataset

# Build the training pipeline first, then the validation pipeline
train_data, valid_data = load_dataset(train_file), load_dataset(valid_file)


def train_input_fn():
    '''
    Input function used for training.

    Outputs:
        The next-batch tensors of the training pipeline, in whatever
        structure `parser` produces (an Estimator expects (features, labels))
    '''
    # Build the pipeline inside the function so that its ops are created in
    # the graph that is current when the function is called.
    dataset = load_dataset(train_file)
    return dataset.make_one_shot_iterator().get_next()

def eval_input_fn():
    '''
    Input function used for evaluation.

    Outputs:
        The next-batch tensors of the validation pipeline, in whatever
        structure `parser` produces (an Estimator expects (features, labels))
    '''
    # Build the pipeline inside the function so that its ops are created in
    # the graph that is current when the function is called.
    dataset = load_dataset(valid_file)
    return dataset.make_one_shot_iterator().get_next()

    

In [2]:
# Define Model
# Start from an empty default graph. Useful when developing in Jupyter, where
# re-running cells would otherwise keep whatever a previous test created.
tf.reset_default_graph()

def model(features, labels, mode, params):
    '''
    Defines the model function passed into tf.estimator.Estimator

    1. Configure the model via TensorFlow operations
    2. Define the loss function for training/evaluation
    3. Define the training operation/optimizer
    4. Generate predictions
    5. Return predictions/loss/train_op/eval_metric_ops in EstimatorSpec object

    Inputs:
        features: A dict containing the features passed to the model via input_fn
        labels: A Tensor containing the labels passed to the model via input_fn
                (not available when predicting)
        mode: One of the following tf.estimator.ModeKeys string values indicating
               the context in which the model_fn was invoked
                  - tf.estimator.ModeKeys.TRAIN ---> est.train()
                  - tf.estimator.ModeKeys.EVAL ----> est.evaluate()
                  - tf.estimator.ModeKeys.PREDICT -> est.predict()
        params: A dict of hyperparameters; 'learning_rate' is read here

    Outputs:
        tf.estimator.EstimatorSpec that defines the model
    '''
    # 1. Define model structure
    # ...
    output = ...  # TODO: replace with the tensor produced by the network

    # 4. Generate predictions
    # Wrap in a dict, e.g. {'output': output}, to expose several named outputs.
    predictions = output

    if mode == tf.estimator.ModeKeys.PREDICT:
        # No labels are supplied when predicting, so the loss, the metrics
        # and the training op below cannot be built.
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # 2. Define the loss functions
    loss = ...  # TODO: replace with a scalar loss tensor

    # 2.1 Additional metrics for evaluation
    eval_metric_ops = {"rmse": tf.metrics.root_mean_squared_error(
          tf.cast(labels, tf.float64), output)}

    # 3. Define optimizer
    optimizer = tf.train.AdamOptimizer(learning_rate=params['learning_rate'])
    train_op = optimizer.minimize(loss=loss, global_step=tf.train.get_global_step())

    # 5. Return EstimatorSpec
    return tf.estimator.EstimatorSpec(
        mode=mode,
        predictions=predictions,
        loss=loss,
        train_op=train_op,
        eval_metric_ops=eval_metric_ops)
    
# Default run configuration; adjust it here to change checkpointing or
# summary behaviour.
config = tf.estimator.RunConfig()

# `tf.estimator` is a module; the class to instantiate is `Estimator`.
estimator = tf.estimator.Estimator(
    model_fn=model,
    model_dir=output_dir,
    config=config,
    params=hyper_params)

In [3]:
# Set up TensorBoard
# Single op that evaluates every summary registered in the default graph
merged_summary = tf.summary.merge_all()

# Training and validation events go to sibling directories under <output_dir>/TB;
# only the training writer is handed the graph.
tb_root = output_dir + '/TB'
train_writer = tf.summary.FileWriter(tb_root + '/train', tf.get_default_graph())
valid_writer = tf.summary.FileWriter(tb_root + '/valid')


In [8]:
# Session setup
# Open a session, then run the initializer for all global variables in it
sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op)

In [4]:
# Training Loop
# Bundles the estimator with its input functions. Constructing the Experiment
# does not start training; call experiment.train_and_evaluate() to run it.
#
# Optional settings are read from `hyper_params`; a missing key yields None.
# No input hooks are defined in this template: pass SessionRunHook instances
# via `train_monitors` / `eval_hooks` if the input pipeline needs them.
#
# NOTE(review): load_dataset() repeats its data indefinitely, so with
# eval_steps=None an evaluation pass would never run out of input -- set a
# finite eval_steps or use a non-repeating evaluation dataset.
experiment = tf.contrib.learn.Experiment(
    estimator=estimator,
    train_input_fn=train_input_fn,
    eval_input_fn=eval_input_fn,
    train_steps=hyper_params.get('train_steps'),  # Minibatch steps
    min_eval_frequency=hyper_params.get('min_eval_frequency'),  # Eval frequency
    eval_steps=None,  # Use evaluation feeder until it's empty
)

In [5]:
# Evaluation Loop


In [None]:
# Save extra files for reproducibility