In [1]:
# Copyright 2019 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# MNIST TFRecord Creator

This notebook creates TFRecord files from the MNIST dataset bundled with Keras and uploads them to Amazon S3, so that they can be used to train a TensorFlow model with Amazon SageMaker.

In [2]:
import os
import numpy as np
from keras.datasets import mnist
import tensorflow as tf
tf.enable_eager_execution()

Using TensorFlow backend.


## SageMaker-Specific Setup and Configuration

In [3]:
import sagemaker

# Key prefix under which this notebook's objects are stored in the bucket.
prefix = 'sagemaker/ml-model-migration'
# Any S3 bucket can be specified; for convenience we use the SageMaker
# session's default bucket.
bucket = sagemaker.Session().default_bucket()
role = sagemaker.get_execution_role()

In [3]:
def load_mnist_data():
    """Loads MNIST through Keras and adds a trailing channel axis to the images.

    Returns:
        A ``(train_data, test_data)`` tuple. Each element is a dict with an
        ``'images'`` array reshaped to ``[-1, 28, 28, 1]`` and the matching
        ``'labels'`` array exactly as returned by ``mnist.load_data()``.
    """
    image_shape = [-1, 28, 28, 1]
    train_data, test_data = (
        {'images': np.reshape(split_images, image_shape), 'labels': split_labels}
        for split_images, split_labels in mnist.load_data()
    )
    return train_data, test_data

In [4]:
def export_tfrecords(data_set, name, directory):
    """Converts an MNIST-style dataset to a TFRecord file.

    Each image is written as one ``tf.train.Example`` with the int64 features
    ``height``, ``width``, ``depth`` and ``label`` and the bytes feature
    ``image_raw`` (the raw bytes of that image's numpy array).

    Args:
        data_set: Dictionary containing a numpy array of images of shape
            (num_examples, rows, cols, depth) under ``'images'`` and one label
            per image under ``'labels'``.
        name: Name given to the exported tfrecord dataset; the file is written
            to ``<directory>/<name>.tfrecords``.
        directory: Directory that the tfrecord file will be saved in. It is
            created if it does not exist yet.

    Raises:
        ValueError: If the images array is not 4-D, or if the number of labels
            does not match the number of images.
    """
    def _int64_feature(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _bytes_feature(value):
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    images = data_set['images']
    labels = data_set['labels']
    if images.ndim != 4:
        raise ValueError(
            'Expected images of shape (num_examples, rows, cols, depth), '
            'got shape %s' % (images.shape,))
    num_examples, rows, cols, depth = images.shape
    if len(labels) != num_examples:
        raise ValueError(
            'Got %d labels for %d images' % (len(labels), num_examples))

    # Create the output directory so a fresh "Restart & Run All" works even
    # when it does not exist yet ('' means the current directory).
    if directory:
        os.makedirs(directory, exist_ok=True)
    filename = os.path.join(directory, name + '.tfrecords')
    print('Writing', filename)

    # The context manager closes the writer even if serialization fails midway.
    with tf.python_io.TFRecordWriter(filename) as writer:
        for index in range(num_examples):
            # tobytes() is the non-deprecated equivalent of ndarray.tostring().
            image_raw = images[index].tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(rows),
                'width': _int64_feature(cols),
                'depth': _int64_feature(depth),
                'label': _int64_feature(int(labels[index])),
                'image_raw': _bytes_feature(image_raw)}))
            writer.write(example.SerializeToString())

In [6]:
train_data, test_data = load_mnist_data()
# Write one TFRecord file per split into the local `data` directory.
export_tfrecords(data_set=train_data, name="mnist_train", directory="data")
export_tfrecords(data_set=test_data, name="mnist_test", directory="data")

Writing data/mnist_train.tfrecords
Writing data/mnist_test.tfrecords


## Upload TFRecord Files to S3

In [13]:
# Upload the local `data` directory under the key prefix `<prefix>/data/mnist`
# and print the resulting S3 location.
inputs = sagemaker.Session().upload_data(
    path='data',
    bucket=bucket,
    key_prefix=prefix + '/data/mnist',
)
print(inputs)

s3://sagemaker-us-east-2-708267171719/sagemaker/ml-model-migration/data/mnist


## Example Code for Parsing the TFRecords Back Out

In [45]:
# Create a description of the features written by `export_tfrecords`: the
# scalar int64 fields default to 0 and the raw image bytes default to "".
feature_description = {
    key: tf.FixedLenFeature([], tf.int64, default_value=0)
    for key in ('height', 'width', 'depth', 'label')
}
feature_description['image_raw'] = tf.FixedLenFeature([], tf.string, default_value="")

def _parse_function(example_proto):
    """Parses one serialized ``tf.train.Example`` into a decoded image and its label.

    Args:
        example_proto: One serialized example, parsed using the module-level
            ``feature_description``.

    Returns:
        The parsed feature dict with ``'image_raw'``, ``'height'``, ``'width'``
        and ``'depth'`` removed and an ``'image'`` entry added: the raw bytes
        decoded as uint8 and reshaped to ``[height, width, depth]``.
    """
    features = tf.parse_single_example(example_proto, feature_description)
    flat_image = tf.decode_raw(features['image_raw'], tf.uint8)
    image_shape = [features['height'], features['width'], features['depth']]

    # The raw bytes and per-record dimensions are only needed to rebuild the
    # image, so they are dropped from the returned example.
    consumed_keys = ('image_raw', 'height', 'width', 'depth')
    parsed = {key: value for key, value in features.items() if key not in consumed_keys}
    parsed['image'] = tf.reshape(flat_image, image_shape)
    return parsed

In [46]:
def read_dataset(name, directory):
    """Opens ``<directory>/<name>.tfrecords`` as a parsed ``tf.data`` dataset.

    Args:
        name: Base name of the TFRecord file, without the extension.
        directory: Directory containing the TFRecord file.

    Returns:
        A ``tf.data.TFRecordDataset`` mapped through ``_parse_function``.
    """
    record_path = os.path.join(directory, name + '.tfrecords')
    return tf.data.TFRecordDataset(record_path).map(_parse_function)
    

In [47]:
# Re-read the training TFRecords written above as a parsed dataset.
data = read_dataset(name="mnist_train", directory="data")