In [2]:
import os

# Configure XLA / CUDA through environment variables. This runs before `jax`
# is imported in the next cell.
os.environ.update(
    {
        # Do not preallocate GPU memory up front.
        "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
        # Use the "platform" allocator.
        # NOTE(review): the JAX docs describe this allocator as slow (it
        # allocates on demand), which may influence the timings below -- confirm
        # this is intended for a benchmark.
        "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",
        # Select the CUDA device to use: enumerate GPUs by PCI bus id and
        # expose only device 0.
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)

In [9]:
from jax import lax, jit, vmap
from functools import partial

import jax.numpy as jnp
import jax.random as jr
import jax
import numpy as np
# Display the installed JAX version so the timings below can be tied to a release.
jax.__version__

'0.4.23'

In [4]:
# Benchmark input: a 10,000 x 10,000 matrix of normal samples drawn from a
# fixed PRNG key (seed 0), so the data is reproducible across runs.
x = jr.normal(jr.PRNGKey(0), shape=(10_000, 10_000))

# Time the full matrix product x @ x.T. block_until_ready() makes the call wait
# for the result, so the measurement covers the computation and not just the
# asynchronous dispatch.
%timeit jnp.dot(x, x.T).block_until_ready()

58.5 ms ± 1.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [7]:
# View the same 10^8 values as a batch of 10,000 matrices of shape 100 x 100.
x_batch = x.reshape(-1, 100, 100)

# Time the batched product x_batch[i] @ x_batch[i] for every i.
# Note this is a much smaller workload than the full 10,000 x 10,000 product
# (roughly 100x fewer floating-point operations), so the two timings are not a
# like-for-like comparison of the same computation.
%timeit lax.batch_matmul(x_batch, x_batch).block_until_ready()

8.95 ms ± 15.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
# Batched matrix product built by mapping `jnp.dot` over the leading axis of
# both arguments: batch_matmul(a, b)[i] == jnp.dot(a[i], b[i]).
# `jnp.dot` is passed to `vmap` directly; the previous lambda wrapper only
# forwarded its arguments and shadowed the notebook-level array `x`.
batch_matmul = vmap(jnp.dot)

# Time the vmap-based batched product on the same inputs as the
# lax.batch_matmul benchmark, for a direct comparison between the two.
%timeit batch_matmul(x_batch, x_batch).block_until_ready()

8.98 ms ± 50.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
cpu = jax.devices('cpu')[0]
x_cpu = jax.device_put(x, cpu)
x_batch_cpu = jax.device_put(x_batch, cpu)

%timeit jnp.dot(x_cpu, x_cpu).block_until_ready()
%timeit lax.batch_matmul(x_batch_cpu, x_batch_cpu).block_until_ready()
%timeit batch_matmul(x_batch_cpu, x_batch_cpu).block_until_ready()

1.56 s ± 38.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
49 ms ± 248 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
49.5 ms ± 198 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
