From 9b83e47b9590289f976e590129e322822942476e Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Wed, 1 Mar 2017 12:58:05 -0500 Subject: [PATCH] split_blocks: Add optional block index. In some cases, it is handy to have a positional index for each block to know where blocks are relative to each other. Having these indices can be helpful when stitching the results back in some cases. This change adds a flag so that we can optionally get this index as the first argument. Also add a warning to indicate that this flag's default behavior is going to change to `True` in the next minor release. --- kenjutsu/blocks.py | 38 +++++++++++++-- tests/test_blocks.py | 110 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+), 4 deletions(-) diff --git a/kenjutsu/blocks.py b/kenjutsu/blocks.py index 6d5d5ae..f285dc4 100644 --- a/kenjutsu/blocks.py +++ b/kenjutsu/blocks.py @@ -12,7 +12,7 @@ import kenjutsu.format -def split_blocks(space_shape, block_shape, block_halo=None): +def split_blocks(space_shape, block_shape, block_halo=None, index=None): """ Return a list of slicings to cut each block out of an array or other. @@ -30,6 +30,8 @@ def split_blocks(space_shape, block_shape, block_halo=None): space_shape(tuple): Shape of array to slice block_shape(tuple): Size of each block to take block_halo(tuple): Halo to tack on to each block + index(bool): Whether to provide an index for + each block Returns: collections.Sequence of \ @@ -39,9 +41,16 @@ def split_blocks(space_shape, block_shape, block_halo=None): Examples: >>> split_blocks( - ... (2, 3,), (1, 1,), (1, 1,) + ... (2, 3,), (1, 1,), (1, 1,), True ... ) #doctest: +NORMALIZE_WHITESPACE - ([(slice(0, 1, 1), slice(0, 1, 1)), + ([(0, 0), + (0, 1), + (0, 2), + (1, 0), + (1, 1), + (1, 2)], + + [(slice(0, 1, 1), slice(0, 1, 1)), (slice(0, 1, 1), slice(1, 2, 1)), (slice(0, 1, 1), slice(2, 3, 1)), (slice(1, 2, 1), slice(0, 1, 1)), @@ -74,6 +83,19 @@ def split_blocks(space_shape, block_shape, block_halo=None): except ImportError: ifilter, imap = filter, map + if index is None: + index = False + warnings.warn( + "`index` will default to `True` in the next minor release.", + FutureWarning + ) + else: + warnings.warn( + "`index` will be deprecated in the next minor release. Once" + " removed, `split_blocks` will act as if `index` were `True`.", + PendingDeprecationWarning + ) + if block_halo is not None: if not (len(space_shape) == len(block_shape) == len(block_halo)): raise ValueError( @@ -200,8 +222,16 @@ def split_blocks(space_shape, block_shape, block_halo=None): trimmed_halos_per_dim.append(a_trimmed_halo) # Take all combinations of all ranges to get blocks. + result = tuple() + if index: + index_blocks = imap(lambda e: irange(len(e)), ranges_per_dim) + index_blocks = list(itertools.product(*index_blocks)) + result += (index_blocks,) + orig_blocks = list(itertools.product(*ranges_per_dim)) haloed_blocks = list(itertools.product(*haloed_ranges_per_dim)) trimmed_halos = list(itertools.product(*trimmed_halos_per_dim)) - return(orig_blocks, haloed_blocks, trimmed_halos) + result += (orig_blocks, haloed_blocks, trimmed_halos) + + return result diff --git a/tests/test_blocks.py b/tests/test_blocks.py index 5b2d103..3e93b11 100644 --- a/tests/test_blocks.py +++ b/tests/test_blocks.py @@ -89,6 +89,15 @@ def test_split_blocks(self): [(slice(0, 1, 1),), (slice(0, 1, 1),)]) ) + result = blocks.split_blocks((2,), (1,), index=True) + self.assertEqual( + result, + ([(0,), (1,)], + [(slice(0, 1, 1),), (slice(1, 2, 1),)], + [(slice(0, 1, 1),), (slice(1, 2, 1),)], + [(slice(0, 1, 1),), (slice(0, 1, 1),)]) + ) + result = blocks.split_blocks((2,), (-1,)) self.assertEqual( result, @@ -220,6 +229,107 @@ def test_split_blocks(self): (slice(4, 7, 1), slice(3, 5, 1))]) ) + result = blocks.split_blocks((10, 12,), (3, 2,), (4, 3,), True) + self.assertEqual( + result, + ([(0, 0), + (0, 1), + (0, 2), + (0, 3), + (0, 4), + (0, 5), + (1, 0), + (1, 1), + (1, 2), + (1, 3), + (1, 4), + (1, 5), + (2, 0), + (2, 1), + (2, 2), + (2, 3), + (2, 4), + (2, 5), + (3, 0), + (3, 1), + (3, 2), + (3, 3), + (3, 4), + (3, 5)], + [(slice(0, 3, 1), slice(0, 2, 1)), + (slice(0, 3, 1), slice(2, 4, 1)), + (slice(0, 3, 1), slice(4, 6, 1)), + (slice(0, 3, 1), slice(6, 8, 1)), + (slice(0, 3, 1), slice(8, 10, 1)), + (slice(0, 3, 1), slice(10, 12, 1)), + (slice(3, 6, 1), slice(0, 2, 1)), + (slice(3, 6, 1), slice(2, 4, 1)), + (slice(3, 6, 1), slice(4, 6, 1)), + (slice(3, 6, 1), slice(6, 8, 1)), + (slice(3, 6, 1), slice(8, 10, 1)), + (slice(3, 6, 1), slice(10, 12, 1)), + (slice(6, 9, 1), slice(0, 2, 1)), + (slice(6, 9, 1), slice(2, 4, 1)), + (slice(6, 9, 1), slice(4, 6, 1)), + (slice(6, 9, 1), slice(6, 8, 1)), + (slice(6, 9, 1), slice(8, 10, 1)), + (slice(6, 9, 1), slice(10, 12, 1)), + (slice(9, 10, 1), slice(0, 2, 1)), + (slice(9, 10, 1), slice(2, 4, 1)), + (slice(9, 10, 1), slice(4, 6, 1)), + (slice(9, 10, 1), slice(6, 8, 1)), + (slice(9, 10, 1), slice(8, 10, 1)), + (slice(9, 10, 1), slice(10, 12, 1))], + [(slice(0, 7, 1), slice(0, 5, 1)), + (slice(0, 7, 1), slice(0, 7, 1)), + (slice(0, 7, 1), slice(1, 9, 1)), + (slice(0, 7, 1), slice(3, 11, 1)), + (slice(0, 7, 1), slice(5, 12, 1)), + (slice(0, 7, 1), slice(7, 12, 1)), + (slice(0, 10, 1), slice(0, 5, 1)), + (slice(0, 10, 1), slice(0, 7, 1)), + (slice(0, 10, 1), slice(1, 9, 1)), + (slice(0, 10, 1), slice(3, 11, 1)), + (slice(0, 10, 1), slice(5, 12, 1)), + (slice(0, 10, 1), slice(7, 12, 1)), + (slice(2, 10, 1), slice(0, 5, 1)), + (slice(2, 10, 1), slice(0, 7, 1)), + (slice(2, 10, 1), slice(1, 9, 1)), + (slice(2, 10, 1), slice(3, 11, 1)), + (slice(2, 10, 1), slice(5, 12, 1)), + (slice(2, 10, 1), slice(7, 12, 1)), + (slice(5, 10, 1), slice(0, 5, 1)), + (slice(5, 10, 1), slice(0, 7, 1)), + (slice(5, 10, 1), slice(1, 9, 1)), + (slice(5, 10, 1), slice(3, 11, 1)), + (slice(5, 10, 1), slice(5, 12, 1)), + (slice(5, 10, 1), slice(7, 12, 1))], + [(slice(0, 3, 1), slice(0, 2, 1)), + (slice(0, 3, 1), slice(2, 4, 1)), + (slice(0, 3, 1), slice(3, 5, 1)), + (slice(0, 3, 1), slice(3, 5, 1)), + (slice(0, 3, 1), slice(3, 5, 1)), + (slice(0, 3, 1), slice(3, 5, 1)), + (slice(3, 6, 1), slice(0, 2, 1)), + (slice(3, 6, 1), slice(2, 4, 1)), + (slice(3, 6, 1), slice(3, 5, 1)), + (slice(3, 6, 1), slice(3, 5, 1)), + (slice(3, 6, 1), slice(3, 5, 1)), + (slice(3, 6, 1), slice(3, 5, 1)), + (slice(4, 7, 1), slice(0, 2, 1)), + (slice(4, 7, 1), slice(2, 4, 1)), + (slice(4, 7, 1), slice(3, 5, 1)), + (slice(4, 7, 1), slice(3, 5, 1)), + (slice(4, 7, 1), slice(3, 5, 1)), + (slice(4, 7, 1), slice(3, 5, 1)), + (slice(4, 7, 1), slice(0, 2, 1)), + (slice(4, 7, 1), slice(2, 4, 1)), + (slice(4, 7, 1), slice(3, 5, 1)), + (slice(4, 7, 1), slice(3, 5, 1)), + (slice(4, 7, 1), slice(3, 5, 1)), + (slice(4, 7, 1), slice(3, 5, 1))]) + ) + def tearDown(self): pass