In [None]:
import jax
# Switch JAX to 64-bit floats. This runs right after the import, before any JAX
# array is created in the cells below.
jax.config.update("jax_enable_x64", True) # use float64 -- jax by default uses float32

In [None]:
def slow_f(x):
  """Return ``x * x + x * 2``.

  Only ``*`` and ``+`` are applied to ``x``, so scalars and array-like
  operands supporting those operators are both accepted.
  """
  squared = x * x
  doubled = x * 2
  return squared + doubled

In [None]:
# jit-wrapped counterpart of slow_f; the plain slow_f is kept for comparison.
fast_f = jax.jit(slow_f)
# 5000 x 5000 array of ones used as the benchmark input for both variants.
x = jax.numpy.ones((5000, 5000))

In [None]:
%timeit slow_f(x)
%timeit fast_f(x) 

In [None]:
# a 'simple' numpy to numpy function
def calculate_wbt_jax(tc, rh):
    """Evaluate the ``tw`` expression element-wise using ``jax.numpy`` functions.

    NOTE(review): the names suggest ``tc`` is an air temperature, ``rh`` a
    relative humidity and the result a wet-bulb temperature; units are not
    stated in this file -- confirm.

    Args:
        tc: scalar or array operand.
        rh: scalar or array operand, combined with ``tc`` via ``tc + rh``.

    Returns:
        The value of the expression, as produced by the ``jax.numpy`` calls.
    """
    humidity_root = jax.numpy.sqrt(rh + 8.313659)
    scaled_tc = tc * jax.numpy.arctan(0.151977 * humidity_root)
    sum_angle = jax.numpy.arctan(tc + rh)
    shifted_angle = jax.numpy.arctan(rh - 1.676331)
    power_term = 0.00391838 * (rh) ** (3 / 2) * jax.numpy.arctan(0.023101 * rh)
    # Terms are combined in the same left-to-right order as a single expression.
    return scaled_tc + sum_angle - shifted_angle + power_term - 4.686035

In [None]:
# jit-wrapped version of calculate_wbt_jax; the plain function is kept so both
# can be timed side by side.
jax_wbt = jax.jit(calculate_wbt_jax)

In [None]:
import numpy as np
# Seeded generator so the benchmark inputs are identical on every
# Restart & Run All (the unseeded global np.random state gave different data
# on each run).
SEED = 0
size = 1_000_000
rng = np.random.default_rng(SEED)
# NOTE(review): the ranges look like air temperature and relative humidity
# values -- confirm the intended units.
tc = rng.uniform(low=5., high=25., size=size)
rh = rng.uniform(low=60., high=99., size=size)

In [None]:
# Run the jitted function once on the NumPy inputs and report the shape and
# the type of what comes back, one per line.
w = jax_wbt(tc, rh)
print(w.shape, type(w), sep="\n")

In [None]:
# timings with memory created by numpy
%timeit calculate_wbt_jax(tc,rh)
%timeit jax_wbt(tc,rh)

In [None]:
# allocating the arrays with jax seems not to make difference ON CPU when you JIT compile (no copies involved)
# but likely to make difference on GPU

jtc = jax.numpy.ones((size, ))
jrh = jax.numpy.ones((size, ))

%timeit calculate_wbt_jax(jtc,jrh)
%timeit jax_wbt(jtc,jrh)


In [None]:
import numba
import numpy

# compare with numba

def calculate_wbt_numpy(tc, rh):
    """Evaluate the ``tw`` expression element-wise using ``numpy`` functions.

    NOTE(review): the names suggest ``tc`` is an air temperature, ``rh`` a
    relative humidity and the result a wet-bulb temperature; units are not
    stated in this file -- confirm.

    Args:
        tc: scalar or array operand.
        rh: scalar or array operand, combined with ``tc`` via ``tc + rh``.

    Returns:
        The value of the expression, as produced by the ``numpy`` calls.
    """
    humidity_root = numpy.sqrt(rh + 8.313659)
    scaled_tc = tc * numpy.arctan(0.151977 * humidity_root)
    sum_angle = numpy.arctan(tc + rh)
    shifted_angle = numpy.arctan(rh - 1.676331)
    power_term = 0.00391838 * (rh) ** (3 / 2) * numpy.arctan(0.023101 * rh)
    # Terms are combined in the same left-to-right order as a single expression.
    return scaled_tc + sum_angle - shifted_angle + power_term - 4.686035

# numba-compiled counterpart of calculate_wbt_numpy, built by applying
# numba.jit with nopython=True, nogil=True and parallel=False to it.
nb_wbt = numba.jit(nopython=True, nogil=True, parallel=False)(calculate_wbt_numpy)

In [None]:
# Using numba has little effect on this function, since most of its operations
# already rely on numpy.
# Nevertheless, the numba timings are slower than those of the jax-compiled code.
%timeit calculate_wbt_numpy(tc, rh)
%timeit nb_wbt(tc,rh)


In [None]:
jtk = jax.numpy.ones((size, ))
jmart = jax.numpy.ones((size, ))
jva = jax.numpy.ones((size, ))

%timeit calculate_bgt_jax(jtk,jmart,jva)
%timeit jax_bgt(jtk,jmart,jva)

In [None]:
b1 = calculate_bgt_jax(jtk,jmart,jva)
b2 = jax_bgt(jtk,jmart,jva)

# Compare the two results element by element. `b1.all() == b2.all()` reduced
# each array to a single boolean first, so it passed for any two arrays whose
# "all elements non-zero" status happened to match.
# A small tolerance is used rather than exact equality, since the jitted and
# non-jitted paths are not guaranteed to round identically.
assert b1.shape == b2.shape
np.testing.assert_allclose(b1, b2, rtol=1e-5, atol=1e-8)