
Numba support for JAX #3923

Closed · ktattan opened this issue Apr 1, 2019 · 8 comments

Labels: question, stale

Comments


ktattan commented Apr 1, 2019

Google has a great library, JAX, for automatic differentiation (AD). JAX is undergoing heavy development, and its grad function is a lot easier to use than AD in other libraries (e.g. Theano).

Given that a lot of algorithms (deep neural nets, No-U-Turn sampling, etc.) rely on differentiation at their core, it would be great if Numba could support JAX, so differentiated code could get the incredible speed-ups Numba already provides for NumPy functions.
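
For context, a minimal sketch of what grad looks like in JAX (the toy loss function here is hypothetical):

```python
import jax
import jax.numpy as jnp

def loss(w):                    # hypothetical toy function
    return jnp.sum(w ** 2)

grad_loss = jax.grad(loss)      # reverse-mode AD in a single call
print(grad_loss(jnp.array([1.0, 2.0, 3.0])))  # [2. 4. 6.]
```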

seibert (Contributor) commented Apr 2, 2019

Given that JAX is also a JIT compiler, are there situations where it can compute the gradient, but you would not want to compile the resulting function with JAX itself?
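
For reference, compiling the differentiated function with JAX itself is a one-liner (toy loss is hypothetical):

```python
import jax
import jax.numpy as jnp

def loss(w):                         # hypothetical toy function
    return jnp.sum(w ** 2)

fast_grad = jax.jit(jax.grad(loss))  # JAX compiles its own gradient
fast_grad(jnp.ones(3))               # traced and compiled on first call
```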

seibert (Contributor) commented Apr 2, 2019

Or are you looking for the reverse, where Numba has interop with JAX to use it to differentiate functions you are compiling with Numba?

ktattan (Author) commented Apr 2, 2019

I suppose the former. But ultimately, whichever implementation is most general and benefits the community the most would be best.

What I've noticed is that it is hard to compile functions with Numba's njit when JAX differentiation is involved. Even if the function is broken into many sub-functions that can individually be compiled with njit, the "main" function (the outer loop, say) cannot be compiled, at least not in nopython mode, so the code doesn't achieve great speed-ups.

I hope this makes sense. It's hard to generalize my use cases to the wider community.
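
A minimal sketch of the failure mode described above (the toy loss and training loop are hypothetical):

```python
import numba
import jax
import jax.numpy as jnp

def loss(w):                  # hypothetical toy function
    return jnp.sum(w ** 2)

grad_loss = jax.grad(loss)

@numba.njit
def outer_loop(w, steps):
    for _ in range(steps):
        # Fails to compile in nopython mode: Numba cannot type a call
        # into an arbitrary Python object such as a JAX-traced function.
        w = w - 0.1 * grad_loss(w)
    return w
```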

seibert (Contributor) commented Apr 2, 2019

I haven't looked at the details of how JAX differentiation works, but I know we took a close look at Tangent. The code that it generates is very close to being Numba-supportable, if we added support for the particular stack data structure that Tangent uses in its generated code. I suspect JAX has a similar issue.

seibert (Contributor) commented Apr 2, 2019

Looking into JAX more, it seems that they delegate to Autograd for differentiation, so this request is more specifically a "Numba support for Autograd". :)

@stuartarchibald added the question label and removed the needtriage label Apr 11, 2019
agoose77 commented Feb 12, 2020

This doesn't solve the issue of deeper Numba/JAX integration, but I wrote some code that leverages CustomCall as a means of invoking Numba functions. I needed this to implement a spline interpolator inside of jax.jit, as tracing the pure-Python function fails because certain NumPy functions have not yet been implemented in JAX. In some cases you can write the lax code yourself, but I was having trouble with that, so I opted for a quicker solution.

Since you can't use it to do grad, it's only useful for the limited scenario where you want to use a non-JAX-jittable function inside of JAX and don't need (or can explicitly calculate) the derivatives.
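
For the "can explicitly calculate the derivatives" case, JAX's custom_jvp hook is the relevant escape hatch; a minimal sketch with a hypothetical stand-in function:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def f(x):
    return jnp.sin(x)  # stand-in for a black-box (e.g. Numba-backed) call

@f.defjvp
def f_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    return f(x), jnp.cos(x) * t  # hand-written derivative rule

jax.grad(f)(1.0)  # uses the custom rule instead of tracing f's internals
```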

PoC calling numba from jax.jit
PoC generating wrapper code for numba.cfunc

Note that this implementation involves an additional pointer-chase per call due to the wrapper.
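
A rough sketch of what the numba.cfunc side of such a wrapper can look like, assuming an XLA-CPU-style `void(out, ins)` calling convention; the signature and fixed shape here are hypothetical simplifications, not the actual PoC code:

```python
from numba import cfunc, types, carray

# Hypothetical simplified signature: void fn(double* out, double** ins)
sig = types.void(
    types.CPointer(types.double),
    types.CPointer(types.CPointer(types.double)),
)

@cfunc(sig)
def add_one(out_ptr, in_ptrs):
    x = carray(in_ptrs[0], (4,))   # view the first input buffer
    out = carray(out_ptr, (4,))    # view the output buffer
    for i in range(4):
        out[i] = x[i] + 1.0

# add_one.address is the raw function pointer that a CustomCall target
# registration would hand to XLA.
```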

bhack commented Apr 4, 2020

Check also the thread at https://llvm.discourse.group/t/numpy-scipy-op-set/768 (especially the last posts) /cc @sklam @DrTodd13

github-actions bot commented Jun 8, 2021

This issue is marked as stale as it has had no activity in the past 30 days. Please close this issue if no further response or action is needed. Otherwise, please respond with any updates and confirm that this issue still needs to be addressed.

@github-actions bot added the stale label Jun 8, 2021