Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support overriding implementations of JAX functions within a scope #4117

Closed
wants to merge 3 commits into from

Conversation

shoyer
Copy link
Member

@shoyer shoyer commented Aug 21, 2020

The idea here is to support overriding the implementation of higher order JAX
functions within a limited scope, for libraries such as Haiku and Flax that
implement their own versions of this functions that support mutation.

Usage example:

with jax.override_context({'lax.scan': my_scan}):
  # all calls to lax.scan() are replaced by my_scan()
  ...

Ultimately, it would be great to replace this with a unified interface for
mutable state in JAX, but this could be a convenient temporary measure.

@google-cla google-cla bot added the cla: yes label Aug 21, 2020
@shoyer shoyer changed the title Support implementations of JAX functions within a context. Support overriding implementations of JAX functions within a context. Aug 21, 2020
@shoyer shoyer changed the title Support overriding implementations of JAX functions within a context. Support overriding implementations of JAX functions within a scope Aug 21, 2020
@juliuskunze
Copy link
Contributor

juliuskunze commented Aug 21, 2020

@shoyer This feature could potentially allow a large speedup of jnp functions running on the NumPy backend: numpy_eval()(jnp.<some_fun>) could be mapped directly to np.<some_fun>, instead of tracing jnp.<some_fun> and calling the NumPy implementations of the encountered lax primitives. It might make #3893 fast out-of-the-box, removing the need to keep optimized NumPy implementations of shape rules.

@shoyer
Copy link
Member Author

shoyer commented Aug 21, 2020

@juliuskunze I would be a little reluctant to support directly overriding jax.numpy functions. JAX works so well in part because transformations are written at the LAX primitive level, which is a much more uniform and rationalized interface than jax.numpy. That's the main reason why I suggested this interface as a temporary measure -- it would be much better (if possible) to figure out a uniform interface for threading variables holding mutable state through transformations.

jax/util.py Outdated Show resolved Hide resolved
jax/util.py Outdated Show resolved Hide resolved
jax/util.py Outdated Show resolved Hide resolved
@NeilGirdhar
Copy link
Contributor

Whoa, this is awesome!! This would resolve my problem in Haiku (linked above) and it seems very versatile too. What's the future of this pull request?

@jekbradbury
Copy link
Contributor

I think I support this. We haven't made any progress on a unified state mechanism since the last time this was proposed, and it's better than external libraries directly monkeypatching JAX. Note that most libraries that provide state-aware versions of functions like jit, scan, etc. add additional knobs/functionality to those functions (and not necessarily in a mutually compatible way), so this feature doesn't solve all aspects of the problem.

@shoyer
Copy link
Member Author

shoyer commented Feb 7, 2021

Since drafting this, we actually have use-case for this in our own application code where haiku.scan wasn't quite right and we wanted our own version.

My one thought is that we should probably make it clearer that it's experimental. If we don't think it's worth making a new experimental module, perhaps we should just include "experimental" in the name, e.g., jax.experimental_override_context? Or perhaps it can just live in the top-level jax.experimental namespace along with the enable_x64/disable_x64 context managers.

@shoyer shoyer force-pushed the override-context branch 2 times, most recently from 373fb79 to 5ae5c9b Compare February 7, 2021 20:27
@shoyer shoyer marked this pull request as ready for review February 7, 2021 20:27
@shoyer
Copy link
Member Author

shoyer commented Feb 7, 2021

This is ready for review now.

@NeilGirdhar
Copy link
Contributor

Wow, I appreciate how fast the Jax team responds to issues!

I'm just curious, would it be worth adding the eval_shape override to the list? Haiku seems to provide an override.

@froystig
Copy link
Member

froystig commented Feb 8, 2021

Placing override_context in experimental sounds good to me.

@shoyer
Copy link
Member Author

shoyer commented Feb 8, 2021

I'm just curious, would it be worth adding the eval_shape override to the list? Haiku seems to provide an override.

Yes, absolutely!

I think I got them all, but let me know if you notice any others that are missing. (It's also easy to extend this in the future, of course.)

@shoyer
Copy link
Member Author

shoyer commented Feb 10, 2021

Tests are passing -- is anyone from the JAX team up for a review?

@shoyer shoyer added the pull ready Ready for copybara import and testing label Feb 10, 2021
@shoyer
Copy link
Member Author

shoyer commented Feb 26, 2021

After discussion with the whole JAX team, we decided not to go further with this PR at this time. It solves some real problems (writing library code that uses higher order functions that wants to be compatible with stateful transformations), but:

  1. Clearly isn't the ultimate solution (which would be some universal system for managing state inside JAX)
  2. It is still something that is going to require application specific choices. For example, if you scan over a neural net layer, should the weights be scanned over or not?

For now, the recommendation is to write your own, library specific implementations of higher order functions like scan and your own scope for overrides, if necessary.

@shoyer shoyer closed this Feb 26, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla: yes pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

6 participants