Skip to content

Commit

Permalink
Fashion mnist dataset (#7809)
Browse files Browse the repository at this point in the history
* fixed typo

* added fashion-mnist dataset

* added docs

* pep8

* grammer

* use offset instead of struct

* reshape as in docs
  • Loading branch information
kashif authored and fchollet committed Sep 6, 2017
1 parent 5625d70 commit a379b42
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 0 deletions.
33 changes: 33 additions & 0 deletions docs/templates/datasets.md
Expand Up @@ -146,6 +146,39 @@ from keras.datasets import mnist
- __path__: if you do not have the index file locally (at `'~/.keras/datasets/' + path`), it will be downloaded to this location.


---

## Fashion-MNIST database of fashion articles

Dataset of 60,000 28x28 grayscale images of 10 fashion categories, along with a test set of 10,000 images. This dataset can be used as a drop-in replacement for MNIST. The class labels are:

| Label | Description |
| --- | --- |
| 0 | T-shirt/top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |

### Usage:

```python
from keras.datasets import fashion_mnist

(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
```

- __Returns:__
    - 2 tuples:
        - __x_train, x_test__: uint8 array of grayscale image data with shape (num_samples, 28, 28).
        - __y_train, y_test__: uint8 array of labels (integers in range 0-9) with shape (num_samples,).


---

## Boston housing price regression dataset
Expand Down
1 change: 1 addition & 0 deletions keras/datasets/__init__.py
Expand Up @@ -6,3 +6,4 @@
from . import cifar10
from . import cifar100
from . import boston_housing
from . import fashion_mnist
37 changes: 37 additions & 0 deletions keras/datasets/fashion_mnist.py
@@ -0,0 +1,37 @@
import gzip
import os

from ..utils.data_utils import get_file
import numpy as np


def load_data():
    """Loads the Fashion-MNIST dataset.

    The four gzipped archives (train/test labels and images) are obtained
    through `get_file` and decoded into `uint8` Numpy arrays.

    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
        Each image array is reshaped to `(len(labels), 28, 28)`.
    """
    cache_dir = os.path.join('datasets', 'fashion-mnist')
    url_root = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
    archive_names = ['train-labels-idx1-ubyte.gz',
                     'train-images-idx3-ubyte.gz',
                     't10k-labels-idx1-ubyte.gz',
                     't10k-images-idx3-ubyte.gz']

    # Resolve every archive to a local path before decoding any of them.
    train_lb, train_img, test_lb, test_img = [
        get_file(name, origin=url_root + name, cache_subdir=cache_dir)
        for name in archive_names]

    # Label archives: the leading 8 bytes are skipped (presumably the
    # IDX header); every remaining byte is read as one uint8 label.
    with gzip.open(train_lb, 'rb') as stream:
        raw_labels = stream.read()
    y_train = np.frombuffer(raw_labels, np.uint8, offset=8)

    # Image archives: the leading 16 bytes are skipped, and the payload
    # is viewed as one 28x28 image per label.
    with gzip.open(train_img, 'rb') as stream:
        raw_pixels = stream.read()
    x_train = np.frombuffer(raw_pixels, np.uint8, offset=16)
    x_train = x_train.reshape(len(y_train), 28, 28)

    with gzip.open(test_lb, 'rb') as stream:
        raw_labels = stream.read()
    y_test = np.frombuffer(raw_labels, np.uint8, offset=8)

    with gzip.open(test_img, 'rb') as stream:
        raw_pixels = stream.read()
    x_test = np.frombuffer(raw_pixels, np.uint8, offset=16)
    x_test = x_test.reshape(len(y_test), 28, 28)

    return (x_train, y_train), (x_test, y_test)
11 changes: 11 additions & 0 deletions tests/keras/datasets/test_datasets.py
Expand Up @@ -8,6 +8,7 @@
from keras.datasets import imdb
from keras.datasets import mnist
from keras.datasets import boston_housing
from keras.datasets import fashion_mnist


def test_cifar():
Expand Down Expand Up @@ -75,5 +76,15 @@ def test_boston_housing():
assert len(x_test) == len(y_test)


def test_fashion_mnist():
    """Checks the sample counts returned by `fashion_mnist.load_data()`."""
    # Data download tests are slow, so they only run 20% of the time
    # to keep frequent test runs fast.
    random.seed(time.time())
    if not random.random() > 0.8:
        return

    train_split, test_split = fashion_mnist.load_data()
    x_train, y_train = train_split
    x_test, y_test = test_split
    assert len(x_train) == len(y_train) == 60000
    assert len(x_test) == len(y_test) == 10000


# Allow running this test module directly as a script; pytest is pointed
# at this file only.
if __name__ == '__main__':
    pytest.main([__file__])

1 comment on commit a379b42

@kadsad28
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I typed the following code:

import numpy as np
import tensorflow as tf
from tensorflow import keras
import keras
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
import pandas as pd
from keras.models import load_model
from matplotlib.pyplot import imshow

from keras.datasets import fashion_mnist

train_data, test_data = datasets.fashion_mnist.load_data()

I obtained the following error:

Downloading data from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Traceback (most recent call last):
File "C:\Users\DELL\Anaconda3\lib\site-packages\keras\utils\data_utils.py", line 222, in get_file
urlretrieve(origin, fpath, dl_progress)
File "C:\Users\DELL\Anaconda3\lib\urllib\request.py", line 248, in urlretrieve
with contextlib.closing(urlopen(url, data)) as fp:
File "C:\Users\DELL\Anaconda3\lib\urllib\request.py", line 223, in urlopen
return opener.open(url, data, timeout)
File "C:\Users\DELL\Anaconda3\lib\urllib\request.py", line 532, in open
response = meth(req, response)
File "C:\Users\DELL\Anaconda3\lib\urllib\request.py", line 642, in http_response
'http', request, response, code, msg, hdrs)
File "C:\Users\DELL\Anaconda3\lib\urllib\request.py", line 570, in error
return self._call_chain(*args)
File "C:\Users\DELL\Anaconda3\lib\urllib\request.py", line 504, in _call_chain
result = func(*args)
File "C:\Users\DELL\Anaconda3\lib\urllib\request.py", line 650, in http_error_default
raise HTTPError(req.full_url, code, msg, hdrs, fp)
urllib.error.HTTPError: HTTP Error 407: Proxy Authentication Required

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "C:/Users/DELL/Dataload.py", line 28, in
train_data, test_data = datasets.fashion_mnist.load_data()
File "C:\Users\DELL\Anaconda3\lib\site-packages\keras\datasets\fashion_mnist.py", line 29, in load_data
cache_subdir=dirname))
File "C:\Users\DELL\Anaconda3\lib\site-packages\keras\utils\data_utils.py", line 224, in get_file
raise Exception(error_msg.format(origin, e.code, e.msg))
Exception: URL fetch failure on http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz: 407 -- Proxy Authentication Required

Can anyone help me solve this problem (how to load a dataset)?
I have the same problem when loading: mnist, reuters, cifar10, cifar100, boston_housing

Please sign in to comment.