Skip to content

Commit

Permalink
Rewrite BlockwiseDep handling of map_blocks
Browse files Browse the repository at this point in the history
This is an attempt to improve the serialization situation. Instead of
trying to put the data into a custom subclass of BlockwiseDep, insert it
as (constant) arguments to the wrapper function. This then relies on
SubgraphCallable to handle the serialization.

This may theoretically improve serialization costs when the graph is
materialized on the client (which I think is still the default for
arrays), because the raw data is inside the SubgraphCallable and hence
only serialised once, with production of the individual block_infos left
to the workers.

The benchmark code in #7686 is a bit slower than the previous version
(about 7s for compute), but still faster than main.
  • Loading branch information
bmerry committed May 26, 2021
1 parent bf0c01b commit bc2b235
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 94 deletions.
115 changes: 26 additions & 89 deletions dask/array/core.py
Expand Up @@ -14,7 +14,7 @@
from numbers import Integral, Number
from operator import add, getitem, mul
from threading import Lock
from typing import Dict, List
from typing import Dict
from typing import Mapping as TypingMapping
from typing import Optional, Tuple

Expand All @@ -32,7 +32,7 @@
persist,
tokenize,
)
from ..blockwise import BlockwiseDep, BlockwiseDepDict, broadcast_dimensions
from ..blockwise import BlockwiseDepDict, broadcast_dimensions
from ..context import globalmethod
from ..core import quote
from ..delayed import Delayed, delayed
Expand Down Expand Up @@ -446,7 +446,7 @@ def __init__(
self.starts = starts
self.num_chunks = tuple(len(s) - 1 for s in self.starts)

def __getitem__(self, idx: Tuple[int, ...]) -> dict:
def __call__(self, idx: Tuple[int, ...]) -> dict:
location = tuple((idx[c] if c >= 0 else 0) for c in self.idx_remap)
return {
"shape": self.shape,
Expand All @@ -455,108 +455,52 @@ def __getitem__(self, idx: Tuple[int, ...]) -> dict:
"chunk-location": location,
}

def __dask_distributed_pack__(self):
return {
"idx_remap": self.idx_remap,
"shape": self.shape,
"starts": self.starts,
}

@classmethod
def __dask_distributed_unpack__(cls, state):
# msgpack turns tuples into lists, so we have to convert back
return cls(
tuple(state["idx_remap"]),
tuple(state["shape"]),
tuple((tuple(s) for s in state["starts"])),
)


class _BlockInfoOutput:
def __init__(self, shape: Tuple[int, ...], starts: Tuple[int, ...], meta) -> None:
def __init__(self, shape: Tuple[int, ...], starts: Tuple[int, ...], dtype) -> None:
self.shape = shape
self.starts = starts
self.meta = meta
self.dtype = dtype
self.num_chunks = tuple(len(s) - 1 for s in self.starts)

def __getitem__(self, idx: Tuple[int, ...]) -> dict:
def __call__(self, idx: Tuple[int, ...]) -> dict:
return {
"shape": self.shape,
"num-chunks": self.num_chunks,
"array-location": [(s[i], s[i + 1]) for s, i in zip(self.starts, idx)],
"chunk-location": idx,
"chunk-shape": tuple(s[i + 1] - s[i] for s, i in zip(self.starts, idx)),
"dtype": self.meta.dtype,
}

def __dask_distributed_pack__(self):
from distributed.protocol import to_serialize

return {
"shape": self.shape,
"starts": self.starts,
"meta": to_serialize(self.meta),
"dtype": self.dtype,
}

@classmethod
def __dask_distributed_unpack__(cls, state):
from distributed.protocol import deserialize

# msgpack turns tuples into lists, so we have to convert back
return cls(
tuple(state["shape"]),
tuple(tuple(s) for s in state["starts"]),
deserialize(state["meta"].header, state["meta"].frames),
)


class _BlockInfo(BlockwiseDep):
class _BlockInfo:
"""Generate ``block_info`` parameters for :func:`map_blocks` on the fly."""

def __init__(
self, output: _BlockInfoOutput, inputs: TypingMapping[int, _BlockInfoInput]
):
self.numblocks = output.num_chunks
self.produces_tasks = False
self.output = output
self.inputs = inputs

def __getitem__(self, idx: Tuple[int, ...]) -> Dict[Optional[str], dict]:
info = {key: array_info[idx] for key, array_info in self.inputs.items()}
info[None] = self.output[idx]
def __call__(self, idx: Tuple[int, ...]) -> Dict[Optional[str], dict]:
info = {key: array_info(idx) for key, array_info in self.inputs.items()}
info[None] = self.output(idx)
return info

def __dask_distributed_pack__(
self, required_indices: Optional[List[Tuple[int, ...]]] = None
):
return {
"output": self.output.__dask_distributed_pack__(),
"inputs": {
name: value.__dask_distributed_pack__()
for name, value in self.inputs.items()
},
}

@classmethod
def __dask_distributed_unpack__(cls, state):
return cls(
_BlockInfoOutput.__dask_distributed_unpack__(state["output"]),
{
name: _BlockInfoInput.__dask_distributed_unpack__(value)
for name, value in state["inputs"].items()
},
)


def _pass_extra_kwargs(func, keys, *args, **kwargs):
def _pass_block_info(func, block_info, block_id, *args, **kwargs):
"""Helper for :func:`map_blocks` to pass `block_info` or `block_id`.
For each element of `keys`, a corresponding element of args is changed
to a keyword argument with that key, before all arguments are passed on
to `func`.
"""
kwargs.update(zip(keys, args))
return func(*args[len(keys) :], **kwargs)
if has_keyword(func, "block_id"):
kwargs.update(block_id=block_id)
if has_keyword(func, "block_info"):
kwargs.update(block_info=block_info(block_id))
return func(*args, **kwargs)


def map_blocks(
Expand Down Expand Up @@ -841,15 +785,7 @@ def map_blocks(
**kwargs,
)

extra_argpairs = []
extra_names = []
# If func has block_id as an argument, construct an array of block IDs and
# prepare to inject it.
if has_keyword(func, "block_id"):
# We don't need to provide any keys to BlockwiseDepDict, because the
# default for missing keys is to provide the block ID.
extra_argpairs.append((BlockwiseDepDict({}, numblocks=out.numblocks), out_ind))
extra_names.append("block_id")
block_info = None

# If func has block_info as an argument, construct an array of block info
# objects and prepare to inject it.
Expand Down Expand Up @@ -883,25 +819,26 @@ def map_blocks(
block_info_output = _BlockInfoOutput(
out.shape,
[cached_cumsum(c, initial_zero=True) for c in out.chunks],
out._meta,
out.dtype,
)
block_info = _BlockInfo(block_info_output, block_info_inputs)
extra_argpairs.append((block_info, out_ind))
extra_names.append("block_info")

if extra_argpairs:
if has_keyword(func, "block_info") or has_keyword(func, "block_id"):
# Rewrite the Blockwise layer. It would be nice to find a way to
# avoid doing it twice, but it's currently needed to determine
# out.chunks from the first pass. Since it constructs a Blockwise
# rather than an expanded graph, it shouldn't be too expensive.
# We don't need to provide any keys to BlockwiseDepDict, because the
# default for missing keys is to provide the block ID.
out = blockwise(
_pass_extra_kwargs,
_pass_block_info,
out_ind,
func,
None,
tuple(extra_names),
block_info,
None,
*concat(extra_argpairs),
BlockwiseDepDict({}, numblocks=out.numblocks),
out_ind,
*concat(argpairs),
name=out.name,
new_axes=new_axes,
Expand Down
11 changes: 6 additions & 5 deletions dask/tests/test_distributed.py
Expand Up @@ -455,16 +455,17 @@ async def test_map_blocks_block_info(c, s, a, b):
np = pytest.importorskip("numpy")

def func(x, y, block_info):
with open("blockinfo.txt", "a") as f:
print(block_info, file=f, flush=True)
return np.array([[block_info]], dtype=object)

a = da.ones((4,), chunks=2)
b = da.ones((3, 2), chunks=(1, 2))
# optimize_graph=False ensures that the _BlockInfo is serialized to the
out = da.map_blocks(func, a, b, dtype=object, chunks=((1, 1, 1), (1, 1)))
print(out)
# optimize_graph=False ensures that the _BlockInfoDep is serialized to the
# scheduler, rather than materialised on the client.
blocks = await c.compute(
da.map_blocks(func, a, b, meta=np.ndarray((), dtype=object), chunks=(1, 1)),
optimize_graph=False,
)
blocks = await c.compute(out, optimize_graph=False)
assert blocks.shape == (3, 2)
assert blocks[2, 1] == {
0: {
Expand Down

0 comments on commit bc2b235

Please sign in to comment.