From: https://gitlab.inria.fr/jls/geomstats-diffauto

## Automatic Differentiation in Geomstats: Autograd

## Moving from Autograd to Jax: Pros

Autograd:
- no longer under active development,
- tends to be too slow for medium to large-scale experiments,
- development for running Autograd on GPUs was never completed, and therefore training is limited by the execution time of native NumPy code.

Jax:
- is the successor to the Autograd library,
- combines automatic differentiation with hardware acceleration through XLA, which compiles linear algebra operations into faster instructions, often with lower memory usage as well,
- provides the function transformation jit for just-in-time compilation,
- provides vmap and pmap for vectorization and parallelization,
- runs models on a GPU (or TPU) if available.
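
As a minimal sketch of these transformations (the function `squared_norm` and the sample arrays are illustrative assumptions, not taken from Geomstats), `jax.grad`, `jax.jit` and `jax.vmap` can be combined as follows:

```python
import jax
import jax.numpy as jnp


def squared_norm(x):
    # Hypothetical example function: sum of squares of a 1-D array.
    return jnp.sum(x ** 2)


# grad: derivative of squared_norm with respect to x, which is 2 * x.
grad_squared_norm = jax.grad(squared_norm)

# jit: compile the gradient function with XLA on its first call.
fast_grad = jax.jit(grad_squared_norm)

# vmap: apply squared_norm to each row of a batch without a Python loop.
batched_squared_norm = jax.vmap(squared_norm)

x = jnp.array([1.0, 2.0, 3.0])
batch = jnp.stack([x, 2.0 * x])

print(fast_grad(x))                 # gradient, equal to 2 * x
print(batched_squared_norm(batch))  # one squared norm per row
```

Each transformation takes a function and returns a new function, so they can be stacked in any order.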

https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html

Jax showcases convincing computation times when training neural networks:
<img src="figures/jax.png" width="400"/>


## Moving from Autograd to Jax: Cons

**Issue 1**: 
https://stackoverflow.com/questions/64517793/why-is-this-function-slower-in-jax-vs-numpy


In [None]:
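import jax
import jax.numpy as jnp
import numpy as np

# C, Mi, C_new, Mi_new (and their *_jax counterparts) are the arrays
# defined in the linked Stack Overflow question.

def testFunction_numpy_v2(C, Mi, C_new, Mi_new):
    # Solve the linear system, then reduce over the first axis.
    Wg_new = np.linalg.solve(C_new, Mi_new)
    Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
    return C_new, Mi_new, Wg_new, Cg_new

@jax.jit
def testFunction_JAX_v2(C, Mi, C_new, Mi_new):
    # Same computation, written with jax.numpy and compiled with jit.
    Wg_new = jnp.linalg.solve(C_new, Mi_new)
    Cg_new = -0.5 * (Mi_new.conj() * Wg_new).sum(0)
    return C_new, Mi_new, Wg_new, Cg_new

%timeit testFunction_numpy_v2(C, Mi, C_new, Mi_new)
# 1000 loops, best of 3: 1.11 ms per loop
%timeit testFunction_JAX_v2(C_jax, Mi_jax, C_new_jax, Mi_new_jax)
# 1000 loops, best of 3: 1.35 ms per loop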

Jax can slow down numpy computations for simple functions:
- JAX/numpy generate effectively the same short series of BLAS/LAPACK calls executed on CPUs
- Not much room for improvement over numpy: with such small arrays JAX's overhead is apparent.

**Issue 2**: 
https://gist.github.com/fehiepsi/00915b4a926f63d665a997de236fad80
Illustration of Jax overhead:
- Numpy has really low operation dispatch overheads 
    - years of expert engineering
- JAX's operation dispatch overheads are high
    - not for any design reason but because it hasn't yet been engineered to the same point.

Dispatch costs:
- are independent of the size of the arrays being operated on,
- dominate the run time when many small operations are dispatched,
- become small relative to the computation itself as the data being operated on scales up.

Another way to crush dispatch overheads is to use @jit. When you call an @jit function, you only pay the dispatch cost once, no matter how many jax.numpy functions are called inside that @jit function. Plus, XLA will end-to-end optimize the whole function, including fusing operations together and optimizing memory layouts and use.
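
A minimal sketch of this pattern, assuming a hypothetical function `normalize` that chains several `jax.numpy` operations: the function is wrapped with `jax.jit` once, outside any loop, and the compiled version is reused.

```python
import jax
import jax.numpy as jnp


def normalize(x):
    # Several jax.numpy operations: without jit, each one is
    # dispatched separately.
    centered = x - jnp.mean(x)
    return centered / jnp.sqrt(jnp.sum(centered ** 2))


# Wrap once and reuse the jitted function.
normalize_jit = jax.jit(normalize)

x = jnp.arange(1.0, 6.0)

# The first call traces and compiles. block_until_ready waits for the
# asynchronous computation to finish, which matters when timing.
result = normalize_jit(x).block_until_ready()

# Later calls with the same shape and dtype reuse the compiled code.
result_again = normalize_jit(x).block_until_ready()
```

Calling `jax.jit(f)(x)` inside a timed loop instead re-creates the wrapper on every iteration, which adds Python-side work on top of the dispatch itself.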

# Are computations in Geomstats favorable to Jax?

In [2]:
!pip install jax jaxlib



# Dot from np arrays

In [3]:
from timeit import timeit

import autograd.numpy as anp
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd

In [4]:
small_sizes = [5, 10, 50]
sizes = small_sizes + [100, 1000, 10000]

In [5]:
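# Rows of timings: 0 = autograd, 1 = jax, 2 = jax jit, 3 = numpy.
# Each entry is the total time in seconds for 50000 calls.
timings = np.zeros((4, len(sizes)))

for i_size, size in enumerate(sizes):
    print(f'Size {size}...')
    c = np.random.normal(size=(size,))
    d = np.random.normal(size=(size,))
    timings[0, i_size] = timeit(
        lambda: anp.dot(c, d), number=50000)
    timings[1, i_size] = timeit(
        lambda: jnp.dot(c, d), number=50000)
    # jax.jit is re-applied inside the lambda, so creating the jitted
    # wrapper is part of what is timed on every call.
    timings[2, i_size] = timeit(
        lambda: jax.jit(jnp.dot)(c, d), number=50000)
    timings[3, i_size] = timeit(
        lambda: np.dot(c, d), number=50000)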



Size 5...
Size 10...
Size 50...
Size 100...
Size 1000...
Size 10000...


In [6]:
timings_df = pd.DataFrame({
    "function": "dot from np.arrays",
    "size": sizes,
    "autograd": timings[0, :],
    "jax": timings[1, :],
    "jax jit": timings[2, :],
    "numpy": timings[3, :],
})

display(timings_df)

Total time in seconds for 50,000 calls:

| | function | size | autograd | jax | jax jit | numpy |
|---|---|---|---|---|---|---|
| 0 | dot from np.arrays | 5 | 0.096677 | 15.522388 | 13.001843 | 0.040756 |
| 1 | dot from np.arrays | 10 | 0.092427 | 15.271843 | 13.273408 | 0.042178 |
| 2 | dot from np.arrays | 50 | 0.099572 | 15.328912 | 12.941105 | 0.041351 |
| 3 | dot from np.arrays | 100 | 0.082619 | 15.234657 | 13.249006 | 0.043444 |
| 4 | dot from np.arrays | 1000 | 0.101783 | 15.883423 | 13.225798 | 0.048018 |
| 5 | dot from np.arrays | 10000 | 0.181295 | 15.836179 | 13.630156 | 0.130876 |


Computation time varies little with the size:

--> Most of the computation time comes from fixed per-call overhead such as dispatch, not from the arithmetic itself.
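
To make the fixed overhead concrete, the totals above can be converted to a per-call cost. The sketch below only does arithmetic on values copied from the table (sizes 5 and 10000); the helper name `per_call_us` is an illustrative assumption.

```python
N_CALLS = 50_000  # value of `number` passed to timeit above

# Total seconds for N_CALLS calls, copied from the table above,
# as (size 5, size 10000).
totals = {
    "autograd": (0.096677, 0.181295),
    "jax": (15.522388, 15.836179),
    "jax jit": (13.001843, 13.630156),
    "numpy": (0.040756, 0.130876),
}


def per_call_us(total_seconds):
    # Average time of a single call, in microseconds.
    return total_seconds / N_CALLS * 1e6


for name, (small, large) in totals.items():
    fixed = per_call_us(small)
    growth = per_call_us(large) - fixed
    print(f"{name:>8}: {fixed:8.2f} us per call at size 5, "
          f"+{growth:.2f} us at size 10000")
```

Growing the arrays from 5 to 10000 elements adds at most a few microseconds per call for every backend, whereas a single Jax call on these np arrays costs a few hundred microseconds regardless of size.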

# Dot from jax arrays

In [7]:
timings2 = np.zeros((4, len(sizes)))
key0 = jax.random.PRNGKey(0)
key1 = jax.random.PRNGKey(1)
                   
for i_size, size in enumerate(sizes):
    print(f'Size {size}...')
    c = jax.random.normal(key0, shape=(size, ))
    d = jax.random.normal(key1, shape=(size, ))
    timings2[0, i_size] = timeit(
        lambda: anp.dot(c, d), number = 50000)
    timings2[1, i_size] = timeit(
        lambda: jnp.dot(c, d), number = 50000)
    timings2[2, i_size] = timeit(
        lambda: jax.jit(jnp.dot)(c, d), number = 50000)
    timings2[3, i_size] = timeit(
        lambda: np.dot(c, d), number = 50000)

Size 5...
Size 10...
Size 50...
Size 100...
Size 1000...
Size 10000...


In [8]:
timings_df2 = pd.DataFrame({
    "function": "dot from jax arrays",
    "size": sizes,
    "autograd": timings2[0, :],
    "jax": timings2[1, :],
    "jax jit": timings2[2, :],
    "numpy": timings2[3, :],
})

display(timings_df2)

Total time in seconds for 50,000 calls:

| | function | size | autograd | jax | jax jit | numpy |
|---|---|---|---|---|---|---|
| 0 | dot from jax arrays | 5 | 0.294678 | 6.010035 | 8.816176 | 0.231162 |
| 1 | dot from jax arrays | 10 | 0.274839 | 6.105139 | 8.838437 | 0.235207 |
| 2 | dot from jax arrays | 50 | 0.288358 | 6.060228 | 8.909856 | 0.243972 |
| 3 | dot from jax arrays | 100 | 0.293221 | 6.126371 | 8.803957 | 0.237043 |
| 4 | dot from jax arrays | 1000 | 0.2988 | 6.107768 | 8.760931 | 0.24842 |
| 5 | dot from jax arrays | 10000 | 0.359755 | 5.963718 | 8.647516 | 0.301663 |


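- Autograd and NumPy suffer from a x2 to x6 slow-down compared with np arrays (about x3 for Autograd and x5 to x6 for NumPy at sizes up to 1000).
- Jax benefits from a speed-up of about x2.5, and Jax JIT from about x1.5.
- Jax and Jax JIT are still significantly slower (more than x15) than Autograd and NumPy.
- Computation time barely depends on the size.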
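
One plausible reading of these numbers is that each library pays a conversion cost when handed the other library's array type. A hedged sketch of a benchmark that keeps conversion, wrapper creation and compilation out of the timed region is given below; the size, the number of calls and the variable names are illustrative assumptions, and no particular outcome is implied.

```python
from timeit import timeit

import jax
import jax.numpy as jnp
import numpy as np

size = 1000
rng = np.random.default_rng(0)
c_np = rng.normal(size=(size,)).astype(np.float32)
d_np = rng.normal(size=(size,)).astype(np.float32)

# Convert once, outside the timed region.
c_jax = jnp.asarray(c_np)
d_jax = jnp.asarray(d_np)

# Build the jitted function once and trigger compilation before timing.
dot_jit = jax.jit(jnp.dot)
dot_jit(c_jax, d_jax).block_until_ready()

n_calls = 1000  # smaller than the 50000 used above, to keep the sketch quick

# Each backend only sees its own array type.
t_numpy = timeit(lambda: np.dot(c_np, d_np), number=n_calls)
# block_until_ready makes the timing include the actual computation,
# since Jax dispatches work asynchronously.
t_jit = timeit(
    lambda: dot_jit(c_jax, d_jax).block_until_ready(), number=n_calls)

print(f"numpy on np arrays:    {t_numpy:.4f} s for {n_calls} calls")
print(f"jax jit on jax arrays: {t_jit:.4f} s for {n_calls} calls")
```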

# Einsum from np arrays

In [9]:
timings3 = np.zeros((4, len(small_sizes)))
                   
for i_size, size in enumerate(small_sizes):
    print(f'Size {size}...')
    c = np.random.normal(size=(size, size))
    d = np.random.normal(size=(size, size))
    timings3[0, i_size] = timeit(
        lambda: anp.einsum("ni,ni->ni", c, d), number = 50000)
    timings3[1, i_size] = timeit(
        lambda: jnp.einsum("ni,ni->ni", c, d), number = 50000)
# Note: Can't pass strings to jit decorated functions
#     timings3[2, i_size] = timeit(
#         lambda: jax.jit(jnp.einsum)("ni,ni->ni", c, d), number = 50000)
    timings3[3, i_size] = timeit(
        lambda: np.einsum("ni,ni->ni", c, d), number = 50000)

Size 5...
Size 10...
Size 50...


In [20]:
timings_df3 = pd.DataFrame({
    "function": "einsum for np arrays",
    "size": small_sizes,
    "autograd": timings3[0, :],
    "jax": timings3[1, :],
    "jax jit": np.nan,
    "numpy": timings3[3, :],
})

display(timings_df3)

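Total time in seconds for 50,000 calls:

| | function | size | autograd | jax | jax jit | numpy |
|---|---|---|---|---|---|---|
| 0 | einsum for np arrays | 5 | 0.178299 | 3.875827 | NaN | 0.111117 |
| 1 | einsum for np arrays | 10 | 0.175369 | 3.83341 | NaN | 0.112446 |
| 2 | einsum for np arrays | 50 | 0.26601 | 4.266139 | NaN | 0.202558 |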
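
The jit column is empty because a string cannot be passed as a traced argument to a jitted function. One possible workaround, sketched here under the assumption that the subscripts are fixed in advance, is to bind the string with `functools.partial` before applying `jax.jit`, so that only the arrays are traced. The name `einsum_ni` and the sample arrays are illustrative.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

# Bind the subscripts string first: jit then only sees the two arrays.
einsum_ni = jax.jit(partial(jnp.einsum, "ni,ni->ni"))

c = np.arange(6.0).reshape(2, 3)
d = np.ones((2, 3)) * 2.0

# "ni,ni->ni" keeps both indices, so this is an element-wise product.
out = einsum_ni(c, d).block_until_ready()
```

Marking the string argument as static through the `static_argnums` parameter of `jax.jit` is another option that may be worth trying.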


# Einsum from Jax arrays

In [12]:
timings4 = np.zeros((4, len(small_sizes)))
key0 = jax.random.PRNGKey(0)
key1 = jax.random.PRNGKey(1)
                   
for i_size, size in enumerate(small_sizes):
    print(f'Size {size}...')
    c = jax.random.normal(key0, shape=(size, size))
    d = jax.random.normal(key1, shape=(size, size))
    timings4[0, i_size] = timeit(
        lambda: anp.einsum("ni,ni->ni", c, d), number = 50000)
    timings4[1, i_size] = timeit(
        lambda: jnp.einsum("ni,ni->ni", c, d), number = 50000)
# Note: Can't pass strings to jit decorated functions
#     timings4[2, i_size] = timeit(
#         lambda: jax.jit(jnp.einsum)("ni,ni->ni", c, d), number = 50000)
    timings4[3, i_size] = timeit(
        lambda: np.einsum("ni,ni->ni", c, d), number = 50000)

Size 5...
Size 10...
Size 50...


In [13]:
timings_df4 = pd.DataFrame({
    "function": "einsum for jax arrays",
    "size": small_sizes,
    "autograd": timings4[0, :],
    "jax": timings4[1, :],
    "jax jit": np.nan,
    "numpy": timings4[3, :],
})

display(timings_df4)

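Total time in seconds for 50,000 calls:

| | function | size | autograd | jax | jax jit | numpy |
|---|---|---|---|---|---|---|
| 0 | einsum for jax arrays | 5 | 0.36263 | 3.654455 | NaN | 0.287643 |
| 1 | einsum for jax arrays | 10 | 0.378126 | 3.663458 | NaN | 0.291371 |
| 2 | einsum for jax arrays | 50 | 0.420035 | 3.680894 | NaN | 0.35448 |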


# Eig on np arrays

In [14]:
timings5 = np.zeros((4, len(small_sizes)))
                   
for i_size, size in enumerate(small_sizes):
    print(f'Size {size}...')
    c = np.random.normal(size=(size, size))
    timings5[0, i_size] = timeit(
        lambda: anp.linalg.eig(c), number = 50000)
    timings5[1, i_size] = timeit(
        lambda: jnp.linalg.eig(c), number = 50000)
    timings5[2, i_size] = timeit(
        lambda: jax.jit(jnp.linalg.eig)(c), number = 50000)
    timings5[3, i_size] = timeit(
        lambda: np.linalg.eig(c), number = 50000)

Size 5...
Size 10...
Size 50...


In [18]:
timings_df5 = pd.DataFrame({
    "function": "eig for np arrays",
    "size": small_sizes,
    "autograd": timings5[0, :],
    "jax": timings5[1, :],
    "jax jit": timings5[2, :],
    "numpy": timings5[3, :],
})

display(timings_df5)

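Total time in seconds for 50,000 calls:

| | function | size | autograd | jax | jax jit | numpy |
|---|---|---|---|---|---|---|
| 0 | eig for np arrays | 5 | 1.459353 | 9.752146 | 11.233599 | 1.426106 |
| 1 | eig for np arrays | 10 | 2.049612 | 9.8199 | 11.134745 | 2.128796 |
| 2 | eig for np arrays | 50 | 42.312973 | 31.265696 | 32.883014 | 42.454529 |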


# Eig on jax arrays

In [16]:
timings6 = np.zeros((4, len(small_sizes)))
key0 = jax.random.PRNGKey(0)
                   
for i_size, size in enumerate(small_sizes):
    print(f'Size {size}...')
    c = jax.random.normal(key0, shape=(size, size))
    timings6[0, i_size] = timeit(
        lambda: anp.linalg.eig(c), number = 50000)
    timings6[1, i_size] = timeit(
        lambda: jnp.linalg.eig(c), number = 50000)
    timings6[2, i_size] = timeit(
        lambda: jax.jit(jnp.linalg.eig)(c), number = 50000)
    timings6[3, i_size] = timeit(
        lambda: np.linalg.eig(c), number = 50000)

Size 5...
Size 10...
Size 50...


In [21]:
timings_df6 = pd.DataFrame({
    "function": "eig for jax arrays",
    "size": small_sizes,
    "autograd": timings6[0, :],
    "jax": timings6[1, :],
    "jax jit": timings6[2, :],
    "numpy": timings6[3, :],
})

display(timings_df6)

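Total time in seconds for 50,000 calls:

| | function | size | autograd | jax | jax jit | numpy |
|---|---|---|---|---|---|---|
| 0 | eig for jax arrays | 5 | 1.696682 | 6.355492 | 8.260189 | 1.585099 |
| 1 | eig for jax arrays | 10 | 2.47831 | 6.541062 | 8.399578 | 2.30441 |
| 2 | eig for jax arrays | 50 | 36.830241 | 29.331656 | 28.888744 | 37.386768 |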


# A more comprehensive series of tests

https://gitlab.inria.fr/jls/geomstats-diffauto

On matrix operations (np.dot, np.linalg.svd, np.linalg.eig), jax is 4 times quicker than numpy and autograd.numpy.


JL: we may raise an issue, but I guess the answer we will receive is that it is a well-known problem. There are some guidelines to follow when using jax. It is quite difficult to dig into the code to verify that all requirements are met…
e.g.: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html
https://jax.readthedocs.io/en/latest/faq.html#benchmarking-jax-code
https://github.com/google/jax/issues/427
What I understand is that jax.numpy is a rewrite of numpy that is API-compliant with numpy (most of the time, not 100% guaranteed)…

# Conclusion