In [2]:
import jax
import jax.numpy as jnp
import numpy as np

In [3]:
def f(x):
    y1 = x + x*x + 3
    y2 = x*x + x*x.T

    return y1 * y2

x = np.random.randn(3000, 3000).astype('float32')
jax_x_gpu = jax.device_put(jnp.array(x), jax.devices('gpu')[0])
jax_x_cpu = jax.device_put(jnp.array(x), jax.devices('cpu')[0])

jax_f_cpu = jax.jit(f, backend='cpu')
jax_f_gpu = jax.jit(f, backend='gpu')

# warmup
jax_f_cpu(jax_x_cpu)
jax_f_gpu(jax_x_gpu)


Array([[ 4.2908535e+00, -3.5497880e-01,  3.2365820e+00, ...,
         1.7582173e+00,  1.6552262e+00,  4.4780904e-01],
       [ 5.5248916e-01,  4.9810520e-01,  3.3992612e+00, ...,
        -4.8352739e-01,  1.3106327e+00,  7.4119699e-01],
       [-6.2175333e-01,  1.8980707e+00,  2.8199699e+01, ...,
        -9.3777597e-01,  2.9834471e+00, -6.9638664e-01],
       ...,
       [ 9.9432975e-02,  1.6582045e+01,  5.9136353e+00, ...,
         2.1347990e+00,  2.2645648e+00,  1.1824804e-01],
       [ 6.6936164e+00, -1.7576742e-01,  3.1529326e+00, ...,
        -1.0557010e-01,  4.5092940e-01,  6.1506230e-01],
       [ 9.1438246e+00,  2.7565269e+00,  3.6416953e+00, ...,
         1.5194368e+00,  2.6418930e-02,  7.2142425e+00]], dtype=float32)

In [4]:
%timeit -n100 f(x)

51 ms ± 910 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
%timeit -n100 f(jax_x_cpu).block_until_ready()

106 ms ± 6.56 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
%timeit -n100 jax_f_cpu(jax_x_cpu).block_until_ready()

14.8 ms ± 871 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
%timeit -n100 f(jax_x_gpu).block_until_ready()

1.27 ms ± 538 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
%timeit -n100 jax_f_gpu(jax_x_gpu).block_until_ready()

268 μs ± 59.6 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
