Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
Fix multi-channel hdf5 volume loader
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Jan 8, 2019
1 parent f4b8fe8 commit 030c273
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
10 changes: 7 additions & 3 deletions inferno/io/volumetric/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def __init__(self, path, path_in_h5_dataset=None, data_slice=None, transforms=No
else:
raise NotImplementedError

# get the dataslice
if data_slice is None or isinstance(data_slice, (str, list)):
self.data_slice = vu.parse_data_slice(data_slice)
elif isinstance(data_slice, dict):
Expand All @@ -223,17 +224,20 @@ def __init__(self, path, path_in_h5_dataset=None, data_slice=None, transforms=No

slicing_config_for_name = pyu.get_config_for_name(slicing_config, name)

# adapt data-slice if this is a multi-channel volume (slice is not applied to channel dimension)
if self.data_slice is not None and slicing_config_for_name.get('is_multichannel', False):
self.data_slice = (slice(None),) + self.data_slice

assert 'window_size' in slicing_config_for_name
assert 'stride' in slicing_config_for_name

# Read in volume from file (can be hdf5, n5 or zarr)
dataslice_ = None if self.data_slice is None else tuple(self.data_slice)
if self.is_h5(self.path):
volume = iou.fromh5(self.path, self.path_in_h5_dataset,
dataslice=dataslice_)
dataslice=self.data_slice)
else:
volume = iou.fromz5(self.path, self.path_in_h5_dataset,
dataslice=dataslice_)
dataslice=self.data_slice)
# Initialize superclass with the volume
super(HDF5VolumeLoader, self).__init__(volume=volume, name=name, transforms=transforms,
**slicing_config_for_name)
Expand Down
5 changes: 2 additions & 3 deletions inferno/io/volumetric/volumetric_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def parse_data_slice(data_slice):
return data_slice
elif isinstance(data_slice, (list, tuple)) and \
all([isinstance(_slice, slice) for _slice in data_slice]):
return list(data_slice)
return tuple(data_slice)
else:
assert isinstance(data_slice, str)
# Get rid of whitespace
Expand All @@ -162,5 +162,4 @@ def parse_data_slice(data_slice):
step = int(step) if step is not None and step != '' else None
# Build slices
slices.append(slice(start, stop, step))
# Done.
return slices
return tuple(slices)

0 comments on commit 030c273

Please sign in to comment.