Improve reuse of temporaries with numpy #5933

Merged
merged 1 commit into dask:master from bmerry:numpy-temporary-reuse on Mar 6, 2020

Conversation

@bmerry commented Feb 20, 2020

  • Tests added / passed
  • Passes black dask / flake8 dask

numpy has an optimisation that allows operations to be done in-place if
an operand is a temporary (has no more references) e.g. in the
expression `a + b + c`, `c` will be added in-place to the result of `a +
b`.

Tweak the single-threaded scheduler to avoid creating unnecessary
references. To make this effective, also modify Blockwise to run
`fuse` on the subgraph before creating SubgraphCallable.

This reduces the total runtime of the following code from 7.3s to 5.7s
on my machine:

```python
from pprint import pprint
import dask.array as da
from dask.blockwise import optimize_blockwise
from dask.base import visualize
from dask.array.optimization import optimize

a = da.ones(2000000000, chunks=10000000)
b = a + a + a
c = da.sum(b)
c.compute()
```
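
As background on the numpy behaviour being exploited, the elision can be seen with plain numpy as well (a minimal illustration, not part of this PR; the array size is arbitrary):

```python
import numpy as np

a = np.ones(10_000_000)

# In a + a + a, the intermediate (a + a) has no other references, so numpy
# (1.13+) can add the final operand into that buffer in place instead of
# allocating another large array for the result.
fast = a + a + a

# Holding an explicit reference to the intermediate defeats the elision:
# numpy must allocate a separate buffer for the final result.
tmp = a + a
slow = tmp + a
```

The scheduler tweak in this PR is aimed at the same check: dropping dask's own extra references to intermediate results so that numpy's temporary-elision can succeed inside a fused blockwise task.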
@bmerry commented Feb 20, 2020

I have a few questions:

  • I used `fuse` on the subgraph because it achieved the particular optimisation I needed, but should one run a more general optimisation and/or pass any of the optional arguments to `fuse`? Since a `SubgraphCallable` is going to be evaluated many times, it seems worth spending a little time optimising it, but potentially the optimisations that are suitable for the parallel schedulers aren't as useful for `dask.get`. (A rough sketch of what `fuse` does to a linear chain follows after this list.)

  • What's the most appropriate place to do the subgraph optimisation? I started off doing it in `rewrite_blockwise`, on the basis that transformations should be done in optimisation functions, and if someone specifically asks for optimisation not to happen then we shouldn't mess with their graphs. But that caused the unit tests for `rewrite_blockwise` to fail, because they check how the graph has been rewritten.
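
For reference, a minimal sketch of the kind of fusion in question (the graph, keys, and `inc` helper are made up for illustration, not taken from this PR):

```python
from dask.optimization import fuse


def inc(x):
    return x + 1


# A linear chain: "w" depends on "x", which depends on "y".
dsk = {"y": 1, "x": (inc, "y"), "w": (inc, "x")}

# Fusing with "w" as the protected output collapses the chain into a single
# nested task, so the subgraph no longer keeps separate intermediate results
# (and therefore no extra references) alive while it runs.
fused, dependencies = fuse(dsk, keys=["w"])
print(fused)
```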

@bmerry bmerry requested a review from mrocklin February 20, 2020 11:15
@mrocklin

@jcrist you may be interested in this one

@bmerry commented Mar 2, 2020

@jcrist any thoughts?

@quasiben quasiben added this to In progress in Core maintenance Mar 3, 2020
@TomAugspurger

It would also be nice to have an asv benchmark for this in https://github.com/dask/dask-benchmarks/pulls. I could easily see it being undone by a refactor that adds another reference to a value.
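
Something along these lines could be a starting point (purely illustrative; the class name, sizes, and scheduler choice are assumptions, not the benchmark that was eventually added):

```python
import dask.array as da


class ReuseTemporaries:
    # Elementwise expression whose runtime benefits from numpy reusing
    # temporaries in place on the single-threaded scheduler.
    def setup(self):
        self.a = da.ones(200_000_000, chunks=10_000_000)

    def time_sum_of_three(self):
        da.sum(self.a + self.a + self.a).compute(scheduler="synchronous")
```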

> basis that transformations should be done in optimisation functions and if someone specifically asks for optimisation not to happen then we shouldn't mess with their graphs.

That sounds right to me, too. Can you share the output from the failed test? It might make sense to adjust it.

@jcrist commented Mar 3, 2020

Overall this seems sane to me. I agree with Tom that a benchmark here would be good.

> What's the most appropriate place to do the subgraph optimisation?

The optimization should be done prior to passing to `SubgraphCallable` (as you've done here). When created from a blockwise operation, the same `SubgraphCallable` is used multiple times (amortizing the cost of optimization), but there are other places in the codebase where a `SubgraphCallable` is created that will only be used once (and thus the time to optimize is usually just extra cost).
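
A rough illustration of that amortisation point (assuming `SubgraphCallable` from `dask.optimization`; the subgraph and input values are made up):

```python
from operator import add

from dask.optimization import SubgraphCallable

# A tiny subgraph with a single input placeholder "_0".
dsk = {"out": (add, (add, "_0", "_0"), "_0")}
func = SubgraphCallable(dsk, "out", ["_0"])

# The callable is built once per blockwise layer but invoked once per output
# block, so time spent optimising the subgraph up front is repaid many times.
print([func(block) for block in (1, 2, 3)])  # [3, 6, 9]
```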

@bmerry commented Mar 4, 2020

If you want to look into the errors, here's a patch against this branch for my alternative implementation:

```diff
diff --git a/dask/blockwise.py b/dask/blockwise.py
index 82af9ae9..c574e96c 100644
--- a/dask/blockwise.py
+++ b/dask/blockwise.py
@@ -193,8 +193,7 @@ class Blockwise(Mapping):
             return self._cached_dict
         else:
             keys = tuple(map(blockwise_token, range(len(self.indices))))
-            dsk, _ = fuse(self.dsk, [self.output])
-            func = SubgraphCallable(dsk, self.output, keys)
+            func = SubgraphCallable(self.dsk, self.output, keys)
             self._cached_dict = make_blockwise_graph(
                 func,
                 self.output,
@@ -683,6 +682,7 @@ def rewrite_blockwise(inputs):
 
     sub = {blockwise_token(k): blockwise_token(v) for k, v in sub.items()}
     dsk = {k: subs(v, sub) for k, v in dsk.items()}
+    dsk, _ = fuse(dsk, [root])
 
     indices_check = {k for k, v in indices if v is not None}
     numblocks = toolz.merge([inp.numblocks for inp in inputs.values()])
```

Here's the first failure (there are several, though). It should be possible to adjust the expected values to account for the optimisation, but I worry that it will make the test more fragile, because the expected value will now depend on implementation details: both which optimizations we choose to run and how those optimizations are implemented. I guess we can handle the latter by parametrizing with the unoptimized graph and optimizing it inside the test (a rough sketch of that idea follows after the failure output below).

```
_______________________ test_rewrite[inputs2-expected2] ________________________

inputs = [Blockwise<(('a', ('i',)),) -> b>, Blockwise<(('b', ('j',)),) -> c>]
expected = ('c', 'j', {'b': (<function inc at 0x7facc1005b70>, '_0'), 'c': (<function inc at 0x7facc1005b70>, 'b')}, [('a', 'j')])

    @pytest.mark.parametrize(
        "inputs,expected",
        [
            # output name, output index, task, input indices
            [[(b, "i", {b: (inc, _0)}, [(a, "i")])], (b, "i", {b: (inc, _0)}, [(a, "i")])],
            [
                [
                    (b, "i", {b: (inc, _0)}, [(a, "i")]),
                    (c, "i", {c: (dec, _0)}, [(a, "i")]),
                    (d, "i", {d: (add, _0, _1, _2)}, [(a, "i"), (b, "i"), (c, "i")]),
                ],
                (d, "i", {b: (inc, _0), c: (dec, _0), d: (add, _0, b, c)}, [(a, "i")]),
            ],
            [
                [
                    (b, "i", {b: (inc, _0)}, [(a, "i")]),
                    (c, "j", {c: (inc, _0)}, [(b, "j")]),
                ],
                (c, "j", {b: (inc, _0), c: (inc, b)}, [(a, "j")]),
            ],
            [
                [
                    (b, "i", {b: (sum, _0)}, [(a, "ij")]),
                    (c, "k", {c: (inc, _0)}, [(b, "k")]),
                ],
                (c, "k", {b: (sum, _0), c: (inc, b)}, [(a, "kA")]),
            ],
            [
                [
                    (c, "i", {c: (inc, _0)}, [(a, "i")]),
                    (d, "i", {d: (inc, _0)}, [(b, "i")]),
                    (g, "ij", {g: (add, _0, _1)}, [(c, "i"), (d, "j")]),
                ],
                (
                    g,
                    "ij",
                    {g: (add, c, d), c: (inc, _0), d: (inc, _1)},
                    [(a, "i"), (b, "j")],
                ),
            ],
            [
                [
                    (b, "ji", {b: (np.transpose, _0)}, [(a, "ij")]),
                    (c, "ij", {c: (add, _0, _1)}, [(a, "ij"), (b, "ij")]),
                ],
                (c, "ij", {c: (add, _0, b), b: (np.transpose, _1)}, [(a, "ij"), (a, "ji")]),
            ],
            [
                [
                    (c, "i", {c: (add, _0, _1)}, [(a, "i"), (b, "i")]),
                    (d, "i", {d: (inc, _0)}, [(c, "i")]),
                ],
                (d, "i", {d: (inc, c), c: (add, _0, _1)}, [(a, "i"), (b, "i")]),
            ],
            [
                [
                    (b, "ij", {b: (np.transpose, _0)}, [(a, "ji")]),
                    (d, "ij", {d: (np.dot, _0, _1)}, [(b, "ik"), (c, "kj")]),
                ],
                (
                    d,
                    "ij",
                    {d: (np.dot, b, _0), b: (np.transpose, _1)},
                    [(c, "kj"), (a, "ki")],
                ),
            ],
            [
                [
                    (c, "i", {c: (add, _0, _1)}, [(a, "i"), (b, "i")]),
                    (f, "i", {f: (add, _0, _1)}, [(d, "i"), (e, "i")]),
                    (g, "i", {g: (add, _0, _1)}, [(c, "i"), (f, "i")]),
                ],
                (
                    g,
                    "i",
                    {g: (add, c, f), f: (add, _2, _3), c: (add, _0, _1)},
                    [(a, "i"), (b, "i"), (d, "i"), (e, "i")],
                ),
            ],
            [
                [
                    (c, "i", {c: (add, _0, _1)}, [(a, "i"), (b, "i")]),
                    (f, "i", {f: (add, _0, _1)}, [(a, "i"), (e, "i")]),
                    (g, "i", {g: (add, _0, _1)}, [(c, "i"), (f, "i")]),
                ],
                (
                    g,
                    "i",
                    {g: (add, c, f), f: (add, _0, _2), c: (add, _0, _1)},
                    [(a, "i"), (b, "i"), (e, "i")],
                ),
            ],
            [
                [
                    (b, "i", {b: (sum, _0)}, [(a, "ij")]),
                    (c, "i", {c: (inc, _0)}, [(b, "i")]),
                ],
                (c, "i", {c: (inc, b), b: (sum, _0)}, [(a, "iA")]),
            ],
            [
                [
                    (c, "i", {c: (inc, _0)}, [(b, "i")]),
                    (d, "i", {d: (add, _0, _1, _2)}, [(a, "i"), (b, "i"), (c, "i")]),
                ],
                (d, "i", {d: (add, _0, _1, c), c: (inc, _1)}, [(a, "i"), (b, "i")]),
            ],
            # Include literals
            [
                [(b, "i", {b: (add, _0, _1)}, [(a, "i"), (123, None)])],
                (b, "i", {b: (add, _0, _1)}, [(a, "i"), (123, None)]),
            ],
            [
                [
                    (b, "i", {b: (add, _0, _1)}, [(a, "i"), (123, None)]),
                    (c, "j", {c: (add, _0, _1)}, [(b, "j"), (456, None)]),
                ],
                (
                    c,
                    "j",
                    {b: (add, _1, _2), c: (add, b, _0)},
                    [(456, None), (a, "j"), (123, None)],
                ),
            ],
            # Literals that compare equal (e.g. 0 and False) aren't deduplicated
            [
                [
                    (b, "i", {b: (add, _0, _1)}, [(a, "i"), (0, None)]),
                    (c, "j", {c: (add, _0, _1)}, [(b, "j"), (False, None)]),
                ],
                (
                    c,
                    "j",
                    {b: (add, _1, _2), c: (add, b, _0)},
                    [(False, None), (a, "j"), (0, None)],
                ),
            ],
            # Literals are deduplicated
            [
                [
                    (b, "i", {b: (add, _0, _1)}, [(a, "i"), (123, None)]),
                    (c, "j", {c: (add, _0, _1)}, [(b, "j"), (123, None)]),
                ],
                (c, "j", {b: (add, _1, _0), c: (add, b, _0)}, [(123, None), (a, "j")]),
            ],
        ],
    )
    def test_rewrite(inputs, expected):
        inputs = [
            Blockwise(
                *inp, numblocks={k: (1,) * len(v) for k, v in inp[-1] if v is not None}
            )
            for inp in inputs
        ]
        result = rewrite_blockwise(inputs)
        result2 = (
            result.output,
            "".join(result.output_indices),
            result.dsk,
            [
                (name, "".join(ind) if ind is not None else ind)
                for name, ind in result.indices
            ],
        )
>       assert result2 == expected
E       AssertionError: assert ('c', 'j', {'... [('a', 'j')]) == ('c', 'j', {'... [('a', 'j')])
E         At index 2 diff: {'c': 'b-c', 'b-c': (<function inc at 0x7facc1005b70>, (<function inc at 0x7facc1005b70>, '_0'))} != {'b': (<function inc at 0x7facc1005b70>, '_0'), 'c': (<function inc at 0x7facc1005b70>, 'b')}
E         Use -v to get the full diff

dask/array/tests/test_atop.py:184: AssertionError
```
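
A rough sketch of the parametrization idea mentioned above (hypothetical helper and argument names, not the actual test):

```python
from dask.blockwise import rewrite_blockwise
from dask.optimization import fuse


def check_rewrite(inputs, expected_unoptimized_dsk, output):
    # Run the same fuse pass over the hand-written expected graph, so the
    # expectation tracks whatever optimisation rewrite_blockwise applies
    # instead of hard-coding the fused key names by hand.
    result = rewrite_blockwise(inputs)
    expected_dsk, _ = fuse(expected_unoptimized_dsk, [output])
    assert result.dsk == expected_dsk
```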

TomAugspurger added a commit to TomAugspurger/dask-benchmarks that referenced this pull request Mar 4, 2020
@TomAugspurger

Added a benchmark for this in TomAugspurger/dask-benchmarks@9daf334.

Thanks @bmerry!

@TomAugspurger TomAugspurger merged commit 10db6ba into dask:master Mar 6, 2020
@TomAugspurger TomAugspurger moved this from In progress to Done in Core maintenance Mar 6, 2020
@bmerry bmerry deleted the numpy-temporary-reuse branch February 25, 2021 12:54