Tensorflow Estimators can be created from new or existing `tf.Keras` models.

In [0]:
!pip install -q tf-nightly

In [16]:
import tensorflow as tf
import numpy as np
import tensorflow_datasets as tfds
import pandas as pd

print("Tensorflow Version: {}".format(tf.__version__))
print("Eager Mode: {}".format(tf.executing_eagerly()))
print("GPU {} available".format("is" if tf.config.experimental.list_physical_devices("GPU") else "not"))

Tensorflow Version: 2.2.0-dev20200113
Eager Mode: True
GPU is available


# Create a Simple tf.Keras Model

In [0]:
def build_keras_model():
  def build_model(inputs):
    x = tf.keras.layers.Dense(units=16, activation='elu', 
                              input_shape=[4], name='input')(inputs)
    x = tf.keras.layers.Dropout(0.2)(x)
    y = tf.keras.layers.Dense(units=1, activation='sigmoid', name='output')(x)
    return y

  inputs = tf.keras.Input(shape=[4])
  outputs = build_model(inputs)
  model = tf.keras.Model(inputs, outputs)
  return model

In [4]:
model = build_keras_model()
model.summary()

model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 4)]               0         
_________________________________________________________________
input (Dense)                (None, 16)                80        
_________________________________________________________________
dropout (Dropout)            (None, 16)                0         
_________________________________________________________________
output (Dense)               (None, 1)                 17        
Total params: 97
Trainable params: 97
Non-trainable params: 0
_________________________________________________________________


# Create an Input Function

In [0]:
def input_fn():
  split = tfds.Split.TRAIN
  dataset = tfds.load('iris', split=split, as_supervised=True)
  dataset = dataset.map(lambda feats, labels: (feats, labels))
  dataset = dataset.repeat(None).batch(32)
  return dataset

In [12]:
for feats, labels in input_fn().take(1):
  print(feats)
  print(labels)

tf.Tensor(
[[6.1 2.8 4.7 1.2]
 [5.7 3.8 1.7 0.3]
 [7.7 2.6 6.9 2.3]
 [6.  2.9 4.5 1.5]
 [6.8 2.8 4.8 1.4]
 [5.4 3.4 1.5 0.4]
 [5.6 2.9 3.6 1.3]
 [6.9 3.1 5.1 2.3]
 [6.2 2.2 4.5 1.5]
 [5.8 2.7 3.9 1.2]
 [6.5 3.2 5.1 2. ]
 [4.8 3.  1.4 0.1]
 [5.5 3.5 1.3 0.2]
 [4.9 3.1 1.5 0.1]
 [5.1 3.8 1.5 0.3]
 [6.3 3.3 4.7 1.6]
 [6.5 3.  5.8 2.2]
 [5.6 2.5 3.9 1.1]
 [5.7 2.8 4.5 1.3]
 [6.4 2.8 5.6 2.2]
 [4.7 3.2 1.6 0.2]
 [6.1 3.  4.9 1.8]
 [5.  3.4 1.6 0.4]
 [6.4 2.8 5.6 2.1]
 [7.9 3.8 6.4 2. ]
 [6.7 3.  5.2 2.3]
 [6.7 2.5 5.8 1.8]
 [6.8 3.2 5.9 2.3]
 [4.8 3.  1.4 0.3]
 [4.8 3.1 1.6 0.2]
 [4.6 3.6 1.  0.2]
 [5.7 4.4 1.5 0.4]], shape=(32, 4), dtype=float32)
tf.Tensor([1 0 2 1 1 0 1 2 1 1 2 0 0 0 0 1 2 1 1 2 0 2 0 2 2 2 2 2 0 0 0 0], shape=(32,), dtype=int64)


# Create an Estimator from the TF.Keras Model

A tf.keras.Model can firstly be converted into an `tf.estimator.Estimator` via the `tf.keras.estimator.model_to_estimator` API and then be trained via the `tf.estimator` API.

Train the model and evaluate it on the dataset.

In [20]:
model_dir = '/tmp/tfkeras_example/'
!rm -rf {model_dir}
keras_estimator = tf.keras.estimator.model_to_estimator(keras_model=model, 
                                                        model_dir=model_dir)
!ls -l {model_dir}/keras

keras_estimator.train(input_fn=input_fn, steps=25)
eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10)
print(pd.Series(eval_result))

INFO:tensorflow:Using default config.


INFO:tensorflow:Using default config.


INFO:tensorflow:Using the Keras model provided.


INFO:tensorflow:Using the Keras model provided.


INFO:tensorflow:Using config: {'_model_dir': '/tmp/tfkeras_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


INFO:tensorflow:Using config: {'_model_dir': '/tmp/tfkeras_example/', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


total 16
-rw-r--r-- 1 root root   89 Jan 14 09:30 checkpoint
-rw-r--r-- 1 root root 1188 Jan 14 09:30 keras_model.ckpt.data-00000-of-00002
-rw-r--r-- 1 root root 2204 Jan 14 09:30 keras_model.ckpt.data-00001-of-00002
-rw-r--r-- 1 root root 1193 Jan 14 09:30 keras_model.ckpt.index
INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tfkeras_example/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})


INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='/tmp/tfkeras_example/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})


INFO:tensorflow:Warm-starting from: /tmp/tfkeras_example/keras/keras_model.ckpt


INFO:tensorflow:Warm-starting from: /tmp/tfkeras_example/keras/keras_model.ckpt


INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.


INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.


INFO:tensorflow:Warm-started 4 variables.


INFO:tensorflow:Warm-started 4 variables.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Create CheckpointSaverHook.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Saving checkpoints for 0 into /tmp/tfkeras_example/model.ckpt.


INFO:tensorflow:Saving checkpoints for 0 into /tmp/tfkeras_example/model.ckpt.


INFO:tensorflow:loss = 3.1260388, step = 0


INFO:tensorflow:loss = 3.1260388, step = 0


INFO:tensorflow:Saving checkpoints for 25 into /tmp/tfkeras_example/model.ckpt.


INFO:tensorflow:Saving checkpoints for 25 into /tmp/tfkeras_example/model.ckpt.


INFO:tensorflow:Loss for final step: 2.1776383.


INFO:tensorflow:Loss for final step: 2.1776383.


INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Done calling model_fn.


INFO:tensorflow:Starting evaluation at 2020-01-14T09:30:06Z


INFO:tensorflow:Starting evaluation at 2020-01-14T09:30:06Z


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Graph was finalized.


INFO:tensorflow:Restoring parameters from /tmp/tfkeras_example/model.ckpt-25


INFO:tensorflow:Restoring parameters from /tmp/tfkeras_example/model.ckpt-25


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Done running local_init_op.


INFO:tensorflow:Evaluation [1/10]


INFO:tensorflow:Evaluation [1/10]


INFO:tensorflow:Evaluation [2/10]


INFO:tensorflow:Evaluation [2/10]


INFO:tensorflow:Evaluation [3/10]


INFO:tensorflow:Evaluation [3/10]


INFO:tensorflow:Evaluation [4/10]


INFO:tensorflow:Evaluation [4/10]


INFO:tensorflow:Evaluation [5/10]


INFO:tensorflow:Evaluation [5/10]


INFO:tensorflow:Evaluation [6/10]


INFO:tensorflow:Evaluation [6/10]


INFO:tensorflow:Evaluation [7/10]


INFO:tensorflow:Evaluation [7/10]


INFO:tensorflow:Evaluation [8/10]


INFO:tensorflow:Evaluation [8/10]


INFO:tensorflow:Evaluation [9/10]


INFO:tensorflow:Evaluation [9/10]


INFO:tensorflow:Evaluation [10/10]


INFO:tensorflow:Evaluation [10/10]


INFO:tensorflow:Inference Time : 0.27480s


INFO:tensorflow:Inference Time : 0.27480s


INFO:tensorflow:Finished evaluation at 2020-01-14-09:30:06


INFO:tensorflow:Finished evaluation at 2020-01-14-09:30:06


INFO:tensorflow:Saving dict for global step 25: accuracy = 0.33125, global_step = 25, loss = 1.6237723


INFO:tensorflow:Saving dict for global step 25: accuracy = 0.33125, global_step = 25, loss = 1.6237723


INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tfkeras_example/model.ckpt-25


INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: /tmp/tfkeras_example/model.ckpt-25


accuracy        0.331250
loss            1.623772
global_step    25.000000
dtype: float64
