Skip to content

Commit

Permalink
Merge pull request #4 from jakirkham/add_blocks_split
Browse files Browse the repository at this point in the history
Add `blocks_split`
  • Loading branch information
jakirkham committed Oct 27, 2016
2 parents 1d7335b + 0483dfe commit 0a19d31
Show file tree
Hide file tree
Showing 2 changed files with 332 additions and 0 deletions.
181 changes: 181 additions & 0 deletions kenjutsu/kenjutsu.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
__date__ = "$Sep 08, 2016 15:46:46 EDT$"


import itertools
import operator
import math
import warnings


def reformat_slice(a_slice, a_length=None):
Expand Down Expand Up @@ -374,3 +377,181 @@ def len_slices(slices, lengths=None):
lens = tuple(lens)

return(lens)


def blocks_split(space_shape, block_shape, block_halo=None):
    """
    Return slicings to cut each block out of an array or other.

    Takes an array with ``space_shape`` and ``block_shape`` for every
    dimension and a ``block_halo`` to extend each block on each side. From
    this, it can compute slicings to use for cutting each block out from
    the original array, HDF5 dataset or other.

    Note:
        Blocks on the boundary that cannot extend the full range will
        be truncated to the largest block that will fit. This will raise
        a warning, which can be converted to an exception, if needed.

        A block size of ``-1`` in a dimension is replaced by the full
        extent of the space in that dimension.

    Args:
        space_shape(tuple):        Shape of array to slice
        block_shape(tuple):        Size of each block to take
        block_halo(tuple):         Halo to tack on to each block
                                   (defaults to no halo in any dimension)

    Returns:
        tuple of 3 lists:          ``(blocks, haloed_blocks, trimmed_halos)``.
                                   Each list holds one tuple of slices per
                                   block (one slice per dimension), ordered
                                   with the last dimension varying fastest.
                                   ``blocks`` selects each block from the
                                   space, ``haloed_blocks`` selects each
                                   block extended by the halo (clipped to
                                   the space), and ``trimmed_halos`` is
                                   each block's position relative to the
                                   start of its haloed block.

    Raises:
        AssertionError:            If the arguments do not all have the
                                   same number of dimensions.

    Examples:

        >>> blocks_split(
        ...     (2, 3,), (1, 1,), (1, 1,)
        ... ) #doctest: +NORMALIZE_WHITESPACE
        ([(slice(0, 1, 1), slice(0, 1, 1)),
          (slice(0, 1, 1), slice(1, 2, 1)),
          (slice(0, 1, 1), slice(2, 3, 1)),
          (slice(1, 2, 1), slice(0, 1, 1)),
          (slice(1, 2, 1), slice(1, 2, 1)),
          (slice(1, 2, 1), slice(2, 3, 1))],
        <BLANKLINE>
         [(slice(0, 2, 1), slice(0, 2, 1)),
          (slice(0, 2, 1), slice(0, 3, 1)),
          (slice(0, 2, 1), slice(1, 3, 1)),
          (slice(0, 2, 1), slice(0, 2, 1)),
          (slice(0, 2, 1), slice(0, 3, 1)),
          (slice(0, 2, 1), slice(1, 3, 1))],
        <BLANKLINE>
         [(slice(0, 1, 1), slice(0, 1, 1)),
          (slice(0, 1, 1), slice(1, 2, 1)),
          (slice(0, 1, 1), slice(1, 2, 1)),
          (slice(1, 2, 1), slice(0, 1, 1)),
          (slice(1, 2, 1), slice(1, 2, 1)),
          (slice(1, 2, 1), slice(1, 2, 1))])
    """

    # Normalize to tuples so that any sequence (e.g. a list) is accepted.
    space_shape = tuple(space_shape)
    block_shape = tuple(block_shape)

    # Validate explicitly instead of using ``assert`` so the checks are not
    # stripped under ``python -O``. ``AssertionError`` is kept on purpose so
    # that callers see the same exception type and message as before.
    if block_halo is None:
        if len(space_shape) != len(block_shape):
            raise AssertionError(
                "The dimensions of `space_shape` and `block_shape` "
                "should be the same."
            )
        block_halo = (0,) * len(space_shape)
    else:
        block_halo = tuple(block_halo)
        if not (len(space_shape) == len(block_shape) == len(block_halo)):
            raise AssertionError(
                "The dimensions of `space_shape`, `block_shape`, and "
                "`block_halo` should be the same."
            )

    def clip(value, lower, upper):
        # Clamp ``value`` into the closed interval ``[lower, upper]``.
        return min(max(value, lower), upper)

    # Warn about every dimension that the block size does not divide evenly.
    uneven_dims = [
        str(each_dim)
        for each_dim, (space_len, block_len)
        in enumerate(zip(space_shape, block_shape))
        if space_len % block_len
    ]
    if uneven_dims:
        warnings.warn(
            "Blocks will not evenly divide the array."
            " The following dimensions will be unevenly divided: %s."
            % ", ".join(uneven_dims),
            RuntimeWarning
        )

    ranges_per_dim = []
    haloed_ranges_per_dim = []
    trimmed_halos_per_dim = []

    for space_len, block_len, halo_len in zip(
            space_shape, block_shape, block_halo
    ):
        # A block size of -1 means "use the whole dimension".
        if block_len == -1:
            block_len = space_len

        # Construct each block using the block size given.
        # The last stop is allowed to spill over the end of the space.
        starts = list(range(0, space_len, block_len))
        stops = list(range(block_len, block_len + space_len, block_len))

        # Add the halo to each block on both sides and clip to the space.
        haloed_starts = [clip(s - halo_len, 0, space_len) for s in starts]
        haloed_stops = [clip(e + halo_len, 0, space_len) for e in stops]

        # Compute how to trim the halo off of each block, i.e. the block's
        # position relative to the start of its haloed block. The stop is
        # taken from the unclipped block stop (existing behavior), so for a
        # truncated boundary block it may extend past the haloed block's end.
        trimmed_starts = [s - h for s, h in zip(starts, haloed_starts)]
        trimmed_stops = [e - h for e, h in zip(stops, haloed_starts)]

        # Clip each block to the boundaries.
        starts = [clip(s, 0, space_len) for s in starts]
        stops = [clip(e, 0, space_len) for e in stops]

        # Convert all ranges to slices for easier use.
        a_range = tuple(slice(s, e) for s, e in zip(starts, stops))
        a_range_haloed = tuple(
            slice(s, e) for s, e in zip(haloed_starts, haloed_stops)
        )
        a_trimmed_halo = tuple(
            slice(s, e) for s, e in zip(trimmed_starts, trimmed_stops)
        )

        # Format all slices using the module's ``reformat_slices`` helper
        # and collect them for this dimension.
        ranges_per_dim.append(reformat_slices(a_range))
        haloed_ranges_per_dim.append(reformat_slices(a_range_haloed))
        trimmed_halos_per_dim.append(reformat_slices(a_trimmed_halo))

    # Take all combinations of all ranges to get blocks.
    blocks = list(itertools.product(*ranges_per_dim))
    haloed_blocks = list(itertools.product(*haloed_ranges_per_dim))
    trimmed_halos = list(itertools.product(*trimmed_halos_per_dim))

    return(blocks, haloed_blocks, trimmed_halos)
151 changes: 151 additions & 0 deletions tests/test_kenjutsu.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,156 @@ def load_tests(loader, tests, ignore):
return tests


class TestKenjutsu(unittest.TestCase):
    """Unit tests for ``kenjutsu.blocks_split``."""

    def setUp(self):
        pass


    def test_blocks_split(self):
        """Check blocks, haloed blocks and trimmed halos for several shapes."""

        def unit(start, stop):
            # Every expected slice carries an explicit step of 1.
            return slice(start, stop, 1)

        def grid(rows, cols):
            # All (row, col) slice pairs, with the column varying fastest.
            return [
                (unit(*row), unit(*col)) for row in rows for col in cols
            ]

        # 1-D space split into unit blocks, no halo.
        result = kenjutsu.blocks_split((2,), (1,))
        expected = (
            [(unit(0, 1),), (unit(1, 2),)],
            [(unit(0, 1),), (unit(1, 2),)],
            [(unit(0, 1),), (unit(0, 1),)],
        )
        self.assertEqual(result, expected)

        # A block size of -1 yields a single block spanning the dimension.
        result = kenjutsu.blocks_split((2,), (-1,))
        expected = (
            [(unit(0, 2),)],
            [(unit(0, 2),)],
            [(unit(0, 2),)],
        )
        self.assertEqual(result, expected)

        # 2-D space with unit blocks: omitting the halo and passing an
        # explicit zero halo are expected to give the same answer.
        unit_rows = [(0, 1), (1, 2)]
        unit_cols = [(0, 1), (1, 2), (2, 3)]
        expected = (
            grid(unit_rows, unit_cols),
            grid(unit_rows, unit_cols),
            grid([(0, 1)] * 2, [(0, 1)] * 3),
        )
        for halo_args in ((), ((0, 0),)):
            result = kenjutsu.blocks_split((2, 3,), (1, 1,), *halo_args)
            self.assertEqual(result, expected)

        # 2-D space with a halo larger than the block; dimension 0 is not
        # evenly divided (10 by 3), so its last block is truncated to 9:10.
        result = kenjutsu.blocks_split((10, 12,), (3, 2,), (4, 3,))
        expected = (
            grid(
                [(0, 3), (3, 6), (6, 9), (9, 10)],
                [(0, 2), (2, 4), (4, 6), (6, 8), (8, 10), (10, 12)],
            ),
            grid(
                [(0, 7), (0, 10), (2, 10), (5, 10)],
                [(0, 5), (0, 7), (1, 9), (3, 11), (5, 12), (7, 12)],
            ),
            grid(
                [(0, 3), (3, 6), (4, 7), (4, 7)],
                [(0, 2), (2, 4), (3, 5), (3, 5), (3, 5), (3, 5)],
            ),
        )
        self.assertEqual(result, expected)


    def tearDown(self):
        pass



if __name__ == "__main__":
    # Run the test suite when this file is executed as a script; whatever
    # ``unittest.main()`` hands back is passed on as the exit status.
    exit_status = unittest.main()
    sys.exit(exit_status)

0 comments on commit 0a19d31

Please sign in to comment.