Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

factor convert_closure from ode to custom_derivatives #5244

Merged
merged 4 commits into from Dec 31, 2020

Conversation

froystig
Copy link
Member

Adds a closure-conversion utility to custom derivatives, and in turn to our API, useful when setting up custom derivatives for higher-order functions. This is based on a subroutine in ode, with minor modifications to generalize it for wider use.

Fixes #5222

@froystig froystig requested a review from mattjj December 22, 2020 18:26
@froystig froystig self-assigned this Dec 22, 2020
@google-cla google-cla bot added the cla: yes label Dec 22, 2020
@shoyer
Copy link
Member

shoyer commented Dec 22, 2020

This is great! One thing I might add to the docs is that you can get a pytree compatible function with tree_util.Partial(*closure_convert(objective_fn, x0)).

Another good use case for this helper would be lax.custom_root, which with closure_convert (replacing _initial_style_jaxpr) can now be implemented entirely with JAX public APIs.

Copy link
Member

@mattjj mattjj left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Idea for a test mentioned below.

jax/custom_derivatives.py Outdated Show resolved Hide resolved
jax/custom_derivatives.py Show resolved Hide resolved
@mattjj
Copy link
Member

mattjj commented Dec 30, 2020

@shoyer I haven't thought through this, but I think that might not interact well with nondiff_argnums. Perhaps that can be sorted out but we might have to give it some thought and try out some examples.

@froystig froystig force-pushed the closure-convert branch 2 times, most recently from 3284db9 to c6e1cc2 Compare December 31, 2020 02:38
@froystig froystig added the pull ready Ready for copybara import and testing label Dec 31, 2020
@copybara-service copybara-service bot merged commit ba46e64 into master Dec 31, 2020
@froystig froystig deleted the closure-convert branch December 31, 2020 03:32
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla: yes pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

utility for closure-conversion in higher-order functions with custom derivatives
3 participants