In [1]:
import numpy as np
import jax.numpy as jnp

In [2]:
from jax import jit, grad, jacfwd, jacrev

In [3]:
import multiprocess as mp

## Jamie's HZPT model

For this model, we are computing its Jacobian matrix with shape (50, 3).

### Analytic

In [4]:
r = np.logspace(-1,2)

def hzpt_g_analytic(r,A0,R,R1h):
    A_grad = - np.exp(-r/R) / (4*np.pi*r*R**2) *(1 - (R/R1h)**2 * np.exp(-(R+R1h)*r/(R*R1h)))
    R_grad = A0*(-(r- 2*R)/R**2 + 2*np.exp(-r*(1/R + 1/R1h))*r /R1h**2) * np.exp(-r/R) / (4*np.pi*r*R**2)
    R1h_grad = A0* np.exp(-r*(1/R + 1/R1h))* R**2 *(r-2*R1h)/R1h**4 * np.exp(-r/R) / (4*np.pi*r*R**2)
    return np.array([A_grad,R_grad,R1h_grad])

%timeit hzpt_g_analytic(r, 750.,26.,2.)

45.7 µs ± 1.45 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


⬆️ single call, with timeit

In [5]:
%%time
for _ in range(int(1e5)):
    hzpt_g_analytic(r + _ * 1e-6, 750. + _ * 1e-6, 26. + _ * 1e-6, 2. + _ * 1e-6)

CPU times: user 4.77 s, sys: 8.37 ms, total: 4.78 s
Wall time: 4.78 s


⬆️ 1e5 calls, 1 local worker

In [6]:
%%time

def foo(x):
    hzpt_g_analytic(r + x, 750. + x, 26. + x, 2. + x)

with mp.Pool(4) as pool:
    pool.map(foo, np.arange(1e5) * 1e-6)

CPU times: user 1.67 s, sys: 42.7 ms, total: 1.72 s
Wall time: 2.21 s


⬆️ 1e5 calls, 4 parallel workers, not returning results (less overhead)

In [7]:
%%time

def foo(x):
    return hzpt_g_analytic(r + x, 750. + x, 26. + x, 2. + x)

with mp.Pool(4) as pool:
    pool.map(foo, np.arange(1e5) * 1e-6)

CPU times: user 2.31 s, sys: 241 ms, total: 2.55 s
Wall time: 3.72 s


⬆️ 1e5 calls, 4 parallel workers, returning results (more overhead)

### Jax (fwd mode)

In [4]:
r = jnp.logspace(-1,2)

@jit
def hzpt_f_jax(r,A0,R,R1h):
    F2 = jnp.exp(-r/R) / (4*jnp.pi*r*R**2)
    return -A0 * F2 * (1 - (R/R1h)**2 * jnp.exp(-(R+R1h)*r/(R*R1h)))



In [5]:
def hzpt_jacfwd_jax(r, A0, R, R1h):
    return jacfwd(hzpt_f_jax, argnums=(1, 2, 3))(r, A0, R, R1h)

hzpt_jacfwd_jax(r, 750., 26., 2.)
%timeit hzpt_jacfwd_jax(r, 750., 26., 2.)

3.59 ms ± 206 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


⬆️ no jit, single call, with timeit

In [6]:
@jit
def hzpt_jacfwd_jax(r, A0, R, R1h):
    return jacfwd(hzpt_f_jax, argnums=(1, 2, 3))(r, A0, R, R1h)

hzpt_jacfwd_jax(r, 750., 26., 2.)
%timeit hzpt_jacfwd_jax(r, 750., 26., 2.)

11.4 µs ± 391 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


⬆️ jit, single call, with timeit

In [7]:
%%time
for _ in range(int(1e5)):
    hzpt_jacfwd_jax(r + _ * 1e-6, 750. + _ * 1e-6, 26. + _ * 1e-6, 2. + _ * 1e-6)

CPU times: user 1.87 s, sys: 0 ns, total: 1.87 s
Wall time: 1.87 s


⬆️ jit, 1e5 calls, 1 local worker

In [8]:
%%time

def foo(x):
    hzpt_jacfwd_jax(r + x, 750. + x, 26. + x, 2. + x)

with mp.Pool(4) as pool:
    pool.map(foo, np.arange(1e5) * 1e-6)

CPU times: user 1.8 s, sys: 85.9 ms, total: 1.89 s
Wall time: 2.06 s


⬆️ jit, 1e5 calls, 4 parallel workers, not returning results (less overhead)

In [9]:
%%time

def foo(x):
    return hzpt_jacfwd_jax(r + x, 750. + x, 26. + x, 2. + x)

with mp.Pool(4) as pool:
    pool.map(foo, np.arange(1e5) * 1e-6)

CPU times: user 2.7 s, sys: 137 ms, total: 2.84 s
Wall time: 5.75 s


⬆️ jit, 1e5 calls, 4 parallel workers, returning results (more overhead)

### Jax (rev mode)

In [4]:
r = jnp.logspace(-1,2)

@jit
def hzpt_f_jax(r,A0,R,R1h):
    F2 = jnp.exp(-r/R) / (4*jnp.pi*r*R**2)
    return -A0 * F2 * (1 - (R/R1h)**2 * jnp.exp(-(R+R1h)*r/(R*R1h)))



In [5]:
def hzpt_jacrev_jax(r, A0, R, R1h):
    return jacrev(hzpt_f_jax, argnums=(1, 2, 3))(r, A0, R, R1h)

hzpt_jacrev_jax(r, 750., 26., 2.)
%timeit hzpt_jacrev_jax(r, 750., 26., 2.)

4.69 ms ± 85.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


⬆️ no jit, single call, with timeit

In [6]:
@jit
def hzpt_jacrev_jax(r, A0, R, R1h):
    return jacrev(hzpt_f_jax, argnums=(1, 2, 3))(r, A0, R, R1h)

hzpt_jacrev_jax(r, 750., 26., 2.)
%timeit hzpt_jacrev_jax(r, 750., 26., 2.)

26.4 µs ± 441 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


⬆️ jit, single call, with timeit

In [7]:
%%time
for _ in range(int(1e5)):
    hzpt_jacrev_jax(r + _ * 1e-6, 750. + _ * 1e-6, 26. + _ * 1e-6, 2. + _ * 1e-6)

CPU times: user 3.74 s, sys: 7.2 ms, total: 3.75 s
Wall time: 3.76 s


⬆️ jit, 1e5 calls, 1 local worker

In [8]:
%%time

def foo(x):
    hzpt_jacrev_jax(r + x, 750. + x, 26. + x, 2. + x)

with mp.Pool(4) as pool:
    pool.map(foo, np.arange(1e5) * 1e-6)

CPU times: user 1.72 s, sys: 95.5 ms, total: 1.81 s
Wall time: 2.51 s


⬆️ jit, 1e5 calls, 4 parallel workers, not returning results (less overhead)

In [9]:
%%time

def foo(x):
    return hzpt_jacrev_jax(r + x, 750. + x, 26. + x, 2. + x)

with mp.Pool(4) as pool:
    pool.map(foo, np.arange(1e5) * 1e-6)

CPU times: user 2.69 s, sys: 160 ms, total: 2.85 s
Wall time: 6.78 s


⬆️ jit, 1e5 calls, 4 parallel workers, returning results (more overhead)