In [4]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from six.moves import urllib

import numpy as np

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib

In [5]:
# Remote locations of the three abalone CSV files used below
TRAINING_URL = 'http://download.tensorflow.org/data/abalone_train.csv'
TEST_URL = 'http://download.tensorflow.org/data/abalone_test.csv'
PREDICTION_URL = 'http://download.tensorflow.org/data/abalone_predict.csv'

# Local file names: the text following the final '/' of each URL
TRAINING_FILENAME = TRAINING_URL.rpartition('/')[2]
TEST_FILENAME = TEST_URL.rpartition('/')[2]
PREDICTION_FILENAME = PREDICTION_URL.rpartition('/')[2]

def maybe_download(url, filename):
  """Fetch a CSV from `url` unless `filename` already exists, then load it.

  Args:
    url: Address the CSV is retrieved from when no local copy is present.
    filename: Local path used both as the download target and as the file
      handed to the CSV loader.

  Returns:
    The object returned by `load_csv_without_header` for `filename`; the rest
    of this notebook reads its `.data` and `.target` attributes.
  """
  if not os.path.exists(filename):
    # Download under a temporary name and rename only once the transfer has
    # finished, so an interrupted download is never mistaken for a complete
    # cached copy by the existence check above on the next run.
    partial_filename = filename + '.part'
    try:
      urllib.request.urlretrieve(url, partial_filename)
      os.rename(partial_filename, filename)
    finally:
      # Only still present when the download or the rename failed.
      if os.path.exists(partial_filename):
        os.remove(partial_filename)
  # `np.int` was merely an alias of the builtin `int` and has been removed
  # from recent NumPy releases; the builtin selects the same dtype.
  dataset = tf.contrib.learn.datasets.base.load_csv_without_header(
    filename=filename, target_dtype=int, features_dtype=np.float32)
  return dataset
  
# Load the train, test and prediction sets (in that order), downloading any
# file that is not already present locally
training_dataset, test_dataset, prediction_dataset = (
    maybe_download(source_url, local_name)
    for source_url, local_name in (
        (TRAINING_URL, TRAINING_FILENAME),
        (TEST_URL, TEST_FILENAME),
        (PREDICTION_URL, PREDICTION_FILENAME),
    )
)

In [6]:
# Step size handed to the optimizer through the Estimator's params
LEARNING_RATE = 1e-3

# Emit INFO-level TensorFlow log messages
tf.logging.set_verbosity(tf.logging.INFO)

def model_fn(features, targets, mode, params):
  """Model function for Estimator.

  Builds a network with two 10-unit ReLU hidden layers and a single linear
  output unit. When targets are available it also defines a mean squared
  error loss, an RMSE evaluation metric and an SGD training op.

  Args:
    features: Input tensor fed to the first hidden layer.
    targets: Tensor of regression targets, or None when the estimator calls
      this function without labels (the contrib Estimator documentation
      states that None is passed for inference).
    mode: Mode key supplied by the estimator; forwarded to `ModelFnOps`.
    params: Dict of hyperparameters; `params["learning_rate"]` is read when
      the training op is built.

  Returns:
    A `ModelFnOps` whose predictions dict exposes the network output under
    the key "ages". Loss, train op and eval metrics are included only when
    `targets` is not None.
  """

  first_hidden_layer = tf.contrib.layers.relu(features, 10)
  second_hidden_layer = tf.contrib.layers.relu(first_hidden_layer, 10)
  output_layer = tf.contrib.layers.linear(second_hidden_layer, 1)

  # Reshape output layer to 1-dim Tensor to return predictions
  predictions = tf.reshape(output_layer, [-1])
  predictions_dict = {"ages": predictions}

  # Without targets there is nothing to compute a loss or metric against, so
  # return the predictions alone instead of failing on a None tensor.
  if targets is None:
    return model_fn_lib.ModelFnOps(mode=mode, predictions=predictions_dict)

  # Calculate loss using mean squared error
  loss = tf.losses.mean_squared_error(targets, predictions)

  # Calculate root mean squared error as additional eval metric
  eval_metric_ops = {
      "rmse": tf.metrics.root_mean_squared_error(
          tf.cast(targets, tf.float32), predictions)
  }

  train_op = tf.contrib.layers.optimize_loss(
      loss=loss,
      global_step=tf.contrib.framework.get_global_step(),
      learning_rate=params["learning_rate"],
      optimizer="SGD")

  return model_fn_lib.ModelFnOps(
      mode=mode,
      predictions=predictions_dict,
      loss=loss,
      train_op=train_op,
      eval_metric_ops=eval_metric_ops)

# Hyperparameters made available to model_fn through its `params` argument
model_params = dict(learning_rate=LEARNING_RATE)

# Build the Estimator around the custom model function
nn = tf.contrib.learn.Estimator(
    model_fn=model_fn,
    params=model_params)

def get_train_inputs():
  """Input function for training.

  Returns:
    A `(features, targets)` pair of constant tensors created from
    `training_dataset.data` and `training_dataset.target`.
  """
  return (tf.constant(training_dataset.data),
          tf.constant(training_dataset.target))

# Train the model on the training inputs for a fixed number of steps
TRAIN_STEPS = 5000
nn.fit(input_fn=get_train_inputs, steps=TRAIN_STEPS)

# Evaluate the trained model on the held-out test set
def get_test_inputs():
  """Input function for evaluation.

  Returns:
    A `(features, targets)` pair of constant tensors created from
    `test_dataset.data` and `test_dataset.target`.
  """
  return (tf.constant(test_dataset.data),
          tf.constant(test_dataset.target))

# Run a single evaluation step and report the resulting metrics
ev = nn.evaluate(input_fn=get_test_inputs, steps=1)
print("Loss: %s" % (ev["loss"],))
print("Root Mean Squared Error: %s" % (ev["rmse"],))

# Predict on the prediction set and list the results, numbered from 1
predictions = nn.predict(x=prediction_dataset.data, as_iterable=True)
for position, prediction in enumerate(predictions, start=1):
  print("Prediction %s: %s" % (position, prediction["ages"]))

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_num_ps_replicas': 0, '_keep_checkpoint_max': 5, '_task_type': None, '_is_chief': True, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f3e885dbd90>, '_model_dir': '/tmp/tmpPraVMU', '_save_checkpoints_steps': None, '_keep_checkpoint_every_n_hours': 10000, '_session_config': None, '_tf_random_seed': None, '_environment': 'local', '_num_worker_replicas': 0, '_task_id': 0, '_save_summary_steps': 100, '_tf_config': gpu_options {
  per_process_gpu_memory_fraction: 1.0
}
, '_evaluation_master': '', '_master': ''}
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Saving checkpoints for 1 into /tmp/tmpPraVMU/model.ckpt.
INFO:tensorflow:loss = 110.505, step = 1
INFO:tensorflow:global_step/sec: 631.469
INFO:tensorflow:loss = 7.90882, step = 101 (0.161 sec)
INFO:tensorflow:global_step/sec: 638.908
INFO:tensorflow:loss = 7.28289, step = 201 (0.157 sec)
INF