diff --git a/chainer/datasets/__init__.py b/chainer/datasets/__init__.py index b9e2f431f6a1..292f991386a1 100644 --- a/chainer/datasets/__init__.py +++ b/chainer/datasets/__init__.py @@ -4,6 +4,7 @@ from chainer.datasets import mnist # NOQA from chainer.datasets import ptb # NOQA from chainer.datasets import sub_dataset # NOQA +from chainer.datasets import svhn # NOQA from chainer.datasets import transform_dataset # NOQA from chainer.datasets import tuple_dataset # NOQA @@ -25,5 +26,6 @@ from chainer.datasets.sub_dataset import split_dataset_n_random # NOQA from chainer.datasets.sub_dataset import split_dataset_random # NOQA from chainer.datasets.sub_dataset import SubDataset # NOQA +from chainer.datasets.svhn import get_svhn # NOQA from chainer.datasets.transform_dataset import TransformDataset # NOQA from chainer.datasets.tuple_dataset import TupleDataset # NOQA diff --git a/chainer/datasets/svhn.py b/chainer/datasets/svhn.py new file mode 100644 index 000000000000..d27c34e009e8 --- /dev/null +++ b/chainer/datasets/svhn.py @@ -0,0 +1,90 @@ +import os + +import numpy +try: + from scipy import io + _scipy_available = True +except ImportError: + _scipy_available = False + +from chainer.dataset import download +from chainer.datasets import tuple_dataset + + +def get_svhn(withlabel=True, scale=1., dtype=numpy.float32, + label_dtype=numpy.int32): + """Gets the SVHN dataset. + + `SVHN <http://ufldl.stanford.edu/housenumbers/>`_ is a dataset + similar to MNIST but composed of cropped images of house numbers. + The functionality is identical to the MNIST dataset, + with the exception that there is no ``ndim`` argument. + + Args: + withlabel (bool): If ``True``, it returns datasets with labels. In this + case, each example is a tuple of an image and a label. Otherwise, + the datasets only contain images. + scale (float): Pixel value scale. If it is 1 (default), pixels are + scaled to the interval ``[0, 1]``. + dtype: Data type of resulting image arrays. + label_dtype: Data type of the labels.
+ + Returns: + A tuple of two datasets. If ``withlabel`` is ``True``, both datasets + are :class:`~chainer.datasets.TupleDataset` instances. Otherwise, both + datasets are arrays of images. + + """ + if not _scipy_available: + raise RuntimeError('scipy is not available') + + train_raw = _retrieve_svhn_training() + train = _preprocess_svhn(train_raw, withlabel, scale, dtype, + label_dtype) + test_raw = _retrieve_svhn_test() + test = _preprocess_svhn(test_raw, withlabel, scale, dtype, + label_dtype) + return train, test + + +def _preprocess_svhn(raw, withlabel, scale, image_dtype, label_dtype): + images = raw["x"].transpose(3, 2, 0, 1) + images = images.astype(image_dtype) + images *= scale / 255. + + labels = raw["y"].astype(label_dtype).flatten() + # labels go from 1-10, with the digit "0" having label 10. + # Set "0" to be label 0 to restore expected ordering + labels[labels == 10] = 0 + + if withlabel: + return tuple_dataset.TupleDataset(images, labels) + else: + return images + + +def _retrieve_svhn_training(): + url = "http://ufldl.stanford.edu/housenumbers/train_32x32.mat" + return _retrieve_svhn("train.npz", url) + + +def _retrieve_svhn_test(): + url = "http://ufldl.stanford.edu/housenumbers/test_32x32.mat" + return _retrieve_svhn("test.npz", url) + + +def _retrieve_svhn(name, url): + root = download.get_dataset_directory('pfnet/chainer/svhn') + path = os.path.join(root, name) + return download.cache_or_load_file( + path, lambda path: _make_npz(path, url), numpy.load) + + +def _make_npz(path, url): + _path = download.cached_download(url) + raw = io.loadmat(_path) + images = raw["X"].astype(numpy.uint8) + labels = raw["y"].astype(numpy.uint8) + + numpy.savez_compressed(path, x=images, y=labels) + return {'x': images, 'y': labels}