Skip to content

Commit

Permalink
Take broadcast trick into account for array chunk memory (#447)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Apr 22, 2024
1 parent 1516ec5 commit 0f3021a
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 21 deletions.
4 changes: 2 additions & 2 deletions cubed/core/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from cubed.runtime.types import Callback, Executor
from cubed.spec import Spec, spec_from_config
from cubed.storage.zarr import open_if_lazy_zarr_array
from cubed.utils import chunk_memory
from cubed.utils import array_memory
from cubed.vendor.dask.array.core import normalize_chunks

from .plan import arrays_to_plan
Expand Down Expand Up @@ -60,7 +60,7 @@ def zarray(self):
@property
def chunkmem(self):
    """Amount of memory in bytes that a single chunk uses."""
    # A chunk's memory is the dtype itemsize times the number of
    # elements in one chunk (the chunksize shape).
    dtype = self.dtype
    shape = self.chunksize
    return array_memory(dtype, shape)

@property
def chunksize(self):
Expand Down
6 changes: 3 additions & 3 deletions cubed/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from cubed.spec import spec_from_config
from cubed.utils import (
_concatenate2,
chunk_memory,
array_memory,
get_item,
offset_to_block_id,
to_chunksize,
Expand Down Expand Up @@ -966,7 +966,7 @@ def reduction(
while any(n > 1 for i, n in enumerate(result.numblocks) if i in axis):
# merge along axis
target_chunks = list(result.chunksize)
chunk_mem = chunk_memory(intermediate_dtype, result.chunksize)
chunk_mem = array_memory(intermediate_dtype, result.chunksize)
for i, s in enumerate(result.shape):
if i in axis:
assert result.chunksize[i] == 1 # result of reduction
Expand Down Expand Up @@ -1229,7 +1229,7 @@ def key_function(out_key):
# to stay within limits (maybe because the iterator doesn't free the previous object
# before getting the next). We also need extra memory to hold two reduced chunks, since
# they are concatenated two at a time.
extra_projected_mem = x.chunkmem + 2 * chunk_memory(dtype, to_chunksize(chunks))
extra_projected_mem = x.chunkmem + 2 * array_memory(dtype, to_chunksize(chunks))

return general_blockwise(
_partial_reduce,
Expand Down
2 changes: 1 addition & 1 deletion cubed/core/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def visualize(

elif node_type == "array":
target = d["target"]
chunkmem = memory_repr(chunk_memory(target.dtype, target.chunks))
chunkmem = memory_repr(chunk_memory(target))
nbytes = None

# materialized arrays are light orange, virtual arrays are white
Expand Down
15 changes: 11 additions & 4 deletions cubed/primitive/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@
from cubed.runtime.types import CubedPipeline
from cubed.storage.zarr import T_ZarrArray, lazy_zarr_array
from cubed.types import T_Chunks, T_DType, T_Shape, T_Store
from cubed.utils import chunk_memory, get_item, map_nested, split_into, to_chunksize
from cubed.utils import (
array_memory,
chunk_memory,
get_item,
map_nested,
split_into,
to_chunksize,
)
from cubed.vendor.dask.array.core import normalize_chunks
from cubed.vendor.dask.blockwise import _get_coord_mapping, _make_dims, lol_product
from cubed.vendor.dask.core import flatten
Expand Down Expand Up @@ -316,12 +323,12 @@ def general_blockwise(
# - we assume compression has no effect (so it's an overestimate)
# - ideally we'd be able to look at nbytes_stored,
# but this is not possible in general since the array has not been written yet
projected_mem += chunk_memory(array.dtype, array.chunks) * 2
projected_mem += array_memory(array.dtype, array.chunks) * 2
# output
# memory for a compressed and an uncompressed output array chunk
# - this assumes the blockwise function creates a new array
# - numcodecs uses a working output buffer that's the size of the array being compressed
projected_mem += chunk_memory(dtype, chunksize) * 2
projected_mem += array_memory(dtype, chunksize) * 2

if projected_mem > allowed_mem:
raise ValueError(
Expand Down Expand Up @@ -450,7 +457,7 @@ def peak_projected_mem(primitive_ops):
memory_modeller = MemoryModeller()
for p in primitive_ops:
memory_modeller.allocate(p.projected_mem)
chunkmem = chunk_memory(p.target_array.dtype, p.target_array.chunks)
chunkmem = chunk_memory(p.target_array)
memory_modeller.free(p.projected_mem - chunkmem)
return memory_modeller.peak_mem

Expand Down
12 changes: 11 additions & 1 deletion cubed/storage/virtual.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from cubed.backend_array_api import namespace as nxp
from cubed.backend_array_api import numpy_array_to_backend_array
from cubed.types import T_DType, T_RegularChunks, T_Shape
from cubed.utils import broadcast_trick, memory_repr
from cubed.utils import array_memory, broadcast_trick, memory_repr


class VirtualEmptyArray:
Expand Down Expand Up @@ -37,6 +37,11 @@ def __getitem__(self, key):
# use broadcast trick so array chunks only occupy a single value in memory
return broadcast_trick(nxp.empty)(indexer.shape, dtype=self.dtype)

@property
def chunkmem(self):
    """Memory (bytes) occupied by a single chunk.

    Because the broadcast trick is used when materializing chunks,
    a chunk only ever holds one real element in memory.
    """
    one_element = (1,)
    return array_memory(self.dtype, one_element)

@property
def oindex(self):
return self.template.oindex
Expand Down Expand Up @@ -75,6 +80,11 @@ def __getitem__(self, key):
indexer.shape, fill_value=self.fill_value, dtype=self.dtype
)

@property
def chunkmem(self):
    """Memory (bytes) occupied by a single chunk.

    The broadcast trick means a chunk is backed by a single element,
    so only one item's worth of memory is counted.
    """
    one_element = (1,)
    return array_memory(self.dtype, one_element)

@property
def oindex(self):
return self.template.oindex
Expand Down
12 changes: 6 additions & 6 deletions cubed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

from cubed.backend_array_api import namespace as nxp
from cubed.utils import (
array_memory,
block_id_to_offset,
broadcast_trick,
chunk_memory,
extract_stack_summaries,
join_path,
map_nested,
Expand All @@ -22,11 +22,11 @@
)


def test_chunk_memory():
assert chunk_memory(np.int64, (3,)) == 24
assert chunk_memory(np.int32, (3,)) == 12
assert chunk_memory(np.int32, (3, 5)) == 60
assert chunk_memory(np.int32, (0,)) == 0
def test_array_memory():
    # array_memory is itemsize * number of elements; a zero-length
    # dimension yields zero bytes
    cases = [
        (np.int64, (3,), 24),
        (np.int32, (3,), 12),
        (np.int32, (3, 5), 60),
        (np.int32, (0,), 0),
    ]
    for dtype, shape, expected in cases:
        assert array_memory(dtype, shape) == expected


def test_block_id_to_offset():
Expand Down
18 changes: 14 additions & 4 deletions cubed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,25 @@
import tlz as toolz

from cubed.backend_array_api import namespace as nxp
from cubed.types import T_DType, T_RectangularChunks, T_RegularChunks
from cubed.vendor.dask.array.core import _check_regular_chunks
from cubed.types import T_DType, T_RectangularChunks, T_RegularChunks, T_Shape
from cubed.vendor.dask.array.core import _check_regular_chunks, normalize_chunks

PathType = Union[str, Path]


def chunk_memory(dtype: T_DType, chunksize: T_RegularChunks) -> int:
def array_memory(dtype: T_DType, shape: T_Shape) -> int:
    """Calculate the amount of memory in bytes that an array uses.

    This is the dtype's per-element size multiplied by the total
    number of elements implied by ``shape`` (compression is not
    taken into account).
    """
    itemsize = np.dtype(dtype).itemsize
    n_elements = prod(shape)
    return itemsize * n_elements


def chunk_memory(arr) -> int:
    """Calculate the amount of memory in bytes that a single chunk uses."""
    # Arrays may report their own per-chunk memory (e.g. virtual arrays
    # that use the broadcast trick); prefer that when present.
    _missing = object()
    reported = getattr(arr, "chunkmem", _missing)
    if reported is not _missing:
        return reported
    # Otherwise derive it from the array's chunking: normalize, take one
    # chunk's shape, and compute its memory footprint.
    chunksize = to_chunksize(
        normalize_chunks(arr.chunks, shape=arr.shape, dtype=arr.dtype)
    )
    return array_memory(arr.dtype, chunksize)


def offset_to_block_id(offset: int, numblocks: Tuple[int, ...]) -> Tuple[int, ...]:
Expand Down

0 comments on commit 0f3021a

Please sign in to comment.