# MNIST distributed training with TensorFlow

### Set up the environment

In [1]:
import os
import sagemaker
from sagemaker import get_execution_role

# Session object used further down to upload the dataset to S3.
sagemaker_session = sagemaker.Session()

# Execution role that is handed to the TensorFlow estimator for the training job.
role = get_execution_role()

### Download the MNIST dataset

In [2]:
import utils
from tensorflow.contrib.learn.python.learn.datasets import mnist
import tensorflow as tf

# Fetch MNIST into the local 'data' directory (uint8 pixels, reshape disabled)
# with a validation split of size 5000.
data_sets = mnist.read_data_sets(
    'data',
    dtype=tf.uint8,
    reshape=False,
    validation_size=5000)

# Convert each split in turn; the cell output shows one .tfrecords file being
# written under data/ per split.
for split_name in ('train', 'validation', 'test'):
    utils.convert_to(getattr(data_sets, split_name), split_name, 'data')

  from ._conv import register_converters as _register_converters


Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting data/t10k-labels-idx1-ubyte.gz
('Writing', 'data/train.tfrecords')
('Writing', 'data/validation.tfrecords')
('Writing', 'data/test.tfrecords')


### Upload the data
We use the ```sagemaker.Session.upload_data``` function to upload our datasets to an S3 location. The return value, ```inputs```, identifies that location -- we will use it later when we start the training job.

In [3]:
# Push everything under the local 'data' directory to S3 beneath the
# 'data/DEMO-mnist' key prefix; the returned location is passed to fit() later.
inputs = sagemaker_session.upload_data(
    path='data',
    key_prefix='data/DEMO-mnist')

INFO:sagemaker:Created S3 bucket: sagemaker-us-west-2-766924284651


# Construct a script for distributed training 

In [4]:
# Display the training script that the estimator below uses as its entry point.
!cat 'mnist.py'

import os
import tensorflow as tf
from tensorflow.python.estimator.model_fn import ModeKeys as Modes

INPUT_TENSOR_NAME = 'inputs'
SIGNATURE_NAME = 'predictions'

LEARNING_RATE = 0.001


def model_fn(features, labels, mode, params):
    # Input Layer
    input_layer = tf.reshape(features[INPUT_TENSOR_NAME], [-1, 28, 28, 1])

    # Convolutional Layer #1
    conv1 = tf.layers.conv2d(
        inputs=input_layer,
        filters=32,
        kernel_size=[5, 5],
        padding='same',
        activation=tf.nn.relu)

    # Pooling Layer #1
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

    # Convolutional Layer #2 and Pooling Layer #2
    conv2 = tf.layers.conv2d(
        inputs=pool1,
        filters=64,
        kernel_size=[5, 5],
        padding='same',
        activation=tf.nn.relu)
    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

    # Dense Layer
    pool2_flat = tf.reshape(pool2, [-1, 7

## Create a training job using the sagemaker.TensorFlow estimator

In [5]:
from sagemaker.tensorflow import TensorFlow

# Describe the training job: which script to run, under which role, on what
# hardware, and for how many training / evaluation steps.
mnist_estimator = TensorFlow(
    entry_point='mnist.py',
    role=role,
    train_instance_count=2,  # two instances -> distributed training
    train_instance_type='ml.c4.xlarge',
    training_steps=1000,
    evaluation_steps=100)

# Launch the job on the uploaded dataset; the job's logs are shown in the cell
# output until it reports completion.
mnist_estimator.fit(inputs)

INFO:sagemaker:Creating training-job with name: sagemaker-tensorflow-2018-04-15-02-09-49-209


........................................................
[32m2018-04-15 02:14:27,077 INFO - root - running container entrypoint[0m
[32m2018-04-15 02:14:27,078 INFO - root - starting train task[0m
[32m2018-04-15 02:14:27,082 INFO - container_support.training - Training starting[0m
[31m2018-04-15 02:14:27,756 INFO - root - running container entrypoint[0m
[31m2018-04-15 02:14:27,756 INFO - root - starting train task[0m
[31m2018-04-15 02:14:27,761 INFO - container_support.training - Training starting[0m
  from ._conv import register_converters as _register_converters[0m
  from ._conv import register_converters as _register_converters[0m
[32m2018-04-15 02:14:28,978 INFO - botocore.vendored.requests.packages.urllib3.connectionpool - Starting new HTTP connection (1): 169.254.170.2[0m
[31m2018-04-15 02:14:29,745 INFO - botocore.vendored.requests.packages.urllib3.connectionpool - Starting new HTTP connection (1): 169.254.170.2[0m
[31m2018-04-15 02:14:29,961 INFO - botocore.ve

[32m2018-04-15 02:14:33.526753: I tensorflow/core/platform/s3/aws_logging.cc:54] Initializing config loader against fileName /root//.aws/config and using profilePrefix = 1[0m
[32m2018-04-15 02:14:33.526800: I tensorflow/core/platform/s3/aws_logging.cc:54] Initializing config loader against fileName /root//.aws/credentials and using profilePrefix = 0[0m
[32m2018-04-15 02:14:33.526817: I tensorflow/core/platform/s3/aws_logging.cc:54] Setting provider to read credentials from /root//.aws/credentials for credentials file and /root//.aws/config for the config file , for use with profile default[0m
[32m2018-04-15 02:14:33.526830: I tensorflow/core/platform/s3/aws_logging.cc:54] Creating HttpClient with max connections2 and scheme http[0m
[32m2018-04-15 02:14:33.526849: I tensorflow/core/platform/s3/aws_logging.cc:54] Initializing CurlHandleContainer with size 2[0m
[32m2018-04-15 02:14:33.526864: I tensorflow/core/platform/s3/aws_logging.cc:54] Creating TaskRole with default ECSCre

[31m2018-04-15 02:14:47.964233: E tensorflow/core/platform/s3/aws_logging.cc:60] No response body. Response code: 404[0m
[31m2018-04-15 02:14:47.964361: W tensorflow/core/platform/s3/aws_logging.cc:57] If the signature check failed. This could be because of a time skew. Attempting to adjust the signer.[0m
[31m2018-04-15 02:14:47.964576: I tensorflow/core/platform/s3/aws_logging.cc:54] Connection has been released. Continuing.[0m
[31m2018-04-15 02:14:48.039004: I tensorflow/core/platform/s3/aws_logging.cc:54] Connection has been released. Continuing.[0m
[31m2018-04-15 02:14:48.062412: E tensorflow/core/platform/s3/aws_logging.cc:60] No response body. Response code: 404[0m
[31m2018-04-15 02:14:48.062761: W tensorflow/core/platform/s3/aws_logging.cc:57] If the signature check failed. This could be because of a time skew. Attempting to adjust the signer.[0m
[31m2018-04-15 02:14:48.063076: I tensorflow/core/platform/s3/aws_logging.cc:54] Connection has been released. Continuing

[31m2018-04-15 02:15:29,832 INFO - tensorflow - loss = 0.051766116, step = 232 (38.247 sec)[0m
[31m2018-04-15 02:15:32,779 INFO - tensorflow - global_step/sec: 5.00828[0m
[31m2018-04-15 02:15:53,196 INFO - tensorflow - global_step/sec: 4.99576[0m
[32m2018-04-15 02:15:58,742 INFO - tensorflow - loss = 0.03206717, step = 377 (42.460 sec)[0m
[31m2018-04-15 02:16:07,545 INFO - tensorflow - loss = 0.054420207, step = 419 (37.712 sec)[0m
[31m2018-04-15 02:16:13,893 INFO - tensorflow - global_step/sec: 4.92834[0m
[31m2018-04-15 02:16:34,701 INFO - tensorflow - global_step/sec: 4.90195[0m
[31m2018-04-15 02:16:34.702351: I tensorflow/core/platform/s3/aws_logging.cc:54] Connection has been released. Continuing.[0m
[31m2018-04-15 02:16:34.798755: I tensorflow/core/platform/s3/aws_logging.cc:54] Connection has been released. Continuing.[0m
[31m2018-04-15 02:16:34.870621: I tensorflow/core/platform/s3/aws_logging.cc:54] Connection has been released. Continuing.[0m
[31m2018-04-1

[31m2018-04-15 02:18:09,945 INFO - tensorflow - Evaluation [10/100][0m
[31m2018-04-15 02:18:10,895 INFO - tensorflow - Evaluation [20/100][0m
[31m2018-04-15 02:18:11,767 INFO - tensorflow - Evaluation [30/100][0m
[31m2018-04-15 02:18:12,814 INFO - tensorflow - Evaluation [40/100][0m
[31m2018-04-15 02:18:13,710 INFO - tensorflow - Evaluation [50/100][0m
[31m2018-04-15 02:18:14,562 INFO - tensorflow - Evaluation [60/100][0m
[31m2018-04-15 02:18:15,572 INFO - tensorflow - Evaluation [70/100][0m
[31m2018-04-15 02:18:16,506 INFO - tensorflow - Evaluation [80/100][0m
[31m2018-04-15 02:18:17,422 INFO - tensorflow - Evaluation [90/100][0m
[31m2018-04-15 02:18:18,369 INFO - tensorflow - Evaluation [100/100][0m
[31m2018-04-15 02:18:18,400 INFO - tensorflow - Finished evaluation at 2018-04-15-02:18:18[0m
[31m2018-04-15 02:18:18,402 INFO - tensorflow - Saving dict for global step 1001: accuracy = 0.9861, global_step = 1001, loss = 0.045372747[0m
[31m2018-04-15 02:18:18.403

===== Job Complete =====
Billable seconds: 733


The **```fit```** method creates a training job that runs on two **ml.c4.xlarge** instances. The logs above show the instances training, evaluating, and incrementing the number of **training steps**.

At the end of training, the job generates a saved model for TensorFlow Serving.

# Deploy the trained model to prepare for predictions

The ```deploy()``` method creates an endpoint that serves prediction requests in real time.

In [6]:
# Host the trained model behind an endpoint backed by a single ml.m4.xlarge
# instance; the returned predictor is used to send requests to it.
mnist_predictor = mnist_estimator.deploy(
    initial_instance_count=1,
    instance_type='ml.m4.xlarge')

INFO:sagemaker:Creating model with name: sagemaker-tensorflow-2018-04-15-02-09-49-209
INFO:sagemaker:Creating endpoint with name sagemaker-tensorflow-2018-04-15-02-09-49-209


-------------------------------------------------------------------------!

# Invoking the endpoint

In [7]:
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST with one-hot labels; its test split supplies the sample requests.
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
test_split = mnist.test

# Query the endpoint with the first ten test images, one request per image,
# and print the true label followed by the predicted class.
for i in range(10):
    data = test_split.images[i].tolist()
    tensor_proto = tf.make_tensor_proto(
        values=np.asarray(data),
        shape=[1, len(data)],  # a batch holding a single example
        dtype=tf.float32)
    predict_response = mnist_predictor.predict(tensor_proto)

    print("========================================")
    # Labels are one-hot encoded, so argmax recovers the digit.
    label = np.argmax(test_split.labels[i])
    print("label is {}".format(label))
    predicted_classes = predict_response['outputs']['classes']
    prediction = predicted_classes['int64Val'][0]
    print("prediction is {}".format(prediction))

Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting /tmp/data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz
label is 7
prediction is 7
label is 2
prediction is 2
label is 1
prediction is 1
label is 0
prediction is 0
label is 4
prediction is 4
label is 1
prediction is 1
label is 4
prediction is 4
label is 9
prediction is 9
label is 5
prediction is 5
label is 9
prediction is 9
