diff --git a/kenjutsu/blocks.py b/kenjutsu/blocks.py index 6d5d5ae..f285dc4 100644 --- a/kenjutsu/blocks.py +++ b/kenjutsu/blocks.py @@ -12,7 +12,7 @@ import kenjutsu.format -def split_blocks(space_shape, block_shape, block_halo=None): +def split_blocks(space_shape, block_shape, block_halo=None, index=None): """ Return a list of slicings to cut each block out of an array or other. @@ -30,6 +30,8 @@ def split_blocks(space_shape, block_shape, block_halo=None): space_shape(tuple): Shape of array to slice block_shape(tuple): Size of each block to take block_halo(tuple): Halo to tack on to each block + index(bool): Whether to provide an index for + each block Returns: collections.Sequence of \ @@ -39,9 +41,16 @@ def split_blocks(space_shape, block_shape, block_halo=None): Examples: >>> split_blocks( - ... (2, 3,), (1, 1,), (1, 1,) + ... (2, 3,), (1, 1,), (1, 1,), True ... ) #doctest: +NORMALIZE_WHITESPACE - ([(slice(0, 1, 1), slice(0, 1, 1)), + ([(0, 0), + (0, 1), + (0, 2), + (1, 0), + (1, 1), + (1, 2)], + + [(slice(0, 1, 1), slice(0, 1, 1)), (slice(0, 1, 1), slice(1, 2, 1)), (slice(0, 1, 1), slice(2, 3, 1)), (slice(1, 2, 1), slice(0, 1, 1)), @@ -74,6 +83,19 @@ def split_blocks(space_shape, block_shape, block_halo=None): except ImportError: ifilter, imap = filter, map + if index is None: + index = False + warnings.warn( + "`index` will default to `True` in the next minor release.", + FutureWarning + ) + else: + warnings.warn( + "`index` will be deprecated in the next minor release. Once" + " removed, `split_blocks` will act as if `index` were `True`.", + PendingDeprecationWarning + ) + if block_halo is not None: if not (len(space_shape) == len(block_shape) == len(block_halo)): raise ValueError( @@ -200,8 +222,16 @@ def split_blocks(space_shape, block_shape, block_halo=None): trimmed_halos_per_dim.append(a_trimmed_halo) # Take all combinations of all ranges to get blocks. + result = tuple() + if index: + index_blocks = imap(lambda e: irange(len(e)), ranges_per_dim) + index_blocks = list(itertools.product(*index_blocks)) + result += (index_blocks,) + orig_blocks = list(itertools.product(*ranges_per_dim)) haloed_blocks = list(itertools.product(*haloed_ranges_per_dim)) trimmed_halos = list(itertools.product(*trimmed_halos_per_dim)) - return(orig_blocks, haloed_blocks, trimmed_halos) + result += (orig_blocks, haloed_blocks, trimmed_halos) + + return result diff --git a/tests/test_blocks.py b/tests/test_blocks.py index 5b2d103..3e93b11 100644 --- a/tests/test_blocks.py +++ b/tests/test_blocks.py @@ -89,6 +89,15 @@ def test_split_blocks(self): [(slice(0, 1, 1),), (slice(0, 1, 1),)]) ) + result = blocks.split_blocks((2,), (1,), index=True) + self.assertEqual( + result, + ([(0,), (1,)], + [(slice(0, 1, 1),), (slice(1, 2, 1),)], + [(slice(0, 1, 1),), (slice(1, 2, 1),)], + [(slice(0, 1, 1),), (slice(0, 1, 1),)]) + ) + result = blocks.split_blocks((2,), (-1,)) self.assertEqual( result, @@ -220,6 +229,107 @@ def test_split_blocks(self): (slice(4, 7, 1), slice(3, 5, 1))]) ) + result = blocks.split_blocks((10, 12,), (3, 2,), (4, 3,), True) + self.assertEqual( + result, + ([(0, 0), + (0, 1), + (0, 2), + (0, 3), + (0, 4), + (0, 5), + (1, 0), + (1, 1), + (1, 2), + (1, 3), + (1, 4), + (1, 5), + (2, 0), + (2, 1), + (2, 2), + (2, 3), + (2, 4), + (2, 5), + (3, 0), + (3, 1), + (3, 2), + (3, 3), + (3, 4), + (3, 5)], + [(slice(0, 3, 1), slice(0, 2, 1)), + (slice(0, 3, 1), slice(2, 4, 1)), + (slice(0, 3, 1), slice(4, 6, 1)), + (slice(0, 3, 1), slice(6, 8, 1)), + (slice(0, 3, 1), slice(8, 10, 1)), + (slice(0, 3, 1), slice(10, 12, 1)), + (slice(3, 6, 1), slice(0, 2, 1)), + (slice(3, 6, 1), slice(2, 4, 1)), + (slice(3, 6, 1), slice(4, 6, 1)), + (slice(3, 6, 1), slice(6, 8, 1)), + (slice(3, 6, 1), slice(8, 10, 1)), + (slice(3, 6, 1), slice(10, 12, 1)), + (slice(6, 9, 1), slice(0, 2, 1)), + (slice(6, 9, 1), slice(2, 4, 1)), + (slice(6, 9, 1), slice(4, 6, 1)), + (slice(6, 9, 1), slice(6, 8, 1)), + (slice(6, 9, 1), slice(8, 10, 1)), + (slice(6, 9, 1), slice(10, 12, 1)), + (slice(9, 10, 1), slice(0, 2, 1)), + (slice(9, 10, 1), slice(2, 4, 1)), + (slice(9, 10, 1), slice(4, 6, 1)), + (slice(9, 10, 1), slice(6, 8, 1)), + (slice(9, 10, 1), slice(8, 10, 1)), + (slice(9, 10, 1), slice(10, 12, 1))], + [(slice(0, 7, 1), slice(0, 5, 1)), + (slice(0, 7, 1), slice(0, 7, 1)), + (slice(0, 7, 1), slice(1, 9, 1)), + (slice(0, 7, 1), slice(3, 11, 1)), + (slice(0, 7, 1), slice(5, 12, 1)), + (slice(0, 7, 1), slice(7, 12, 1)), + (slice(0, 10, 1), slice(0, 5, 1)), + (slice(0, 10, 1), slice(0, 7, 1)), + (slice(0, 10, 1), slice(1, 9, 1)), + (slice(0, 10, 1), slice(3, 11, 1)), + (slice(0, 10, 1), slice(5, 12, 1)), + (slice(0, 10, 1), slice(7, 12, 1)), + (slice(2, 10, 1), slice(0, 5, 1)), + (slice(2, 10, 1), slice(0, 7, 1)), + (slice(2, 10, 1), slice(1, 9, 1)), + (slice(2, 10, 1), slice(3, 11, 1)), + (slice(2, 10, 1), slice(5, 12, 1)), + (slice(2, 10, 1), slice(7, 12, 1)), + (slice(5, 10, 1), slice(0, 5, 1)), + (slice(5, 10, 1), slice(0, 7, 1)), + (slice(5, 10, 1), slice(1, 9, 1)), + (slice(5, 10, 1), slice(3, 11, 1)), + (slice(5, 10, 1), slice(5, 12, 1)), + (slice(5, 10, 1), slice(7, 12, 1))], + [(slice(0, 3, 1), slice(0, 2, 1)), + (slice(0, 3, 1), slice(2, 4, 1)), + (slice(0, 3, 1), slice(3, 5, 1)), + (slice(0, 3, 1), slice(3, 5, 1)), + (slice(0, 3, 1), slice(3, 5, 1)), + (slice(0, 3, 1), slice(3, 5, 1)), + (slice(3, 6, 1), slice(0, 2, 1)), + (slice(3, 6, 1), slice(2, 4, 1)), + (slice(3, 6, 1), slice(3, 5, 1)), + (slice(3, 6, 1), slice(3, 5, 1)), + (slice(3, 6, 1), slice(3, 5, 1)), + (slice(3, 6, 1), slice(3, 5, 1)), + (slice(4, 7, 1), slice(0, 2, 1)), + (slice(4, 7, 1), slice(2, 4, 1)), + (slice(4, 7, 1), slice(3, 5, 1)), + (slice(4, 7, 1), slice(3, 5, 1)), + (slice(4, 7, 1), slice(3, 5, 1)), + (slice(4, 7, 1), slice(3, 5, 1)), + (slice(4, 7, 1), slice(0, 2, 1)), + (slice(4, 7, 1), slice(2, 4, 1)), + (slice(4, 7, 1), slice(3, 5, 1)), + (slice(4, 7, 1), slice(3, 5, 1)), + (slice(4, 7, 1), slice(3, 5, 1)), + (slice(4, 7, 1), slice(3, 5, 1))]) + ) + def tearDown(self): pass