In [8]:
import os
import numpy as np
from keras.datasets import mnist
import tensorflow as tf



# Turn on TF 1.x eager execution; later cells iterate a tf.data dataset with a
# plain Python loop and print the resulting tensor values directly.
tf.enable_eager_execution()

In [9]:
def load_mnist_data():
    """Load MNIST via Keras and give every image an explicit channel axis.

    Returns:
        A ``(train_data, test_data)`` tuple. Each element is a dictionary with
        an ``'images'`` entry (the images reshaped to ``[-1, 28, 28, 1]``) and
        a ``'labels'`` entry (the labels exactly as ``mnist.load_data()``
        returned them).
    """
    def _as_dict(split):
        # A split is the (images, labels) pair produced by mnist.load_data().
        pixels, targets = split
        return {'images': np.reshape(pixels, [-1, 28, 28, 1]), 'labels': targets}

    train_split, test_split = mnist.load_data()
    return _as_dict(train_split), _as_dict(test_split)

In [50]:
def export_tfrecords(data_set, name, directory):
    """Converts an image/label dataset (e.g. MNIST) to a tfrecords file.

    Every example is written as a ``tf.train.Example`` holding the int64
    features ``height``, ``width``, ``depth`` and ``label`` plus the bytes
    feature ``image_raw`` (the raw buffer of a single image).

    Args:
        data_set: Dictionary containing a numpy array of images under
            ``'images'`` (indexed as ``[example, row, col, channel]``) and the
            matching labels under ``'labels'``.
        name: Name given to the exported tfrecord dataset; ``'.tfrecords'`` is
            appended to form the file name.
        directory: Directory that the tfrecord file will be saved in. It is
            created if it does not exist yet.
    """
    def _int64_feature(value):
        # Wrap a single integer in the Feature proto expected by tf.train.Example.
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _bytes_feature(value):
        # Wrap a single bytes object in the Feature proto expected by tf.train.Example.
        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

    images = data_set['images']
    labels = data_set['labels']
    num_examples = images.shape[0]
    rows = images.shape[1]
    cols = images.shape[2]
    depth = images.shape[3]

    # The writer does not create missing parent directories, so make sure the
    # target exists; an empty string means "current directory" and is skipped.
    if directory:
        os.makedirs(directory, exist_ok=True)

    filename = os.path.join(directory, name + '.tfrecords')
    print('Writing', filename)

    # The context manager closes the file even if serializing an example fails.
    with tf.python_io.TFRecordWriter(filename) as writer:
        for index in range(num_examples):
            # tobytes() yields the same bytes as the deprecated tostring().
            image_raw = images[index].tobytes()
            example = tf.train.Example(features=tf.train.Features(feature={
                'height': _int64_feature(rows),
                'width': _int64_feature(cols),
                'depth': _int64_feature(depth),
                # int() converts numpy integer scalars to a plain Python int.
                'label': _int64_feature(int(labels[index])),
                'image_raw': _bytes_feature(image_raw)}))
            writer.write(example.SerializeToString())

In [51]:
# Load both MNIST splits, then write each one to its own file under "data".
train_data, test_data = load_mnist_data()
for split_name, split_data in (("mnist_train", train_data), ("mnist_test", test_data)):
    export_tfrecords(split_data, split_name, "data")

Writing data/mnist_train.tfrecords
Writing data/mnist_test.tfrecords


In [45]:
# Schema used to parse each serialized tf.Example: four scalar int64 features
# (defaulting to 0) followed by the raw image bytes (defaulting to "").
feature_description = {
    key: tf.FixedLenFeature([], tf.int64, default_value=0)
    for key in ('height', 'width', 'depth', 'label')
}
feature_description['image_raw'] = tf.FixedLenFeature([], tf.string, default_value="")

def _parse_function(example_proto):
  """Decode one serialized tf.Example into a dict with 'label' and 'image'.

  The proto is parsed with the module-level ``feature_description``. The raw
  bytes are decoded as uint8 and reshaped to ``[height, width, depth]``; the
  helper entries ('image_raw', 'height', 'width', 'depth') are removed from
  the returned dictionary.
  """
  parsed = tf.parse_single_example(example_proto, feature_description)

  # Pull the shape fields and raw bytes out of the dict as they are consumed.
  image_shape = [parsed.pop(key) for key in ('height', 'width', 'depth')]
  flat_pixels = tf.decode_raw(parsed.pop('image_raw'), tf.uint8)

  parsed['image'] = tf.reshape(flat_pixels, image_shape)
  return parsed

In [46]:
def read_dataset(name, directory):
    """Open ``<directory>/<name>.tfrecords`` as a dataset of parsed examples.

    Args:
        name: Base name of the tfrecords file, without the extension.
        directory: Directory that contains the file.

    Returns:
        A ``tf.data`` dataset whose elements are produced by applying
        ``_parse_function`` to each record in the file.
    """
    record_path = os.path.join(directory, name + '.tfrecords')
    return tf.data.TFRecordDataset(record_path).map(_parse_function)
    

In [47]:
# Build the parsed dataset from data/mnist_train.tfrecords.
data = read_dataset("mnist_train","data")

In [48]:
# Sanity check: print the repr of the first 10 parsed examples. Each repr
# includes the full image array, so the output is long.
for raw in data.take(10):
  print(repr(raw))

{'label': <tf.Tensor: id=841, shape=(), dtype=int64, numpy=5>, 'image': <tf.Tensor: id=840, shape=(28, 28, 1), dtype=uint8, numpy=
array([[[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0],
        [  0]],

       [[  0],
        [ 