Classify MNIST images with a DNNClassifier

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

  from ._conv import register_converters as _register_converters


In [2]:
# Define constants
image_dim = 28
num_labels = 10
batch_size = 80
num_steps = 8000
hidden_layers = [128, 32]

In [3]:
# Step 1: Create a function to parse MNIST data
def parser(record):
    features = tf.parse_single_example(record,
            features = {
                    'images': tf.FixedLenFeature([], tf.string),
                    'labels': tf.FixedLenFeature([], tf.int64),
                    })
    image = tf.decode_raw(features['images'], tf.uint8)
    image.set_shape([image_dim * image_dim])
    image = tf.cast(image, tf.float32) * (1.0/255) - 0.5
    label = features['labels']
    return image, label

In [4]:
# Step 2: Describe input data with a feature column
column = tf.feature_column.numeric_column('pixels', shape=[image_dim * image_dim])

# Step 3: Create a DNNClassifier with the feature column
dnn_class = tf.estimator.DNNClassifier(hidden_layers, [column],
        model_dir='dnn_output', n_classes=num_labels)

INFO:tensorflow:Using default config.
INFO:tensorflow:Using config: {'_model_dir': 'dnn_output', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': None, '_save_checkpoints_secs': 600, '_session_config': None, '_keep_checkpoint_max': 5, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': 100, '_train_distribute': None, '_device_fn': None, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x000002511F62A438>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1}


In [5]:
# Step 4: Train the estimator
def train_func():
    dataset = tf.data.TFRecordDataset('images/mnist_train.tfrecords')
    dataset = dataset.map(parser).repeat().batch(batch_size)
    image, label = dataset.make_one_shot_iterator().get_next()
    return {'pixels': image}, label
dnn_class.train(train_func, steps=num_steps)

# Step 5: Test the estimator
def test_func():
    dataset = tf.data.TFRecordDataset('images/mnist_test.tfrecords')    
    dataset = dataset.map(parser).batch(batch_size)
    image, label = dataset.make_one_shot_iterator().get_next()
    return {'pixels': image}, label
metrics = dnn_class.evaluate(test_func)

INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Create CheckpointSaverHook.
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Saving checkpoints for 0 into dnn_output\model.ckpt.
INFO:tensorflow:loss = 186.97464, step = 1
INFO:tensorflow:global_step/sec: 201.555
INFO:tensorflow:loss = 82.945694, step = 101 (0.497 sec)
INFO:tensorflow:global_step/sec: 227.243
INFO:tensorflow:loss = 49.25924, step = 201 (0.441 sec)
INFO:tensorflow:global_step/sec: 229.312
INFO:tensorflow:loss = 29.470917, step = 301 (0.435 sec)
INFO:tensorflow:global_step/sec: 231.939
INFO:tensorflow:loss = 32.09416, step = 401 (0.432 sec)
INFO:tensorflow:global_step/sec: 231.15
INFO:tensorflow:loss = 31.050756, step = 501 (0.431 sec)
INFO:tensorflow:global_step/sec: 228.748
INFO:tensorflow:loss = 29.216133, step = 601 (0.438 sec)
INFO:tensorflow:global_step/sec: 219.751
INFO:tensorflow:loss = 8.

INFO:tensorflow:Loss for final step: 12.084917.
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2018-08-23-07:40:17
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from dnn_output\model.ckpt-8000
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Finished evaluation at 2018-08-23-07:40:18
INFO:tensorflow:Saving dict for global step 8000: accuracy = 0.96, average_loss = 0.12955147, global_step = 8000, loss = 10.364118
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 8000: dnn_output\model.ckpt-8000


In [6]:
# Display metrics
for key, value in metrics.items():
    print(key, ': ', value)

accuracy :  0.96
average_loss :  0.12955147
loss :  10.364118
global_step :  8000
