Numba support for JAX #3923
Comments
Given that JAX is also a JIT compiler, are there situations where it can compute the gradient, but you would not want to compile the resulting function with JAX itself?
Or are you looking for the reverse, where Numba has interop with JAX so that JAX can differentiate functions you are compiling with Numba?
I suppose the former, but ultimately whichever implementation is most general and benefits the community the most would be best. What I've noticed is that it is hard to compile functions with numba njit when JAX differentiation is involved: even if the function is broken into many sub-functions that can individually be compiled with njit, the "main" function (the outer loop, say) cannot be compiled (at least not in nopython mode), and therefore the code doesn't achieve great speed-ups (see the sketch below). I hope this makes sense; it's hard to generalize my use cases to the wider community.
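A minimal sketch of the pattern being described, with a made-up toy objective and function names. Calling the njit-compiled outer loop fails with a Numba typing error, because the JAX gradient function is an opaque Python object that Numba cannot type in nopython mode:

```python
import numba
import jax
import jax.numpy as jnp

def loss(w):
    return jnp.sum(w ** 2)          # toy differentiable objective

grad_loss = jax.grad(loss)          # gradient is easy to obtain with JAX

@numba.njit
def gradient_descent(w0, steps, lr):
    # The loop itself is trivially njit-able, but the call to grad_loss is not:
    # Numba cannot lower a JAX-traced Python function, so nopython compilation fails.
    w = w0
    for _ in range(steps):
        w = w - lr * grad_loss(w)
    return w
```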
I haven't looked at the details of how JAX differentiation works, but I know we took a close look at Tangent. The code that it generates is very close to being Numba-supportable, if we added support for the particular stack data structure that Tangent uses in its generated code (a rough illustration of that pattern is sketched below). I suspect JAX has a similar issue.
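Tangent does source-to-source reverse-mode AD, and its generated code saves intermediate values on a stack during the forward pass and pops them during the backward pass. A hand-written sketch of that general pattern (not Tangent's actual output; the plain Python list stands in for the stack data structure mentioned above):

```python
# Hand-written illustration of the stack pattern used by source-to-source
# reverse-mode AD tools such as Tangent (not actual Tangent output).
def f(x, n):
    y = x
    for _ in range(n):
        y = y * x
    return y                  # f(x) = x**(n + 1)

def df_dx(x, n):
    stack = []                # the data structure Numba would need to support
    y = x
    for _ in range(n):
        stack.append(y)       # save intermediates on the forward pass
        y = y * x
    # Backward pass: walk the operations in reverse order.
    dy = 1.0
    dx = 0.0
    for _ in range(n):
        y_prev = stack.pop()
        dx += dy * y_prev     # contribution of x in y = y_prev * x
        dy = dy * x           # propagate the adjoint through the multiply
    dx += dy                  # y started as x itself
    return dx                 # equals (n + 1) * x**n
```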
Looking into JAX more, it seems that they delegate to Autograd for differentiation, so this request is more specifically a "Numba support for Autograd". :)
This doesn't solve the issue of deeper numba/jax integration, but I wrote some code to leverage […]. As you can't use it to do grad, it's only useful for the limited scenario where you want to use a non-jax-jittable function inside of jax and don't need (or can explicitly calculate) the derivatives. PoC: calling numba from […]. Note that this implementation involves an additional pointer-chase per call due to the wrapper.
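The proof-of-concept links above did not survive extraction. As a rough sketch under the assumption that a host callback is acceptable, newer JAX versions expose jax.pure_callback, which lets a jitted JAX function call back into ordinary Python (and thus into Numba-compiled code), with the same limitation that no derivative is defined for the call (function names below are made up):

```python
import numpy as np
import numba
import jax
import jax.numpy as jnp

@numba.njit
def numba_rolling_sum(x):
    # Easy for Numba, awkward to express as a single jnp operation.
    out = np.empty_like(x)
    acc = 0.0
    for i in range(x.shape[0]):
        acc += x[i]
        out[i] = acc
    return out

def rolling_sum(x):
    # pure_callback ships the array back to host Python, where Numba runs it;
    # JAX treats the call as a black box, so no gradient is available.
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(numba_rolling_sum, result_shape, x)

@jax.jit
def model(x):
    return jnp.sum(rolling_sum(x) ** 2)

print(model(jnp.arange(4.0)))
```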
Check also the thread at https://llvm.discourse.group/t/numpy-scipy-op-set/768 (especially the last posts) /cc @sklam @DrTodd13
This issue is marked as stale as it has had no activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with any updates and confirm that this issue still needs to be addressed.
Google has a great library, JAX, for automatic differentiation (AD). It seems JAX is undergoing heavy development, and using its grad method is a lot easier than AD in other libraries (e.g. Theano). Given that a lot of algorithms (deep neural nets, No-U-Turn Sampling, etc.) rely on differentiation at their core, it would be great if Numba could support JAX, so such code could get the incredible speed-up Numba already provides for NumPy functions.
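For reference, a minimal example of how little code the grad method requires (the toy function is just for illustration):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(jnp.sin(x) ** 2)

df = jax.grad(f)                     # reverse-mode derivative of f
print(df(jnp.array([0.1, 0.2])))     # gradient has the same shape as the input
```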