Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
- enable volume loader with multiple channels
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Aug 19, 2018
1 parent 5757090 commit 3649afd
Showing 1 changed file with 28 additions and 9 deletions.
37 changes: 28 additions & 9 deletions inferno/io/volumetric/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,58 @@
from . import volumetric_utils as vu
from ...utils import io_utils as iou
from ...utils import python_utils as pyu
from ...utils.exceptions import assert_, ShapeError


class VolumeLoader(SyncableDataset):
def __init__(self, volume, window_size, stride, downsampling_ratio=None, padding=None,
padding_mode='reflect', transforms=None, return_index_spec=False, name=None):
padding_mode='reflect', transforms=None, return_index_spec=False, name=None,
is_multichannel=False):
super(VolumeLoader, self).__init__()
# Validate volume
assert isinstance(volume, np.ndarray), str(type(volume))
# Validate window size and stride
assert len(window_size) == volume.ndim, "%i, %i" % (len(window_size), volume.ndim)
assert len(stride) == volume.ndim
if is_multichannel:
assert_(len(window_size) + 1 == volume.ndim, "%i, %i" % (len(window_size),
volume.ndim),
ShapeError)
assert_(len(stride) + 1 == volume.ndim, exception_type=ShapeError)
# TODO implemnent downsampling and padding for multi-channel volume
assert_(downsampling_ratio is None, exception_type=NotImplementedError)
assert_(padding is None, exception_type=NotImplementedError)
else:
assert_(len(window_size) == volume.ndim, "%i, %i" % (len(window_size),
volume.ndim),
ShapeError)
assert_(len(stride) == volume.ndim, exception_type=ShapeError)
# Validate transforms
assert transforms is None or callable(transforms)
assert_(transforms is None or callable(transforms))

self.name = name
self.return_index_spec = return_index_spec
self.volume = volume
self.window_size = window_size
self.stride = stride
self.padding_mode = padding_mode
self.is_multichannel = is_multichannel
self.transforms = transforms
# DataloaderIter should do the shuffling
self.shuffle = False

ndim = self.volume.ndim - 1 if is_multichannel else self.volume.ndim

if downsampling_ratio is None:
self.downsampling_ratio = [1] * self.volume.ndim
self.downsampling_ratio = [1] * ndim
elif isinstance(downsampling_ratio, int):
self.downsampling_ratio = [downsampling_ratio] * self.volume.ndim
elif isinstance(downsampling_ratio, (list, tuple)):
assert len(downsampling_ratio) == self.volume.ndim
assert_(len(downsampling_ratio) == self.volume.ndim, exception_type=ShapeError)
self.downsampling_ratio = list(downsampling_ratio)
else:
raise NotImplementedError

if padding is None:
self.padding = [[0, 0]] * self.volume.ndim
self.padding = [[0, 0]] * ndim
else:
self.padding = padding
self.pad_volume()
Expand All @@ -60,7 +76,8 @@ def pad_volume(self, padding=None):
return self.volume

def make_sliding_windows(self):
return list(vu.slidingwindowslices(shape=list(self.volume.shape),
shape = self.volume.shape[1:] if self.is_multichannel else self.volume.shape
return list(vu.slidingwindowslices(shape=list(shape),
window_size=self.window_size,
strides=self.stride,
shuffle=self.shuffle,
Expand All @@ -71,6 +88,8 @@ def __getitem__(self, index):
# Casting to int would allow index to be IndexSpec objects.
index = int(index)
slices = self.base_sequence[index]
if self.is_multichannel:
slices = (slice(None),) + tuple(slices)
sliced_volume = self.volume[tuple(slices)]
if self.transforms is None:
transformed = sliced_volume
Expand All @@ -83,7 +102,7 @@ def __getitem__(self, index):

def clone(self, volume=None, transforms=None, name=None):
# Make sure the volume shapes check out
assert volume.shape == self.volume.shape
assert_(volume.shape == self.volume.shape, exception_type=ShapeError)
# Make a new instance (without initializing)
new = type(self).__new__(type(self))
# Update dictionary to initialize
Expand Down

0 comments on commit 3649afd

Please sign in to comment.