Do JAX jit'd Python loops run faster than jit'd LAX loop constructs? #402

Closed
sussillo opened this issue Feb 18, 2019 · 8 comments · Fixed by #452
Labels
question Questions for the JAX team

Comments

@sussillo

sussillo commented Feb 18, 2019

Please see the gist (minimal repro) where I tried to implement a recurrent computation (an echo state network) both in JAX (a naive Python for loop under jit) and in LAX (lax.fori_loop). Using lax.fori_loop resulted in a roughly 2.6x slowdown relative to the jit'd naive Python for loop.

JAX speed
Params seed 100001
JAX run 0.0160 sec

LAX speed
LAX run 0.0418 sec

This is not blocking me, but I was surprised by it and I cannot find anything I did wrong, though I may have misused the APIs in some way. JAX versions are listed as a comment in the gist.
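The gist itself is not reproduced in this thread. As a rough stand-in, here is a minimal sketch of the two variants being compared, assuming a standard echo state network update x ← tanh(W x + W_in u_t). The function names, sizes, weight scaling, and random inputs are hypothetical and not taken from the gist.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical echo state network update x <- tanh(W x + W_in u_t).
# A stand-in for the gist, not the gist's actual code.


@jax.jit
def esn_python_loop(W, W_in, x0, inputs):
    x = x0
    # This loop runs while JAX traces, so XLA receives one copy of the
    # update per time step (a fully unrolled program).
    for t in range(inputs.shape[0]):
        x = jnp.tanh(W @ x + W_in @ inputs[t])
    return x


@jax.jit
def esn_fori_loop(W, W_in, x0, inputs):
    def body(t, x):
        # `t` is a traced loop index, so inputs[t] is a dynamic slice.
        return jnp.tanh(W @ x + W_in @ inputs[t])

    # The loop itself is staged out to XLA as a single loop primitive.
    return lax.fori_loop(0, inputs.shape[0], body, x0)


key_w, key_in, key_u = jax.random.split(jax.random.PRNGKey(0), 3)
n_state, n_in, n_steps = 64, 3, 100  # arbitrary sizes
# Scaled so the spectral radius of W is roughly 0.5.
W = 0.5 * jax.random.normal(key_w, (n_state, n_state)) / jnp.sqrt(n_state)
W_in = jax.random.normal(key_in, (n_state, n_in))
inputs = jax.random.normal(key_u, (n_steps, n_in))
x0 = jnp.zeros(n_state)

x_unrolled = esn_python_loop(W, W_in, x0, inputs)
x_looped = esn_fori_loop(W, W_in, x0, inputs)
```

Both functions compute the same final state; they differ only in what program XLA is asked to compile.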

@mattjj mattjj added the question Questions for the JAX team label Feb 18, 2019
@mattjj mattjj changed the title Do JAX jit'd loops run faster than LAX jit'd loops? Do JAX jit'd Python loops run faster than jit'd LAX loop constructs? Feb 18, 2019
@mattjj
Collaborator

mattjj commented Feb 18, 2019

Thanks for bringing up this issue! Hope you don't mind that I tweaked the title to be a bit more explicit.

My current best guess is that this is a potential performance improvement for XLA to make. Some thoughts below.

When you use a regular Python loop construct under JAX's @jit, JAX (like Autograd) doesn't even see the Python loop and instead just traces out the unrolled computation; as a consequence, the program that gets staged out to XLA is also unrolled. When you instead use a lax.while_loop (or lax.fori_loop), the JAX tracer sees that loop as a primitive, and stages out a loop construct in the XLA program.

In the latter case, XLA gets strictly more information: it sees that there's a loop construct, and it has the option to unroll it when it's statically unrollable. (TODO for us: check that JAX is lowering the loop in a way that keeps static-unrollability transparent for XLA, since we do lift constants into the loop carry tuple.) So it should be able to generate code that is at least as good.
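One way to see the difference between the two cases is to print the jaxprs that JAX stages out. A minimal sketch with a hypothetical `step` function: the Python loop appears as five straight-line copies of the body, while `lax.fori_loop` appears as a single loop primitive whose body contains one copy. In current JAX, `fori_loop` with static integer bounds is lowered via `scan` rather than `while`, so the exact primitive name depends on the version.

```python
import jax
import jax.numpy as jnp
from jax import lax


def step(x):
    # lax.tanh (rather than jnp.tanh) so each step is exactly one `tanh` equation.
    return lax.tanh(2.0 * x)


def python_loop(x):
    for _ in range(5):  # runs in Python while tracing
        x = step(x)
    return x


def fori(x):
    return lax.fori_loop(0, 5, lambda i, x: step(x), x)


x = jnp.ones(3)
unrolled_jaxpr = str(jax.make_jaxpr(python_loop)(x))
loop_jaxpr = str(jax.make_jaxpr(fori)(x))
print(unrolled_jaxpr)  # five copies of the body, no loop primitive
print(loop_jaxpr)      # one loop primitive whose body contains a single tanh
```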

From XLA's perspective, there are the usual tradeoffs with unrolling here: unrolling a loop could increase the code size (which might increase execution time if the code has to be loaded onto the device for each kernel launch), but enables some more optimizations and involves fewer branches. It may be that "branches" are expensive for the GPU backend because the loop condition might need to be pulled back to the host and checked there on each iteration, meaning more synchronizations than would be necessary in the unrolled case. Still, it seems that in principle XLA could do some unrolling or partial unrolling to mitigate that effect in cases like this one.
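Current JAX releases (well after this 2019 thread) expose a knob for this partial-unrolling tradeoff: `lax.scan` accepts an `unroll` argument that places several copies of the body inside each iteration of the compiled loop. A sketch with a hypothetical recurrence `cell`; the trip count of 64 and the unroll factor of 8 are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import lax


def cell(x, _):
    # Hypothetical recurrence; the second argument is the (absent) per-step input.
    return jnp.tanh(1.5 * x), None


@jax.jit
def rolled(x):
    final, _ = lax.scan(cell, x, None, length=64)
    return final


@jax.jit
def partially_unrolled(x):
    # unroll=8: each iteration of the compiled loop applies `cell` 8 times,
    # so 8 iterations cover all 64 steps.
    final, _ = lax.scan(cell, x, None, length=64, unroll=8)
    return final


x = jnp.linspace(-1.0, 1.0, 5)
a = rolled(x)
b = partially_unrolled(x)
```

The results are the same; only the shape of the compiled program changes, which is where the code-size versus branch-count tradeoff described above comes in.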

Separate from questions of execution time, one of the main reasons to use loop constructs now is to reduce compile times: unrolling big loops can mean staging out a large program to XLA, and since XLA does a lot of optimizations, that can mean a lot of redundant work.

So as a general rule of thumb, we can think of using loop constructs like lax.while_loop and lax.fori_loop as tools to reduce compile times, often by a huge amount, but that could sometimes result in reduced execution performance. In principle those reductions in execution performance could be minimal, but right now there are probably cases where XLA loops aren't nearly as fast to execute as the unrolled code.
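When comparing the two styles, it helps to time the first call (tracing plus XLA compilation plus execution) separately from later calls (execution only), and to call `block_until_ready()`, since JAX dispatches work asynchronously. A minimal harness under those assumptions; the `body` function, the trip count `N_STEPS`, and the repeat count are hypothetical.

```python
import time

import jax
import jax.numpy as jnp
from jax import lax

N_STEPS = 200  # hypothetical trip count


def body(x):
    return 0.99 * jnp.sin(x) + 0.01


def unrolled(x):
    for _ in range(N_STEPS):  # staged out as N_STEPS copies of `body`
        x = body(x)
    return x


def looped(x):
    return lax.fori_loop(0, N_STEPS, lambda i, x: body(x), x)


def time_first_and_steady(fn, x, repeats=10):
    """Return (first-call seconds, mean steady-state seconds, output)."""
    jitted = jax.jit(fn)
    t0 = time.perf_counter()
    out = jitted(x).block_until_ready()  # trace + compile + run
    first_call = time.perf_counter() - t0
    t0 = time.perf_counter()
    for _ in range(repeats):
        out = jitted(x).block_until_ready()  # cached executable: run only
    steady = (time.perf_counter() - t0) / repeats
    return first_call, steady, out


x = jnp.linspace(0.0, 1.0, 1000)
for name, fn in [("unrolled", unrolled), ("fori_loop", looped)]:
    first, steady, _ = time_first_and_steady(fn, x)
    print(f"{name}: first call {first:.4f} s, steady state {steady:.6f} s")
```

As the trip count grows, the unrolled version would be expected to spend more of its first call in compilation; whether its steady-state time is lower depends on the backend and XLA version, per the discussion above.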

@mattjj
Collaborator

mattjj commented Feb 18, 2019

@hawkinsp added another hypothesis to the list of things-for-us-to-investigate: it could be that the loop is causing XLA to choose bad memory layouts.

@jonasrauber
Contributor

jonasrauber commented Feb 19, 2019

Just to clarify: using a lax loop construct and jitting the whole loop should always be faster than using a Python loop and only jitting the loop body (given a sufficient number of iterations, and ignoring the loop body)?!
Jitting the Python loop (i.e. unrolling) is of course a whole different story and can result in much slower compile times.
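The two scenarios in this question can be sketched side by side. In `host_loop`, Python drives the iteration and dispatches the compiled `step` once per iteration; in `device_loop`, the whole loop is one compiled program and Python dispatches it once. The names and the update rule are hypothetical; both functions compute the same result.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def step(x):
    # Hypothetical loop body; any jit-compatible update would do.
    return 0.5 * x + 0.5 * jnp.cos(x)


def host_loop(x, n):
    # Python drives the loop: one dispatch of the compiled `step` per iteration.
    for _ in range(n):
        x = step(x)
    return x


@jax.jit
def device_loop(x, n):
    # The whole loop is one compiled program. `n` is traced, so this
    # fori_loop is lowered to a while loop and one executable serves every n.
    return lax.fori_loop(0, n, lambda i, x: step(x), x)


x = jnp.linspace(-2.0, 2.0, 8)
y_host = host_loop(x, 50)
y_device = device_loop(x, 50)
```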

@hawkinsp
Collaborator

It looks like for some reason XLA's while->for loop optimization is not firing, and this means that we spend a significant fraction of time on each loop iteration synchronizing the loop counter (?) from the GPU back to the host. I'm following up with the XLA folks.
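For context on the while->for optimization: a `fori_loop` lowered to a while loop is essentially a `while_loop` whose carry holds an integer counter and whose condition compares that counter to a bound. The sketch below writes that shape out by hand; `fori_as_while` and `damped_update` are hypothetical names, and this is a simplification for illustration, not JAX's actual lowering. A compiler pass that recognizes the counter's start, increment, and bound can derive the trip count up front; when it does not, the condition is a device-computed value evaluated every iteration, which on GPU can mean the host synchronization described in the comment above.

```python
import jax.numpy as jnp
from jax import lax


def fori_as_while(lower, upper, body_fun, init_val):
    """Simplified sketch of fori_loop expressed as a while_loop."""

    def cond(carry):
        i, _ = carry
        # The loop condition: compare the counter in the carry against the bound.
        return i < upper

    def body(carry):
        i, x = carry
        # Run one step of the user body, then increment the counter.
        return i + 1, body_fun(i, x)

    init = (jnp.asarray(lower, dtype=jnp.int32), init_val)
    _, result = lax.while_loop(cond, body, init)
    return result


def damped_update(i, x):
    # Hypothetical body that also uses the loop index.
    return 0.9 * x + i


x0 = jnp.zeros(3)
print(fori_as_while(0, 10, damped_update, x0))
```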

@mattjj
Collaborator

mattjj commented Feb 19, 2019

> Just to clarify: using a lax loop construct and jitting the whole loop should always be faster than using a python loop and only jitting the loop body (if there is a sufficient number of iterations and ignoring the loop body)?!

I didn't mean to say that; I meant to compare a Python loop under @jit (which would be fully staged into the XLA computation in an unrolled form) with a primitive loop construct. I believe those are the cases in the OP. Neither of those is the second scenario you mentioned, where we're executing FLOPs in a loop in Python (with or without a jitted body).

Providing a rolled loop is giving more information to the compiler, and so under the Sufficiently Smart Compiler hypothesis that should always be better. But manual loop unrolling can be helpful when you know better than the compiler's heuristics (e.g. based on the distribution of workloads you expect). I believe that's true even of modern C compilers.
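Manual unrolling can also be combined with a loop construct: run `k` steps of the body per iteration of a `lax.fori_loop`, and unroll any leftover steps in Python. A sketch with a hypothetical `step` function and unroll factor `k`; whether it helps in practice depends on the workload and backend.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


def step(x):
    # Hypothetical cheap loop body.
    return jnp.sin(x) + 0.1


@partial(jax.jit, static_argnums=(1, 2))
def manually_unrolled(x, n_steps, k):
    def k_steps(_, x):
        # Python loop inside the body: traced k times, so each loop
        # iteration contains k copies of `step`.
        for _ in range(k):
            x = step(x)
        return x

    x = lax.fori_loop(0, n_steps // k, k_steps, x)
    for _ in range(n_steps % k):  # leftover steps, unrolled at trace time
        x = step(x)
    return x


@partial(jax.jit, static_argnums=1)
def rolled(x, n_steps):
    return lax.fori_loop(0, n_steps, lambda i, x: step(x), x)


x = jnp.linspace(0.0, 1.0, 4)
```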

@jonasrauber
Contributor

@mattjj sorry, my comment was really a related question about a while loop that can't be unrolled (I should not have put it here): is it faster to use a lax.while_loop construct and jit everything, or to stick with an unjitted Python while loop with a jitted loop body? I would assume the former, but maybe lax.while_loop still has some downsides?
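For a loop whose trip count depends on the data, the two options look roughly like this. The example is a hypothetical Newton iteration for elementwise square roots that stops once the update falls below a tolerance. In `host_while`, converting the convergence test to a Python `float` blocks on the device and copies the value to the host every iteration; `device_while` keeps the test inside a `lax.while_loop`, so the whole loop is dispatched once. All names, the tolerance, and the iteration cap are assumptions for illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def newton_step(x, a):
    # One Newton update for x**2 = a (elementwise square root).
    return 0.5 * (x + a / x)


def host_while(a, tol=1e-5, max_iter=100):
    x = a
    for _ in range(max_iter):
        x_new = newton_step(x, a)
        # float(...) waits for the device and copies the value to the host,
        # once per iteration.
        if float(jnp.max(jnp.abs(x_new - x))) < tol:
            return x_new
        x = x_new
    return x


@partial(jax.jit, static_argnames=("tol", "max_iter"))
def device_while(a, tol=1e-5, max_iter=100):
    def cond(carry):
        i, _, delta = carry
        return (i < max_iter) & (delta >= tol)

    def body(carry):
        i, x, _ = carry
        x_new = newton_step(x, a)
        return i + 1, x_new, jnp.max(jnp.abs(x_new - x))

    init = (jnp.asarray(0, dtype=jnp.int32), a, jnp.asarray(jnp.inf, dtype=a.dtype))
    _, x, _ = lax.while_loop(cond, body, init)
    return x


a = jnp.array([2.0, 9.0, 16.0])
print(host_while(a), device_while(a))
```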

@mattjj
Collaborator

mattjj commented Feb 22, 2019

@jonasrauber No problem!

Hmm, I'd expect the lax.while_loop and lax.fori_loop constructs to be faster. The main downside is that they're not reverse-mode differentiable (though they could be made forward-mode differentiable; it's the partial-eval step that is hard for reverse mode).

EDIT: I mean I'd expect them to be faster after this bug is resolved (which I expect to happen soon).
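The differentiability picture has changed since this comment. In current JAX, `lax.scan` (and `lax.fori_loop` with static bounds, which is lowered via `scan`) supports reverse mode, while `lax.while_loop` supports forward mode (`jax.jvp`) but raises an error under `jax.grad`. A small check of that behavior, using hypothetical functions `f_scan` and `f_while` that both apply `sin` three times:

```python
import jax
import jax.numpy as jnp
from jax import lax


def f_scan(x):
    # Apply sin three times with lax.scan.
    y, _ = lax.scan(lambda c, _: (jnp.sin(c), None), x, None, length=3)
    return y


def f_while(x):
    # The same computation with lax.while_loop and an explicit counter.
    def cond(carry):
        i, _ = carry
        return i < 3

    def body(carry):
        i, c = carry
        return i + 1, jnp.sin(c)

    _, y = lax.while_loop(cond, body, (jnp.asarray(0, dtype=jnp.int32), x))
    return y


x = jnp.float32(1.0)
grad_scan = jax.grad(f_scan)(x)  # reverse mode through scan
# Forward mode through while_loop: the tangent is the same derivative.
_, tangent_while = jax.jvp(f_while, (x,), (jnp.float32(1.0),))

try:
    jax.grad(f_while)(x)
    while_reverse_mode_ok = True
except ValueError:
    # Reverse-mode differentiation of lax.while_loop raises a ValueError.
    while_reverse_mode_ok = False
```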

hawkinsp added a commit to hawkinsp/jax that referenced this issue Feb 26, 2019
Updates XLA to tensorflow/tensorflow@00afc7b.

The new XLA release removes the use of protocol buffers from the XLA client. Fixes jax-ml#349.
Add backward compatibility shims to jaxlib to allow older jax releases to still work on an up to date jaxlib.

The new XLA release also incorporates a fix that avoids a host-device copy for every iteration of a `lax.fori_loop()` on GPU. Fixes jax-ml#402.

Add a new jaxlib.__version__ field, change jax/jaxlib compatibility logic to check for it.
@hawkinsp hawkinsp mentioned this issue Feb 26, 2019
@hawkinsp
Collaborator

This is now fixed, but to get the fix, you either need to rebuild jaxlib from source or to wait until we push new binary wheels to PyPI (probably later this week).


4 participants