Skip to content

Commit

Permalink
Fixing MNIST dataloader (#133)
Browse files Browse the repository at this point in the history
  • Loading branch information
tihbe committed Dec 3, 2021
1 parent ce75473 commit 1c83f1f
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 18 deletions.
3 changes: 0 additions & 3 deletions src/lava/utils/dataloader/mnist.npy

This file was deleted.

75 changes: 60 additions & 15 deletions src/lava/utils/dataloader/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,25 @@


class MnistDataset:
# Download mirrors tried in order for each archive; the first mirror
# that answers successfully is used (see download_mnist below).
mirrors = [
"http://yann.lecun.com/exdb/mnist/",
"https://ossci-datasets.s3.amazonaws.com/mnist/",
"https://storage.googleapis.com/cvdf-datasets/mnist/",
]

# The four gzipped IDX archives that make up MNIST:
# train/test images and train/test labels, in that order.
# decompress_convert_save relies on this ordering when it packs the
# arrays into the saved .npy structure.
files = [
"train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz",
"t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz",
]

def __init__(self, data_path=os.path.join(os.path.dirname(__file__),
'mnist.npy')):
"""data_path (str): Path to mnist.npy file containing the MNIST
dataset"""
if not os.path.exists(data_path):
# Download MNIST from internet and convert it to .npy
os.makedirs(os.path.join(os.path.dirname(__file__), 'temp'),
exist_ok=False)
exist_ok=True)
MnistDataset. \
download_mnist(path=os.path.join(
os.path.dirname(__file__),
Expand All @@ -24,16 +35,35 @@ def __init__(self, data_path=os.path.join(os.path.dirname(__file__),
# GUnzip, Parse and save MNIST data as .npy
MnistDataset.decompress_convert_save(
download_path=os.path.join(os.path.dirname(__file__), 'temp'),
save_path=os.path.dirname(data_path))
save_path=data_path)
self.data = np.load(data_path, allow_pickle=True)

@staticmethod
def download_mnist(path=os.path.join(os.path.dirname(__file__), 'temp')):
    """Download the four raw MNIST archives into *path*.

    For every archive in ``MnistDataset.files``, each mirror in
    ``MnistDataset.mirrors`` is tried in order; the first mirror that
    responds wins and the remaining mirrors are skipped.

    Parameters
    ----------
    path : str
        Directory into which the ``*.gz`` archives are written.
        The directory is assumed to exist.

    Raises
    ------
    urllib.error.URLError
        If every mirror failed for one of the archives; the last
        error encountered is re-raised.
    ValueError
        If a configured mirror URL does not start with "http".
    """
    import urllib.request
    import urllib.error

    for file in MnistDataset.files:
        err = None
        for mirror in MnistDataset.mirrors:
            try:
                url = f"{mirror}{file}"
                if not url.lower().startswith("http"):
                    # BUG FIX: the original did `raise "Url does not
                    # start with http"` — raising a plain string is a
                    # TypeError in Python 3. Raise a real exception.
                    raise ValueError(
                        f"Url does not start with http: {url}")
                # Security linter disabled deliberately: only the
                # hardcoded mirror URLs above are ever fetched.
                res = urllib.request.urlopen(url)  # nosec
                with open(os.path.join(path, file), "wb") as f:
                    f.write(res.read())
                break  # this archive is done; skip remaining mirrors
            except urllib.error.URLError as exception:
                err = exception
                continue  # try the next mirror
        else:
            # for/else: no mirror succeeded for this archive.
            print("Failed to download mnist dataset")
            raise err

@staticmethod
def decompress_convert_save(
        download_path=os.path.join(os.path.dirname(__file__), 'temp'),
        # NOTE(review): the save_path default was hidden in the diff
        # view; reconstructed from the __init__ call site — confirm.
        save_path=os.path.join(os.path.dirname(__file__), 'mnist.npy')):
    """Gunzip the raw IDX archives and save them as a single .npy file.

    Parameters
    ----------
    download_path : str
        Path of the downloaded raw MNIST dataset in IDX format.
    save_path : str
        Path at which the processed npy file will be saved.

    Notes
    -----
    After ``data = np.load(save_path, allow_pickle=True)``, data is a
    np.array of np.arrays:
        train_imgs   = data[0][0]  # shape 60000 x 28 x 28
        test_imgs    = data[1][0]  # shape 10000 x 28 x 28
        train_labels = data[0][1]  # shape 60000
        test_labels  = data[1][1]  # shape 10000
    """
    import gzip

    arrays = []
    for file in MnistDataset.files:
        with gzip.open(os.path.join(download_path, file), "rb") as f:
            raw = f.read()
        if "images" in file:
            # IDX image files carry a 16-byte header, then uint8
            # pixels; 28x28 pixels per image.
            arr = np.frombuffer(raw, np.uint8, offset=16)
            arr = arr.reshape(-1, 28, 28)
        else:
            # IDX label files carry an 8-byte header, then one
            # uint8 label per image.
            arr = np.frombuffer(raw, np.uint8, offset=8)
        arrays.append(arr)

    # Pack as [[train_imgs, train_labels], [test_imgs, test_labels]];
    # dtype=object because the inner arrays have different shapes.
    np.save(
        save_path,
        np.array(
            [[arrays[0], arrays[1]], [arrays[2], arrays[3]]],
            dtype="object"),
    )

@property
def train_images(self):
Expand Down

0 comments on commit 1c83f1f

Please sign in to comment.