Make map_blocks with block_info produce a Blockwise#5896
Make map_blocks with block_info produce a Blockwise#5896TomAugspurger merged 3 commits intodask:masterfrom
Conversation
|
I've marked this a draft PR because it's built on top of #5895. In the meantime if you want to look at what it's doing, ignore the first two commits. In this (rather contrived) example, import dask.array as da
import numpy as np
import timeit
def combine(a, b, block_id=None):
return a + b if block_id[0] == block_id[1] else np.zeros_like(a)
a = da.ones((1000, 100), chunks=(1, 1))
b = da.zeros((1000, 100), chunks=(1, 1))
print(timeit.timeit('da.map_blocks(combine, a, b, dtype=a.dtype)', number=10, globals=globals()))
c = da.map_blocks(combine, a, b, dtype=a.dtype)
print(timeit.timeit('c.compute()', number=1, globals=globals()))Gains with |
|
By the way, are there any issues around serializability of the closures returned by |
Everything will work, but serialization is a bit slower. In general we try to use raw top-level defined functions and pass through additional data as keywords when possible. The core |
Ok, I'll have a look next week to see if there is some way to achieve that - possibly by passing the original function as a keyword arg. |
|
I figured out how to avoid using closures. Once #5895 is merged (it's currently approved) I'll rebase against master. I'll also need to check the test failures, but I think it's just a formatting issue from black. |
Previously, using map_blocks with block_info or block_id would denature a Blockwise and do a substitution on every instantiation of the subgraph callable to inject the block info. Instead, create a dask array whose elements are the block infos, and pass this as an extra parameter into `dask.array.blockwise`. This makes the layer a Blockwise, and hence a candidate for `rewrite_blockwise`.
bf0a5c1 to
c3f97e2
Compare
|
I've rebased against master, so this is ready for review now. |
|
By the way, I haven't added any tests because this is just an optimisation, not a new feature, and I think the existing map_blocks tests are reasonably comprehensive. |
|
I apologize for the delay in reviewing this. There are relatively few people familiar with this part of the codebase, and it's a bit more involved to review.
It still depends on a non-Blockwise though, but I guess that's ok? rewrite_blockwise is comfortable treating that as a small input to the entire resulting output? |
|
We might test that |
I'm now also less certain that there is any benefit to the resulting graph structure - I should definitely write a test to check that rewrite_blockwise is actually able to do the optimisation I'm claiming. My initial plan had been to follow this up by supporting |
If there is no concatenation, set `concatenate` to None. map_blocks always passes True, which was preventing map_blocks layers from being fused with elementwise layers by `optimize_blockwise`.
|
So it turned out that Once that's fixed, this example becomes much faster, both to build the structures and to compute (the chunks are tiny to emphasize the overheads): #!/usr/bin/env python3
import time
import dask.array as da
from dask.blockwise import optimize_blockwise
import numpy as np
CHUNK_SIZE = 10
NCHUNKS = 10000
SIZE = CHUNK_SIZE * NCHUNKS
def combine(x, y, block_id):
return x + y
t0 = time.monotonic()
base = [da.full((SIZE,), i, dtype=np.int8, chunks=CHUNK_SIZE) for i in range(4)]
a = base[0] + base[1]
b = da.map_blocks(combine, a, base[2], dtype=np.int8)
c = b + base[3]
t1 = time.monotonic()
c.compute()
t2 = time.monotonic()
print(t1 - t0)
print(t2 - t1)On master: On this branch: |
Just to clarify, that's fixed on this branch (likely in bad97f0)? |
Yes. |
|
Thanks @bmerry! |
Previously, using map_blocks with block_info or block_id would denature
a Blockwise and do a substitution on every instantiation of the subgraph
callable to inject the block info.
Instead, create a dask array whose elements are the block infos, and
pass this as an extra parameter into
dask.array.blockwise. This makesthe layer a Blockwise, and hence a candidate for
rewrite_blockwise.black dask/flake8 dask