In [1]:
import jax
# Switch JAX to 64-bit floats (float64); by default JAX uses float32.
jax.config.update("jax_enable_x64", True)

In [2]:
def slow_f(x):
    """Return ``x * x + x * 2``.

    Only ``*`` and ``+`` are used, so the argument may be a scalar or an
    array (arrays are processed element-wise).
    """
    squared = x * x
    doubled = x * 2
    return squared + doubled

In [4]:
# JIT-wrapped counterpart of slow_f, plus a 5000 x 5000 array of ones that
# serves as the benchmark input for both versions.
fast_f = jax.jit(slow_f)
x = jax.numpy.ones(shape=(5000, 5000))

In [5]:
%timeit slow_f(x)
%timeit fast_f(x) 

77.5 ms ± 387 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
24.8 ms ± 53.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
def calculate_wbt_jax(tc, rh):
    """Evaluate a 'simple' array-in / array-out formula with jax.numpy.

    The result is an arctan-based closed-form expression of ``tc`` and ``rh``.
    Judging by the names this is presumably a wet-bulb temperature computed
    from air temperature and relative humidity -- confirm the units against
    the formula's source.

    All operations are element-wise, so scalars and arrays are both accepted.

    :param tc: first input (presumably temperature)
    :param rh: second input (presumably relative humidity)
    :returns: the value of the expression, as produced by jax.numpy
    """
    atan = jax.numpy.arctan

    temperature_term = tc * atan(0.151977 * jax.numpy.sqrt(rh + 8.313659))
    mixed_term = atan(tc + rh)
    humidity_term = atan(rh - 1.676331)
    power_term = 0.00391838 * rh ** (3 / 2) * atan(0.023101 * rh)

    # Accumulate strictly left to right so the floating-point result does not
    # depend on how the terms are grouped.
    tw = temperature_term + mixed_term
    tw = tw - humidity_term
    tw = tw + power_term
    return tw - 4.686035

In [8]:
# Wrap calculate_wbt_jax with jax.jit; the plain function stays available so
# both variants can be timed against each other.
jax_wbt = jax.jit(calculate_wbt_jax)

In [9]:
import numpy as np
# Synthetic benchmark inputs. A seeded Generator makes the arrays (and every
# result derived from them) identical after a kernel restart.
size = 1_000_000
rng = np.random.default_rng(seed=42)
tc = rng.uniform(low=5., high=25., size=size)   # values drawn from [5, 25)
rh = rng.uniform(low=60., high=99., size=size)  # values drawn from [60, 99)

In [10]:
# Run the compiled function once on the NumPy inputs and report the shape and
# the type of what comes back.
w = jax_wbt(tc, rh)
print(w.shape, type(w), sep="\n")

(1000000,)
<class 'jaxlib.xla_extension.ArrayImpl'>


In [13]:
# timings with memory created by numpy
%timeit calculate_wbt_jax(tc,rh)
%timeit jax_wbt(tc,rh)

33.3 ms ± 846 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
8.93 ms ± 10.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
# allocating the arrays with jax seems not to make difference ON CPU when you JIT compile (no copies involved)
# but likely to make difference on GPU

jtc = jax.numpy.ones((size, ))
jrh = jax.numpy.ones((size, ))

%timeit calculate_wbt_jax(jtc,jrh)
%timeit jax_wbt(jtc,jrh)


16.2 ms ± 29.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
8.69 ms ± 3.98 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [17]:
import numba
import numpy

# compare with numba

def calculate_wbt_numpy(tc, rh):
    """Evaluate the arctan-based ``tw(tc, rh)`` expression with NumPy.

    Judging by the names this is presumably a wet-bulb temperature computed
    from air temperature and relative humidity -- confirm the units against
    the formula's source.

    All operations are element-wise, so scalars and arrays are both accepted.

    :param tc: first input (presumably temperature)
    :param rh: second input (presumably relative humidity)
    :returns: the value of the expression
    """
    scaled_root = 0.151977 * numpy.sqrt(rh + 8.313659)
    first = tc * numpy.arctan(scaled_root)
    second = numpy.arctan(tc + rh)
    third = numpy.arctan(rh - 1.676331)
    fourth = 0.00391838 * rh ** (3 / 2) * numpy.arctan(0.023101 * rh)

    # Same left-to-right accumulation as a single chained expression.
    tw = first + second - third + fourth - 4.686035
    return tw

# Numba-compiled counterpart of calculate_wbt_numpy, built with
# nopython=True, nogil=True and parallel=False.
nb_wbt = numba.jit(nopython=True, nogil=True, parallel=False)(calculate_wbt_numpy)

In [19]:
# Numba has little effect on this function because most of the work already
# happens inside NumPy routines.
# Even so, the Numba timings are slower than those of the JAX-compiled code.
%timeit calculate_wbt_numpy(tc, rh)
%timeit nb_wbt(tc,rh)


43.6 ms ± 53.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
39.1 ms ± 28.5 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [21]:
# Try JAX with more complex functions, including pure functions that call other pure functions

def kelvin_to_celsius(tk):
    """Convert a temperature from kelvin to degrees Celsius (subtracts 273.15).

    :param tk: temperature in kelvin, scalar or array
    :returns: ``tk - 273.15``
    """
    tc = tk - 273.15
    return tc


def calculate_bgt_jax(t_k, mrt, va):
    """Evaluate the closed-form root of a quartic and convert it to Celsius.

    With ``f = 1.1e8 * va**0.6 / (0.98 * 0.15**0.4)``, ``a = f / 2`` and
    ``b = -f * t_k - mrt**4`` the function evaluates the algebraic solution of
    ``x**4 + 2*a*x + b = 0`` and returns ``x - 273.15``. For ``t_k == mrt``
    that root equals ``t_k``.

    Judging by the names this is presumably a globe temperature derived from
    air temperature, mean radiant temperature and wind speed -- confirm
    against the formula's source.

    Inputs may be Python scalars, NumPy arrays or JAX arrays; the function can
    be used directly or wrapped in ``jax.jit``.

    :param t_k: presumably air temperature in kelvin -- confirm
    :param mrt: presumably mean radiant temperature in kelvin -- confirm
    :param va: presumably wind speed -- confirm units
    :returns: the quartic root passed through ``kelvin_to_celsius``
    """
    f = (1.1e8 * va**0.6) / (0.98 * 0.15**0.4)
    a = f / 2
    b = -f * t_k - mrt**4
    rt1 = 3 ** (1 / 3)
    rt2 = jax.numpy.sqrt(3) * jax.numpy.sqrt(27 * a**4 - 16 * b**3) + 9 * a**2
    rt3 = 2 * 2 ** (2 / 3) * b

    # Clamp at zero with jax.numpy.maximum rather than ``a.clip(min=0)``: the
    # method only exists on arrays, so a plain Python scalar ``va`` used to
    # raise AttributeError here. Only the ``4 * a`` term below sees the clamped
    # value; rt2 above was built from the unclamped one, as before.
    a = jax.numpy.maximum(a, 0)

    # Sub-expressions that appear several times in the root formula.
    cbrt_rt2 = rt2 ** (1 / 3)
    term_b = rt3 / (rt1 * cbrt_rt2)
    term_q = (2 ** (1 / 3) * cbrt_rt2) / 3 ** (2 / 3)
    sqrt_sum = jax.numpy.sqrt(term_b + term_q)

    bgt_quartic = -1 / 2 * sqrt_sum + 1 / 2 * jax.numpy.sqrt(
        (4 * a) / sqrt_sum - term_q - term_b
    )
    bgt_c = kelvin_to_celsius(bgt_quartic)
    return bgt_c

# JIT-wrapped version of calculate_bgt_jax.
jax_bgt = jax.jit(calculate_bgt_jax)