## Creates tf records from mnist dataset
### TODO (davit): create this from feature store

In [1]:
import gzip
import os
import tempfile

import numpy
from six.moves import urllib

import tensorflow as tf

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log
1,application_1586507110993_0002,pyspark,idle,Link,Link


SparkSession available as 'spark'.


In [2]:
def _read32(bytestream):
  dt = numpy.dtype(numpy.uint32).newbyteorder('>')
  return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]

def extract_images(f):
  """
  Extract the images into a 4D uint8 numpy array.
  """
  print('Extracting', f.name)
  with gzip.GzipFile(fileobj=f) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError('Invalid magic number %d in MNIST image file: %s' %
                       (magic, f.name))
    num_images = _read32(bytestream)
    rows = _read32(bytestream)
    cols = _read32(bytestream)
    buf = bytestream.read(rows * cols * num_images)
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols, 1)
    return data

def extract_labels(f, one_hot=False, num_classes=10):
  """
  Extract the labels into a 1D uint8 numpy array.
  """
  print('Extracting', f.name)
  with gzip.GzipFile(fileobj=f) as bytestream:
    magic = _read32(bytestream)
    if magic != 2049:
      raise ValueError('Invalid magic number %d in MNIST label file: %s' %
                       (magic, f.name))
    num_items = _read32(bytestream)
    buf = bytestream.read(num_items)
    labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    return labels

In [3]:
def load_dataset(images_file, labels_file):
  """Download and parse MNIST dataset."""

  #images_file = download(images_file)
  #labels_file = download(abels_file)

  with tf.io.gfile.GFile(images_file, 'rb') as f:
    images = extract_images(f)
    images = images.reshape(images.shape[0], images.shape[1] * images.shape[2])
    images = images.astype(numpy.float32)
    images = numpy.multiply(images, 1.0 / 255.0)
    
  with tf.io.gfile.GFile(labels_file, 'rb') as f:
    labels = extract_labels(f)

  return images, labels


In [4]:
import pydoop.hdfs as pydoop

a = pydoop.path.abspath("hdfs:///Projects/demo_deep_learning_admin000/TourData/mnist/MNIST_data/t10k-images-idx3-ubyte.gz")
b = pydoop.path.abspath("hdfs:///Projects/demo_deep_learning_admin000/TourData/mnist/MNIST_data/train-labels-idx1-ubyte.gz")

c = pydoop.path.abspath("hdfs:///Projects/demo_deep_learning_admin000/TourData/mnist/MNIST_data/t10k-images-idx3-ubyte.gz")
d = pydoop.path.abspath("hdfs:///Projects/demo_deep_learning_admin000/TourData/mnist/MNIST_data/train-labels-idx1-ubyte.gz")

train_images, train_labels = load_dataset(images_file=a,labels_file=b)
test_images, test_labels = load_dataset(c, d) 


Extracting hdfs:///Projects/demo_deep_learning_admin000/TourData/mnist/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting hdfs:///Projects/demo_deep_learning_admin000/TourData/mnist/MNIST_data/train-labels-idx1-ubyte.gz
Extracting hdfs:///Projects/demo_deep_learning_admin000/TourData/mnist/MNIST_data/t10k-images-idx3-ubyte.gz
Extracting hdfs:///Projects/demo_deep_learning_admin000/TourData/mnist/MNIST_data/train-labels-idx1-ubyte.gz

In [5]:
from pyspark.sql.types import *
train_data = [(train_images[i].tolist(), int(train_labels[i])) for i in range(len(train_images))]
schema = StructType([StructField("image_raw", ArrayType(FloatType())),
                     StructField("label", LongType())])
train_df = spark.createDataFrame(train_data, schema)

path = "hdfs:///Projects/demo_deep_learning_admin000/TourData/mnist/train/df-mnist_train.tfrecord"
num_partition = 4
train_df.repartition(num_partition).write.format("tfrecords").mode("overwrite").save(path)

In [7]:
test_data = [(test_images[i].tolist(), int(test_labels[i])) for i in range(len(test_images))]
schema = StructType([StructField("image_raw", ArrayType(FloatType())),
                     StructField("label", LongType())])
test_df = spark.createDataFrame(test_data, schema)

path = "hdfs:///Projects/demo_deep_learning_admin000/TourData/mnist/validation/df-mnist_test.tfrecord"
num_partition = 4
test_df.repartition(num_partition).write.format("tfrecords").mode("overwrite").save(path)