Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* fixed typo * added fashion-mnist dataset * added docs * pep8 * grammar * use offset instead of struct * reshape as in docs
- Loading branch information
Showing
4 changed files
with
82 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,3 +6,4 @@ | |
from . import cifar10 | ||
from . import cifar100 | ||
from . import boston_housing | ||
from . import fashion_mnist |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import gzip | ||
import os | ||
|
||
from ..utils.data_utils import get_file | ||
import numpy as np | ||
|
||
|
||
def load_data():
    """Loads the Fashion-MNIST dataset.

    Downloads the four gzipped IDX files from the Fashion-MNIST S3 bucket
    (cached locally via `get_file`) and parses them into NumPy arrays.

    # Returns
        Tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    dirname = os.path.join('datasets', 'fashion-mnist')
    base = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
    files = ['train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
             't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz']

    # Download (or reuse cached copies of) all four archives.
    paths = [get_file(fname, origin=base + fname, cache_subdir=dirname)
             for fname in files]

    def _read_labels(path):
        # IDX1 label file: 8-byte header, then one uint8 per label.
        with gzip.open(path, 'rb') as f:
            return np.frombuffer(f.read(), np.uint8, offset=8)

    def _read_images(path, count):
        # IDX3 image file: 16-byte header, then `count` 28x28 uint8 images.
        with gzip.open(path, 'rb') as f:
            return np.frombuffer(f.read(), np.uint8,
                                 offset=16).reshape(count, 28, 28)

    y_train = _read_labels(paths[0])
    x_train = _read_images(paths[1], len(y_train))
    y_test = _read_labels(paths[2])
    x_test = _read_images(paths[3], len(y_test))

    return (x_train, y_train), (x_test, y_test)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
a379b42
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I typed the following code:
import numpy as np
import tensorflow as tf
from tensorflow import keras
import keras
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
import pandas as pd
from keras.models import load_model
from matplotlib.pyplot import imshow
from keras.datasets import fashion_mnist
train_data, test_data = datasets.fashion_mnist.load_data()
I obtained the following error:
Downloading data from http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Traceback (most recent call last):
File "C:\Users\DELL\Anaconda3\lib\site-packages\keras\utils\data_utils.py", line 222, in get_file
urlretrieve(origin, fpath, dl_progress)
File "C:\Users\DELL\Anaconda3\lib\urllib\request.py", line 248, in urlretrieve
with contextlib.closing(urlopen(url, data)) as fp:
File "C:\Users\DELL\Anaconda3\lib\urllib\request.py", line 223, in urlopen
return opener.open(url, data, timeout)
File "C:\Users\DELL\Anaconda3\lib\urllib\request.py", line 532, in open
response = meth(req, response)
File "C:\Users\DELL\Anaconda3\lib\urllib\request.py", line 642, in http_response
'http', request, response, code, msg, hdrs)
File "C:\Users\DELL\Anaconda3\lib\urllib\request.py", line 570, in error
return self._call_chain(*args)
File "C:\Users\DELL\Anaconda3\lib\urllib\request.py", line 504, in _call_chain
result = func(*args)
File "C:\Users\DELL\Anaconda3\lib\urllib\request.py", line 650, in http_error_default
raise HTTPError(req.full_url, code, msg, hdrs, fp)
urllib.error.HTTPError: HTTP Error 407: Proxy Authentication Required
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:/Users/DELL/Dataload.py", line 28, in
train_data, test_data = datasets.fashion_mnist.load_data()
File "C:\Users\DELL\Anaconda3\lib\site-packages\keras\datasets\fashion_mnist.py", line 29, in load_data
cache_subdir=dirname))
File "C:\Users\DELL\Anaconda3\lib\site-packages\keras\utils\data_utils.py", line 224, in get_file
raise Exception(error_msg.format(origin, e.code, e.msg))
Exception: URL fetch failure on http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz: 407 -- Proxy Authentication Required
Can anyone help me solve this problem (how to load a dataset)?
I have the same problem when loading: mnist, reuters, cifar10, cifar100, boston_housing