In [1]:
import jax
import jax.numpy as jnp

In [2]:
# Root PRNG key for the whole notebook; fixed seed so every run draws the same data.
key = jax.random.PRNGKey(seed=0)

In [3]:
# Derive two independent subkeys from the root key. Feeding the same key to
# both `normal` calls (key reuse) yields correlated draws -- depending on the
# PRNG implementation, `batched_x` can even equal the first rows of `mat`.
mat_key, x_key = jax.random.split(key)
mat = jax.random.normal(mat_key, (150, 100))         # matrix applied to each example
batched_x = jax.random.normal(x_key, (10, 100))      # batch of 10 vectors of length 100

In [4]:
def apply_matrix(v):
  """Apply the module-level matrix ``mat`` to a single, un-batched input.

  Computes ``jnp.dot(mat, v)``. With ``mat`` as created earlier in this
  notebook (shape (150, 100)), a length-100 vector maps to a length-150 vector.

  Note: ``mat`` is read from the enclosing (global) scope, not passed in.
  """
  product = jnp.dot(mat, v)
  return product

In [6]:
def naively_batched_apply_matrix(v_batched):
  """Batch ``apply_matrix`` with a plain Python loop (the slow baseline).

  Iterates over the leading axis of ``v_batched``, applies ``apply_matrix`` to
  each example separately, and stacks the per-example results along a new
  leading axis. Not jit-compiled, so every example is dispatched on its own.
  """
  per_example_results = []
  for example in v_batched:
    per_example_results.append(apply_matrix(example))
  return jnp.stack(per_example_results)

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
2.4 ms ± 151 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
@jax.jit
def batched_apply_matrix(v_batched):
  """Hand-vectorized batch version of ``apply_matrix``, compiled with jit.

  Instead of computing ``jnp.dot(mat, v)`` once per example, it performs one
  product of the whole batch against the transposed matrix, so row ``i`` of
  the result equals ``jnp.dot(mat, v_batched[i])``.

  NOTE: ``mat`` is a global captured when jit traces this function. For an
  input shape/dtype that has already been traced, re-defining ``mat`` without
  re-running this cell leaves the compiled function using the old matrix.
  """
  # Contract each example's features against the rows of ``mat``.
  mat_transposed = jnp.transpose(mat)
  return jnp.dot(v_batched, mat_transposed)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
45.6 µs ± 12.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
@jax.jit
def vmap_batched_apply_matrix(v_batched):
  """Batch ``apply_matrix`` automatically with ``jax.vmap`` under jit.

  ``jax.vmap`` maps the single-example function over the leading axis of
  ``v_batched`` and stacks the outputs along the leading axis of the result,
  so no manual rewrite of ``apply_matrix`` is needed.
  """
  per_example_fn = jax.vmap(apply_matrix, in_axes=0, out_axes=0)
  return per_example_fn(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
44.8 µs ± 4.08 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
