Add JAX backend - autograd is deprecated #800

Open
johmathe opened this issue May 19, 2020 · 6 comments

@johmathe
Collaborator

Autograd [1] has been deprecated in favor of JAX [2]. It would be great to add a JAX backend for running vector/matrix/tensor operations on GPUs/TPUs.

Quoting the autograd website:

Note: Autograd is still being maintained but is no longer actively developed. The main developers (Dougal Maclaurin, David Duvenaud, Matt Johnson, and Jamie Townsend) are now working on JAX, with Dougal and Matt working on it full-time. JAX combines a new version of Autograd with extra features such as jit

[1] https://github.com/HIPS/autograd
[2] https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
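The quoted note says JAX combines Autograd-style differentiation with extra features such as jit. As a minimal sketch (the function is an arbitrary example, not geomstats code), `jax.grad` plays the role of `autograd.grad`, and `jax.jit` compiles the resulting gradient with XLA for whatever CPU/GPU/TPU is available:

```python
import jax
import jax.numpy as jnp


def squared_norm(vector):
    # NumPy-style code, written against jax.numpy instead of autograd.numpy.
    return jnp.sum(vector ** 2)


# Differentiate first, then JIT-compile the gradient function with XLA.
squared_norm_grad = jax.jit(jax.grad(squared_norm))

vector = jnp.array([1.0, 2.0, 3.0])
gradient = squared_norm_grad(vector)  # analytically equal to 2 * vector
```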

@johmathe
Collaborator Author

After giving it a quick try, it seems that we are facing the same issue we would face with the PyTorch and TF backends: the numpy arrays become immutable, which makes many tests fail:

Example:

======================================================================
ERROR: test_exp_and_dist_and_projection_to_tangent_space (tests.test_hypersphere.TestHypersphere)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/johmathe/geomstats/tests/test_hypersphere.py", line 430, in test_exp_and_dist_and_projection_to_tangent_space
    tangent_vec=tangent_vec, base_point=base_point)
  File "/Users/johmathe/geomstats/geomstats/vectorization.py", line 104, in wrapper
    result = function(*vect_args, **vect_kwargs)
  File "/Users/johmathe/geomstats/geomstats/geometry/hypersphere.py", line 478, in exp
    mask_0)
  File "/Users/johmathe/geomstats/geomstats/_backend/numpy/__init__.py", line 196, in assignment
    x_new[indices] = values
ValueError: assignment destination is read-only
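The failing line mutates an array in place, which JAX does not allow. Assuming the JAX backend hands this helper JAX arrays (or read-only NumPy views of them), the JAX-native equivalent is the functional `.at[...].set(...)` update, which returns a new array and leaves the input untouched. A minimal sketch, where `assignment_jax` is a hypothetical name and argument order, not part of geomstats:

```python
import jax.numpy as jnp


def assignment_jax(x, values, indices):
    """Hypothetical JAX counterpart of `x_new[indices] = values`.

    JAX arrays are immutable, so instead of writing into `x` we return
    an updated copy built with the functional `.at[...].set(...)` API.
    """
    x = jnp.asarray(x)
    return x.at[indices].set(values)


x = jnp.zeros((2, 3))
x_new = assignment_jax(x, 5.0, (1, 2))  # only entry [1, 2] changes

try:
    x[0, 0] = 1.0  # in-place item assignment is rejected by JAX
except TypeError:
    print("JAX arrays do not support item assignment")
```

Because the update is functional, any backend helper written this way has to return the new array, and callers have to use the return value rather than rely on mutation.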

Some other tests fail for precision reasons:

======================================================================
FAIL: test_exp_vectorization (tests.test_hyperbolic.TestHyperbolic)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/johmathe/geomstats/tests/test_hyperbolic.py", line 222, in test_exp_vectorization
    self.assertAllClose(result, expected)
  File "/Users/johmathe/geomstats/geomstats/tests.py", line 77, in assertAllClose
    return np.testing.assert_allclose(a, b, rtol=rtol, atol=atol)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/numpy/testing/_private/utils.py", line 1533, in assert_allclose
    verbose=verbose, header=header, equal_nan=equal_nan)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/numpy/testing/_private/utils.py", line 846, in assert_array_compare
    raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=1e-06, atol=1e-06

Mismatched elements: 4 / 12 (33.3%)
Max absolute difference: 0.04579544
Max relative difference: 0.00093243
 x: array([[ 1.987827,  0.795378,  1.192449,  0.947046],
       [49.19424 ,  9.551132, 39.643112, 27.50042 ],
       [ 3.127928,  2.036517, -0.146305,  2.148284]], dtype=float32)
 y: array([[ 1.987827,  0.795378,  1.192449,  0.947046],
       [49.148445,  9.542269, 39.606182, 27.474813],
       [ 3.127927,  2.036517, -0.146305,  2.148283]], dtype=float32)
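Both arrays in this failure are `float32`. One thing worth checking (an assumption about the cause, not a confirmed diagnosis) is precision: JAX uses 32-bit floats by default, whereas NumPy defaults to `float64`. JAX's 64-bit mode has to be switched on explicitly at startup, before any arrays are created:

```python
import jax
import jax.numpy as jnp

# Enable 64-bit types; do this at startup, before creating any JAX arrays.
jax.config.update("jax_enable_x64", True)

x = jnp.ones(3)  # float64 now, instead of JAX's default float32
```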

@ulupo

ulupo commented Feb 5, 2021

Hi everyone! I am interested in, but not at all familiar with, the internals of this project (I maintain a Python library for topological data analysis and have been lured to look into geomstats by @ninamiolane :)), so this question is very naive: does geomstats make use of just-in-time compilers such as the ones provided by numba or JAX? Is this in the works? In the case of numba specifically, it should be feasible to add njit decorators to code currently handled by numpy. AFAIK things might be a little more complicated with JAX due to array immutability, etc., but still largely doable. The performance gains can be well worth the effort, especially when there are nested loops. I was recently able to speed up some code by a factor of almost 100 in this manner.
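To make the njit suggestion concrete, here is a minimal sketch of a nested-loop kernel compiled with numba's `njit`. The kernel (pairwise squared Euclidean distances) is an illustrative assumption, not geomstats code, and the `try`/`except` fallback simply runs it as plain Python when numba is not installed, so the dependency stays optional:

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # numba is optional: without it, the decorator leaves the function as-is.
    def njit(func):
        return func


@njit
def pairwise_sq_dists(points):
    """Squared Euclidean distances between all rows of a 2-D array."""
    n, d = points.shape
    out = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            acc = 0.0
            for k in range(d):
                diff = points[i, k] - points[j, k]
                acc += diff * diff
            out[i, j] = acc
            out[j, i] = acc  # the matrix is symmetric
    return out


pts = np.array([[0.0, 0.0], [3.0, 4.0], [1.0, 0.0]])
dists = pairwise_sq_dists(pts)
```

With numba installed, the first call includes compilation time, so any timing should be taken on later calls.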

@ninamiolane
Collaborator

Hey @ulupo, good to see you here 🎉

Yes, we would really like to speed up the library, as speed is currently its main limitation. But I do not believe that anyone is actively working on it right now.

Do you have a recommendation on which tool to use: numba, JAX, both, or something else? Or am I right in understanding that you recommend numba? I have also seen this tweet discussing both: https://twitter.com/MilesCranmer/status/1205663981022564353.

Thank you so much for your insights! 🙏

@ulupo

ulupo commented Feb 6, 2021

I have never tried JAX, but I notice you are already planning to use it (see #801). Since you need autodiff capabilities anyway, perhaps it would make sense to start by seeing how much speedup can be obtained without introducing numba as an extra dependency. Of course, it might then make sense to benchmark numba-based approaches to see whether the extra dependency is worth it. From what I hear, numba's LLVM compiler works at a slightly lower level than JAX's XLA, which should give numba an edge in some circumstances.
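When benchmarking numba- or JAX-compiled functions against plain NumPy, two details are easy to miss: the first call includes JIT compilation, and JAX dispatches work asynchronously, so a call can return before the result is computed. A minimal, library-agnostic timing helper that accounts for both (the name `time_call` is hypothetical):

```python
import timeit

import numpy as np


def time_call(func, *args, repeat=5, number=20):
    """Best average time per call of func(*args), in seconds.

    A warm-up call runs first so that one-off JIT compilation
    (numba, JAX) is not counted in the measurement.
    """
    def run():
        out = func(*args)
        # JAX returns before the computation finishes; wait for it if possible.
        if hasattr(out, "block_until_ready"):
            out.block_until_ready()
        return out

    run()  # warm-up, triggers compilation for jitted functions
    times = timeit.repeat(run, repeat=repeat, number=number)
    return min(times) / number


def broadcast_sq_dists(p):
    # Plain NumPy baseline: pairwise squared distances via broadcasting.
    return ((p[:, None, :] - p[None, :, :]) ** 2).sum(-1)


points = np.random.default_rng(0).normal(size=(200, 3))
numpy_time = time_call(broadcast_sq_dists, points)
print(f"NumPy broadcast version: {numpy_time * 1e6:.1f} us per call")
```

The same helper can then be pointed at a numba- or JAX-compiled variant of the function to compare like with like.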

@ninamiolane
Collaborator

Thank you for the insights, which make a lot of sense. I think we will prioritize numba indeed, but we don't have anyone actively working on it at the moment. Thus, if you want to give it a try, you are more than welcome!

@ninamiolane
Collaborator

@ulupo there is an ongoing discussion in #920 to tackle this in the next days, feel free to jump in if you are interested!


5 participants