Skip to content

Commit

Permalink
Fix unify_chunks to return regular chunks in all cases. (#470)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Jun 5, 2024
1 parent 101d59b commit 4eb886f
Show file tree
Hide file tree
Showing 4 changed files with 149 additions and 72 deletions.
36 changes: 33 additions & 3 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import numpy as np
import zarr
from tlz import concat, partition
from tlz import concat, first, partition
from toolz import accumulate, map
from zarr.indexing import (
IntDimIndexer,
Expand All @@ -37,7 +37,7 @@
offset_to_block_id,
to_chunksize,
)
from cubed.vendor.dask.array.core import common_blockdim, normalize_chunks
from cubed.vendor.dask.array.core import normalize_chunks
from cubed.vendor.dask.array.utils import validate_axis
from cubed.vendor.dask.blockwise import broadcast_dimensions, lol_product
from cubed.vendor.dask.utils import has_keyword
Expand Down Expand Up @@ -1383,7 +1383,9 @@ def unify_chunks(*args: "Array", **kwargs):
else:
nameinds.append((a, ind))

chunkss = broadcast_dimensions(nameinds, blockdim_dict, consolidate=common_blockdim)
chunkss = broadcast_dimensions(
nameinds, blockdim_dict, consolidate=smallest_blockdim
)

arrays = []
for a, i in arginds:
Expand All @@ -1400,8 +1402,36 @@ def unify_chunks(*args: "Array", **kwargs):
)
if chunks != a.chunks and all(a.chunks):
# this will raise if chunks are not regular
# but this should never happen with smallest_blockdim
chunksize = to_chunksize(chunks)
arrays.append(rechunk(a, chunksize))
else:
arrays.append(a)
return chunkss, arrays


def smallest_blockdim(blockdims):
    """Pick the smallest block dimensions from a list of block dimensions.

    Unlike Dask's ``common_blockdim``, this returns regular chunks
    (assuming regular chunks are passed in), so the result can always be
    converted to a single chunk size.

    Parameters
    ----------
    blockdims : iterable of tuple of int
        Candidate chunkings of one dimension; all non-trivial candidates
        must sum to the same dimension size.

    Returns
    -------
    tuple of int
        The chosen block dimensions for this axis.

    Raises
    ------
    ValueError
        If the non-trivial candidates do not sum to the same value.
    """
    # All candidates empty: zero-dimensional case, nothing to choose.
    if not any(blockdims):
        return ()

    # "Non-trivial" chunkings split the dimension into more than one chunk.
    non_trivial_dims = {d for d in blockdims if len(d) > 1}

    if len(non_trivial_dims) == 0:
        # Every candidate is a single chunk; keep the largest one.
        return max(blockdims, key=lambda d: d[0])
    if len(non_trivial_dims) == 1:
        # Exactly one candidate splits the axis; use it as-is.
        (only,) = non_trivial_dims
        return only

    if len({sum(d) for d in non_trivial_dims}) > 1:
        raise ValueError("Chunks do not add up to same value", blockdims)

    # Multiple non-trivial candidates: pick the one with the smallest first
    # chunk (for regular chunks the first chunk is the chunk size).
    return min(non_trivial_dims, key=lambda d: d[0])
9 changes: 8 additions & 1 deletion cubed/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def test_add(spec, any_executor):
)


def test_add_with_broadcast(spec, executor):
def test_add_different_chunks(spec, executor):
a = xp.ones((10, 10), chunks=(10, 2), spec=spec)
b = xp.ones((10, 10), chunks=(2, 10), spec=spec)
c = xp.add(a, b)
Expand All @@ -178,6 +178,13 @@ def test_add_with_broadcast(spec, executor):
)


def test_add_different_chunks_fail(spec, executor):
    # Chunk sizes 3 and 5 do not evenly divide one another, so unifying
    # chunks must still produce a result that computes correctly.
    x = xp.ones((10,), chunks=(3,), spec=spec)
    y = xp.ones((10,), chunks=(5,), spec=spec)
    total = xp.add(x, y)
    expected = np.ones((10,)) + np.ones((10,))
    assert_array_equal(total.compute(executor=executor), expected)


def test_equal(spec):
a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
b = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec)
Expand Down
108 changes: 108 additions & 0 deletions cubed/tests/test_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import numpy as np
import pytest
from numpy.testing import assert_array_equal

import cubed.array_api as xp
from cubed.core.ops import smallest_blockdim, unify_chunks
from cubed.tests.utils import TaskCounter


def test_smallest_blockdim():
    # (input blockdims, expected result) pairs
    cases = [
        ([], ()),
        ([(5,), (5,)], (5,)),
        ([(5,), (3, 2)], (3, 2)),
        ([(5, 5), (3, 3, 3, 1)], (3, 3, 3, 1)),
        ([(2, 1), (2, 1)], (2, 1)),
        ([(2, 2, 1), (3, 2), (2, 2, 1)], (2, 2, 1)),
    ]
    for blockdims, expected in cases:
        assert smallest_blockdim(blockdims) == expected

    with pytest.raises(ValueError, match="Chunks do not add up to same value"):
        smallest_blockdim([(2, 1), (2, 2)])


@pytest.mark.parametrize(
    "chunks_a, chunks_b, expected_chunksize",
    [
        ((2,), (4,), (2,)),
        ((4,), (2,), (2,)),
        ((6,), (10,), (6,)),
        ((10,), (10,), (10,)),
        ((5,), (10,), (5,)),
        ((3,), (5,), (3,)),
        ((5,), (3,), (3,)),
    ],
)
def test_unify_chunks_elemwise(chunks_a, chunks_b, expected_chunksize):
    # Unifying two 1-d arrays should rechunk both to the expected chunk size.
    x = xp.ones((10,), chunks=chunks_a)
    y = xp.ones((10,), chunks=chunks_b)

    _, unified = unify_chunks(x, "i", y, "i")
    assert all(u.chunksize == expected_chunksize for u in unified)

    # The elementwise result must still be correct after rechunking.
    total = xp.add(x, y)
    assert_array_equal(total.compute(), np.ones((10,)) + np.ones((10,)))


@pytest.mark.parametrize(
    "chunks_a, chunks_b, expected_chunksize",
    [
        ((2, 2), (4, 4), (2, 2)),
        ((2, 4), (4, 2), (2, 2)),
        ((4, 2), (2, 4), (2, 2)),
        ((3, 5), (5, 3), (3, 3)),
        ((3, 10), (10, 3), (3, 3)),
    ],
)
def test_unify_chunks_elemwise_2d(chunks_a, chunks_b, expected_chunksize):
    # Unifying two 2-d arrays over the same index pattern should rechunk
    # both to the expected (regular) chunk size.
    x = xp.ones((10, 10), chunks=chunks_a)
    y = xp.ones((10, 10), chunks=chunks_b)

    _, unified = unify_chunks(x, "ij", y, "ij")
    assert all(u.chunksize == expected_chunksize for u in unified)

    # The elementwise result must still be correct after rechunking.
    total = xp.add(x, y)
    assert_array_equal(total.compute(), np.ones((10, 10)) + np.ones((10, 10)))


@pytest.mark.parametrize(
    "chunks_a, chunks_b, expected_chunksize",
    [
        ((2, 2), (4, 4), (2, 2)),
        ((2, 4), (2, 4), (2, 2)),
        ((4, 2), (4, 2), (2, 2)),
        ((3, 5), (3, 5), (3, 3)),
        ((3, 10), (3, 10), (3, 3)),
    ],
)
def test_unify_chunks_blockwise_2d(chunks_a, chunks_b, expected_chunksize):
    # Here the index patterns are transposed ("ij" vs "ji"), as in a
    # matrix multiply, so the unified chunks must agree across both axes.
    x = xp.ones((10, 10), chunks=chunks_a)
    y = xp.ones((10, 10), chunks=chunks_b)

    _, unified = unify_chunks(x, "ij", y, "ji")
    assert all(u.chunksize == expected_chunksize for u in unified)

    product = xp.matmul(x, y)
    assert_array_equal(product.compute(), np.matmul(np.ones((10, 10)), np.ones((10, 10))))


def test_unify_chunks_broadcast_scalar():
    # Adding a scalar broadcasts it against the array's chunks.
    arr = xp.ones((10,), chunks=(3,))
    shifted = arr + 1
    assert_array_equal(shifted.compute(), np.ones((10,)) + 1)


def test_unify_chunks_broadcast_2d():
    matrix = xp.ones((10, 10), chunks=(3, 3))
    vector = xp.ones((10,), chunks=(5,))
    result = xp.add(matrix, vector)

    # Check that the vector is rechunked *before* broadcasting, to avoid
    # materializing the full (broadcasted) array.
    task_counter = TaskCounter()
    computed = result.compute(callbacks=[task_counter])
    num_created_arrays = 2  # vector rechunked, result
    # 1 task for the rechunk of the vector, 16 for the addition operation
    expected_tasks = num_created_arrays + 1 + 16
    assert task_counter.value == expected_tasks

    assert_array_equal(computed, np.ones((10, 10)) + np.ones((10,)))
68 changes: 0 additions & 68 deletions cubed/vendor/dask/array/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,71 +462,3 @@ def _check_regular_chunks(chunkset):
if chunks[-1] > chunks[0]:
return False
return True


def common_blockdim(blockdims):
    """Find the common block dimensions from the list of block dimensions

    Currently only implements the simplest possible heuristic: the common
    block-dimension is the only one that does not fully span a dimension.
    This is a conservative choice that allows us to avoid potentially very
    expensive rechunking.

    Assumes that each element of the input block dimensions has all the same
    sum (i.e., that they correspond to dimensions of the same size).

    Examples
    --------
    >>> common_blockdim([(3,), (2, 1)])
    (2, 1)
    >>> common_blockdim([(1, 2), (2, 1)])
    (1, 1, 1)
    >>> common_blockdim([(2, 2), (3, 1)])  # doctest: +SKIP
    Traceback (most recent call last):
    ...
    ValueError: Chunks do not align
    """
    # Zero-dimensional case: no axis to align.
    if not any(blockdims):
        return ()
    # "Non-trivial" chunkings split the dimension into more than one chunk.
    non_trivial_dims = {d for d in blockdims if len(d) > 1}
    if len(non_trivial_dims) == 1:
        # Only one candidate splits the axis; use it as-is.
        return first(non_trivial_dims)
    if len(non_trivial_dims) == 0:
        # Every candidate is a single chunk; keep the largest one.
        return max(blockdims, key=first)

    # NaN chunk sizes indicate the array's chunk sizes are unknown.
    if np.isnan(sum(map(sum, blockdims))):
        raise ValueError(
            "Arrays' chunk sizes (%s) are unknown.\n\n"
            "A possible solution:\n"
            " x.compute_chunk_sizes()" % blockdims
        )

    if len(set(map(sum, non_trivial_dims))) > 1:
        raise ValueError("Chunks do not add up to same value", blockdims)

    # We have multiple non-trivial chunks on this axis
    # e.g. (5, 2) and (4, 3)

    # We create a single chunk tuple with the same total length
    # that evenly divides both, e.g. (4, 1, 2)

    # To accomplish this we walk down all chunk tuples together, finding the
    # smallest element, adding it to the output, and subtracting it from all
    # other elements and remove the element itself. We stop once we have
    # burned through all of the chunk tuples.
    # For efficiency's sake we reverse the lists so that we can pop off the end
    rchunks = [list(ntd)[::-1] for ntd in non_trivial_dims]
    total = sum(first(non_trivial_dims))
    i = 0

    out = []
    while i < total:
        # Take the smallest remaining leading chunk across all candidates...
        m = min(c[-1] for c in rchunks)
        out.append(m)
        # ...and consume that many elements from every candidate, dropping
        # any chunk that is fully consumed.
        for c in rchunks:
            c[-1] -= m
            if c[-1] == 0:
                c.pop()
        i += m

    return tuple(out)

0 comments on commit 4eb886f

Please sign in to comment.