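Based on the gist at https://gist.github.com/shoyer/4e0328c277e46f58c47d79b85a51aa0a — benchmarks how the `unroll` argument of `lax.scan` affects the speed of a Horner-scheme `polyval` on CPU and on GPU.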

In [4]:
import os

# Runtime configuration for XLA and CUDA. These settings only take effect if
# they are in place before JAX initializes its backends, so keep this cell
# ahead of any use of `jax`.
os.environ.update({
    # Do not let XLA reserve a large share of GPU memory up front.
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",
    # Allocate and free GPU memory exactly on demand. NOTE(review): the JAX
    # docs describe this allocator as very slow; it may inflate the GPU
    # timings recorded below.
    "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",
    # Number CUDA devices by PCI bus ID so that index "0" is stable ...
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    # ... and expose only that single GPU to this process.
    "CUDA_VISIBLE_DEVICES": "0",
})

In [6]:
# Standard library first, then third-party packages.
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import jit, lax

# Display the JAX version that the recorded timings were produced with.
jax.__version__

'0.4.23'

In [7]:
@partial(jit, static_argnames=['unroll'], backend='cpu')
def polyval(p, x, unroll=64):
  """Evaluate a polynomial at ``x`` with Horner's scheme, compiled for CPU.

  Coefficients are ordered highest degree first (the ``np.polyval``
  convention), so the result is ``p[0]*x**(n-1) + ... + p[n-1]``.

  Args:
    p: coefficient array. The Horner recurrence scans over its leading axis,
      and any trailing axes broadcast against ``x``.
    x: point(s) at which to evaluate the polynomial.
    unroll: static ``lax.scan`` unroll factor. Each distinct value is
      compiled separately.

  Returns:
    Array of shape ``broadcast_shapes(p.shape[1:], x.shape)`` and dtype
    ``result_type(p, x)``.
  """
  def horner_step(acc, coeff):
    # One Horner step: scale the running value by x, then add the next
    # coefficient. Nothing is stacked per step, hence the ``None`` output.
    return acc * x + coeff, None

  out_dtype = jnp.result_type(p, x)
  out_shape = lax.broadcast_shapes(p.shape[1:], x.shape)
  # Start from zero so that the first step yields exactly p[0].
  acc_init = lax.full_like(x, 0, shape=out_shape, dtype=out_dtype)
  result, _ = lax.scan(horner_step, acc_init, p, unroll=unroll)
  return result


x = np.random.rand(100).astype(np.float32)
p = np.random.randn(10000).astype(np.float32)

print("CPU")
for unroll in [1, 2, 4, 8, 16, 32, 64, 128]:
  print(f"unroll={unroll}")
  %time polyval(p, x, unroll).block_until_ready()
  %timeit polyval(p, x, unroll).block_until_ready()

CPU
unroll=1
CPU times: user 56.8 ms, sys: 166 ms, total: 223 ms
Wall time: 295 ms
82.6 µs ± 142 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
unroll=2
CPU times: user 24.3 ms, sys: 3.62 ms, total: 27.9 ms
Wall time: 27.8 ms
43.6 µs ± 1.15 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
unroll=4
CPU times: user 24.6 ms, sys: 3.58 ms, total: 28.1 ms
Wall time: 28 ms
33.5 µs ± 959 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
unroll=8
CPU times: user 24.5 ms, sys: 6.83 ms, total: 31.4 ms
Wall time: 31.1 ms
37.4 µs ± 1.82 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
unroll=16
CPU times: user 36.9 ms, sys: 3.37 ms, total: 40.3 ms
Wall time: 40.6 ms
50 µs ± 12.2 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
unroll=32
CPU times: user 73.8 ms, sys: 80 µs, total: 73.9 ms
Wall time: 74.3 ms
73.5 µs ± 647 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
unroll=64
CPU times: user 103 ms, sys: 32 µs, total: 103 ms
Wall

In [8]:
@partial(jit, static_argnames=['unroll'], backend='gpu')
def polyval(p, x, unroll=64):
  """Evaluate a polynomial at ``x`` with Horner's scheme, compiled for GPU.

  This redefines the CPU ``polyval`` from the earlier cell. The computation
  is identical; only ``backend='gpu'`` differs.

  Coefficients are ordered highest degree first (the ``np.polyval``
  convention), so the result is ``p[0]*x**(n-1) + ... + p[n-1]``.

  Args:
    p: coefficient array. The Horner recurrence scans over its leading axis,
      and any trailing axes broadcast against ``x``.
    x: point(s) at which to evaluate the polynomial.
    unroll: static ``lax.scan`` unroll factor. Each distinct value is
      compiled separately.

  Returns:
    Array of shape ``broadcast_shapes(p.shape[1:], x.shape)`` and dtype
    ``result_type(p, x)``.
  """
  def horner_step(acc, coeff):
    # One Horner step: scale the running value by x, then add the next
    # coefficient. Nothing is stacked per step, hence the ``None`` output.
    return acc * x + coeff, None

  out_dtype = jnp.result_type(p, x)
  out_shape = lax.broadcast_shapes(p.shape[1:], x.shape)
  # Start from zero so that the first step yields exactly p[0].
  acc_init = lax.full_like(x, 0, shape=out_shape, dtype=out_dtype)
  result, _ = lax.scan(horner_step, acc_init, p, unroll=unroll)
  return result


x = jax.device_put(np.random.rand(100))
p = jax.device_put(np.random.randn(10000))

print("GPU")
for unroll in [1, 2, 4, 8, 16, 32, 64, 128]:
  print(f"unroll={unroll}")
  %time polyval(p, x, unroll).block_until_ready()
  %timeit polyval(p, x, unroll).block_until_ready()

GPU
unroll=1
CPU times: user 109 ms, sys: 6.73 ms, total: 116 ms
Wall time: 152 ms
46.6 ms ± 1.75 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
unroll=2
CPU times: user 71 ms, sys: 50 µs, total: 71 ms
Wall time: 78.1 ms
23 ms ± 434 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
unroll=4
CPU times: user 58.9 ms, sys: 98 µs, total: 59 ms
Wall time: 67.9 ms
11.2 ms ± 11.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
unroll=8
CPU times: user 53.5 ms, sys: 3.34 ms, total: 56.8 ms
Wall time: 64.2 ms
5.66 ms ± 21.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
unroll=16
CPU times: user 69.2 ms, sys: 0 ns, total: 69.2 ms
Wall time: 76.5 ms
2.87 ms ± 7.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
unroll=32
CPU times: user 122 ms, sys: 10.1 ms, total: 132 ms
Wall time: 128 ms
2.18 ms ± 2.74 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
unroll=64
CPU times: user 243 ms, sys: 238 µs, total: 243 ms
Wall time: 241 ms
1.12 ms ± 575 ns