In [1]:
import argparse
import sys
import tempfile

In [2]:
from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.tensor_forest.client import eval_metrics
from tensorflow.contrib.tensor_forest.client import random_forest
from tensorflow.contrib.tensor_forest.python import tensor_forest
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.platform import app

In [3]:
# Parsed command-line arguments (the namespace returned by
# parser.parse_known_args()). Stays None until the __main__ cell assigns it,
# which happens before app.run() invokes main().
FLAGS = None

In [4]:
def build_estimator(model_dir):
    """Construct a TensorForest classifier behind the scikit-style SKCompat API.

    Forest size and stopping behaviour are taken from the module-level
    ``FLAGS`` namespace (``num_trees``, ``max_nodes``, ``use_training_loss``).

    Args:
        model_dir: Directory handed to the estimator for its model output.

    Returns:
        An ``estimator.SKCompat`` wrapping a ``TensorForestEstimator``.
    """
    # MNIST: 10 digit classes, 784 (= 28 * 28) flattened pixel features.
    hparams = tensor_forest.ForestHParams(
        num_classes=10,
        num_features=784,
        num_trees=FLAGS.num_trees,
        max_nodes=FLAGS.max_nodes,
    )

    # Choose the graph builder: plain random forest unless loss-based
    # termination was requested on the command line.
    builder_cls = (tensor_forest.TrainingLossForest
                   if FLAGS.use_training_loss
                   else tensor_forest.RandomForestGraphs)

    forest = random_forest.TensorForestEstimator(
        hparams,
        graph_builder_class=builder_cls,
        model_dir=model_dir,
    )
    return estimator.SKCompat(forest)

In [5]:
def train_and_eval():
    """Train a TensorForest classifier on MNIST and print test-set metrics.

    All settings are read from the module-level ``FLAGS`` namespace:
    ``model_dir`` (a fresh temporary directory is created when it is empty),
    ``data_dir``, ``batch_size`` and ``train_steps``. Each evaluation result
    is printed as ``key: value`` in sorted key order.
    """
    model_dir = FLAGS.model_dir or tempfile.mkdtemp()
    print('model dir = %s' % model_dir)

    est = build_estimator(model_dir)

    # one_hot=False: labels are integer class ids rather than one-hot vectors.
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=False)

    # Coerce explicitly: the flag may have been parsed as a string, which would
    # otherwise be handed to the estimator unchanged.
    batch_size = int(FLAGS.batch_size)

    # Pass --train_steps through so the documented flag actually bounds training
    # (it used to be parsed but ignored).
    est.fit(x=mnist.train.images, y=mnist.train.labels,
            batch_size=batch_size, steps=FLAGS.train_steps)

    # Evaluate accuracy using the TensorForest client's metric function and
    # the matching prediction key.
    metric_name = 'accuracy'
    metrics = {
        metric_name: metric_spec.MetricSpec(
            eval_metrics.get_metric(metric_name),
            prediction_key=eval_metrics.get_prediction_key(metric_name)),
    }

    results = est.score(x=mnist.test.images, y=mnist.test.labels,
                        batch_size=batch_size, metrics=metrics)
    for key in sorted(results):
        print('%s: %s' % (key, results[key]))
        

In [6]:
def main(_):
    """Entry point invoked by ``app.run``.

    Args:
        _: Leftover command-line arguments forwarded by ``app.run``; unused,
            since train_and_eval() reads its settings from FLAGS.
    """
    del _  # explicitly unused
    train_and_eval()

In [None]:
def parse_bool_flag(value):
    """Convert a command-line string such as 'true', 'False', '1' or 'no' to bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    'False') as True, so an explicit converter is needed.

    Args:
        value: The raw flag string (or an already-converted bool, e.g. a
            ``const``/``default`` value).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'Expected a boolean value, got %r.' % value)


def build_arg_parser():
    """Return the ArgumentParser describing this script's command-line flags."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_dir',
        type=str,
        default='',
        help='Base directory for output models.')
    parser.add_argument(
        '--data_dir',
        type=str,
        default='/tmp/data/',
        help='Directory for storing data')
    parser.add_argument(
        '--train_steps',
        type=int,
        default=1000,
        help='Number of training steps.')
    # type=int so a CLI-supplied value is parsed as a number, matching the
    # integer default.
    parser.add_argument(
        '--batch_size',
        type=int,
        default=1000,
        help='Number of examples in a training batch.')
    parser.add_argument(
        '--num_trees',
        type=int,
        default=100,
        help='Number of trees in the forest.')
    parser.add_argument(
        '--max_nodes',
        type=int,
        default=1000,
        help='Max total nodes in a single tree.')
    # nargs='?' with const=True lets a bare `--use_training_loss` mean True,
    # while an explicit value such as 'False' is still honoured.
    parser.add_argument(
        '--use_training_loss',
        type=parse_bool_flag,
        nargs='?',
        const=True,
        default=False,
        help='If true, use training loss as termination criteria.')
    return parser


if __name__ == '__main__':
    # parse_known_args leaves unrecognised arguments (e.g. the Jupyter kernel's
    # own `-f <connection file>`) in `unparsed` instead of erroring.
    FLAGS, unparsed = build_arg_parser().parse_known_args()
    app.run(main=main, argv=[sys.argv[0]] + unparsed)



model dir = /tmp/tmpx57l4pc8
INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_environment': 'local', '_keep_checkpoint_max': 5, '_tf_random_seed': None, '_is_chief': True, '_task_id': 0, '_evaluation_master': '', '_tf_config': gpu_options {
  per_process_gpu_memory_fraction: 1.0
}
, '_num_ps_replicas': 0, '_save_checkpoints_steps': None, '_save_summary_steps': 100, '_master': '', '_task_type': None, '_keep_checkpoint_every_n_hours': 10000, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f779149cef0>, '_save_checkpoints_secs': 600}
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting /tmp/data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracti