---
# MNIST Classification in TensorFlow 1.6

Building, training, and visualizing a convolutional neural network (CNN) for MNIST digit classification.

---

## Setup

In [1]:
import numpy as np
import tensorflow as tf

from typing import Dict

# Module-level shorthands for the Estimator mode keys; the fully qualified
# `tf.estimator.ModeKeys.*` names are tedious to spell out at every use.
TRAIN, PREDICT, EVAL = (
    tf.estimator.ModeKeys.TRAIN,
    tf.estimator.ModeKeys.PREDICT,
    tf.estimator.ModeKeys.EVAL,
)

# Set TensorFlow's logging verbosity to INFO.
tf.logging.set_verbosity(tf.logging.INFO)

## Model function definition

In [None]:
def model_fn(features: Dict[str, tf.Tensor],
             labels: Dict[str, tf.Tensor],
             mode: tf.estimator.ModeKeys) -> tf.estimator.EstimatorSpec:
    """Model function for an MNIST-classifying convolutional neural network.

    Args:
        features (Dict[str, tf.Tensor]): Dictionary of input Tensors.
        labels (Dict[str, tf.Tensor]): Dictionary of label Tensors.
        mode (tf.estimator.ModeKeys): Estimator mode.

    Returns:
        (tf.estimator.EstimatorSpec): MNIST CNN EstimatorSpec.

    """
    # Reshape input
    with tf.name_scope('input'):
        input_layer = tf.reshape(features["x"], [-1, 28, 28, 1])

    # First convolution + pooling block
    with tf.name_scope('conv-block1'):
        # Convolve
        with tf.name_scope('conv'):
            conv1 = tf.layers.conv2d(
                inputs=input_layer,
                filters=32,
                kernel_size=[5, 5],
                padding='valid',
                activation=tf.nn.relu)
        # Pool
        with tf.name_scope('pool'):
            pool1 = tf.layers.max_pooling2d(inputs=conv1,
                                            pool_size=[2, 2],
                                            strides=2)

    # Second convolution + pooling block
    with tf.name_scope('conv-block2'):
        # Convolve
        with tf.name_scope('conv'):
            conv2 = tf.layers.conv2d(
                inputs=pool1,
                filters=64,
                kernel_size=[5, 5],
                padding='valid',
                activation=tf.nn.relu)
        # Pool
        with tf.name_scope('pool'):
            pool2 = tf.layers.max_pooling2d(inputs=conv2,
                                            pool_size=[2, 2],
                                            strides=2)
    # Dense block
    with tf.name_scope('dense-block'):
        # Flatten the second pooling output
        with tf.name_scope('flatten'):
            pool2_flat = tf.contrib.layers.flatten(pool2)

        # Dense layer with 1024 hidden units
        with tf.name_scope('dense'):
            dense = tf.layers.dense(inputs=pool2_flat,
                                    units=1024,
                                    activation=tf.nn.relu)

        # Apply dropout during training
        with tf.name_scope('dropout'):
            dropout = tf.layers.dropout(inputs=dense,
                                        rate=0.4,
                                        training=mode == TRAIN)

    # Compute class logits
    with tf.name_scope('logits'):
        logits = tf.layers.dense(inputs=dropout, units=10)

    # Compute class predictions
    with tf.name_scope('classes'):
        classes = tf.argmax(input=logits, axis=1)

    # Compute class probabilities
    with tf.name_scope('probabilities'):
        probabilities = tf.nn.softmax(logits, name='softmax_tensor')


    # Both predictions (for PREDICT and EVAL modes)
    predictions = {
        'classes': classes,
        'probabilities': probabilities
    }

    # For a forward pass, no need to build optimization ops
    if mode == PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode,
                                          predictions=predictions)

    # Calculate loss for TRAIN and EVAL modes
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels,
                                                 logits=logits)
    
    # Configure the training op
    if mode == TRAIN: