Add JAX backend - autograd is deprecated #800
Comments
Giving it a quick try, it seems that we are facing the same issue as we would with the pytorch and tf backends: the numpy arrays become immutable, which makes lots of tests fail. Example:
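(The failing test output from this comment is not preserved in this copy of the thread. As a minimal sketch of the underlying behavior, assuming jax.numpy is swapped in for numpy as the array module, and using JAX's current .at[...].set(...) indexed-update API:)

```python
import numpy as np
import jax.numpy as jnp

x_np = np.zeros(3)
x_np[0] = 1.0                 # fine: numpy arrays are mutable

x_jax = jnp.zeros(3)
# x_jax[0] = 1.0              # raises TypeError: JAX arrays are immutable
x_jax = x_jax.at[0].set(1.0)  # functional update that returns a new array
```

Every in-place assignment in the code base and its tests would have to be rewritten in this functional style.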
Some other tests fail for precision reasons:
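(The output is likewise not preserved. One plausible cause, noted here as an assumption about these particular tests: JAX computes in single precision by default, while numpy-based tests typically expect float64. Double precision is a documented opt-in:)

```python
import jax

# JAX defaults to float32; enable 64-bit mode at startup so results
# match numpy's float64 within the tolerances the tests expect.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
print(jnp.array([1.0]).dtype)  # float64 instead of the default float32
```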
Hi everyone! I am interested in, but not at all familiar with, the internals of this project (I maintain a Python library for topological data analysis and have been lured to look into …
Hey @ulupo, good to see you here 🎉 Yes, we would really like to speed up the library, as speed is currently its main limitation, but I do not believe that anyone is actively working on it right now. Do you have a recommendation about which tool to use: numba versus jax, versus both, versus something else? Or am I right in understanding that you recommend numba? I have also seen this tweet discussing both: https://twitter.com/MilesCranmer/status/1205663981022564353. Thank you so much for your insights! 🙏
I have never tried …
Thank you for the insights, which make a lot of sense. I think we will prioritize …
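To make the numba-versus-jax question above concrete, here is a minimal side-by-side sketch (illustrative only, not code from this project) of one kernel under each tool: numba JIT-compiles explicit Python loops over numpy arrays, while jax.jit compiles vectorized array expressions through XLA and also brings autodiff and GPU/TPU execution.

```python
import numpy as np
import numba
import jax
import jax.numpy as jnp

@numba.njit
def pairwise_sq_dists_numba(x):
    # numba style: compile plain Python loops over numpy arrays
    n, d = x.shape
    out = np.empty((n, n))
    for i in range(n):
        for j in range(n):
            s = 0.0
            for k in range(d):
                diff = x[i, k] - x[j, k]
                s += diff * diff
            out[i, j] = s
    return out

@jax.jit
def pairwise_sq_dists_jax(x):
    # jax style: vectorized expressions, compiled via XLA
    diff = x[:, None, :] - x[None, :, :]
    return jnp.sum(diff ** 2, axis=-1)

x = np.random.rand(100, 3)
# loose tolerance because jax runs in float32 unless x64 mode is enabled
np.testing.assert_allclose(pairwise_sq_dists_numba(x),
                           np.asarray(pairwise_sq_dists_jax(x)), rtol=1e-5)
```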
Original issue description:
Autograd [1] is deprecated in favor of JAX [2]. It would be great to add a JAX backend for deploying vector/matrix/tensor operations to GPUs/TPUs.
Quoting the autograd website:
[1] https://github.com/HIPS/autograd
[2] https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
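As a rough sketch of what the migration involves (the function below is illustrative, not this project's actual backend API): jax.grad is a near drop-in replacement for autograd.grad, and jax.jit additionally compiles the computation for CPU/GPU/TPU.

```python
import jax
import jax.numpy as jnp

def squared_norm(x):
    # hypothetical example function; any jnp-based scalar function works
    return jnp.sum(x ** 2)

grad_fn = jax.grad(squared_norm)  # near drop-in for autograd.grad(squared_norm)
x = jnp.array([1.0, 2.0, 3.0])
print(grad_fn(x))                 # [2. 4. 6.]
print(jax.jit(grad_fn)(x))        # same gradient, XLA-compiled
```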