New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support overriding implementations of JAX functions within a scope #4117
Conversation
@shoyer This feature could potentially allow a large speedup of |
@juliuskunze I would be a little reluctant to support directly overriding |
Whoa, this is awesome!! This would resolve my problem in Haiku (linked above) and it seems very versatile too. What's the future of this pull request? |
I think I support this. We haven't made any progress on a unified state mechanism since the last time this was proposed, and it's better than external libraries directly monkeypatching JAX. Note that most libraries that provide state-aware versions of functions like |
Since drafting this, we actually have use-case for this in our own application code where My one thought is that we should probably make it clearer that it's experimental. If we don't think it's worth making a new experimental module, perhaps we should just include "experimental" in the name, e.g., |
373fb79
to
5ae5c9b
Compare
5ae5c9b
to
f3efc6f
Compare
f3efc6f
to
06987b4
Compare
This is ready for review now. |
Wow, I appreciate how fast the Jax team responds to issues! I'm just curious, would it be worth adding the |
Placing |
c5f9f86
to
314f302
Compare
Yes, absolutely! I think I got them all, but let me know if you notice any others that are missing. (It's also easy to extend this in the future, of course.) |
Tests are passing -- is anyone from the JAX team up for a review? |
After discussion with the whole JAX team, we decided not to go further with this PR at this time. It solves some real problems (writing library code that uses higher order functions that wants to be compatible with stateful transformations), but:
For now, the recommendation is to write your own, library specific implementations of higher order functions like |
The idea here is to support overriding the implementation of higher order JAX
functions within a limited scope, for libraries such as Haiku and Flax that
implement their own versions of this functions that support mutation.
Usage example:
Ultimately, it would be great to replace this with a unified interface for
mutable state in JAX, but this could be a convenient temporary measure.