[toc]

# Tensorflow estimator earlystop

tensorflow estimator 中的 earlystop 是通过 training_hook 实现的

In [2]:
!pip install requests



In [91]:
import requests
import csv

TRAIN_URL = "http://download.tensorflow.org/data/iris_training.csv"
TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

def download_file(url, fname):
    r = requests.get(url)
    with open(fname, 'w') as f:
        f.write(r.text)
        
download_file(TRAIN_URL, 'train.txt')
download_file(TEST_URL, 'test.txt')

In [93]:
import tensorflow as tf
import functools

def create_model(x, hidden_units, n_classes):
    for units in hidden_units:
        x = tf.layers.dense(x, units=units, activation='relu')
    logits = tf.layers.dense(x, n_classes)
    return logits


def model_fn(features, labels, mode, params):
    x, y = features
    hidden_units = params['hidden_units']
    n_classes = params['n_class']
    logits = create_model(x, hidden_units, n_classes)

    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.AdamOptimizer().minimize(
            loss, global_step=tf.train.get_or_create_global_step()
        )
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
    elif mode == tf.estimator.ModeKeys.EVAL:
        metrics = {
            "acc": tf.metrics.accuracy(labels=y, predictions=logits)
        }
        return tf.estimator.EstimatorSpec(mode, loss=loss, metrics=metrics)
    else:
        raise Exception("Only support TRAIN and EVAL mode")

In [94]:
def data_geneartor(fname):
    with open(fname) as f:
        reader = csv.reader(f)
        next(reader)
        for row in reader:
            yield [row[:4]], [row[-1]]

In [97]:
def input_fn(fname, 
             shuffle_and_repeat=True, 
             params=None):
    params = params if params is not None else {}
    dataset = tf.data.Dataset.from_generator(
        functools.partial(data_geneartor, fname),
        (tf.float32, tf.int32)
    )

    if shuffle_and_repeat:
        dataset = dataset.shuffle(params.get('buffer', 64)).repeat(params.get('epochs', 32))

    dataset = dataset.batch(params.get('batch_size', 32)).prefetch(1)
    return dataset


train_input_fn = functools.partial(input_fn, 'train.txt')
eval_input_fn = functools.partial(input_fn, 'test.txt')

In [98]:
params = {'hidden_units': [50, 20, 10], 'n_classes': 3}
cfg = tf.estimator.RunConfig(save_checkpoints_secs=120)
estimator = tf.estimator.Estimator(model_fn, 'output_dir', cfg, params)
hook = tf.estimator.experimental.stop_if_no_increase_hook(
    estimator, 'acc', max_steps_without_increase=50
)
train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, hooks=[hook])
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, throttle_secs=120)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

INFO:tensorflow:Using config: {'_model_dir': 'output_dir', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 120, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fece9f710d0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}
INFO:tensorflow:Not using Distribute Coordinator.
INFO:tensorflow:Running training and evaluation lo

OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function.

# References
1. [tf_ner/main.py at master · guillaumegenthial/tf_ner](https://github.com/guillaumegenthial/tf_ner/blob/master/models/lstm_crf/main.py#L36)