Do JAX jit'd Python loops run faster than jit'd LAX loop constructs? #402
Thanks for bringing up this issue! Hope you don't mind that I tweaked the title to be a bit more explicit. My current best guess is this might be a potential performance improvement for XLA to make. Some thoughts below.

When you use a regular Python loop construct under JAX's `jit`, the loop is unrolled at trace time, so XLA only sees straight-line code. When you use `lax.fori_loop`, the loop is staged out explicitly. In the latter case, XLA gets strictly more information: it sees that there's a loop construct, and it has the option to unroll it when it's statically unrollable. (TODO for us: check that JAX is lowering the loop in a way that keeps static unrollability transparent to XLA, since we do lift constants into the loop carry tuple.) So it should be able to generate code that is at least as good.

From XLA's perspective, there are the usual tradeoffs with unrolling here: unrolling a loop can increase the code size (which might increase execution time if the code has to be loaded onto the device for each kernel launch), but it enables more optimizations and involves fewer branches. It may be that branches are expensive for the GPU backend because the loop condition might need to be pulled back to the host and checked there on each iteration, meaning more synchronizations than would be necessary in the unrolled case. Still, it seems that in principle XLA could do some unrolling or partial unrolling to mitigate that effect in cases like this one.

Separate from questions of execution time, one of the main reasons to use loop constructs now is to reduce compile times: unrolling big loops can mean staging out a large program to XLA, and since XLA does a lot of optimizations, that can mean a lot of redundant work. So as a general rule of thumb, we can think of using loop constructs like `lax.fori_loop` as a way to trade away some execution-time optimization opportunities for shorter compile times.
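A minimal sketch of the two lowering paths described above (the loop body and shapes here are illustrative stand-ins, not from the gist):

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def unrolled(x):
    # Python loop: traced away under jit, so XLA sees 10 copies
    # of the body as straight-line code.
    for _ in range(10):
        x = jnp.tanh(x @ x)  # hypothetical loop body
    return x

@jax.jit
def rolled(x):
    # lax.fori_loop: staged out as a single XLA While construct,
    # which XLA may itself choose to unroll.
    return lax.fori_loop(0, 10, lambda i, x: jnp.tanh(x @ x), x)
```

Both compute the same result; they differ only in what XLA sees.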
@hawkinsp added another hypothesis to the list of things-for-us-to-investigate: it could be that the loop is causing XLA to choose bad memory layouts.
Just to clarify: using a lax loop construct and jitting the whole loop should always be faster than using a Python loop and only jitting the loop body (given a sufficient number of iterations, and ignoring the cost of the loop body itself)?!
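For reference, the pattern being asked about here, sketched with the same hypothetical body as above: only the loop body is compiled, so each iteration is a separate dispatch from Python.

```python
import jax
import jax.numpy as jnp

@jax.jit
def body(x):
    return jnp.tanh(x @ x)

def python_driven(x, n=10):
    for _ in range(n):
        x = body(x)  # one dispatch from Python per iteration
    return x
```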
It looks like for some reason XLA's while->for loop optimization is not firing, and this means that we spend a significant fraction of time on each loop iteration synchronizing the loop counter (?) from the GPU back to the host. I'm following up with the XLA folks.
I didn't mean to say that; I meant to compare a Python loop under `jit` with a lax loop construct under `jit`.

Providing a rolled loop is giving more information to the compiler, and so under the Sufficiently Smart Compiler hypothesis that should always be better. But manual loop unrolling can be helpful when you know better than the compiler's heuristics (e.g. based on the distribution of workloads you expect). I believe that's true even of modern C compilers.
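One way to do that kind of partial unrolling by hand, as a sketch (the unroll factor 4 and the body are arbitrary choices; newer JAX releases also expose an `unroll` argument on `lax.scan` for the same purpose):

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def partially_unrolled(x):
    def body(i, x):
        # The inner Python loop is traced away, so each While iteration
        # carries 4 copies of the step: 25 iterations x 4 steps = 100 steps.
        for _ in range(4):
            x = jnp.tanh(x @ x)
        return x
    return lax.fori_loop(0, 25, body, x)
```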
@mattjj sorry, my comment was rather a related question for the case where I have a while loop that can't be unrolled (I should not have put it here): is it faster to use a lax while_loop construct and jit everything, vs. sticking to an unjitted Python while loop with a jitted loop body? I would assume so, but maybe lax while_loop still has some downsides?
@jonasrauber No problem! Hmm, I'd expect the fully jitted `lax.while_loop` version to be faster. EDIT: I mean I'd expect it to be faster after this bug is resolved (which I expect to happen soon).
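The two variants in question, as a sketch (the convergence condition and body here are made up for illustration):

```python
import jax
import jax.numpy as jnp
from jax import lax

def cond(x):
    return jnp.max(jnp.abs(x)) > 1e-3  # scalar predicate

def step(x):
    return 0.5 * jnp.tanh(x)

# Fully jitted: the whole loop is staged out as one XLA While.
@jax.jit
def fully_jitted(x):
    return lax.while_loop(cond, step, x)

# Python-driven: only the body is jitted; evaluating the predicate in
# Python forces a device-to-host transfer on every iteration.
jit_step = jax.jit(step)

def python_while(x):
    while cond(x):
        x = jit_step(x)
    return x
```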
Updates XLA to tensorflow/tensorflow@00afc7b. The new XLA release removes the use of protocol buffers from the XLA client. Fixes jax-ml#349.

Adds backward-compatibility shims to jaxlib to allow older jax releases to still work on an up-to-date jaxlib.

The new XLA release also incorporates a fix that avoids a host-device copy for every iteration of a `lax.fori_loop()` on GPU. Fixes jax-ml#402.

Adds a new `jaxlib.__version__` field and changes jax/jaxlib compatibility logic to check for it.
This is now fixed, but to get the fix, you either need to rebuild jaxlib from source or to wait until we push new binary wheels to PyPI (probably later this week).
Please see the linked gist (minimal repro), where I tried to implement a recurrent computation (an echo state network) both in JAX and LAX. Using `lax.fori_loop` resulted in a roughly 3x slowdown over a jit'd naive Python for loop.

This is not blocking me, but I was surprised by it and I cannot find anything I did wrong, though I may have misused the APIs in some way. JAX versions are listed as a comment in the gist.
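A hedged sketch of a timing harness for this kind of comparison (it reuses the `unrolled`/`rolled` definitions from the sketch in the first reply above; the shape, step count, and repetition count are arbitrary). Calling `block_until_ready()` matters here, since JAX dispatches work asynchronously:

```python
import time
import jax.numpy as jnp

def bench(f, x, reps=5):
    f(x).block_until_ready()      # compile and warm up first
    t0 = time.perf_counter()
    for _ in range(reps):
        f(x).block_until_ready()  # sync so we time execution, not dispatch
    return (time.perf_counter() - t0) / reps

x = 0.01 * jnp.ones((512, 512))
print("unrolled Python loop:", bench(unrolled, x))
print("lax.fori_loop:       ", bench(rolled, x))
```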