## 通过 Keras 模型创建 Estimator
TensorFlow 完全支持 TensorFlow Estimator，可以从新的和现有的 tf.keras 模型创建 Estimator。本教程包含了该过程完整且最为简短的示例。

In [36]:
import tensorflow as tf
import numpy as np

In [37]:
# Build a small feed-forward classifier layer by layer: 4 input features,
# a 16-unit ReLU hidden layer, dropout, and a 3-unit output layer with no
# activation (the loss configured at compile time is told to expect logits).
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(16, activation='relu', input_shape=(4,)))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(3))

In [38]:
# The output layer emits raw scores, so the loss must apply the softmax
# itself (from_logits=True). Labels are integer class ids, hence "sparse".
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(optimizer='adam', loss=loss_fn)
model.summary()

Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_4 (Dense)              (None, 16)                80        
_________________________________________________________________
dropout_2 (Dropout)          (None, 16)                0         
_________________________________________________________________
dense_5 (Dense)              (None, 3)                 51        
Total params: 131
Trainable params: 131
Non-trainable params: 0
_________________________________________________________________


## 创建 Input Function 
Estimator 在训练（`train`）和评估（`evaluate`）时都需要传入一个 `input_fn`：它是一个函数，返回由 `(features, labels)` 组成的 `tf.data.Dataset`。

In [39]:
# Order of the columns in the CSV file.
column_names = [
    'sepal_length',
    'sepal_width',
    'petal_length',
    'petal_width',
    'species',
]

# Every column except the last one is a feature; the last one is the label.
*feature_names, label_name = column_names

print("Features: {}".format(feature_names))
print("Label: {}".format(label_name))

Features: ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
Label: species


In [44]:
def pack_features_vector(features, labels):
    """Pack a mapping of per-column feature tensors into one stacked tensor.

    Args:
        features: Mapping whose values are the individual feature tensors
            (presumably one batch-sized tensor per CSV column -- confirm
            against the dataset this is mapped over).
        labels: Label tensor; passed through unchanged.

    Returns:
        A ``(stacked_features, labels)`` tuple, where the feature tensors
        have been stacked along axis 1 in the mapping's iteration order.
    """
    per_column = list(features.values())
    return tf.stack(per_column, axis=1), labels

def input_fn(batch_size=32, train_dataset_fp='./datasets/iris_training.csv'):
    """Build the batched, endlessly repeating CSV dataset fed to the Estimator.

    Both parameters default to the values that were previously hard-coded in
    the body, so calling ``input_fn()`` with no arguments (which is how it is
    used elsewhere in this notebook) behaves exactly as before. The parameters
    allow the same pipeline to be reused with another batch size or CSV file,
    e.g. via ``functools.partial`` or a ``lambda``.

    Args:
        batch_size: Number of CSV rows combined into each batch.
        train_dataset_fp: Path of the CSV file to read. Its columns are
            interpreted according to the module-level ``column_names``, with
            ``label_name`` used as the label column.

    Returns:
        A ``tf.data.Dataset`` of ``(features, labels)`` batches in which the
        per-column features have been stacked into a single tensor by
        ``pack_features_vector``. The dataset repeats indefinitely, so callers
        must bound consumption themselves (e.g. with ``steps=``).
    """
    dataset = tf.data.experimental.make_csv_dataset(
        train_dataset_fp,
        batch_size,
        column_names=column_names,
        label_name=label_name,
        num_epochs=1,
    )
    # Collapse the per-column feature mapping into one feature tensor per batch.
    dataset = dataset.map(pack_features_vector)
    # A single pass over the file is read above; repeat it without end here.
    return dataset.repeat()

In [46]:
# Call input_fn once so the notebook displays the resulting dataset's
# element shapes and dtypes.
input_fn()

<RepeatDataset shapes: ((None, 4), (None,)), types: (tf.float32, tf.int32)>

## 将 Keras 模型转换为 Estimator

In [47]:
import pathlib

# Directory handed to the Estimator as its model_dir; create it (and any
# missing parents) up front, tolerating the case where it already exists.
model_dir = './keras2est_model'
pathlib.Path(model_dir).mkdir(exist_ok=True, parents=True)

# Wrap the compiled Keras model in an Estimator that uses that directory.
keras_estimator = tf.keras.estimator.model_to_estimator(
    keras_model=model,
    model_dir=model_dir,
)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using the Keras model provided.
INFO:tensorflow:Using config: {'_model_dir': './keras2est_model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': ClusterSpec({}), '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


In [48]:
# input_fn repeats forever, so both calls are bounded explicitly by `steps`.
keras_estimator.train(steps=500, input_fn=input_fn)

# NOTE: evaluation reuses the same input_fn as training, so the reported
# result is measured on the training CSV rather than on held-out data.
eval_result = keras_estimator.evaluate(steps=10, input_fn=input_fn)
print('Eval result: {}'.format(eval_result))

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Warm-starting with WarmStartSettings: WarmStartSettings(ckpt_to_initialize_from='./keras2est_model/keras/keras_model.ckpt', vars_to_warm_start='.*', var_name_to_vocab_info={}, var_name_to_prev_var_name={})
INFO:tensorflow:Warm-starting from: ./keras2est_model/keras/keras_model.ckpt
INFO:tensorflow:Warm-starting variables only in TRAINABLE_VARIABLES.
INFO:tensorflow:Warm-started 4 variables.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 0...
INFO:tensorflow:Saving checkpoints for 0 into ./keras2est_model/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 0...
INFO:tensorflow:loss = 1.4627235, step = 0
INFO:tensorflow:global_step/sec: 705.129
INFO:tensorflow:loss = 0.6269758, step = 100