<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc" style="margin-top: 1em;"><ul class="toc-item"><li><span><a href="#Setup" data-toc-modified-id="Setup-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Setup</a></span></li><li><span><a href="#Load-Data-Set" data-toc-modified-id="Load-Data-Set-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Load Data Set</a></span></li><li><span><a href="#Feature-columns" data-toc-modified-id="Feature-columns-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Feature columns</a></span></li><li><span><a href="#Model-&amp;-config" data-toc-modified-id="Model-&amp;-config-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Model &amp; config</a></span></li><li><span><a href="#Input-function" data-toc-modified-id="Input-function-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Input function</a></span></li><li><span><a href="#Train" data-toc-modified-id="Train-6"><span class="toc-item-num">6&nbsp;&nbsp;</span>Train</a></span></li><li><span><a href="#Eval" data-toc-modified-id="Eval-7"><span class="toc-item-num">7&nbsp;&nbsp;</span>Eval</a></span></li><li><span><a href="#Restore-Model" data-toc-modified-id="Restore-Model-8"><span class="toc-item-num">8&nbsp;&nbsp;</span>Restore Model</a></span></li></ul></div>

## Setup

In [1]:
import tensorflow as tf
import numpy as np


## Load Data Set

In [2]:
(train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()

In [3]:
train_x.shape
train_y.shape
type(train_y[0])

(60000, 28, 28)

(60000,)

numpy.uint8

## Feature columns

In [4]:
feature_columns = [tf.feature_column.numeric_column("x", shape=[28, 28])]

## Model & config

Set the logging frequency to once every 1000 steps.

In [5]:
config = tf.estimator.RunConfig(log_step_count_steps=1000)

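The output below lists the settings that can be changed. For example, _keep_checkpoint_max sets the maximum number of saved checkpoints to keep, and _save_checkpoints_secs sets how often (in seconds) a checkpoint is saved. When constructing tf.estimator.RunConfig, these are passed without the leading underscore, e.g. keep_checkpoint_max.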

In [27]:
vars(classifier.config)

{'_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec at 0x7f9fe86a9898>,
 '_distribute': None,
 '_evaluation_master': '',
 '_global_id_in_cluster': 0,
 '_is_chief': True,
 '_keep_checkpoint_every_n_hours': 10000,
 '_keep_checkpoint_max': 5,
 '_log_step_count_steps': 1000,
 '_master': '',
 '_model_dir': '/tmp/mnist_model',
 '_num_ps_replicas': 0,
 '_num_worker_replicas': 1,
 '_save_checkpoints_secs': 600,
 '_save_checkpoints_steps': None,
 '_save_summary_steps': 100,
 '_service': None,
 '_session_config': None,
 '_task_id': 0,
 '_task_type': 'worker',
 '_tf_random_seed': None}
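To make the effect of keep_checkpoint_max more concrete, here is a minimal sketch in plain Python. It is not TensorFlow's implementation: it simply assumes that checkpoints are written one at a time and that only the most recent keep_checkpoint_max of them are retained. The function name and the step numbers are hypothetical.

```python
from collections import deque


def simulate_retention(saved_steps, keep_checkpoint_max=5):
    """Return the checkpoint steps still retained after saving `saved_steps` in order.

    Illustrative only: models "keep the N most recent checkpoints" with a
    bounded deque and never calls TensorFlow.
    """
    retained = deque(maxlen=keep_checkpoint_max)
    for step in saved_steps:
        retained.append(step)  # the oldest entry is dropped once the deque is full
    return list(retained)


# Hypothetical save points, one every 1000 steps.
steps = list(range(1000, 8001, 1000))
remaining = simulate_retention(steps, keep_checkpoint_max=5)
print(remaining)  # [4000, 5000, 6000, 7000, 8000]
```

Under this assumption, a long training run with the default of 5 would leave only the last five checkpoints available for later evaluation.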

In [6]:
classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                        hidden_units=[256, 32],
                                        optimizer=tf.train.AdamOptimizer(1e-4),
                                        n_classes=10,
                                        dropout=0.1, 
                                        config=config,
                                        model_dir="/tmp/mnist_model")

INFO:tensorflow:Using config: {'_model_dir': '/tmp/mnist_model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 1000, '_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7fa0db5cb4e0>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


The labels must be converted to int32 here; otherwise TensorFlow raises: TypeError: Value passed to parameter 'labels' has DataType uint8 not in list of allowed values: int32, int64

## Input function

In [12]:
train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": train_x}, y=train_y.astype(np.int32), num_epochs=None, batch_size=50, shuffle=True)
test_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": test_x}, y=test_y.astype(np.int32), num_epochs=1, shuffle=False)
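The label cast used in the input functions above can be checked in isolation with NumPy. This is a small sketch; the label values are a hypothetical stand-in for train_y, which loads as uint8.

```python
import numpy as np

# Stand-in for the MNIST labels, which load as uint8 (hypothetical values).
labels_uint8 = np.array([5, 0, 4, 1, 9], dtype=np.uint8)

# astype returns a new array and leaves the original untouched.
labels_int32 = labels_uint8.astype(np.int32)

print(labels_uint8.dtype, labels_int32.dtype)  # uint8 int32
```

Because astype copies the data, the original uint8 arrays stay available for anything else that needs them.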

## Train

In [8]:
classifier.train(input_fn=train_input_fn, steps=20000)

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/mnist_model/model.ckpt-10003
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 10004 into /tmp/mnist_model/model.ckpt.
INFO:tensorflow:loss = 32.4552, step = 10004
INFO:tensorflow:global_step/sec: 245.468
INFO:tensorflow:loss = 14.070244, step = 11004 (4.076 sec)
INFO:tensorflow:global_step/sec: 240.218
INFO:tensorflow:loss = 9.689271, step = 12004 (4.164 sec)
INFO:tensorflow:global_step/sec: 240.201
INFO:tensorflow:loss = 15.497894, step = 13004 (4.163 sec)
INFO:tensorflow:global_step/sec: 233.643
INFO:tensorflow:loss = 14.839357, step = 14004 (4.281 sec)
INFO:tensorflow:global_step/sec: 238.44
INFO:tensorflow:loss = 3.9544375, step = 15004 (4.192 sec)
INFO:tensorflow:global_step/sec: 235.545
INFO:tensorflow:loss = 1

<tensorflow.python.estimator.canned.dnn.DNNClassifier at 0x7fa0db5cb320>

In [16]:
!ls -l {classifier.model_dir}

total 20544
-rw-r--r-- 1 root root     277 Apr  7 20:04 checkpoint
drwxr-xr-x 2 root root    4096 Apr  7 20:12 eval
-rw-r--r-- 1 root root 3010526 Apr  7 19:49 events.out.tfevents.1523101541.ubuntu
-rw-r--r-- 1 root root 1672160 Apr  7 20:01 events.out.tfevents.1523101753.ubuntu
-rw-r--r-- 1 root root 1207185 Apr  7 20:03 events.out.tfevents.1523102516.ubuntu
-rw-r--r-- 1 root root 1474407 Apr  7 20:04 events.out.tfevents.1523102586.ubuntu
-rw-r--r-- 1 root root  325958 Apr  7 20:03 graph.pbtxt
-rw-r--r-- 1 root root 2514184 Apr  7 19:49 model.ckpt-10001.data-00000-of-00001
-rw-r--r-- 1 root root     799 Apr  7 19:49 model.ckpt-10001.index
-rw-r--r-- 1 root root  143591 Apr  7 19:49 model.ckpt-10001.meta
-rw-r--r-- 1 root root 2514184 Apr  7 20:01 model.ckpt-10002.data-00000-of-00001
-rw-r--r-- 1 root root     799 Apr  7 20:01 model.ckpt-10002.index
-rw-r--r-- 1 root root  143591 Apr  7 20:01 model.ckpt-10002.meta
-rw-r--r-- 1 root root 2514184 Apr  7 20:01 model.ckpt-100

## Eval

In [13]:
classifier.evaluate(input_fn=test_input_fn)

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2018-04-07-12:12:16
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/mnist_model/model.ckpt-30003
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2018-04-07-12:12:16
INFO:tensorflow:Saving dict for global step 30003: accuracy = 0.9673, average_loss = 0.16608413, global_step = 30003, loss = 21.023306


{'accuracy': 0.9673,
 'average_loss': 0.16608413,
 'global_step': 30003,
 'loss': 21.023306}
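The two loss values in this dictionary can be related with a little arithmetic. The sketch below rests on assumptions the notebook does not state: that average_loss is the mean per-example loss, that loss is the mean of the per-batch summed losses, and that test_input_fn falls back to the numpy_input_fn default batch_size of 128. Under those assumptions the numbers agree closely with the reported output.

```python
import math

num_examples = 10000       # size of the MNIST test set
batch_size = 128           # assumed: numpy_input_fn default, since test_input_fn sets none
average_loss = 0.16608413  # copied from the evaluation output above

# The last batch is partial, so round the batch count up.
num_batches = math.ceil(num_examples / batch_size)

# Total loss over the test set, divided evenly across the batches.
loss_per_batch = average_loss * num_examples / num_batches

print(num_batches)               # 79
print(round(loss_per_batch, 3))  # 21.023
```

If the assumptions hold, loss depends on the evaluation batch size, while average_loss and accuracy are the values to compare across runs.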

## Restore Model

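Once model_dir is specified, the model automatically loads the saved parameters from it, but the hyperparameters must stay consistent with those used for training. For example, if training used hidden_units=[256, 32] and the model is loaded with hidden_units=[256, 33], evaluation fails with the following error: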
```
InvalidArgumentError (see above for traceback): tensor_name = dnn/hiddenlayer_1/bias; shape in shape_and_slice spec [33] does not match the shape stored in checkpoint: [32]
```
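One way to avoid this mismatch is to read the layer widths back from the checkpoint before rebuilding the estimator. The helper below is a sketch in plain Python: infer_hidden_units is a hypothetical name, it expects (name, shape) pairs in the form that a variable listing of the checkpoint would give, and it assumes the hidden-layer biases are named dnn/hiddenlayer_<i>/bias as in the error message above. The example listing is made up for illustration.

```python
import re


def infer_hidden_units(variables):
    """Recover hidden layer widths from (name, shape) pairs.

    Assumes hidden-layer biases are named dnn/hiddenlayer_<i>/bias, as in
    the error message above. Other variables are ignored.
    """
    pattern = re.compile(r"^dnn/hiddenlayer_(\d+)/bias$")
    widths = {}
    for name, shape in variables:
        match = pattern.match(name)
        if match:
            widths[int(match.group(1))] = shape[0]
    # Order the widths by layer index.
    return [widths[i] for i in sorted(widths)]


# Hypothetical listing for a model trained with hidden_units=[256, 32].
example_variables = [
    ("dnn/hiddenlayer_0/bias", [256]),
    ("dnn/hiddenlayer_0/kernel", [784, 256]),
    ("dnn/hiddenlayer_1/bias", [32]),
    ("dnn/hiddenlayer_1/kernel", [256, 32]),
    ("dnn/logits/bias", [10]),
    ("dnn/logits/kernel", [32, 10]),
    ("global_step", []),
]

print(infer_hidden_units(example_variables))  # [256, 32]
```

With TensorFlow available, a listing of this form can likely be obtained from tf.train.list_variables("/tmp/mnist_model"), assuming the installed version provides that function.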

In [34]:
classifier2 = tf.estimator.DNNClassifier(feature_columns=feature_columns,
                                         hidden_units=[256, 32],
                                         optimizer=tf.train.AdamOptimizer(1e-4),
                                         n_classes=10,
                                         dropout=0.1,
                                         config=config,
                                         model_dir="/tmp/mnist_model")

INFO:tensorflow:Using config: {'_model_dir': '/tmp/mnist_model', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 1000, '_distribute': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f9f88714940>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


By default, the latest checkpoint is loaded.

In [39]:
classifier2.evaluate(input_fn=test_input_fn)

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2018-04-07-12:34:17
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/mnist_model/model.ckpt-30003
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2018-04-07-12:34:17
INFO:tensorflow:Saving dict for global step 30003: accuracy = 0.9673, average_loss = 0.16608413, global_step = 30003, loss = 21.023306


{'accuracy': 0.9673,
 'average_loss': 0.16608413,
 'global_step': 30003,
 'loss': 21.023306}

A specific checkpoint can also be selected with checkpoint_path.

In [38]:
classifier2.evaluate(input_fn=test_input_fn, checkpoint_path="/tmp/mnist_model/model.ckpt-10001")

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2018-04-07-12:34:14
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmp/mnist_model/model.ckpt-10001
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2018-04-07-12:34:15
INFO:tensorflow:Saving dict for global step 10001: accuracy = 0.9116, average_loss = 0.35408103, global_step = 10001, loss = 44.820385


{'accuracy': 0.9116,
 'average_loss': 0.35408103,
 'global_step': 10001,
 'loss': 44.820385}
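To see which values are valid for checkpoint_path, it can help to list the checkpoint prefixes present in the model directory. The following is a minimal sketch in plain Python; checkpoint_prefixes is a hypothetical helper that identifies a checkpoint by its .index file, and the file names are a subset of those in the directory listing from the Train section.

```python
import re


def checkpoint_prefixes(filenames, model_dir):
    """Return checkpoint prefixes found among `filenames`, ordered by global step.

    A checkpoint is identified by its .index file, e.g. model.ckpt-10001.index.
    """
    pattern = re.compile(r"^(model\.ckpt-(\d+))\.index$")
    found = []
    for name in filenames:
        match = pattern.match(name)
        if match:
            step = int(match.group(2))
            found.append((step, model_dir.rstrip("/") + "/" + match.group(1)))
    return [prefix for _, prefix in sorted(found)]


# File names taken from the directory listing in the Train section.
names = [
    "checkpoint",
    "graph.pbtxt",
    "model.ckpt-10002.index",
    "model.ckpt-10001.index",
    "model.ckpt-10001.meta",
]
prefixes = checkpoint_prefixes(names, "/tmp/mnist_model")
print(prefixes)
# ['/tmp/mnist_model/model.ckpt-10001', '/tmp/mnist_model/model.ckpt-10002']
```

In the notebook, the file names could come from os.listdir(classifier.model_dir), and any returned prefix can be passed as checkpoint_path to evaluate.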