Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
Merge pull request #91 from inferno-pytorch/fix-slidingwindow
Browse files Browse the repository at this point in the history
Fix slidingwindow
  • Loading branch information
nasimrahaman committed Oct 18, 2017
2 parents 261abcf + 51f4ea9 commit aa26432
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 9 deletions.
8 changes: 4 additions & 4 deletions inferno/io/volumetric/volume.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ def pad_volume(self, padding=None):

def make_sliding_windows(self):
return list(vu.slidingwindowslices(shape=list(self.volume.shape),
nhoodsize=self.window_size,
stride=self.stride,
window_size=self.window_size,
strides=self.stride,
shuffle=self.shuffle,
ds=self.downsampling_ratio))
add_overhanging=True))

def __getitem__(self, index):
# Casting to int would allow index to be IndexSpec objects.
Expand Down Expand Up @@ -191,4 +191,4 @@ def __init__(self, path, data_slice=None, transforms=None, name=None, **slicing_
volume = volume[self.data_slice] if self.data_slice is not None else volume
# Initialize superclass with the volume
super(TIFVolumeLoader, self).__init__(volume=volume, transforms=transforms,
**slicing_config)
**slicing_config)
60 changes: 56 additions & 4 deletions inferno/io/volumetric/volumetric_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,63 @@
import itertools as it


def slidingwindowslices(shape, window_size, strides,
ds=1, shuffle=True, rngseed=None,
dataslice=None, add_overhanging=True):
# only support lists or tuples for shape, window_size and strides
assert isinstance(shape, (list, tuple))
assert isinstance(window_size, (list, tuple))
assert isinstance(strides, (list, tuple))

dim = len(shape)
assert len(window_size) == dim
assert len(strides) == dim

# check for downsampling
assert isinstance(ds, (list, tuple, int))
if isinstance(ds, int):
ds = [ds] * dim
assert len(ds) == dim

# Seed RNG if a seed is provided
if rngseed is not None:
random.seed(rngseed)

# sliding windows in one dimenstion
def dimension_window(start, stop, wsize, stride, dimsize, ds_dim):
starts = range(start, stop + 1, stride)
slices = [slice(st, st + wsize, ds_dim) for st in starts if st + wsize <= dimsize]

# add an overhanging window at the end if the windoes
# do not fit and `add_overhanging`
if slices[-1].stop != dimsize and add_overhanging:
slices.append(slice(dimsize - wsize, dimsize, ds_dim))

if shuffle:
random.shuffle(slices)
return slices

# determine adjusted start and stop coordinates if we have a dataslice
# otherwise predict the whole volume
if dataslice is not None:
assert len(dataslice) == dim, "Dataslice must be a tuple with len = data dimension."
starts = [sl.start for sl in dataslice]
stops = [sl.stop - wsize for sl, wsize in zip(dataslice, window_size)]
else:
starts = dim * [0]
stops = [dimsize - wsize for dimsize, wsize in zip(shape, window_size)]

nslices = [dimension_window(start, stop, wsize, stride, dimsize, ds_dim)
for start, stop, wsize, stride, dimsize, ds_dim
in zip(starts, stops, window_size, strides, shape, ds)]
return it.product(*nslices)


# This code is legacy af, don't judge
# Define a sliding window iterator (this time, more readable than a wannabe one-liner)
def slidingwindowslices(shape, nhoodsize, stride=1, ds=1, window=None, ignoreborder=True,
shuffle=True, rngseed=None,
startmins=None, startmaxs=None, dataslice=None):
def slidingwindowslices_depr(shape, nhoodsize, stride=1, ds=1, window=None, ignoreborder=True,
shuffle=True, rngseed=None,
startmins=None, startmaxs=None, dataslice=None):
"""
Returns a generator yielding (shuffled) sliding window slice objects.
:type shape: int or list of int
Expand Down Expand Up @@ -108,4 +160,4 @@ def parse_data_slice(data_slice):
# Build slices
slices.append(slice(start, stop, step))
# Done.
return slices
return slices
2 changes: 1 addition & 1 deletion inferno/trainers/callbacks/logging/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@ def set_log_directory(self, log_directory):
assert isinstance(log_directory, str)
if not os.path.isdir(log_directory):
assert not os.path.exists(log_directory)
os.mkdir(log_directory)
os.makedirs(log_directory)
self._log_directory = log_directory
return self

0 comments on commit aa26432

Please sign in to comment.