diff --git a/fuel/datasets/hdf5.py b/fuel/datasets/hdf5.py
index 6991d0ca4..68830ff8a 100644
--- a/fuel/datasets/hdf5.py
+++ b/fuel/datasets/hdf5.py
@@ -488,13 +488,29 @@ def load(self):
                     shapes = None
                 source_shapes.append(shapes)
             self.data_sources = tuple(data_sources)
+            # TODO why is this a tuple, while `.sources` is a list?
             self.source_shapes = tuple(source_shapes)
             # This exists only for request sanity checking purposes.
             self.in_memory_subset = Subset(
                 slice(None), len(self.data_sources[0]))
         else:
+            # I have a feeling we can probably get data shapes in a single way
+            # regardless of whether data is in memory or not
+            source_shapes = []
+            for source_name, subset in zip(self.sources, self.subsets):
+                # Reuse this use case from a few lines up
+                if source_name in self.vlen_sources:
+                    shapes = subset.index_within_subset(
+                        handle[source_name].dims[0]['shapes'],
+                        slice(None))
+                else:
+                    if self.user_given_subset != slice(None):
+                        raise NotImplementedError('Not entirely sure how to handle user slices yet')
+                    shapes = handle[source_name].shape
+                source_shapes.append(shapes)
+
+            self.source_shapes = source_shapes
             self.data_sources = None
-            self.source_shapes = None
             self.in_memory_subset = None
         self._out_of_memory_close()
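
A possible follow-up to the "single way" comment above would be to factor the shape lookup into one helper shared by both branches. A minimal sketch, assuming a hypothetical _get_source_shapes method and reusing only the attributes and calls already visible in the diff (self.subsets, self.vlen_sources, self.user_given_subset, Subset.index_within_subset); the open question flagged by the NotImplementedError is left unresolved:

    def _get_source_shapes(self, handle):
        # Hypothetical helper, a sketch only: one shape lookup shared by the
        # in-memory and out-of-memory branches.
        source_shapes = []
        for source_name, subset in zip(self.sources, self.subsets):
            if source_name in self.vlen_sources:
                # Variable-length sources store one shape per example in a
                # dimension scale; restrict it to this split's subset.
                shapes = subset.index_within_subset(
                    handle[source_name].dims[0]['shapes'], slice(None))
            else:
                if self.user_given_subset != slice(None):
                    # Same unresolved case as in the diff above.
                    raise NotImplementedError(
                        'Not entirely sure how to handle user slices yet')
                # Fixed-shape sources expose a single dataset-level shape.
                shapes = handle[source_name].shape
            source_shapes.append(shapes)
        return tuple(source_shapes)

Note that the in-memory branch currently stores None for fixed-shape sources, so switching it to a helper like this would change that value to the dataset-level shape; whether that is acceptable is part of the TODO.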