diff --git a/docs/templates/datasets.md b/docs/templates/datasets.md index c9457edee5b..a2f1d1a1eab 100644 --- a/docs/templates/datasets.md +++ b/docs/templates/datasets.md @@ -146,6 +146,39 @@ from keras.datasets import mnist - __path__: if you do not have the index file locally (at `'~/.keras/datasets/' + path`), it will be downloaded to this location. +--- + +## Fashion-MNIST database of fashion articles + +Dataset of 60,000 28x28 grayscale images of 10 fashion categories, along with a test set of 10,000 images. This dataset can be used as a drop-in replacement for MNIST. The class labels are: + +| Label | Description | +| --- | --- | +| 0 | T-shirt/top | +| 1 | Trouser | +| 2 | Pullover | +| 3 | Dress | +| 4 | Coat | +| 5 | Sandal | +| 6 | Shirt | +| 7 | Sneaker | +| 8 | Bag | +| 9 | Ankle boot | + +### Usage: + +```python +from keras.datasets import fashion_mnist + +(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data() +``` + +- __Returns:__ + - 2 tuples: + - __x_train, x_test__: uint8 array of grayscale image data with shape (num_samples, 28, 28). + - __y_train, y_test__: uint8 array of labels (integers in range 0-9) with shape (num_samples,). + + --- ## Boston housing price regression dataset diff --git a/keras/datasets/__init__.py b/keras/datasets/__init__.py index 912e8989a18..ae7c49a9d7b 100644 --- a/keras/datasets/__init__.py +++ b/keras/datasets/__init__.py @@ -6,3 +6,4 @@ from . import cifar10 from . import cifar100 from . import boston_housing +from . import fashion_mnist diff --git a/keras/datasets/fashion_mnist.py b/keras/datasets/fashion_mnist.py new file mode 100644 index 00000000000..b381d63f33d --- /dev/null +++ b/keras/datasets/fashion_mnist.py @@ -0,0 +1,37 @@ +import gzip +import os + +from ..utils.data_utils import get_file +import numpy as np + + +def load_data(): + """Loads the Fashion-MNIST dataset. + + # Returns + Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`. + """ + dirname = os.path.join('datasets', 'fashion-mnist') + base = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' + files = ['train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz', + 't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'] + + paths = [] + for file in files: + paths.append(get_file(file, origin=base + file, cache_subdir=dirname)) + + with gzip.open(paths[0], 'rb') as lbpath: + y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8) + + with gzip.open(paths[1], 'rb') as imgpath: + x_train = np.frombuffer(imgpath.read(), np.uint8, + offset=16).reshape(len(y_train), 28, 28) + + with gzip.open(paths[2], 'rb') as lbpath: + y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8) + + with gzip.open(paths[3], 'rb') as imgpath: + x_test = np.frombuffer(imgpath.read(), np.uint8, + offset=16).reshape(len(y_test), 28, 28) + + return (x_train, y_train), (x_test, y_test) diff --git a/tests/keras/datasets/test_datasets.py b/tests/keras/datasets/test_datasets.py index 123b9f82068..6c4cfbd1e5c 100644 --- a/tests/keras/datasets/test_datasets.py +++ b/tests/keras/datasets/test_datasets.py @@ -8,6 +8,7 @@ from keras.datasets import imdb from keras.datasets import mnist from keras.datasets import boston_housing +from keras.datasets import fashion_mnist def test_cifar(): @@ -75,5 +76,15 @@ def test_boston_housing(): assert len(x_test) == len(y_test) +def test_fashion_mnist(): + # only run data download tests 20% of the time + # to speed up frequent testing + random.seed(time.time()) + if random.random() > 0.8: + (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data() + assert len(x_train) == len(y_train) == 60000 + assert len(x_test) == len(y_test) == 10000 + + if __name__ == '__main__': pytest.main([__file__])