fix custom_jvp/vjp closure issues, and nondiff_argnums too! #4008
Currently this method is inherited from the Trace base class in jax's core.py, but I want to remove that default method because falling through to it has caused at least two bugs (one in jax2tf that came up in jax-ml/jax#4008, and another at HEAD described here: jax-ml/jax#4566). This won't change Haiku's behavior at all, but it'll make sure things keep working when the methods on the base Trace class disappear. PiperOrigin-RevId: 337088168 Change-Id: I4d5f85e6a2db5fc6348a2f17445afd23d29d9c9c
This change sets up some internal users so that we can then land #4008.
This caused a test failure when trying to land #4008.
HarvestTrace.process_custom_vjp_call. Currently these fall back to the defaults on the superclass Trace, defined in jax/core.py, but those defaults are going away after jax-ml/jax#4008. Moreover, they probably didn't have the behavior we want! PiperOrigin-RevId: 337426371
jax-ml/jax#4008. In particular, don't use nondiff_argnums for array-valued arguments to custom_vjp functions. This change broke three test cases, which are now skipped. I'll work with lukaszkaiser@ to fix them! PiperOrigin-RevId: 337338931
jax-ml/jax#4008. In particular, don't use nondiff_argnums for array-valued arguments to custom_vjp functions. This change broke three test cases, which are now skipped. I'll work with lukaszkaiser@ to fix them! PiperOrigin-RevId: 337451477
PiperOrigin-RevId: 337538591
def clip_gradient_bwd(res, g):
  lo, hi = res
  return (None, None, jnp.clip(g, lo, hi))  # return None for lo and hi
There's an interesting point here: when an argument can sometimes be a tracer, but sometimes will just be a constant, and you're not actually tracing your code, you might want to query whether returning None is OK. The alternative is to always put in the work to compute the value, even though it might get thrown away later. But idk how important eager mode is for JAX, so it might not matter too much (it definitely did in PyTorch) 🤷
Hmm, I may not be following: returning None is always okay here, as it just is shorthand for "zeros like the input". That is, it works regardless of whether the corresponding input was a tracer, and works efficiently both in eager mode and in jit mode. Can you say more?
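For readers following the thread, here is a minimal sketch of the pattern under discussion. It assumes the reviewed `clip_gradient_bwd` belongs to a `jax.custom_vjp` function that takes `lo` and `hi` as ordinary arguments (rather than via `nondiff_argnums`); the forward rule and the `loss` function are illustrative additions, not taken from the diff.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def clip_gradient(lo, hi, x):
  return x  # identity on the forward pass

def clip_gradient_fwd(lo, hi, x):
  return x, (lo, hi)  # save the bounds as residuals for the backward pass

def clip_gradient_bwd(res, g):
  lo, hi = res
  # None stands in for a zero cotangent for lo and hi.
  return (None, None, jnp.clip(g, lo, hi))

clip_gradient.defvjp(clip_gradient_fwd, clip_gradient_bwd)

# Hypothetical usage: the incoming cotangent is 2.0, which is clipped to 0.5.
def loss(x):
  return 2.0 * clip_gradient(-0.5, 0.5, x)

grad_x = jax.grad(loss)(1.0)
```

Because `None` is only shorthand for "zeros like the input", the backward rule does no extra work for `lo` and `hi`, whether or not it runs under `jit`.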
For posterity: this was merged, as 4a20eea, not sure why copybara got confused.
Woohoo, this is huge!
fixes #4566, fixes #4521, fixes #2912, fixes #3822, fixes #4173, fixes #2520, fixes #3808
If you survey those issues, you'll come to the conclusion that `jax.custom_jvp` and `jax.custom_vjp` don't work with closing over JAX Tracers, and moreover that `nondiff_argnums` can be just as problematic. Indeed, `nondiff_argnums` was effectively implemented in terms of lexical closure. The original plan was not to make `custom_jvp`/`custom_vjp` work with lexical closure, but I never documented that, and the error messages were terrible!

This PR makes `custom_jvp` and `custom_vjp` work with closed-over Tracers. Woo! That is, now `custom_jvp` and `custom_vjp` functions and rules can close over Tracers to our hearts' content. For all non-autodiff transformations, things will Just Work. For autodiff transformations, we'll get a clear error message about why we can't differentiate with respect to values over which a custom_jvp or custom_vjp closes:

This PR accomplishes that goal by following through on the original ansatz of `custom_jvp` and `custom_vjp`, namely to make them work like `core.call`. That is, we do all the `core.process_env_traces` stuff properly.

TODO:
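As an illustration of the closure behavior described in this summary, the sketch below defines a `custom_jvp` function inside another function so that it closes over a value that becomes a Tracer under `jax.vmap`. The names `scaled_sin`, `f`, and `f_jvp` are hypothetical and chosen only for this example. Only the non-autodiff case is exercised; differentiating with respect to the closed-over `y` is the case the description says should produce an error.

```python
import jax
import jax.numpy as jnp

def scaled_sin(x, y):
  @jax.custom_jvp
  def f(x):
    return jnp.sin(x) * y  # closes over y, which is a Tracer under vmap

  @f.defjvp
  def f_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    return f(x), jnp.cos(x) * y * t  # the rule closes over y as well

  return f(x)

# Batch over the closed-over value y only; x is shared across the batch.
out = jax.vmap(scaled_sin, in_axes=(None, 0))(1.0, jnp.arange(3.0))
```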