Using jax.jit inside a function decorated by jax.checkpoint causes recompilation every time #9661
Comments
Hey, since this is a big blocker for our project at the moment, I would be very thankful if some of the jax-internals experts could help with this issue. I tried to dig into the code to find out why this is happening but have found nothing yet :(
Thanks for pinging this and highlighting that it's a blocker. Sorry for not getting to it sooner!
See google#9661 for discussion
This was a bit tricky! There is a problem with grad-of-remat-of-jit causing jit cache misses. These two changes are sufficient to fix it, and I believe both are necessary as well (though see below for alternatives):
The latter is necessary because when round-tripping through a jaxpr (i.e. basically doing eval-jaxpr-of-make-jaxpr, as we often do, including when we linearize) we need to make sure we get cache hits. But we were constructing new opaque callables (with hash/equality defined by object id) every time, meaning we never did get cache hits. In particular, consider this repro involving
# linearize_repro.py
import sys
import jax
identity = jax.checkpoint(jax.jit(lambda x: 2 * x))
_, f_lin = jax.linearize(identity, 1.)
for _ in range(int(sys.argv[1])):
    f_lin(1.)
Running
After #10034 the number of compiles becomes constant with
However, while that fixed the issue for
(TODO explain that the new version of
TODO talk about the new tooling I'm going to add to explain automatically why recompiles happen
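The cache-miss mechanism described above (fresh callables whose hash/equality fall back to object identity) can be illustrated in plain Python. This is only a sketch of the general idea, not JAX's actual cache: `fake_compile`, `make_opaque`, and `Scaled` are hypothetical names invented for this illustration, and `functools.lru_cache` stands in for a compilation cache keyed on the callable.

```python
import functools

compile_count = 0


@functools.lru_cache(maxsize=None)
def fake_compile(fn):
    """Hypothetical stand-in for a compilation cache keyed on the callable."""
    global compile_count
    compile_count += 1  # only incremented on a cache miss
    return fn


def make_opaque(scale):
    # A fresh lambda on every call: hash/equality are based on object identity,
    # so two lambdas that compute the same thing are still different cache keys.
    return lambda x: scale * x


class Scaled:
    """Callable whose hash/equality depend on its contents, not its identity."""

    def __init__(self, scale):
        self.scale = scale

    def __call__(self, x):
        return self.scale * x

    def __eq__(self, other):
        return isinstance(other, Scaled) and self.scale == other.scale

    def __hash__(self):
        return hash((Scaled, self.scale))


# Rebuilding an opaque callable each time misses the cache on every call.
for _ in range(3):
    fake_compile(make_opaque(2))
opaque_misses = compile_count

# Rebuilding a structurally comparable callable hits the cache after the first call.
compile_count = 0
fake_compile.cache_clear()
for _ in range(3):
    fake_compile(Scaled(2))
structural_misses = compile_count

print(opaque_misses, structural_misses)
```

The first loop registers one miss per iteration, while the second registers a single miss, which mirrors why rebuilding opaque callables on every round trip defeats a cache keyed on them.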
@JanLuca is the blocker resolved? (At git HEAD I mean; we haven't done a pypi release with these changes yet.)
@mattjj Yes, I got no OOM error for my test run with the git HEAD today :D Thank you very much for the help!
@mattjj I just encountered an exception in a run of my code using the git HEAD. I don't know if this is related to the change:
I don't think that's related. If it's still happening, can you open a new issue with a repro? cc @pschuh
Never mind, df1c478 already fixed the problem for me; I was just on an outdated checkout during my test.
Using a jitted function inside a function decorated by jax.checkpoint causes a lot of extra compilations even if the arguments still have the same shape. Calculating the gradient of such a function causes a memory leak in the long run, since all the compiled jitted functions seem to be kept in memory. This can be observed in the high memory footprint of backend_compile, which does not appear if checkpointing is disabled.
A self-contained example would be:
Running the script with env JAX_LOG_COMPILES=1 enabled, one can observe:
Commenting out the checkpoint decorator leads to the desired behavior:
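A minimal sketch of the kind of setup described above, offered only as an illustration and not as the reporter's original script: a jitted function called inside a `jax.checkpoint`-decorated function, with the gradient taken repeatedly on inputs of the same shape. The names `inner`, `outer`, and `grad_fn` are assumptions made for this example. The number of compilation messages logged depends on the JAX version in use, so no expected log output is shown.

```python
import jax
import jax.numpy as jnp

# Config-flag counterpart of running with JAX_LOG_COMPILES=1:
# log a message whenever a function is compiled.
jax.config.update("jax_log_compiles", True)


@jax.jit
def inner(x):
    # The jitted function used inside the checkpointed function.
    return 2.0 * x


@jax.checkpoint  # comment this decorator out to compare the compile logs
def outer(x):
    return jnp.sum(inner(x))


grad_fn = jax.grad(outer)
x = jnp.ones(3)

for _ in range(5):
    # Same shape and dtype on every call, so ideally nothing is recompiled.
    g = grad_fn(x)
```

Comparing the logs with and without the `jax.checkpoint` decorator is one way to check whether a given JAX version still shows the repeated compilations reported here.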