### Common gotchas in JAX

- Pure functions: all input data is passed in through the function parameters, and all results are returned through the function's return values (no reading or writing of external state).
- Array updates must be performed with `array.at[idx].set(value)`, which returns an updated copy, instead of NumPy-style in-place assignment (`array[idx] = value`), because JAX arrays are immutable.


In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random


key = random.key(0)
x = random.normal(key, (10,))
print(x)

size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]
364 ms ± 22.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [20]:
## vmap
import numpy as np

# Draw the matrix and the batch from two independent subkeys. Passing the same
# key to both random.normal calls would reuse one random stream, so the two
# arrays would not be independent samples.
mat_key, batch_key = random.split(key)
mat = random.normal(mat_key, (150, 100))       # matrix applied to each vector
batched_x = random.normal(batch_key, (10, 100))  # batch of 10 vectors of length 100

def apply_matrix(v):
  """Return `jnp.dot(mat, v)`, where `mat` is the module-level matrix.

  `mat` is read from the enclosing module rather than passed in, so it must
  be defined before this function is called.
  """
  product = jnp.dot(mat, v)
  return product

def naively_batched_apply_matrix(v_batched):
  """Apply `apply_matrix` to each entry of `v_batched` and stack the results.

  The entries are processed one at a time in a Python loop (iterating over
  the leading axis), then combined into a single array with `jnp.stack`.
  """
  per_vector_results = []
  for vector in v_batched:
    per_vector_results.append(apply_matrix(vector))
  return jnp.stack(per_vector_results)

print('Naively batched')
# block_until_ready() makes the timing wait for JAX's asynchronously
# dispatched computation to finish instead of measuring only the dispatch.
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
# Scratch lines kept for reference (inspecting the result instead of timing it).
# NOTE(review): `naively_batched` is not defined in the visible cells.
# np.shape(naively_batched)
# jnp.stack([apply_matrix(v) for v in batched_x])

Naively batched
349 µs ± 3.22 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [17]:
@jit
def batched_apply_matrix(v_batched):
  """Apply the module-level `mat` to a whole batch with one `jnp.dot` call.

  Computes `jnp.dot(v_batched, mat.T)`, the hand-written batched counterpart
  of calling `jnp.dot(mat, v)` on each row `v` of a 2-D `v_batched`.

  Because the function is jit-compiled and `mat` is a global rather than an
  argument, the value of `mat` is captured when the function is traced;
  rebinding `mat` later is not picked up by an already-compiled version.
  """
  mat_transposed = mat.T
  return jnp.dot(v_batched, mat_transposed)

print('Manually batched')
# block_until_ready() makes the timing wait for JAX's asynchronously
# dispatched computation to finish instead of measuring only the dispatch.
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
7.95 µs ± 135 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [21]:
@jit
def vmap_batched_apply_matrix(v_batched):
  """Apply `apply_matrix` across the leading axis of `v_batched` using `vmap`.

  `vmap` turns the per-vector `apply_matrix` into a batched function (mapping
  over axis 0 by default), so no manual rewrite of the computation is needed.
  """
  batched_fn = vmap(apply_matrix)
  return batched_fn(v_batched)

print('Auto-vectorized with vmap')
# block_until_ready() makes the timing wait for JAX's asynchronously
# dispatched computation to finish instead of measuring only the dispatch.
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
12 µs ± 83.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
