In [6]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys
import tempfile

import numpy

from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.client import random_forest
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.platform import app


In [7]:
# Module-level holder for the run configuration. build_estimator() and
# train_and_eval() read attributes from it (num_trees, max_nodes,
# use_training_loss, model_dir, data_dir, batch_size), so it must be replaced
# with an object carrying those attributes before either function is called;
# while it is still None those attribute reads raise AttributeError.
FLAGS = None

In [8]:
def build_estimator(model_dir):
  """Create the TensorForest estimator configured from the global FLAGS.

  Args:
    model_dir: Value forwarded unchanged as the estimator's `model_dir`.

  Returns:
    The `random_forest.TensorForestEstimator` built from hyper-parameters with
    10 classes, 784 features, and the tree count / node limit from FLAGS.
  """
  hparams = tensor_forest.ForestHParams(
      num_classes=10,
      num_features=784,
      num_trees=FLAGS.num_trees,
      max_nodes=FLAGS.max_nodes)

  # RandomForestGraphs is the default builder; TrainingLossForest replaces it
  # only when the use_training_loss flag is truthy.
  default_builder = tensor_forest.RandomForestGraphs
  builder_cls = (
      tensor_forest.TrainingLossForest
      if FLAGS.use_training_loss else default_builder)

  return random_forest.TensorForestEstimator(
      hparams, graph_builder_class=builder_cls, model_dir=model_dir)

In [3]:
def train_and_eval():
  """Fit the estimator on the MNIST train split and print test metrics.

  Configuration is read from the global FLAGS (model_dir, data_dir,
  batch_size). Prints the model directory first, then one `key: value` line
  per entry of the evaluation result, in sorted key order.
  """
  # An empty/unset model_dir means "use a fresh temporary directory".
  model_dir = FLAGS.model_dir or tempfile.mkdtemp()
  print('model directory = %s' % model_dir)

  estimator = build_estimator(model_dir)

  mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=False)

  def make_input_fn(split, num_epochs, shuffle):
    """Wrap one data split as a numpy input_fn with features under 'images'."""
    return numpy_io.numpy_input_fn(
        x={'images': split.images},
        y=split.labels.astype(numpy.int32),
        batch_size=FLAGS.batch_size,
        num_epochs=num_epochs,
        shuffle=shuffle)

  # Neither an epoch count nor a step count is given for training here.
  # NOTE(review): presumably the estimator decides when to stop - confirm.
  train_fn = make_input_fn(mnist.train, num_epochs=None, shuffle=True)
  estimator.fit(input_fn=train_fn, steps=None)

  accuracy_name = 'accuracy'
  accuracy_spec = metric_spec.MetricSpec(
      eval_metrics.get_metric(accuracy_name),
      prediction_key=eval_metrics.get_prediction_key(accuracy_name))
  requested_metrics = {accuracy_name: accuracy_spec}

  # Evaluation makes a single, unshuffled pass over the test split.
  test_fn = make_input_fn(mnist.test, num_epochs=1, shuffle=False)

  scores = estimator.evaluate(input_fn=test_fn, metrics=requested_metrics)
  for name in sorted(scores):
    print('%s: %s' % (name, scores[name]))


In [4]:
def main(_):
  """Entry point: discard the single positional argument and run training.

  Args:
    _: Ignored; accepted only so the function can be called with one
      positional argument.
  """
  del _  # The argument is intentionally unused.
  train_and_eval()


In [5]:
if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--model_dir',
      type=str,
      default='',
      help='Base directory for output models.'
  )
  parser.add_argument(
      '--data_dir',
      type=str,
      default='/tmp/data/',
      help='Directory for storing data'
  )
  parser.add_argument(
      '--train_steps',
      type=int,
      default=1000,
      help='Number of training steps.'
  )
  parser.add_argument(
      '--batch_size',
      type=str,
      default=1000,
      help='Number of examples in a training batch.'
  )
  parser.add_argument(
      '--num_trees',
      type=int,
      default=100,
      help='Number of trees in the forest.'
  )
  parser.add_argument(
      '--max_nodes',
      type=int,
      default=1000,
      help='Max total nodes in a single tree.'
  )
  parser.add_argument(
      '--use_training_loss',
      type=bool,
      default=False,
      help='If true, use training loss as termination criteria.'
  )


NameError: name 'argparse' is not defined