Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix custom_jvp/vjp closure issues, and nondiff_argnums too! #4008

Closed
wants to merge 1 commit into from

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented Aug 10, 2020

fixes #4566, fixes #4521, fixes #2912, fixes #3822, fixes #4173, fixes #2520, fixes #3808

If you survey those issues, you'll come to the conclusion that jax.custom_jvp and jax.custom_vjp don't work with closing over JAX Tracers, and moreover that nondiff_argnums can be just as problematic. Indeed, nondiff_argnums was effectively implemented in terms of lexical closure. The original plan was not to make custom_jvp/custom_vjp work with lexical closure, but I never documented that, and the error messages were terrible!

This PR makes custom_jvp and custom_vjp work with closed-over Tracers. Woo! That is, now custom_jvp and custom_vjp functions and rules can close over Tracers to our hearts' content. For all non-autodiff transformations, things will Just Work. For autodiff transformations, we'll get a clear error message about why we can't differentiate with respect to values over which a custom_jvp or custom_vjp closes:

Detected differentiation of a custom_jvp function with respect to a closed-over value. That isn't supported because the custom JVP rule only specifies how to differentiate the custom_jvp function with respect to explicit input parameters. Try passing the closed-over value into the custom_jvp function as an argument, and adapting the custom_jvp rule.

This PR accomplishes that goal by following through on the original ansatz of custom_jvp and custom_vjp, namely to make them work like core.call. That is, we do all the core.process_env_traces stuff properly.

TODO:

  • make pre-omnistaging version work again
  • fix jax2tf sublevels interaction
  • fix all google internal users
  • update custom derivatives notebook with new custom_vjp nondiff_argnums advice
  • write a pr message

@google-cla google-cla bot added the cla: yes label Aug 10, 2020
@mattjj mattjj force-pushed the custom-jvp-closure-fixes branch 2 times, most recently from b215d2f to 070beae Compare October 9, 2020 21:28
@mattjj mattjj changed the title custom_jvp/vjp closure issues, just experimenting for now fix custom_jvp/vjp closure issues, and nondiff_argnums too! Oct 9, 2020
@mattjj mattjj marked this pull request as ready for review October 10, 2020 00:05
@mattjj mattjj added the pull ready Ready for copybara import and testing label Oct 10, 2020
copybara-service bot pushed a commit to google-deepmind/dm-haiku that referenced this pull request Oct 14, 2020
Currently this method is inherited from the Trace base class in jax's core.py,
but I want to remove that default method because falling through to it has
caused at least two bugs (one in jax2tf that came up in
jax-ml/jax#4008, and another at HEAD described here:
jax-ml/jax#4566). This won't change Haiku's behavior
at all, but it'll make sure things keep working when the methods on the base
Trace class disappear.

PiperOrigin-RevId: 337088168
Change-Id: I4d5f85e6a2db5fc6348a2f17445afd23d29d9c9c
mattjj added a commit that referenced this pull request Oct 15, 2020
This change sets up some internal users so that we can then land #4008.
mattjj added a commit that referenced this pull request Oct 15, 2020
This caused a test failure when trying to land #4008.
copybara-service bot pushed a commit to tensorflow/probability that referenced this pull request Oct 16, 2020
HarvestTrace.process_custom_vjp_call.

Currently these fall back to the defaults on the superclass Trace, defined in
jax/core.py, but those defaults are going away after
jax-ml/jax#4008. Moreover, they probably didn't have
the behavior we want!

PiperOrigin-RevId: 337426371
copybara-service bot pushed a commit to google/trax that referenced this pull request Oct 16, 2020
jax-ml/jax#4008. In particular, don't use
nondiff_argnums for array-valued arguments to custom_vjp functions.

This change broke three test cases, which are now skipped. I'll work with lukaszkaiser@ to fix them!

PiperOrigin-RevId: 337338931
mattjj added a commit that referenced this pull request Oct 16, 2020
copybara-service bot pushed a commit to google/trax that referenced this pull request Oct 16, 2020
jax-ml/jax#4008. In particular, don't use
nondiff_argnums for array-valued arguments to custom_vjp functions.

This change broke three test cases, which are now skipped. I'll work with lukaszkaiser@ to fix them!

PiperOrigin-RevId: 337338931
copybara-service bot pushed a commit to google/trax that referenced this pull request Oct 16, 2020
jax-ml/jax#4008. In particular, don't use
nondiff_argnums for array-valued arguments to custom_vjp functions.

This change broke three test cases, which are now skipped. I'll work with lukaszkaiser@ to fix them!

PiperOrigin-RevId: 337451477
docs/custom_vjp_update.md Show resolved Hide resolved

def clip_gradient_bwd(res, g):
lo, hi = res
return (None, None, jnp.clip(g, lo, hi)) # return None for lo and hi
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's interesting point here: when an argument can sometimes be a tracer, but sometimes will just be a constant, and you're not actually tracing your code, you might want to query whether returning None is ok. The alternative is to always put in the work to compute the value, even though it might get thrown away later. But idk how important eager mode is for JAX, so it might not matter too much (it definitely did in PyTorch) 🤷

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I may not be following: returning None is always okay here, as it just is shorthand for "zeros like the input". That is, it works regardless of whether the corresponding input was a tracer, and works efficiently both in eager mode and in jit mode. Can you say more?

jax/custom_derivatives.py Show resolved Hide resolved
mattjj added a commit that referenced this pull request Oct 21, 2020
@mattjj
Copy link
Collaborator Author

mattjj commented Oct 21, 2020

Making the fixes from @apaszke in #4664.

@mattjj mattjj closed this Oct 21, 2020
@mattjj
Copy link
Collaborator Author

mattjj commented Oct 21, 2020

For posterity: this was merged, as 4a20eea, not sure why copybara got confused.

@mattjj mattjj deleted the custom-jvp-closure-fixes branch October 21, 2020 00:58
@NeilGirdhar
Copy link
Contributor

Woohoo, this is huge!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment