Skip to content

Commit

Permalink
split_blocks: Add optional block index.
Browse files Browse the repository at this point in the history
In some cases, it is handy to have a positional index for each block to
know where blocks are relative to each other. Having these indices can
be helpful when stitching the results back in some cases. This change
adds a flag so that we can optionally get this index as the first
argument. Also add a warning to indicate that this flag's default
behavior is going to change to `True` in the next minor release.
  • Loading branch information
jakirkham committed Mar 3, 2017
1 parent 5e3f02e commit 9b83e47
Show file tree
Hide file tree
Showing 2 changed files with 144 additions and 4 deletions.
38 changes: 34 additions & 4 deletions kenjutsu/blocks.py
Expand Up @@ -12,7 +12,7 @@
import kenjutsu.format


def split_blocks(space_shape, block_shape, block_halo=None):
def split_blocks(space_shape, block_shape, block_halo=None, index=None):
"""
Return a list of slicings to cut each block out of an array or other.
Expand All @@ -30,6 +30,8 @@ def split_blocks(space_shape, block_shape, block_halo=None):
space_shape(tuple): Shape of array to slice
block_shape(tuple): Size of each block to take
block_halo(tuple): Halo to tack on to each block
index(bool): Whether to provide an index for
each block
Returns:
collections.Sequence of \
Expand All @@ -39,9 +41,16 @@ def split_blocks(space_shape, block_shape, block_halo=None):
Examples:
>>> split_blocks(
... (2, 3,), (1, 1,), (1, 1,)
... (2, 3,), (1, 1,), (1, 1,), True
... ) #doctest: +NORMALIZE_WHITESPACE
([(slice(0, 1, 1), slice(0, 1, 1)),
([(0, 0),
(0, 1),
(0, 2),
(1, 0),
(1, 1),
(1, 2)],
<BLANKLINE>
[(slice(0, 1, 1), slice(0, 1, 1)),
(slice(0, 1, 1), slice(1, 2, 1)),
(slice(0, 1, 1), slice(2, 3, 1)),
(slice(1, 2, 1), slice(0, 1, 1)),
Expand Down Expand Up @@ -74,6 +83,19 @@ def split_blocks(space_shape, block_shape, block_halo=None):
except ImportError:
ifilter, imap = filter, map

if index is None:
index = False
warnings.warn(
"`index` will default to `True` in the next minor release.",
FutureWarning
)
else:
warnings.warn(
"`index` will be deprecated in the next minor release. Once"
" removed, `split_blocks` will act as if `index` were `True`.",
PendingDeprecationWarning
)

if block_halo is not None:
if not (len(space_shape) == len(block_shape) == len(block_halo)):
raise ValueError(
Expand Down Expand Up @@ -200,8 +222,16 @@ def split_blocks(space_shape, block_shape, block_halo=None):
trimmed_halos_per_dim.append(a_trimmed_halo)

# Take all combinations of all ranges to get blocks.
result = tuple()
if index:
index_blocks = imap(lambda e: irange(len(e)), ranges_per_dim)
index_blocks = list(itertools.product(*index_blocks))
result += (index_blocks,)

orig_blocks = list(itertools.product(*ranges_per_dim))
haloed_blocks = list(itertools.product(*haloed_ranges_per_dim))
trimmed_halos = list(itertools.product(*trimmed_halos_per_dim))

return(orig_blocks, haloed_blocks, trimmed_halos)
result += (orig_blocks, haloed_blocks, trimmed_halos)

return result
110 changes: 110 additions & 0 deletions tests/test_blocks.py
Expand Up @@ -89,6 +89,15 @@ def test_split_blocks(self):
[(slice(0, 1, 1),), (slice(0, 1, 1),)])
)

result = blocks.split_blocks((2,), (1,), index=True)
self.assertEqual(
result,
([(0,), (1,)],
[(slice(0, 1, 1),), (slice(1, 2, 1),)],
[(slice(0, 1, 1),), (slice(1, 2, 1),)],
[(slice(0, 1, 1),), (slice(0, 1, 1),)])
)

result = blocks.split_blocks((2,), (-1,))
self.assertEqual(
result,
Expand Down Expand Up @@ -220,6 +229,107 @@ def test_split_blocks(self):
(slice(4, 7, 1), slice(3, 5, 1))])
)

result = blocks.split_blocks((10, 12,), (3, 2,), (4, 3,), True)
self.assertEqual(
result,
([(0, 0),
(0, 1),
(0, 2),
(0, 3),
(0, 4),
(0, 5),
(1, 0),
(1, 1),
(1, 2),
(1, 3),
(1, 4),
(1, 5),
(2, 0),
(2, 1),
(2, 2),
(2, 3),
(2, 4),
(2, 5),
(3, 0),
(3, 1),
(3, 2),
(3, 3),
(3, 4),
(3, 5)],
[(slice(0, 3, 1), slice(0, 2, 1)),
(slice(0, 3, 1), slice(2, 4, 1)),
(slice(0, 3, 1), slice(4, 6, 1)),
(slice(0, 3, 1), slice(6, 8, 1)),
(slice(0, 3, 1), slice(8, 10, 1)),
(slice(0, 3, 1), slice(10, 12, 1)),
(slice(3, 6, 1), slice(0, 2, 1)),
(slice(3, 6, 1), slice(2, 4, 1)),
(slice(3, 6, 1), slice(4, 6, 1)),
(slice(3, 6, 1), slice(6, 8, 1)),
(slice(3, 6, 1), slice(8, 10, 1)),
(slice(3, 6, 1), slice(10, 12, 1)),
(slice(6, 9, 1), slice(0, 2, 1)),
(slice(6, 9, 1), slice(2, 4, 1)),
(slice(6, 9, 1), slice(4, 6, 1)),
(slice(6, 9, 1), slice(6, 8, 1)),
(slice(6, 9, 1), slice(8, 10, 1)),
(slice(6, 9, 1), slice(10, 12, 1)),
(slice(9, 10, 1), slice(0, 2, 1)),
(slice(9, 10, 1), slice(2, 4, 1)),
(slice(9, 10, 1), slice(4, 6, 1)),
(slice(9, 10, 1), slice(6, 8, 1)),
(slice(9, 10, 1), slice(8, 10, 1)),
(slice(9, 10, 1), slice(10, 12, 1))],
[(slice(0, 7, 1), slice(0, 5, 1)),
(slice(0, 7, 1), slice(0, 7, 1)),
(slice(0, 7, 1), slice(1, 9, 1)),
(slice(0, 7, 1), slice(3, 11, 1)),
(slice(0, 7, 1), slice(5, 12, 1)),
(slice(0, 7, 1), slice(7, 12, 1)),
(slice(0, 10, 1), slice(0, 5, 1)),
(slice(0, 10, 1), slice(0, 7, 1)),
(slice(0, 10, 1), slice(1, 9, 1)),
(slice(0, 10, 1), slice(3, 11, 1)),
(slice(0, 10, 1), slice(5, 12, 1)),
(slice(0, 10, 1), slice(7, 12, 1)),
(slice(2, 10, 1), slice(0, 5, 1)),
(slice(2, 10, 1), slice(0, 7, 1)),
(slice(2, 10, 1), slice(1, 9, 1)),
(slice(2, 10, 1), slice(3, 11, 1)),
(slice(2, 10, 1), slice(5, 12, 1)),
(slice(2, 10, 1), slice(7, 12, 1)),
(slice(5, 10, 1), slice(0, 5, 1)),
(slice(5, 10, 1), slice(0, 7, 1)),
(slice(5, 10, 1), slice(1, 9, 1)),
(slice(5, 10, 1), slice(3, 11, 1)),
(slice(5, 10, 1), slice(5, 12, 1)),
(slice(5, 10, 1), slice(7, 12, 1))],
[(slice(0, 3, 1), slice(0, 2, 1)),
(slice(0, 3, 1), slice(2, 4, 1)),
(slice(0, 3, 1), slice(3, 5, 1)),
(slice(0, 3, 1), slice(3, 5, 1)),
(slice(0, 3, 1), slice(3, 5, 1)),
(slice(0, 3, 1), slice(3, 5, 1)),
(slice(3, 6, 1), slice(0, 2, 1)),
(slice(3, 6, 1), slice(2, 4, 1)),
(slice(3, 6, 1), slice(3, 5, 1)),
(slice(3, 6, 1), slice(3, 5, 1)),
(slice(3, 6, 1), slice(3, 5, 1)),
(slice(3, 6, 1), slice(3, 5, 1)),
(slice(4, 7, 1), slice(0, 2, 1)),
(slice(4, 7, 1), slice(2, 4, 1)),
(slice(4, 7, 1), slice(3, 5, 1)),
(slice(4, 7, 1), slice(3, 5, 1)),
(slice(4, 7, 1), slice(3, 5, 1)),
(slice(4, 7, 1), slice(3, 5, 1)),
(slice(4, 7, 1), slice(0, 2, 1)),
(slice(4, 7, 1), slice(2, 4, 1)),
(slice(4, 7, 1), slice(3, 5, 1)),
(slice(4, 7, 1), slice(3, 5, 1)),
(slice(4, 7, 1), slice(3, 5, 1)),
(slice(4, 7, 1), slice(3, 5, 1))])
)


def tearDown(self):
pass
Expand Down

0 comments on commit 9b83e47

Please sign in to comment.