Skip to content

Commit

Permalink
Merge pull request #4478 from akitotakeki/akitotakeki-add-extra-to-svhn
Browse files Browse the repository at this point in the history
add add_extra option to SVHN
  • Loading branch information
kmaehashi committed Mar 22, 2018
2 parents 009da60 + 2da7767 commit a5d1cff
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions chainer/datasets/svhn.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


def get_svhn(withlabel=True, scale=1., dtype=numpy.float32,
label_dtype=numpy.int32):
label_dtype=numpy.int32, add_extra=False):
"""Gets the SVHN dataset.
`The Street View House Numbers (SVHN) dataset <http://ufldl.stanford.edu/housenumbers/>`_
Expand All @@ -34,11 +34,13 @@ def get_svhn(withlabel=True, scale=1., dtype=numpy.float32,
scaled to the interval ``[0, 1]``.
dtype: Data type of resulting image arrays.
label_dtype: Data type of the labels.
add_extra: If ``True``, the extra training set is also retrieved and returned as a third dataset.
Returns:
A tuple of two datasets. If ``withlabel`` is ``True``, both datasets
are :class:`~chainer.datasets.TupleDataset` instances. Otherwise, both
datasets are arrays of images.
If ``add_extra`` is ``False``, a tuple of two datasets (train and test). Otherwise,
a tuple of three datasets (train, test, and extra).
If ``withlabel`` is ``True``, all datasets are
:class:`~chainer.datasets.TupleDataset` instances. Otherwise, all datasets are arrays of images.
""" # NOQA
if not _scipy_available:
Expand All @@ -50,7 +52,13 @@ def get_svhn(withlabel=True, scale=1., dtype=numpy.float32,
test_raw = _retrieve_svhn_test()
test = _preprocess_svhn(test_raw, withlabel, scale, dtype,
label_dtype)
return train, test
if add_extra:
extra_raw = _retrieve_svhn_extra()
extra = _preprocess_svhn(extra_raw, withlabel, scale, dtype,
label_dtype)
return train, test, extra
else:
return train, test


def _preprocess_svhn(raw, withlabel, scale, image_dtype, label_dtype):
Expand Down Expand Up @@ -79,6 +87,11 @@ def _retrieve_svhn_test():
return _retrieve_svhn("test.npz", url)


def _retrieve_svhn_extra():
    """Retrieve the raw SVHN "extra" split.

    Delegates to ``_retrieve_svhn`` with the file name ``extra.npz`` and
    the URL of ``extra_32x32.mat``, and returns whatever that helper
    returns.
    """
    return _retrieve_svhn(
        "extra.npz",
        "http://ufldl.stanford.edu/housenumbers/extra_32x32.mat")


def _retrieve_svhn(name, url):
root = download.get_dataset_directory('pfnet/chainer/svhn')
path = os.path.join(root, name)
Expand Down

0 comments on commit a5d1cff

Please sign in to comment.