Skip to content

Commit

Permalink
Merge pull request mila-iqia#73 from vdumoulin/h5pydataset_determinism
Browse files Browse the repository at this point in the history
Fix nondeterministic behaviour in H5PYDataset
  • Loading branch information
dwf committed Apr 13, 2015
2 parents e544924 + 41390c0 commit 336952e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 34 deletions.
12 changes: 12 additions & 0 deletions docs/h5py_dataset.rst
Expand Up @@ -180,6 +180,18 @@ in the split dictionary. If a particular split/source combination isn't present,
its ``available`` attribute is set to ``False``, which allows us to specify
only what's actually present in the HDF5 file we created.

.. tip::

By default, :class:`~.datasets.hdf5.H5PYDataset` sorts sources in
alphabetical order, and data requests are also returned in that order. If
``sources`` is passed as argument upon instantiation,
:class:`~.datasets.hdf5.H5PYDataset` will use the order of ``sources``
instead. This means that if you want to force a particular source order, you
can do so by explicitly passing the ``sources`` argument with the desired
ordering. For example, if your dataset has two sources named ``'features'``
and ``'targets'`` and you'd like the targets to be returned first, you need
to pass ``sources=('targets', 'features')`` as a constructor argument.

We flush, close the file and *voilà*!

>>> f.flush()
Expand Down
64 changes: 30 additions & 34 deletions fuel/datasets/hdf5.py
@@ -1,5 +1,5 @@
from itertools import product
from collections import defaultdict
from collections import defaultdict, OrderedDict

import h5py
import numpy
Expand Down Expand Up @@ -128,17 +128,17 @@ def __init__(self, path, which_set, subset=None, load_in_memory=False,
subset = subset if subset else slice(None)
if subset.step not in (1, None):
raise ValueError("subset.step must be either 1 or None")
self.subsets = [subset for source in self.provides_sources]
self.load_in_memory = load_in_memory
self.flatten = [] if flatten is None else flatten

super(H5PYDataset, self).__init__(**kwargs)

for source in self.flatten:
if source not in self.provides_sources:
raise ValueError(
"trying to flatten source '{}' which is ".format(source) +
"not provided by the '{}' split".format(self.which_set))

super(H5PYDataset, self).__init__(**kwargs)

self.subsets = [subset for source in self.sources]
self.load()

@staticmethod
Expand Down Expand Up @@ -166,7 +166,7 @@ def create_split_array(split_dict):
for val in split.values():
if len(val) == 3:
comment_len = max([comment_len, len(val[-1])])
sources = tuple(sources)
sources = sorted(list(sources))
source_len = max(len(source) for source in sources)

# Instantiate empty split array
Expand Down Expand Up @@ -204,15 +204,17 @@ def create_split_array(split_dict):

@staticmethod
def parse_split_array(split_array):
    """Parse the HDF5 split array into a nested, order-preserving mapping.

    Parameters
    ----------
    split_array : iterable
        Rows of ``(split, source, start, stop, available, comment)``
        where ``split``, ``source`` and ``comment`` are UTF-8 encoded
        bytes and ``available`` is a truthy/falsy flag.

    Returns
    -------
    collections.OrderedDict
        Maps split name to an ``OrderedDict`` mapping source name to
        ``(start, stop, comment)``.  Rows whose ``available`` flag is
        falsy are skipped.  ``OrderedDict`` is used at both levels so
        that iteration order matches the row order of ``split_array``,
        keeping source ordering deterministic.
    """
    split_dict = OrderedDict()
    for row in split_array:
        split, source, start, stop, available, comment = row
        split = split.decode('utf8')
        source = source.decode('utf8')
        comment = comment.decode('utf8')
        if available:
            # Lazily create the per-split mapping so insertion order of
            # sources within a split is preserved.
            if split not in split_dict:
                split_dict[split] = OrderedDict()
            split_dict[split][source] = (start, stop, comment)
    return split_dict

def _get_file_id(self):
file_id = [f for f in self.ref_counts.keys() if f.name == self.path]
Expand Down Expand Up @@ -241,28 +243,26 @@ def provides_sources(self):
def load(self):
    """Resolve per-source subsets and optionally load data into memory.

    Iterates over ``self.sources`` (not the HDF5 handle's own item
    order) so that subsets, lengths and in-memory data all line up with
    the deterministic source ordering.  For each source the user-given
    subset is merged with the split's ``(start, stop)`` bounds; all
    sources must then span the same number of examples.

    Raises
    ------
    ValueError
        If the resolved subsets of two sources have different lengths.
    """
    handle = self._out_of_memory_open()
    num_examples = None
    for i, source_name in enumerate(self.sources):
        start, stop = self.split_dict[self.which_set][source_name][:2]
        subset = self.subsets[i]
        # Fill in missing slice bounds from the split's own bounds;
        # explicit user bounds win over the split defaults.
        subset = slice(
            start if subset.start is None else subset.start,
            stop if subset.stop is None else subset.stop,
            subset.step)
        self.subsets[i] = subset
        if num_examples is None:
            num_examples = subset.stop - subset.start
        if num_examples != subset.stop - subset.start:
            raise ValueError("sources have different lengths")
    self.num_examples = num_examples
    if self.load_in_memory:
        data_sources = []
        # zip keeps data_sources aligned with self.sources order.
        for source_name, subset in zip(self.sources, self.subsets):
            data = handle[source_name][subset]
            if source_name in self.flatten:
                # Collapse all trailing axes into one feature axis.
                data = data.reshape((data.shape[0], -1))
            data_sources.append(data)
        self.data_sources = data_sources
    else:
        self.data_sources = None
Expand Down Expand Up @@ -296,23 +296,19 @@ def get_data(self, state=None, request=None):
def _in_memory_get_data(self, state=None, request=None):
if state is not None or request is None:
raise ValueError
return self.filter_sources([data_source[request] for data_source
in self.data_sources])
return tuple(data_source[request] for data_source in self.data_sources)

def _out_of_memory_get_data(self, state=None, request=None):
rval = []
for i, (source_name, data_source) in enumerate(state.items()):
if source_name not in self.sources:
continue
subset = self.subsets[i]
for source_name, subset in zip(self.sources, self.subsets):
if isinstance(request, slice):
request = slice(request.start + subset.start,
request.stop + subset.start, request.step)
elif isinstance(request, list):
request = [index + subset.start for index in request]
else:
raise ValueError
data = data_source[request]
data = state[source_name][request]
if source_name in self.flatten:
data = data.reshape((data.shape[0], -1))
rval.append(data)
Expand Down

0 comments on commit 336952e

Please sign in to comment.