Skip to content

Commit

Permalink
support passing in a source url to the mnist read_data_sets function,…
Browse files Browse the repository at this point in the history
… to make it easier to use 'fashion mnist' etc. (tensorflow#12983)
  • Loading branch information
amygdala authored and drpngx committed Sep 12, 2017
1 parent 7951757 commit 4af9be9
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions tensorflow/contrib/learn/python/learn/datasets/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from tensorflow.python.platform import gfile

# CVDF mirror of http://yann.lecun.com/exdb/mnist/
SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'
DEFAULT_SOURCE_URL = 'https://storage.googleapis.com/cvdf-datasets/mnist/'


def _read32(bytestream):
Expand Down Expand Up @@ -215,7 +215,8 @@ def read_data_sets(train_dir,
dtype=dtypes.float32,
reshape=True,
validation_size=5000,
seed=None):
seed=None,
source_url=DEFAULT_SOURCE_URL):
if fake_data:

def fake():
Expand All @@ -227,28 +228,31 @@ def fake():
test = fake()
return base.Datasets(train=train, validation=validation, test=test)

if not source_url: # fall back to the default URL when given an empty string or None
source_url = DEFAULT_SOURCE_URL

TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
TEST_LABELS = 't10k-labels-idx1-ubyte.gz'

local_file = base.maybe_download(TRAIN_IMAGES, train_dir,
SOURCE_URL + TRAIN_IMAGES)
source_url + TRAIN_IMAGES)
with gfile.Open(local_file, 'rb') as f:
train_images = extract_images(f)

local_file = base.maybe_download(TRAIN_LABELS, train_dir,
SOURCE_URL + TRAIN_LABELS)
source_url + TRAIN_LABELS)
with gfile.Open(local_file, 'rb') as f:
train_labels = extract_labels(f, one_hot=one_hot)

local_file = base.maybe_download(TEST_IMAGES, train_dir,
SOURCE_URL + TEST_IMAGES)
source_url + TEST_IMAGES)
with gfile.Open(local_file, 'rb') as f:
test_images = extract_images(f)

local_file = base.maybe_download(TEST_LABELS, train_dir,
SOURCE_URL + TEST_LABELS)
source_url + TEST_LABELS)
with gfile.Open(local_file, 'rb') as f:
test_labels = extract_labels(f, one_hot=one_hot)

Expand All @@ -262,13 +266,13 @@ def fake():
train_images = train_images[validation_size:]
train_labels = train_labels[validation_size:]


options = dict(dtype=dtype, reshape=reshape, seed=seed)

train = DataSet(train_images, train_labels, **options)
validation = DataSet(validation_images, validation_labels, **options)
test = DataSet(test_images, test_labels, **options)

return base.Datasets(train=train, validation=validation, test=test)


Expand Down

0 comments on commit 4af9be9

Please sign in to comment.