back to jit: https://github.com/IanQS/numpy_to_jax/blob/main/exercises/exe_02_jit.ipynb

Let's do a simple-ish task: generate a random 1000×1000 matrix, take the square root of every entry greater than 0.5, and square every other entry. Finally, we multiply the result by itself as a matrix product (m @ m), not elementwise.
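To make the rule concrete, here is a tiny hand-checkable example in plain NumPy. The 2×2 values are made up for illustration. An entry exactly equal to 0.5 is not greater than 0.5, so it gets squared rather than square-rooted.

```python
import numpy as np

# Tiny, hand-checkable input (illustrative values); 0.5 sits exactly on the boundary
m = np.array([[0.25, 0.81],
              [0.50, 1.00]])

# Entries strictly greater than 0.5 are square-rooted, everything else is squared
transformed = np.where(m > 0.5, np.sqrt(m), m**2)
# values: [[0.0625, 0.9], [0.25, 1.0]] (up to float rounding)

# "Multiply with itself" is a matrix product, not an elementwise square
result = transformed @ transformed
# values: [[0.22890625, 0.95625], [0.265625, 1.225]] (up to float rounding)
print(result)
```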

In [4]:
import numpy as np
import jax
import jax.numpy as jnp

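input_arr_np = np.random.rand(1000, 1000)  # float64, uniform on [0, 1)
# With JAX's default config (64-bit disabled), this copy is stored as float32
input_arr_jax = jnp.array(input_arr_np)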

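def func_np(m):
    # Square-root entries > 0.5, square the rest, then matrix-multiply by itself.
    # np.where evaluates both branches on the full array, then selects per element.
    mask = m > 0.5
    m = np.where(mask, np.sqrt(m), m**2)
    return m @ m

def func_jax(m):
    # Same computation, written with jax.numpy
    mask = m > 0.5
    m = jnp.where(mask, jnp.sqrt(m), m**2)
    return m @ m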

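# jax.jit traces and compiles lazily, on the first call for a given input shape/dtype
jitted_func = jax.jit(func_jax)

print("Numpy version:")
%timeit func_np(input_arr_np)

print("Jax version:")
# JAX dispatches work asynchronously. Without .block_until_ready() this may measure
# mostly dispatch time, so it is not directly comparable to the jitted timings
%timeit func_jax(input_arr_jax)

print("Jitted version:")
# First timing: the initial call also triggers compilation, which %timeit's
# loop-count calibration likely absorbed (note "1 loop each" in the output)
%timeit jitted_func(input_arr_jax).block_until_ready()

print("Jitted version:")
# Second timing: the function is already compiled, so this is steady-state execution
%timeit jitted_func(input_arr_jax).block_until_ready()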

Numpy version:
26.6 ms ± 628 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Jax version:
851 μs ± 76.8 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Jitted version:
1.91 ms ± 113 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Jitted version:
1.72 ms ± 313 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
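The eager JAX timing above never waits for its result, and the first jitted %timeit likely absorbed compilation, so the numbers are not measured on equal terms. Below is a minimal sketch of a fairer comparison, using the standard-library timeit module instead of the IPython %timeit magic. It times the first (compiling) call separately, warms up the eager path, and calls block_until_ready() on every JAX result. The variable names and the choice of 20 repetitions are arbitrary, and no timings are claimed here since they depend on your hardware and backend.

```python
import timeit

import numpy as np
import jax
import jax.numpy as jnp


def func_jax(m):
    # Same transform as above: sqrt where m > 0.5, square elsewhere, then matmul
    m = jnp.where(m > 0.5, jnp.sqrt(m), m**2)
    return m @ m


jitted_func = jax.jit(func_jax)
x = jnp.array(np.random.rand(1000, 1000))  # float32 under JAX's default config

# 1) Time the first jitted call on its own: it traces, compiles, then runs
start = timeit.default_timer()
jitted_func(x).block_until_ready()
first_call_s = timeit.default_timer() - start

# 2) Warm up the eager path too, so one-time per-op setup is excluded
func_jax(x).block_until_ready()

# 3) Steady-state timings; block_until_ready() waits for the asynchronous
#    result, so we measure the computation and not just the dispatch
n = 20
eager_s = timeit.timeit(lambda: func_jax(x).block_until_ready(), number=n) / n
jit_s = timeit.timeit(lambda: jitted_func(x).block_until_ready(), number=n) / n

print(f"first jitted call (compile + run): {first_call_s * 1e3:.2f} ms")
print(f"eager JAX, per call:               {eager_s * 1e3:.2f} ms")
print(f"jitted JAX, per call:              {jit_s * 1e3:.2f} ms")
```

Timing the first call separately keeps compilation cost visible without letting it leak into the steady-state average, which is usually the number that matters when the same jitted function is called many times.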
