In [None]:
import os
import gzip
from six.moves import urllib
import numpy as np
import tensorflow as tf

In [7]:
# Where the gzipped MNIST archives are fetched from, and the local cache dir.
SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
WORKING_DIRECTORY = 'data'

# Decoded image geometry (side length, channel count) and the maximum pixel
# value, which extract_data uses to rescale pixels.
IMAGE_SIZE, NUM_CHANNELS, PIXEL_DEPTH = 28, 1, 255

In [9]:
def maybe_download(filename):
    """Fetch `filename` from SOURCE_URL into WORKING_DIRECTORY unless it is already there.

    Args:
        filename: Base name of the archive; appended to SOURCE_URL for the
            download and joined onto WORKING_DIRECTORY for the local path.

    Returns:
        The local path ``os.path.join(WORKING_DIRECTORY, filename)``.
    """
    if not tf.gfile.Exists(WORKING_DIRECTORY):
        tf.gfile.MakeDirs(WORKING_DIRECTORY)
    filepath = os.path.join(WORKING_DIRECTORY, filename)
    if tf.gfile.Exists(filepath):
        with tf.gfile.GFile(filepath) as f:
            print('%s has been downloaded already(%i Bytes), skip download.' % (filename, f.size()))
    else:
        # Download under a temporary name and move it into place only once the
        # transfer completed. Otherwise an interrupted/short download would leave
        # a truncated file at `filepath` that the Exists() check above would
        # treat as a finished download on every later run.
        partial_path = filepath + '.part'
        try:
            urllib.request.urlretrieve(SOURCE_URL + filename, partial_path)
            os.rename(partial_path, filepath)
        except BaseException:
            # BaseException so Ctrl-C (KeyboardInterrupt) also cleans up; re-raised below.
            if os.path.exists(partial_path):
                os.remove(partial_path)
            raise
        with tf.gfile.GFile(filepath) as f:
            size = f.size()
        print('Successfully downloaded %s(%i Bytes)' % (filepath, size))
    return filepath

def extract_data(filename, num_images):
    """Decode the first `num_images` images of a gzipped MNIST IDX3 image file.

    Pixel values in [0, PIXEL_DEPTH] are mapped to [-0.5, 0.5].

    Args:
        filename: Path to the gzip-compressed IDX3 image file.
        num_images: Number of images to read from the start of the file.

    Returns:
        float32 array of shape (num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS).

    Raises:
        ValueError: if the header is missing/not an IDX3 image header of
            IMAGE_SIZE x IMAGE_SIZE images, or the file holds fewer than
            `num_images` images.
    """
    print('Extracting %s' % filename)
    expected_bytes = IMAGE_SIZE * IMAGE_SIZE * NUM_CHANNELS * num_images
    with gzip.open(filename) as bytestream:
        header = bytestream.read(16)
        if len(header) != 16:
            raise ValueError('%s: truncated IDX header' % filename)
        # IDX3 header: four big-endian uint32 -- magic, image count, rows, cols.
        fields = np.frombuffer(header, dtype='>u4')
        magic, rows, cols = int(fields[0]), int(fields[2]), int(fields[3])
        if magic != 2051:
            raise ValueError('%s: bad magic number %d (expected 2051 for IDX3 images)'
                             % (filename, magic))
        if rows != IMAGE_SIZE or cols != IMAGE_SIZE:
            raise ValueError('%s: images are %dx%d, expected %dx%d'
                             % (filename, rows, cols, IMAGE_SIZE, IMAGE_SIZE))
        buf = bytestream.read(expected_bytes)
    if len(buf) != expected_bytes:
        # Without this check a short file only fails later in reshape() with
        # an error that does not mention truncation.
        raise ValueError('%s: expected %d bytes of pixel data for %d images, got %d'
                         % (filename, expected_bytes, num_images, len(buf)))
    data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
    # 2.0 keeps the centring offset at 127.5 under Python 2's integer division too.
    data = (data - PIXEL_DEPTH / 2.0) / PIXEL_DEPTH
    return data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)

def extract_label(filename, num_labels):
    """Decode the first `num_labels` labels of a gzipped MNIST IDX1 label file.

    Args:
        filename: Path to the gzip-compressed IDX1 label file.
        num_labels: Number of labels to read from the start of the file.

    Returns:
        int64 array of shape (num_labels,), one byte-valued label per entry.

    Raises:
        ValueError: if the header is missing/not an IDX1 header, or the file
            holds fewer than `num_labels` labels.
    """
    print('Extracting %s' % filename)
    with gzip.open(filename) as bs:
        header = bs.read(8)
        if len(header) != 8:
            raise ValueError('%s: truncated IDX header' % filename)
        # IDX1 header: two big-endian uint32 -- magic, label count.
        magic = int(np.frombuffer(header[:4], dtype='>u4')[0])
        if magic != 2049:
            raise ValueError('%s: bad magic number %d (expected 2049 for IDX1 labels)'
                             % (filename, magic))
        # Labels are stored as one unsigned byte each.
        buf = bs.read(num_labels)
    if len(buf) != num_labels:
        # A short read would otherwise silently yield fewer labels than images.
        raise ValueError('%s: expected %d labels, got %d' % (filename, num_labels, len(buf)))
    return np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
        
def main(_):
    """Ensure the four MNIST archives are cached locally, then decode them.

    The single argument is ignored; presumably the ``main(argv)`` shape is kept
    so the function can be handed to tf.app.run().

    NOTE(review): the decoded arrays are bound to locals and discarded, so
    nothing is returned and later cells cannot use them. Returning them would
    change the exit status if main is ever run through tf.app.run(), which
    passes main's return value to sys.exit -- confirm before changing.
    """
    archive_names = (
        'train-images-idx3-ubyte.gz',
        'train-labels-idx1-ubyte.gz',
        't10k-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz',
    )
    # Download (or skip) all four archives first, in the order listed above.
    (train_images_path, train_labels_path,
     test_images_path, test_labels_path) = [maybe_download(name) for name in archive_names]

    # Then decode them: 60000 training and 10000 test examples.
    train_data = extract_data(train_images_path, 60000)
    train_labels = extract_label(train_labels_path, 60000)
    test_data = extract_data(test_images_path, 10000)
    test_labels = extract_label(test_labels_path, 10000)
    
if __name__ == '__main__':
    # main() is called directly rather than through tf.app.run(): tf.app.run
    # finishes with sys.exit(), which surfaces as SystemExit inside a Jupyter
    # kernel (presumably why the original call was disabled).
    main(None)


train-images-idx3-ubyte.gz has been downloaded already(9912422 Bytes), skip download.
train-labels-idx1-ubyte.gz has been downloaded already(28881 Bytes), skip download.
t10k-images-idx3-ubyte.gz has been downloaded already(1648877 Bytes), skip download.
t10k-labels-idx1-ubyte.gz has been downloaded already(4542 Bytes), skip download.
Extracting data/train-images-idx3-ubyte.gz
Extracting data/train-labels-idx1-ubyte.gz
Extracting data/t10k-images-idx3-ubyte.gz
Extracting data/t10k-labels-idx1-ubyte.gz
