##### Copyright 2019 The TensorFlow Authors.

Taken from https://github.com/tensorflow/docs/blob/r2.0rc/site/en/r2/tutorials/estimators/keras_model_to_estimator.ipynb

# How to create an Estimator from a Keras model

## Overview

Estimators are fully supported in TensorFlow 2.0 and can be created from new or existing `tf.keras` models. This tutorial contains a complete, minimal example of that process.

## Setup

In [1]:
!pip3 install -q tensorflow-datasets

In [1]:
import tensorflow as tf

import numpy as np
import tensorflow_datasets as tfds
tf.__version__

'2.0.0'

### Create a simple Keras model

In Keras, you assemble *layers* to build *models*. A model is (usually) a graph
of layers. The most common type of model is a stack of layers: the
`tf.keras.Sequential` model.

To build a simple, fully-connected network (i.e. multi-layer perceptron):

In [2]:
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

Compile the model and get a summary.

In [3]:
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 16)                80        
_________________________________________________________________
dropout (Dropout)            (None, 16)                0         
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 17        
=================================================================
Total params: 97
Trainable params: 97
Non-trainable params: 0
_________________________________________________________________
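
The parameter counts in the summary can be checked by hand: a `Dense` layer with a bias has `inputs * units + units` weights, and `Dropout` has none. A small sketch of that arithmetic follows; the helper name `dense_params` is made up for illustration and is not part of Keras.

```python
def dense_params(inputs, units):
    """Weights plus biases of a fully-connected layer with bias enabled."""
    return inputs * units + units

hidden = dense_params(4, 16)   # 4 input features -> 16 hidden units
output = dense_params(16, 1)   # 16 hidden units -> 1 output unit
total = hidden + output        # Dropout contributes no parameters

print(hidden, output, total)   # 80 17 97
```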


### Create an input function

Use the [Datasets API](../../guide/data.md) to scale to large datasets
or multi-device training.

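Estimators need control over when and how their input pipeline is built. To allow this, they require an "input function", or `input_fn`. The `Estimator` calls this function with no arguments, and the `input_fn` must return a `tf.data.Dataset`. In the function below, the features are wrapped in a dict keyed by `dense_input`, which matches the name of the Keras model's input.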

In [4]:
def input_fn():
    split = tfds.Split.TRAIN
    dataset = tfds.load('iris', split=split, as_supervised=True)
    dataset = dataset.map(lambda features, labels: ({'dense_input':features}, labels))
    dataset = dataset.batch(32).repeat()
    return dataset
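
Because the `Estimator` calls `input_fn` with no arguments, any settings (data, batch size, whether to repeat) have to be bound beforehand. One common pattern is a small factory that returns a zero-argument closure. The sketch below is an illustration rather than part of the original notebook: `make_input_fn` is a hypothetical helper, and it uses random in-memory arrays as a stand-in for the Iris download so that it runs offline.

```python
import numpy as np
import tensorflow as tf

def make_input_fn(features, labels, batch_size=32, repeat=True):
    """Return a zero-argument input_fn (hypothetical helper)."""
    def input_fn():
        # Key the features by the Keras input name, as in the cell above.
        dataset = tf.data.Dataset.from_tensor_slices(
            ({'dense_input': features}, labels))
        dataset = dataset.batch(batch_size)
        if repeat:
            # Repeat indefinitely; the caller bounds the work with `steps`.
            dataset = dataset.repeat()
        return dataset
    return input_fn

# Stand-in data with the Iris layout: 150 rows, 4 features, labels in {0, 1, 2}.
rng = np.random.RandomState(0)
fake_features = rng.rand(150, 4).astype(np.float32)
fake_labels = rng.randint(0, 3, size=150).astype(np.int64)

train_input_fn = make_input_fn(fake_features, fake_labels, repeat=True)
eval_input_fn = make_input_fn(fake_features, fake_labels, repeat=False)

features_batch, labels_batch = next(iter(train_input_fn()))
print(features_batch['dense_input'].shape, labels_batch.shape)
```

A non-repeating input function such as `eval_input_fn` ends after one pass, which lets `evaluate` run over the whole dataset once when `steps` is left unset.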

Test out your `input_fn`:

In [5]:
for features_batch, labels_batch in input_fn().take(1):
    print(features_batch)
    print(labels_batch)

Downloading and preparing dataset iris (4.44 KiB) to /home/jupyter/tensorflow_datasets/iris/1.0.0...


Dataset iris downloaded and prepared to /home/jupyter/tensorflow_datasets/iris/1.0.0. Subsequent calls will reuse this data.
{'dense_input': <tf.Tensor: id=200, shape=(32, 4), dtype=float32, numpy=
array([[6.1, 2.8, 4.7, 1.2],
       [5.7, 3.8, 1.7, 0.3],
       [7.7, 2.6, 6.9, 2.3],
       [6. , 2.9, 4.5, 1.5],
       [6.8, 2.8, 4.8, 1.4],
       [5.4, 3.4, 1.5, 0.4],
       [5.6, 2.9, 3.6, 1.3],
       [6.9, 3.1, 5.1, 2.3],
       [6.2, 2.2, 4.5, 1.5],
       [5.8, 2.7, 3.9, 1.2],
       [6.5, 3.2, 5.1, 2. ],
       [4.8, 3. , 1.4, 0.1],
       [5.5, 3.5, 1.3, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.1, 3.8, 1.5, 0.3],
       [6.3, 3.3, 4.7, 1.6],
       [6.5, 3. , 5.8, 2.2],
       [5.6, 2.5, 3.9, 1.1],
       [5.7, 2.8, 4.5, 1.3],
       [6.4, 2.8, 5.6, 2.2],
       [4.7, 3.2, 1.6, 0.2],
       [6.1, 3. , 4.9, 1.8],
       [5. , 3.4, 1.6, 0.4],
       [6.4, 2.8, 5.6, 2.1],
       [7.9, 3.8, 6.4, 2. ],
       [6.7, 3. , 5.2, 2.3],
       [6.7, 2.5, 5.8, 1.8],
       [6.8

### Create an Estimator from the tf.keras model

A `tf.keras.Model` can be trained with the `tf.estimator` API by converting the
model to a `tf.estimator.Estimator` object with
`tf.keras.estimator.model_to_estimator`.

In [6]:
model_dir = "../tboard/estimators_keras+estimator"
keras_estimator = tf.keras.estimator.model_to_estimator(keras_model=model, model_dir=model_dir)

Train and evaluate the estimator.

In [7]:
keras_estimator.train(input_fn=input_fn, steps=25)

Instructions for updating:
Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='../tboard/estimators_keras+estimator/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: ../tboard/estimators_keras+estimator/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 4 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into ../tboard/estimators_keras+estimator/model.ckpt.
INFO:tensorflow:loss = 104.16339, step = 0
INFO:tensorflow:Saving checkpoints for 25 into ../tboard/estimators_keras+estimator/model.ckpt.
INFO:tensorflow:Loss for final step: 80.90711.


<tensorflow_estimator.python.estimator.estimator.EstimatorV2 at 0x7f11e9eb6400>
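
It can help to relate `steps` to passes over the data. Each training step consumes one batch, and the `input_fn` above batches before repeating, so every pass ends with one short batch. A back-of-the-envelope sketch, assuming the Iris training split holds 150 examples:

```python
import math

num_examples = 150   # assumed size of the Iris 'train' split
batch_size = 32      # from dataset.batch(32)
train_steps = 25     # from keras_estimator.train(..., steps=25)

# batch() runs before repeat(), so each pass has a short final batch.
batches_per_pass = math.ceil(num_examples / batch_size)
last_batch_size = num_examples - (batches_per_pass - 1) * batch_size
passes = train_steps / batches_per_pass

print(batches_per_pass, last_batch_size, passes)  # 5 22 5.0
```

The same reasoning applies to the `steps` argument of `evaluate`. Note also that reusing the same `input_fn` for evaluation means the evaluation reads the training split again rather than held-out data.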

In [8]:
eval_result = keras_estimator.evaluate(input_fn=input_fn, steps=10)
print('Eval result: {}'.format(eval_result))

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2019-11-19T22:05:02Z
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from ../tboard/estimators_keras+estimator/model.ckpt-25
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [1/10]
INFO:tensorflow:Evaluation [2/10]
INFO:tensorflow:Evaluation [3/10]
INFO:tensorflow:Evaluation [4/10]
INFO:tensorflow:Evaluation [5/10]
INFO:tensorflow:Evaluation [6/10]
INFO:tensorflow:Evaluation [7/10]
INFO:tensorflow:Evaluation [8/10]
INFO:tensorflow:Evaluation [9/10]
INFO:tensorflow:Evaluation [10/10]
INFO:tensorflow:Finished evaluation at 2019-11-19-22:05:03
INFO:tensorflow:Saving dict for global step 25: global_step = 25, loss = 97.676315
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 25: ../tboard/estimators_keras+estimator/model.ckpt-25


Eval result: {'global_step': 25, 'loss': 97.676315}
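
Iris has three classes, while the model above ends in a single sigmoid unit compiled with `categorical_crossentropy`, so the loss values in this run may be hard to interpret. One possible alternative, sketched here under the assumption that the labels are the integer class ids 0, 1 and 2, is an output layer with three units that emit logits, paired with sparse categorical cross-entropy. This is not the configuration that produced the logs above, and its summary and losses would differ; the name `alt_model` is introduced only for this sketch.

```python
import tensorflow as tf

# Three output units, one logit per Iris class (assumed labels: 0, 1, 2).
alt_model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(3)
])

# Sparse categorical cross-entropy accepts integer labels directly;
# from_logits=True because the last layer applies no softmax.
alt_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam')

print(alt_model.output_shape)    # (None, 3)
print(alt_model.count_params())  # 131
```

`alt_model` could then be passed to `tf.keras.estimator.model_to_estimator` in place of `model`.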
