In [1]:
import sys
# Make the local gknet-benchmarks checkout importable by putting it at the
# front of sys.path (only once, so re-running the cell does not add duplicates).
# NOTE(review): this is a machine-specific absolute path; adjust for other environments.
if not '/home/pop518504/git/gknet-benchmarks' in sys.path:
    sys.path.insert(0, '/home/pop518504/git/gknet-benchmarks')

from jax_md import space, energy, quantity

import jax.numpy as jnp
from jax import grad, random, jit, device_put
from jax.config import config
# Enable JAX's 64-bit mode before any arrays are created.
config.update("jax_enable_x64", True)
import numpy as np

# Asynchronous dispatch
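- Without explicit intervention (e.g. calling `.block_until_ready()` on the result), `JAX` dispatches computations to the GPU asynchronously, so a naive timing only measures the dispatch and not the computation itself.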
- Even without `jit`, the first call includes compilation cost to `XLA`.
- The second call uses the already compiled `XLA` code and is much faster.



In [2]:
A_np = np.random.rand(1000, 1000)

# X = random.uniform(random.PRNGKey(0), (1000, 1000))
%time A = device_put(A_np)  # measure JAX device transfer time

CPU times: user 169 ms, sys: 992 ms, total: 1.16 s
Wall time: 1.34 s


In [3]:
# The same statement is timed twice on purpose:
#   1st call: includes the one-time compilation for this shape/dtype.
#   2nd call: reuses the compiled code, so it shows the actual compute time.
# block_until_ready() makes %time wait for the asynchronously dispatched result.
%time M = jnp.dot(A, A).block_until_ready()
%time M = jnp.dot(A, A).block_until_ready()

CPU times: user 196 ms, sys: 98.7 ms, total: 295 ms
Wall time: 506 ms
CPU times: user 605 µs, sys: 698 µs, total: 1.3 ms
Wall time: 1.07 ms


- No caching of results happens: the fast second call is not a cached result.
- However, the compiled code is specific to the shape and data type of its arguments, so its reuse can easily be mistaken for a caching mechanism ("JAX re-runs the Python function when the type or shape of the argument changes").
- The first call with $M = A \cdot B$ uses the same already compiled code as the previous computation of $M = A \cdot A$ and thus achieves similar performance.


In [4]:
B_np = np.random.rand(1000, 1000)
B = device_put(B_np)

%time M = jnp.dot(A, B).block_until_ready()
%time M = jnp.dot(A, B).block_until_ready()

CPU times: user 1.7 ms, sys: 1.96 ms, total: 3.65 ms
Wall time: 2.5 ms
CPU times: user 1.24 ms, sys: 0 ns, total: 1.24 ms
Wall time: 877 µs


In [5]:
C_np = np.random.rand(999, 999)
C = device_put(C_np)

%time M = jnp.dot(C, C).block_until_ready()
%time M = jnp.dot(C, C).block_until_ready()

CPU times: user 18.6 ms, sys: 0 ns, total: 18.6 ms
Wall time: 16.5 ms
CPU times: user 1.12 ms, sys: 393 µs, total: 1.51 ms
Wall time: 887 µs
