Skip to content

Commit

Permalink
split_blocks: Add optional block index.
Browse files Browse the repository at this point in the history
In some cases, it is handy to have a positional index for each block to
know where blocks are relative to each other. Having these indices can
be helpful when stitching the results back in some cases. This change
adds a flag so that we can optionally get this index as the first
argument.
  • Loading branch information
jakirkham committed Mar 1, 2017
1 parent 0d2b567 commit e97d938
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 4 deletions.
25 changes: 21 additions & 4 deletions kenjutsu/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import kenjutsu.format


def split_blocks(space_shape, block_shape, block_halo=None):
def split_blocks(space_shape, block_shape, block_halo=None, index=False):
"""
Return a list of slicings to cut each block out of an array or other.
Expand All @@ -30,6 +30,8 @@ def split_blocks(space_shape, block_shape, block_halo=None):
space_shape(tuple): Shape of array to slice
block_shape(tuple): Size of each block to take
block_halo(tuple): Halo to tack on to each block
index(bool): Whether to provide an index for
each block
Returns:
collections.Sequence of \
Expand All @@ -39,9 +41,16 @@ def split_blocks(space_shape, block_shape, block_halo=None):
Examples:
>>> split_blocks(
... (2, 3,), (1, 1,), (1, 1,)
... (2, 3,), (1, 1,), (1, 1,), True
... ) #doctest: +NORMALIZE_WHITESPACE
([(slice(0, 1, 1), slice(0, 1, 1)),
([(0, 0),
(0, 1),
(0, 2),
(1, 0),
(1, 1),
(1, 2)],
<BLANKLINE>
[(slice(0, 1, 1), slice(0, 1, 1)),
(slice(0, 1, 1), slice(1, 2, 1)),
(slice(0, 1, 1), slice(2, 3, 1)),
(slice(1, 2, 1), slice(0, 1, 1)),
Expand Down Expand Up @@ -182,8 +191,16 @@ def split_blocks(space_shape, block_shape, block_halo=None):
trimmed_halos_per_dim.append(a_trimmed_halo)

# Take all combinations of all ranges to get blocks.
result = tuple()
if index:
index_blocks = imap(lambda e: irange(len(e)), ranges_per_dim)
index_blocks = list(itertools.product(*index_blocks))
result += (index_blocks,)

orig_blocks = list(itertools.product(*ranges_per_dim))
haloed_blocks = list(itertools.product(*haloed_ranges_per_dim))
trimmed_halos = list(itertools.product(*trimmed_halos_per_dim))

return(orig_blocks, haloed_blocks, trimmed_halos)
result += (orig_blocks, haloed_blocks, trimmed_halos)

return result
101 changes: 101 additions & 0 deletions tests/test_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,107 @@ def test_split_blocks(self):
(slice(4, 7, 1), slice(3, 5, 1))])
)

result = blocks.split_blocks((10, 12,), (3, 2,), (4, 3,), True)
self.assertEqual(
result,
([(0, 0),
(0, 1),
(0, 2),
(0, 3),
(0, 4),
(0, 5),
(1, 0),
(1, 1),
(1, 2),
(1, 3),
(1, 4),
(1, 5),
(2, 0),
(2, 1),
(2, 2),
(2, 3),
(2, 4),
(2, 5),
(3, 0),
(3, 1),
(3, 2),
(3, 3),
(3, 4),
(3, 5)],
[(slice(0, 3, 1), slice(0, 2, 1)),
(slice(0, 3, 1), slice(2, 4, 1)),
(slice(0, 3, 1), slice(4, 6, 1)),
(slice(0, 3, 1), slice(6, 8, 1)),
(slice(0, 3, 1), slice(8, 10, 1)),
(slice(0, 3, 1), slice(10, 12, 1)),
(slice(3, 6, 1), slice(0, 2, 1)),
(slice(3, 6, 1), slice(2, 4, 1)),
(slice(3, 6, 1), slice(4, 6, 1)),
(slice(3, 6, 1), slice(6, 8, 1)),
(slice(3, 6, 1), slice(8, 10, 1)),
(slice(3, 6, 1), slice(10, 12, 1)),
(slice(6, 9, 1), slice(0, 2, 1)),
(slice(6, 9, 1), slice(2, 4, 1)),
(slice(6, 9, 1), slice(4, 6, 1)),
(slice(6, 9, 1), slice(6, 8, 1)),
(slice(6, 9, 1), slice(8, 10, 1)),
(slice(6, 9, 1), slice(10, 12, 1)),
(slice(9, 10, 1), slice(0, 2, 1)),
(slice(9, 10, 1), slice(2, 4, 1)),
(slice(9, 10, 1), slice(4, 6, 1)),
(slice(9, 10, 1), slice(6, 8, 1)),
(slice(9, 10, 1), slice(8, 10, 1)),
(slice(9, 10, 1), slice(10, 12, 1))],
[(slice(0, 7, 1), slice(0, 5, 1)),
(slice(0, 7, 1), slice(0, 7, 1)),
(slice(0, 7, 1), slice(1, 9, 1)),
(slice(0, 7, 1), slice(3, 11, 1)),
(slice(0, 7, 1), slice(5, 12, 1)),
(slice(0, 7, 1), slice(7, 12, 1)),
(slice(0, 10, 1), slice(0, 5, 1)),
(slice(0, 10, 1), slice(0, 7, 1)),
(slice(0, 10, 1), slice(1, 9, 1)),
(slice(0, 10, 1), slice(3, 11, 1)),
(slice(0, 10, 1), slice(5, 12, 1)),
(slice(0, 10, 1), slice(7, 12, 1)),
(slice(2, 10, 1), slice(0, 5, 1)),
(slice(2, 10, 1), slice(0, 7, 1)),
(slice(2, 10, 1), slice(1, 9, 1)),
(slice(2, 10, 1), slice(3, 11, 1)),
(slice(2, 10, 1), slice(5, 12, 1)),
(slice(2, 10, 1), slice(7, 12, 1)),
(slice(5, 10, 1), slice(0, 5, 1)),
(slice(5, 10, 1), slice(0, 7, 1)),
(slice(5, 10, 1), slice(1, 9, 1)),
(slice(5, 10, 1), slice(3, 11, 1)),
(slice(5, 10, 1), slice(5, 12, 1)),
(slice(5, 10, 1), slice(7, 12, 1))],
[(slice(0, 3, 1), slice(0, 2, 1)),
(slice(0, 3, 1), slice(2, 4, 1)),
(slice(0, 3, 1), slice(3, 5, 1)),
(slice(0, 3, 1), slice(3, 5, 1)),
(slice(0, 3, 1), slice(3, 5, 1)),
(slice(0, 3, 1), slice(3, 5, 1)),
(slice(3, 6, 1), slice(0, 2, 1)),
(slice(3, 6, 1), slice(2, 4, 1)),
(slice(3, 6, 1), slice(3, 5, 1)),
(slice(3, 6, 1), slice(3, 5, 1)),
(slice(3, 6, 1), slice(3, 5, 1)),
(slice(3, 6, 1), slice(3, 5, 1)),
(slice(4, 7, 1), slice(0, 2, 1)),
(slice(4, 7, 1), slice(2, 4, 1)),
(slice(4, 7, 1), slice(3, 5, 1)),
(slice(4, 7, 1), slice(3, 5, 1)),
(slice(4, 7, 1), slice(3, 5, 1)),
(slice(4, 7, 1), slice(3, 5, 1)),
(slice(4, 7, 1), slice(0, 2, 1)),
(slice(4, 7, 1), slice(2, 4, 1)),
(slice(4, 7, 1), slice(3, 5, 1)),
(slice(4, 7, 1), slice(3, 5, 1)),
(slice(4, 7, 1), slice(3, 5, 1)),
(slice(4, 7, 1), slice(3, 5, 1))])
)


def tearDown(self):
pass
Expand Down

0 comments on commit e97d938

Please sign in to comment.