Make default transformers be class attributes, and other goodies

vdumoulin committed May 7, 2015
1 parent 35d4cb2 commit 79073b9

Showing 7 changed files with 20 additions and 82 deletions.
9 changes: 6 additions & 3 deletions docs/index.rst
@@ -73,17 +73,20 @@ For example, after downloading the MNIST data to ``/home/your_data/mnist`` we
 construct a handle to the data.
 
 >>> from fuel.datasets import MNIST
->>> mnist = MNIST(which_set='train', flatten=('features',))
+>>> mnist = MNIST(which_set='train')
 
 In order to start reading the data, we need to initialize a *data stream*. A
 data stream combines a dataset with a particular iteration scheme to read data
 in a particular way. Let's say that in this case we want to retrieve random
 minibatches of size 512.
 
 >>> from fuel.streams import DataStream
+>>> from fuel.transformers import Flatten
 >>> from fuel.schemes import ShuffledScheme
->>> stream = DataStream.default_stream(
-...     mnist, iteration_scheme=ShuffledScheme(mnist.num_examples, 512))
+>>> stream = Flatten(
+...     DataStream.default_stream(
+...         mnist, iteration_scheme=ShuffledScheme(mnist.num_examples, 512)),
+...     which_sources=('features',))
 
 Datasets can apply various default transformations on the original
 data stream if their ``apply_default_transformers`` method is called. A
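As a sketch of the resulting usage (assuming, as the paragraph above suggests, that ``DataStream.default_stream`` builds a plain stream and routes it through the dataset's ``apply_default_transformers``):

    >>> from fuel.datasets import MNIST
    >>> from fuel.streams import DataStream
    >>> from fuel.schemes import ShuffledScheme
    >>> mnist = MNIST(which_set='train')
    >>> # One step: default_stream applies MNIST.default_transformers
    >>> # (scale uint8 pixels by 1/255, then cast to floatX).
    >>> stream = DataStream.default_stream(
    ...     mnist, iteration_scheme=ShuffledScheme(mnist.num_examples, 512))
    >>> # Spelled out: build a plain stream, then let the dataset wrap it
    >>> # with its declared default transformers.
    >>> plain = DataStream(
    ...     mnist, iteration_scheme=ShuffledScheme(mnist.num_examples, 512))
    >>> stream = mnist.apply_default_transformers(plain)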
11 changes: 1 addition & 10 deletions fuel/datasets/base.py
@@ -45,6 +45,7 @@ class Dataset(object):
     """
     provides_sources = None
+    default_transformers = tuple()
 
     def __init__(self, sources=None, axis_labels=None):
         if not self.provides_sources:
@@ -66,16 +67,6 @@ def sources(self):
     def sources(self, sources):
         self._sources = sources
 
-    @property
-    def default_transformers(self):
-        if not hasattr(self, '_default_transformers'):
-            self._default_transformers = tuple()
-        return self._default_transformers
-
-    @default_transformers.setter
-    def default_transformers(self, default_transformers):
-        self._default_transformers = default_transformers
-
     def apply_default_transformers(self, stream):
         """Applies default transformers to a stream.
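The body of ``apply_default_transformers`` is collapsed in this view, so the following is only a plausible sketch of how the ``(class, args, kwargs)`` spec tuples are consumed, based on the format used throughout this commit:

    def apply_default_transformers(self, stream):
        # Assumed behavior: wrap the stream with each declared transformer,
        # in declaration order, passing the stream as the first positional
        # argument alongside the spec's args and kwargs.
        for cls, args, kwargs in self.default_transformers:
            stream = cls(stream, *args, **kwargs)
        return stream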
7 changes: 2 additions & 5 deletions fuel/datasets/cifar10.py
@@ -2,7 +2,7 @@
 
 from fuel import config
 from fuel.datasets import H5PYDataset
-from fuel.transformers import ForceFloatX, ScaleAndShift
+from fuel.transformers.defaults import uint8_pixels_to_floatX
 
 
 class CIFAR10(H5PYDataset):
@@ -30,15 +30,12 @@ class CIFAR10(H5PYDataset):
     using the start and stop arguments.
 
     """
-    provides_sources = ('features', 'targets')
     filename = 'cifar10.hdf5'
+    default_transformers = uint8_pixels_to_floatX(('features',))
 
     def __init__(self, which_set, **kwargs):
         kwargs.setdefault('load_in_memory', True)
         super(CIFAR10, self).__init__(self.data_path, which_set, **kwargs)
-        self.default_transformers += (
-            (ScaleAndShift, [1 / 255.0, 0], {'which_sources': ('features',)}),
-            (ForceFloatX, [], {}))
 
     @property
     def data_path(self):
13 changes: 1 addition & 12 deletions fuel/datasets/hdf5.py
@@ -125,7 +125,7 @@ class H5PYDataset(Dataset):
     _file_handles = {}
 
     def __init__(self, path, which_set, subset=None, load_in_memory=False,
-                 flatten=None, driver=None, sort_indices=True, **kwargs):
+                 driver=None, sort_indices=True, **kwargs):
         self.path = path
         self.driver = driver
         self.sort_indices = sort_indices
@@ -140,21 +140,10 @@ def __init__(self, path, which_set, subset=None, load_in_memory=False,
             raise ValueError("subset.step must be either 1 or None")
         self._subset_template = subset
         self.load_in_memory = load_in_memory
-        self.flatten = [] if flatten is None else flatten
 
         kwargs.setdefault('axis_labels', self.load_axis_labels())
         super(H5PYDataset, self).__init__(**kwargs)
 
-        if self.flatten:
-            self.default_transformers = (
-                (Flatten, [], {'which_sources': self.flatten}),)
-
-        for source in self.flatten:
-            if source not in self.provides_sources:
-                raise ValueError(
-                    "trying to flatten source '{}' which is ".format(source) +
-                    "not provided by the '{}' split".format(self.which_set))
-
     @staticmethod
     def create_split_array(split_dict):
         """Create a valid array for the `split` attribute of the root node.
6 changes: 2 additions & 4 deletions fuel/datasets/mnist.py
@@ -3,7 +3,7 @@
 
 from fuel import config
 from fuel.datasets import H5PYDataset
-from fuel.transformers import ForceFloatX, ScaleAndShift
+from fuel.transformers.defaults import uint8_pixels_to_floatX
 
 
 class MNIST(H5PYDataset):
@@ -29,13 +29,11 @@ class MNIST(H5PYDataset):
 
     """
     filename = 'mnist.hdf5'
+    default_transformers = uint8_pixels_to_floatX(('features',))
 
     def __init__(self, which_set, **kwargs):
         kwargs.setdefault('load_in_memory', True)
         super(MNIST, self).__init__(self.data_path, which_set, **kwargs)
-        self.default_transformers += (
-            (ScaleAndShift, [1 / 255.0, 0], {'which_sources': ('features',)}),
-            (ForceFloatX, [], {}))
 
     @property
     def data_path(self):
8 changes: 8 additions & 0 deletions fuel/transformers/defaults.py
@@ -0,0 +1,8 @@
+"""Commonly-used default transformers."""
+from fuel.transformers import ScaleAndShift, Cast
+
+
+def uint8_pixels_to_floatX(which_sources):
+    return (
+        (ScaleAndShift, [1 / 255.0, 0], {'which_sources': which_sources}),
+        (Cast, ['floatX'], {'which_sources': which_sources}))
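Mirroring the MNIST and CIFAR10 changes above, any dataset whose pixels are stored as ``uint8`` can opt into the same preprocessing with a single class attribute: scale by 1/255 into [0, 1], then cast to ``floatX``. A sketch with a hypothetical ``MyImages`` class (the class name and filename are invented for illustration):

    from fuel.datasets import H5PYDataset
    from fuel.transformers.defaults import uint8_pixels_to_floatX


    class MyImages(H5PYDataset):
        # Hypothetical dataset: 'features' arrive as uint8 in [0, 255] and
        # come out of the default stream scaled to [0, 1] and cast to floatX.
        filename = 'my_images.hdf5'
        default_transformers = uint8_pixels_to_floatX(('features',))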
48 changes: 0 additions & 48 deletions tests/test_hdf5.py
@@ -207,30 +207,6 @@ def test_h5py_dataset_in_memory():
         os.remove('tmp.hdf5')
 
 
-def test_h5py_flatten():
-    try:
-        h5file = h5py.File(name='tmp.hdf5', mode="w")
-        features = h5file.create_dataset(
-            'features', (10, 2, 3), dtype='float32')
-        features[...] = numpy.arange(60, dtype='float32').reshape((10, 2, 3))
-        targets = h5file.create_dataset('targets', (10,), dtype='uint8')
-        targets[...] = numpy.arange(10, dtype='uint8')
-        split_dict = {'train': {'features': (0, 10), 'targets': (0, 10)}}
-        h5file.attrs['split'] = H5PYDataset.create_split_array(split_dict)
-        h5file.flush()
-        h5file.close()
-        dataset = H5PYDataset(which_set='train', path='tmp.hdf5',
-                              load_in_memory=True, flatten=['features'])
-        stream = DataStream.default_stream(
-            dataset, iteration_scheme=SequentialScheme(10, 10))
-        assert_equal(next(stream.get_epoch_iterator())[0],
-                     numpy.arange(60).reshape((10, 6)))
-    finally:
-        stream.close()
-        if os.path.exists('tmp.hdf5'):
-            os.remove('tmp.hdf5')
-
-
 def test_h5py_dataset_out_of_memory_sorted_indices():
     try:
         h5file = h5py.File(name='tmp.hdf5', mode="w")
@@ -271,27 +247,3 @@ def test_h5py_dataset_out_of_memory_unsorted_indices():
         dataset.close(handle)
         if os.path.exists('tmp.hdf5'):
             os.remove('tmp.hdf5')
-
-
-def test_h5py_flatten_raises_error_on_invalid_name():
-    try:
-        h5file = h5py.File(name='tmp.hdf5', mode="w")
-        features = h5file.create_dataset(
-            'features', (10, 2, 3), dtype='float32')
-        features[...] = numpy.arange(60, dtype='float32').reshape((10, 2, 3))
-        targets = h5file.create_dataset('targets', (10,), dtype='uint8')
-        targets[...] = numpy.arange(10, dtype='uint8')
-        split_dict = {'train': {'features': (0, 10), 'targets': (0, 10)}}
-        h5file.attrs['split'] = H5PYDataset.create_split_array(split_dict)
-        h5file.flush()
-        h5file.close()
-        dataset = H5PYDataset(path='tmp.hdf5', load_in_memory=False,
-                              which_set='train', flatten=['features'])
-        handle = dataset.open()
-        assert_raises(
-            ValueError, H5PYDataset, 'tmp.hdf5',
-            None, None, False, 'foo', None)
-    finally:
-        dataset.close(handle)
-        if os.path.exists('tmp.hdf5'):
-            os.remove('tmp.hdf5')
