
dask.array.jit or dask.array.vectorize #1946

Closed

mrocklin opened this issue Jan 30, 2017 · 27 comments

@mrocklin (Member)

When working with dask-glm I find myself interacting with functions like the following (where x is a dask.array):

from dask.array import absolute, sign  # elementwise ufuncs that work on dask arrays

# `lamda` is a regularization constant defined elsewhere in dask-glm;
# the spelling avoids Python's `lambda` keyword

def l2(x, t):
    return 1 / (1 + lamda * t) * x

def l1(x, t):
    return (absolute(x) > lamda * t) * (x - sign(x) * lamda * t)

These are costly in a few ways:

  1. They have significant overhead, because each call regenerates a relatively large graph.
  2. At computation time, even if we fuse tasks, we create many intermediate copies of numpy arrays.

So there are two partial solutions that we could combine here:

  1. For any given dtype/shape/chunks signature, we could precompute a dask graph. When the same dtype/shape/chunks signature comes in, we would stitch the new keys in at the right place, change around some tokenized values, and ship the result out without calling all of the dask.array code.
  2. We could numba.jit fused tasks.

Using numba would actually be pretty valuable in some cases in dask-glm. This could be an optimization at the task-graph level. I suspect that if we get good at recognizing recurring patterns and cache well, we could make this fast-ish: (add, _, (mul, _, _)) -> numba.jit(lambda x, y, z: x + y * z). We might also be able to back out patterns based on keys (not sure if this is safe).
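
To make the pattern idea concrete, here is a minimal sketch under stated assumptions: matches, rewrite, and the wildcard sentinel are hypothetical helpers (none of these exist in dask today), and numba is assumed to be installed.

from operator import add, mul
import numba

_ = object()  # wildcard sentinel for patterns

def matches(task, pattern):
    # True if a task tuple structurally matches a pattern tuple
    if pattern is _:
        return True
    if isinstance(pattern, tuple) and isinstance(task, tuple):
        return len(task) == len(pattern) and all(
            matches(t, p) for t, p in zip(task, pattern))
    return task == pattern

# (add, _, (mul, _, _)) -> numba.jit(lambda x, y, z: x + y * z)
pattern = (add, _, (mul, _, _))
fused = numba.jit(nogil=True)(lambda x, y, z: x + y * z)

def rewrite(task):
    if matches(task, pattern):
        _add, x, (_mul, y, z) = task
        return (fused, x, y, z)
    return task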

cc @jcrist @eriknw @sklam @seibert @shoyer

@shoyer (Member) commented Jan 30, 2017

+@crusaderky

I feel your pain. I'm not entirely sure what you're actually proposing, though.

Would it not suffice to simply use numba.jit to decorate these functions in dask-glm? Doing this explicitly seems a lot easier than optimizing after the fact. If introducing dependencies is a concern (and it probably should be), you could pretty easily make a variation of numba.jit that is a no-op if numba is not installed.
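
For reference, a no-op fallback along those lines only takes a few lines. A sketch, handling both the bare @jit form and the parameterized @jit(...) form:

try:
    from numba import jit
except ImportError:
    def jit(*args, **kwargs):
        if args and callable(args[0]) and not kwargs:
            return args[0]          # used bare: @jit
        return lambda func: func    # used with options: @jit(nogil=True)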

@shoyer (Member) commented Jan 30, 2017

For what it's worth, I would expect dask.array.vectorize to be a dask friendly version of numpy.vectorize that doesn't require numba. Numba support would also be handy but I would save that variant for another function name (e.g., numba_vectorize) or a keyword argument (numba=True).
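
A rough sketch of what that could look like (vectorize here is hypothetical, not an existing dask function, and broadcasting/mismatched chunks are glossed over):

import numpy as np
import dask.array as da

def vectorize(pyfunc, otype=np.float64):
    vfunc = np.vectorize(pyfunc, otypes=[otype])  # pure-python fallback
    def wrapper(*args):
        return da.map_blocks(vfunc, *args, dtype=otype)  # apply blockwise
    return wrapper

# e.g. an elementwise soft-threshold:
# shrink = vectorize(lambda x, t: max(abs(x) - t, 0.0))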

@mrocklin (Member, Author)

These functions are fed dask.arrays, so numba.jit on its own won't suffice. In the simple case we could easily do something like the following:

from functools import wraps
from numba import jit
from dask.array import map_blocks

def linear_jit(func):
    jit_func = jit(func)  # compile once, apply blockwise
    @wraps(func)
    def _(*args, **kwargs):
        return map_blocks(jit_func, *args, **kwargs)
    return _

@mrocklin (Member, Author)

I think that concretely I'm proposing two things:

  1. An optional optimization, run over the whole graph after calling fuse, that applies numba.jit to the fused tasks
  2. A function decorator that watches for incoming input signatures, records the output graph, and, when new inputs arrive with the same signature, short-cuts to the precomputed graph (with some straightforward modifications) instead of rerunning all of the dask.array algorithms. In this case the signature is (dtype, shape, chunks) for all dask.arrays and identity for all other inputs. (A skeleton follows this list.)
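
A skeleton of that decorator, to make the shape of the idea concrete. graph_cached is a hypothetical name, and the cache-hit path (stitching new keys into the recorded graph) is left unimplemented because it is exactly the open problem discussed later in this thread:

import dask.array as da

def graph_cached(func):
    cache = {}
    def wrapper(*args):
        sig = tuple(
            (a.dtype, a.shape, a.chunks) if isinstance(a, da.Array) else a
            for a in args)
        if sig not in cache:
            out = func(*args)        # run the full dask.array code once
            cache[sig] = (args, out)
            return out
        # Cache hit: reuse the recorded graph, swapping old input keys for
        # new ones and renaming output keys -- the hard part.
        raise NotImplementedError("key substitution; see replace() below")
    return wrapper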

@shoyer (Member) commented Jan 30, 2017

> An optional optimization, run over the whole graph after calling fuse, that applies numba.jit to the fused tasks

Seems totally reasonable to me. Numba should remain an optional dependency, of course.

> A function decorator that watches for incoming input signatures, records the output graph, and, when new inputs arrive with the same signature, short-cuts to the precomputed graph (with some straightforward modifications) instead of rerunning all of the dask.array algorithms. In this case the signature is (dtype, shape, chunks) for all dask.arrays and identity for all other inputs.

No real objection here, either, though I would consider skipping straight to a programmatic dict for top (#1763).

@crusaderky (Collaborator) commented Jan 30, 2017

First, a word of warning.
I initially coded all the critical tasks of my xarray+dask project in numba. After several weeks chasing random interpreter segfaults, which only appeared when I reached production scale, I rewrote every last numba function in Cython. The result was uglier-looking, and I got locked into a static dtype signature (not a problem for my specific case, as I do everything in float64 anyway), but the whole experience was A LOT smoother: predictable behaviour, easy-to-debug performance, and overall robustness.
I still think that numba has great potential, but my opinion for now is that it's still too green for production use - no offense if any of the developers are reading.

Second: for the functions mentioned in the original post, numba/cython optimization may produce some improvement - but in my experience with very similar-looking functions, not enough to bother, as long as your chunk size is large enough to prevent the GIL from kicking in. YMMV.

The big improvement comes when you simply wrap those expressions in a single pure numpy function (lambda or not), which is then called atomically by dask. Hence the biggest problem as of today: IMHO, dask.array.map_blocks could really use some love, as it's unusable e.g. whenever the inputs have mismatched dimensions.
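
For the same-shape case, this wrapping applied to the l1() example at the top of the thread looks like the following sketch (assuming t and lamda are plain scalars); dask then sees one task per chunk instead of one per elementwise operation:

import numpy as np
import dask.array as da

def _l1_chunk(x, t, lamda):
    # one pure-numpy task per chunk, no dask-level intermediates
    return (np.absolute(x) > lamda * t) * (x - np.sign(x) * lamda * t)

def l1(x, t, lamda):
    return da.map_blocks(_l1_chunk, x, t, lamda, dtype=x.dtype)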

@mrocklin (Member, Author)

Thanks for the experience report @crusaderky. Some questions:

  1. Are you using numba.jit(nogil=True)?
  2. Regarding your second point, how large were your chunks and how long were your computations?

> The big improvement comes when you simply wrap those expressions in a single pure numpy function (lambda or not), which is then called atomically by dask

Can you expand on this?

> Hence the biggest problem as of today: IMHO, dask.array.map_blocks could really use some love, as it's unusable e.g. whenever the inputs have mismatched dimensions.

Yeah, I actually have the reverse perspective: map_blocks shouldn't attempt to be any more complex than it currently is. It has been accumulating features and corner cases for a while now. I would love a different operation, though, if people have recommendations and time to build something. I tend to recommend atop or just dask.delayed to people in these cases.

@crusaderky (Collaborator)

  1. numba.jit(nogil=True, nopython=True, cache=True)
  2. An example of a formula I experimented with that won't get any benefit from cythonizing - the value of a floating-rate bond under a Monte Carlo stress:

discounted_cashflow = cashflow * xarray.ufuncs.exp(-t * (credit_rate + interest_rate + spread))
value = discounted_cashflow.sum('time')

where all inputs are xarray variables:

  • cashflow, credit_rate, interest_rate: dims=(time, scenario), shape=(50, 500k), chunks=(25, 100k)
  • t, spread: dims=(time,), shape=(50,), chunks=None (numpy-backed)

It's worth noting that in this formula I don't have any diamond dependencies; all dask-backed variables are the result of semi-independent branches of computation. This is unlike your l1() function, where x appears three times.

The improvement you achieve by wrapping the whole thing in numpy is that you reduce the number of keys in your final dask dict by a large factor - and hence all of the non-parallelisable work needed to resolve the problem.

@mrocklin (Member, Author)

Instead of Numba we could also generate more complex functions that use the out= parameter of numpy functions. This is probably more work to generate but would be lower-tech (good).
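
For the (add, _, (mul, _, _)) pattern from earlier, a generated out=-style function might look like the following sketch (fused_add_mul is a hypothetical name; array inputs assumed):

import numpy as np

def fused_add_mul(x, y, z, out=None):
    # x + y * z using one scratch buffer instead of two intermediates
    if out is None:
        out = np.empty_like(x)
    np.multiply(y, z, out=out)   # out = y * z
    np.add(x, out, out=out)      # out = x + y * z
    return out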

I'm still somewhat in favor of numba though. It seems easier to try out at least. This may be because I haven't been bitten by it in the same way that @crusaderky has. On the plus side, I suspect that numba devs would respond quickly to any issues we push upstream.

@mrocklin (Member, Author)

Regardless of the numpy/numba choice, another benefit to rewriting fused tasks is serialization time. As we contract the number of tasks with optimizations like #1979 we reduce per-task overhead in all ways except serialization. The large compound tasks that we generate like (add, (sub, (exp, 'x-123'), 'y-123'), 1) are not significantly less expensive to serialize than they were before as several small tasks. If we can reduce these to a simple form of (f, 'x-123', 'y-123', 1) for complex f then it becomes much easier to cache the serialization of f and remove most of this cost.

@jcrist, this is starting to look like

  1. Generation of some Python code
  2. Application with Numba
  3. Passing over the graph with a small set of rewrite rules

Do you have any interest? This would align nicely with @eriknw's current work on fusion.

@mrocklin (Member, Author)

Hrm, I wonder if we can short-circuit some of the rewrite rule cost by using the fact that the key prefixes are likely to be the same.

@jcrist (Member) commented Feb 14, 2017

I'm skeptical of generating python code (and using numba) here, and would rather create an object to "interpret" these tasks, using the out= keyword of numpy functions where available. A few reasons for this:

  • Numba compilation isn't threadsafe. This is what @crusaderky ran into, and it's an issue I ran into when working on datashader. We could hack around this by forcing compilation up front, but it's something to watch out for.

  • Generated code doesn't serialize cheaply. Cloudpickle is required. In comparison, pickling several built-in functions wrapped in python objects is relatively cheap:

In [1]: from operator import add, mul

In [2]: task = (add, (mul, 'x', (add, 1, 'y')), 2)

In [3]: import cloudpickle

In [4]: len(cloudpickle.dumps(task))
Out[4]: 68

In [5]: s = """
   ...: def func(x, y):
   ...:     return ((1 + y) * x) + 2
   ...: """

In [6]: namespace = {'add': add, 'mul': mul}

In [7]: exec(s, namespace)

In [8]: func = namespace['func']

In [9]: len(cloudpickle.dumps(func))
Out[9]: 326
  • Distributed does cache "callables" inside dumps_task, but it doesn't care whether they're actually instances of FunctionType. This means that a callable object would work just as well. Since distributed always sends the bytes (and just caches the serialization cost), having a smaller serialized form is still important (IMO). Numba functions are even more expensive to serialize.

I'm not 100% against generating code here, but I'd prefer to try a simpler implementation first. Either way, this is something I'd be interested in working on, I just need to find the time :).
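
A sketch of that interpreter object (TaskFunc and Slot are hypothetical names): a plain class that evaluates a fused-task template of built-in functions, so standard pickle suffices and no cloudpickle is needed.

from operator import add, mul

class Slot:
    # placeholder for the i-th runtime argument in a task template
    def __init__(self, index):
        self.index = index

class TaskFunc:
    def __init__(self, template):
        self.template = template

    def __call__(self, *args):
        return self._eval(self.template, args)

    def _eval(self, term, args):
        if isinstance(term, Slot):
            return args[term.index]
        if isinstance(term, tuple):          # (func, arg, arg, ...)
            func, *rest = term
            return func(*(self._eval(t, args) for t in rest))
        return term                          # a literal constant

# ((1 + y) * x) + 2, matching the session above:
func = TaskFunc((add, (mul, Slot(0), (add, 1, Slot(1))), 2))
assert func(3, 4) == ((1 + 4) * 3) + 2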

@mrocklin (Member, Author)

Yeah, as long as things are of the form (expensive-function-thing, key, key, value) I'm happy. We would probably also need to do more benchmarking to verify that this was worthwhile (I'm happy to chat about benchmarks if you get into this).

We also compress data between client and scheduler if it becomes large. I haven't seen moving large repetitive bytestrings become a bottleneck yet.

@pitrou (Member) commented Feb 14, 2017

For the issue of serialization cost, you could serialize the function in its string form and then numba-compile it on the worker.
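
A sketch of that approach (SourceFunc is a hypothetical name): the pickle payload is just two strings, and compilation happens lazily, once per worker process.

import numba

class SourceFunc:
    def __init__(self, source, name):
        self.source, self.name = source, name
        self._compiled = None

    def __call__(self, *args):
        if self._compiled is None:           # compile on first use
            ns = {}
            exec(self.source, ns)
            self._compiled = numba.jit(nogil=True)(ns[self.name])
        return self._compiled(*args)

    def __reduce__(self):                    # pickle only the strings
        return (SourceFunc, (self.source, self.name))

f = SourceFunc("def f(x, y):\n    return x + y * 2\n", "f")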

@mrocklin (Member, Author)

@pitrou if you have any interest, the application I'm trying to improve is here: dask/dask-glm#26. The first comment points to this single-threaded benchmark that is, I think, representative of the performance issue.

@crusaderky (Collaborator) commented Feb 15, 2017 via email

@jakirkham (Member)

> A function decorator that watches for incoming input signatures, records the output graph, and, when new inputs arrive with the same signature, short-cuts to the precomputed graph (with some straightforward modifications) instead of rerunning all of the dask.array algorithms. In this case the signature is (dtype, shape, chunks) for all dask.arrays and identity for all other inputs.

FWIU I really like this idea. Could you elaborate on it a bit more, @mrocklin?

@mrocklin (Member, Author) commented May 3, 2017

I could, but what I would say has probably already been said in this issue. So that I don't spend a lot of time writing down all of my thoughts, perhaps you can focus things by asking what in particular you are curious about.

@jakirkham (Member)

Are you proposing caching graphs based on having the same inputs? Or are you proposing caching that is somehow able to stub different inputs into the same graph structure? The former seems straightforward and potentially a nice addition for a few cases, using some preferred caching decorator. With the latter, I'm a little unclear on how this might work.

@mrocklin (Member, Author) commented May 3, 2017

> Or are you proposing caching that is somehow able to stub different inputs into the same graph structure?

I am curious about this case.

> With the latter, I'm a little unclear on how this might work.

Yes. Me too.

@mrocklin (Member, Author) commented May 3, 2017

What are your thoughts on this topic, @jakirkham?

@jakirkham (Member)

Not sure. Haven't thought about it long enough to have any ideas.

That said, this does sound kind of like a LISP-style problem. It might be worth taking a look at Hy and seeing if there are any libraries or tools that could be used or extended for this case.

@mrocklin (Member, Author) commented May 3, 2017

I don't think that solving this problem will require that level of technology.

@mrocklin (Member, Author) commented May 5, 2017

OK, so I think it would be useful to replace many of the keys in a graph with a new set of keys.

>>> dsk = {'x': 1, 'y': (inc, 'x'), 'z': (add, 'x', 'y')}
>>> swap = {'x': 'a', 'y': 'b', 'z': 'c'}
>>> replace(dsk, swap)
{'a': 1, 'b': (inc, 'a'), 'c': (add, 'a', 'b')}
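
A naive reference implementation, as a sketch assuming string keys and nested-tuple tasks (lists and other containers are ignored here):

def replace(dsk, swap):
    def sub(term):
        if isinstance(term, tuple):              # a task: (func, arg, ...)
            return tuple(sub(t) for t in term)
        if isinstance(term, str):                # possibly a key
            return swap.get(term, term)
        return term
    return {swap.get(k, k): sub(v) for k, v in dsk.items()}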

I'm curious if/how we can do this quickly (cc @eriknw @jcrist). A simple example motivating this operation follows:

Example

Let's say that I have a dask object orig with the following graph and keys:

>>> orig.dask
{'orig-1': 1,
 'orig-2': (inc, 'orig-1')}

>>> orig._keys()
['orig-2']

I want to run a function f on this dask object that produces a new dask object with a larger graph.

>>> out = f(orig)
>>> out.dask
{'orig-1': 1,
 'orig-2': (inc, 'orig-1'),
 'out-1': (inc, 2),
 'out-2': (add, 'out-1', 'orig-2')}

I also want to apply this same function to a new object new with the following graph:

>>> new.dask
{'new-1': (g,)}

>>> new._keys()
['new-1']

However, my function f takes a long time to run. I know separately that the new object and the original object are sufficiently similar that calling f is unnecessary: the new bit of the output graph will have identical structure.

So I want to do three things:

  1. Strip out the graph of orig. This is easy to do.
  2. Replace all instances of orig keys with new keys within the part of the graph that was added by f. I want to pass over all tasks with the following replacement mapping {'orig-2': 'new-1'}. In practice this mapping will have many thousands of elements.
  3. Slightly modify all keys in the graph to show that these are new tasks that will have distinct results. For example, we might pass through with the following mapping {'out-1': 'out-1-1', 'out-2': 'out-2-2'}.

Parts two and three both ask for a function that efficiently swaps out key names. I'm curious how fast this can be. The whole purpose of this exercise is to avoid graph-construction and optimization costs.

@mcg1969 commented Apr 17, 2018

I just saw this, and I love the concepts. I would be happy to help flesh them out. I happened upon this page by googling "dask monte carlo", by the way.

@jakirkham (Member)

It would be great to hear your thoughts on this, @mcg1969. :)

@mrocklin (Member, Author)

This will likely be handled by the current effort on high level expression graphs. Closing.
