Using jax.jit inside a function decorated by jax.checkpoint causes recompilation every time #9661

Closed
JanLuca opened this issue Feb 22, 2022 · 8 comments · Fixed by #10037
Labels: bug (Something isn't working)

@JanLuca commented Feb 22, 2022

Using a jitted function inside a function decorated by jax.checkpoint causes a lot of extra compilations, even when the arguments have the same shape. Calculating the gradient of such a function causes a memory leak in the long run, since all the compiled jitted functions appear to be kept in memory. This can be observed in the high memory footprint of backend_compile, which does not appear when checkpointing is disabled.

A self-contained example would be:

import jax
import jax.numpy as jnp

@jax.jit
def f(a):
    return jnp.sum(a)

@jax.checkpoint
def g(a):
    return f(a)

arr = jnp.array([[1,2,3,4,5],[6,7,8,9,10]], dtype=float)

g_v_and_grad = jax.value_and_grad(g)

for i in range(3):
    working_arr = arr + i
    print(g_v_and_grad(working_arr))

Running the script with JAX_LOG_COMPILES=1 set, one can observe:

WARNING:absl:Finished tracing + transforming prim_fun for jit in 0.0002334117889404297 sec
WARNING:absl:Finished tracing + transforming fn for jit in 0.0003993511199951172 sec
WARNING:absl:Compiling fn (139703279463296 for args (ShapedArray(float32[2,5]), ShapedArray(int32[], weak_type=True)).
WARNING:absl:Finished XLA compilation of fn in 0.04700160026550293 sec
WARNING:absl:Finished tracing + transforming f for jit in 0.0010411739349365234 sec
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.00015473365783691406 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209762752 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.04526233673095703 sec
WARNING:absl:Finished tracing + transforming prim_fun for jit in 0.00016546249389648438 sec
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.00014591217041015625 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209798976 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.00750732421875 sec
WARNING:absl:Finished tracing + transforming backward_pass for jit in 0.0011491775512695312 sec
WARNING:absl:Compiling backward_pass (139703209802560 for args (ShapedArray(float32[]),).
WARNING:absl:Finished XLA compilation of transpose(jvp(f)) in 0.041948556900024414 sec
(DeviceArray(55., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]], dtype=float32))
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.00014543533325195312 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209800384 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.007508516311645508 sec
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.0001461505889892578 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209863232 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.007668972015380859 sec
WARNING:absl:Finished tracing + transforming backward_pass for jit in 0.0005974769592285156 sec
WARNING:absl:Compiling backward_pass (139703209362624 for args (ShapedArray(float32[]),).
WARNING:absl:Finished XLA compilation of transpose(jvp(f)) in 0.005425214767456055 sec
(DeviceArray(65., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]], dtype=float32))
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.00014638900756835938 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209350720 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.007513523101806641 sec
WARNING:absl:Finished tracing + transforming <unnamed wrapped function> for jit in 0.00015473365783691406 sec
WARNING:absl:Compiling <unnamed wrapped function> (139703209372160 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.007587909698486328 sec
WARNING:absl:Finished tracing + transforming backward_pass for jit in 0.0005350112915039062 sec
WARNING:absl:Compiling backward_pass (139703209370048 for args (ShapedArray(float32[]),).
WARNING:absl:Finished XLA compilation of transpose(jvp(f)) in 0.0054433345794677734 sec
(DeviceArray(75., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]], dtype=float32))

Commenting out the checkpoint decorator leads to the desired behavior:

WARNING:absl:Finished tracing + transforming prim_fun for jit in 0.0002498626708984375 sec
WARNING:absl:Finished tracing + transforming fn for jit in 0.00040721893310546875 sec
WARNING:absl:Compiling fn (140693235752000 for args (ShapedArray(float32[2,5]), ShapedArray(int32[], weak_type=True)).
WARNING:absl:Finished XLA compilation of fn in 0.04748940467834473 sec
WARNING:absl:Finished tracing + transforming f for jit in 0.0010097026824951172 sec
WARNING:absl:Compiling f (140692730754112 for args (ShapedArray(float32[2,5]),).
WARNING:absl:Finished XLA compilation of jvp(f) in 0.04457998275756836 sec
WARNING:absl:Finished tracing + transforming prim_fun for jit in 0.0001583099365234375 sec
WARNING:absl:Finished tracing + transforming backward_pass for jit in 0.0004944801330566406 sec
WARNING:absl:Compiling backward_pass (140692730730304 for args (ShapedArray(float32[]),).
WARNING:absl:Finished XLA compilation of transpose(jvp(f)) in 0.041858673095703125 sec
(DeviceArray(55., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]], dtype=float32))
(DeviceArray(65., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]], dtype=float32))
(DeviceArray(75., dtype=float32), DeviceArray([[1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]], dtype=float32))
@JanLuca (Author) commented Mar 24, 2022

Hey, since this is a big blocker for our project at the moment, I would be very thankful if one of the jax-internals experts could help with this issue. I tried to dig into the code to find out why this is happening, but found nothing yet :(

@mattjj (Member) commented Mar 24, 2022

Thanks for pinging this and highlighting that it's a blocker. Sorry for not getting to it sooner!

mattjj self-assigned this Mar 24, 2022
mattjj added the bug label Mar 24, 2022
mattjj added a commit to mattjj/jax that referenced this issue Mar 25, 2022
@mattjj (Member) commented Mar 25, 2022

This was a bit tricky! There is a problem with grad-of-remat-of-jit causing jit cache misses. These two changes are sufficient to fix it, and I believe both are necessary as well (though see below for alternatives):

  1. improve remat transpose caching (#10037) and cache tracing of (sub)calls when forming a jaxpr (#9181; specifically the trace_to_subjaxpr_dynamic_memoized function, been meaning to land this for a while...)
  2. improve caching of jax.remat (#10034)

The latter is necessary because when round-tripping through a jaxpr (i.e. basically doing eval-jaxpr-of-make-jaxpr, as we often do, including when we linearize) we need to make sure we get cache hits. But we were constructing new opaque callables (with hash/equality defined by object id) every time, meaning we never did get cache hits.
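
To make the failure mode concrete, here is a minimal Python sketch of why callables compared by object identity defeat such a cache. This is not JAX's actual internals: cached_compile and round_trip are made-up stand-ins for the jit compilation cache and the jaxpr round trip.

import functools

@functools.lru_cache(maxsize=None)
def cached_compile(fn, shape):
    # Stand-in for jit's compilation cache, keyed on the callable itself.
    print(f"compiling {fn} for shape {shape}")
    return fn

def round_trip(f):
    # Stand-in for eval-jaxpr-of-make-jaxpr: wraps f in a brand-new
    # opaque callable whose hash/equality is its object id.
    return lambda x: f(x)

f = lambda x: 2 * x
for _ in range(3):
    g = round_trip(f)          # a fresh object every iteration...
    cached_compile(g, (2, 5))  # ...so this "compiles" all three times

Because each round trip produces a distinct object, the cache key never repeats, which is exactly the kind of miss #10034 is meant to eliminate.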

In particular, consider this repro involving jax.linearize:

# linearize_repro.py
import sys
import jax

identity = jax.checkpoint(jax.jit(lambda x: 2 * x))

_, f_lin = jax.linearize(identity, 1.)
for _ in range(int(sys.argv[1])):
  f_lin(1.)

Running env JAX_LOG_COMPILES=1 python linearize_repro.py $N shows a number of compilations that scales linearly with N. (I piped stderr into grep 'XLA compilation' | wc -l like a pro.)
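
Spelled out, that counting pipeline looks something like this (assuming a POSIX shell; 10 stands in for N):

env JAX_LOG_COMPILES=1 python linearize_repro.py 10 2>&1 | grep 'XLA compilation' | wc -l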

After #10034 the number of compiles becomes constant with N.

However, while that fixed the issue for linearize, it wasn't quite enough for grad (which is basically linearize plus a transposition step). For that, we had to improve the remat transpose rule to support caching. That's what #10037 does.
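
For comparison, a grad variant of the same repro (hypothetical file name, reusing the function from above) would be the following; with only #10034 applied, its compile count still grows with N, which is the gap #10037 closes:

# grad_repro.py
import sys
import jax

identity = jax.checkpoint(jax.jit(lambda x: 2 * x))

g = jax.grad(identity)
for _ in range(int(sys.argv[1])):
  g(1.)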

(TODO explain that the new version of checkpoint in ad_checkpoint.checkpoint also needed separate work)

TODO talk about the new tooling I'm going to add to explain automatically why recompiles happen

mattjj added four more commits to mattjj/jax that referenced this issue Mar 25, 2022
@mattjj (Member) commented Mar 28, 2022

@JanLuca is the blocker resolved? (At git HEAD I mean; we haven't done a pypi release with these changes yet.)

@JanLuca (Author) commented Mar 29, 2022

@mattjj Yes, I got no OOM error for my test run with the git HEAD today :D Thank you very much for the help!

@JanLuca (Author) commented Apr 4, 2022

@mattjj I just encountered an exception in a run of my code using the git HEAD. I don't know whether this is related to the change:

Exception ignored in: functools.partial(<function weakref_lru_cache.<locals>.remove_key at 0x7ff8dfcee670>, (True, 'allow', None), (), ())
Traceback (most recent call last):
  File "/local_scratch/janluca/pypoetry/peps-ad-2ZmN7-1J-py3.9/lib/python3.9/site-packages/jax/_src/util.py", line 230, in remove_key
    del cache[(weak_arg, tctx, args, kwargs)]
KeyError: (<weakref at 0x7ff8a063c0e0; dead>, (True, 'allow', None), (), ())

@mattjj (Member) commented Apr 6, 2022

I don't think that's related. If it's still happening, can you open a new issue with a repro? cc @pschuh

@JanLuca (Author) commented Apr 6, 2022

Never mind, df1c478 already fixed the problem for me; I was just on an outdated checkout during my test.
