In [22]:
import jax
from jax import jit
import jax.numpy as jnp
import numpy as np

In [30]:
# Benchmark input: a 2048 x 2048 matrix of uniform random samples, cast to float32.
x = np.random.rand(2048, 2048).astype(np.float32)

In [31]:
# Build a JAX array from the NumPy data so both benchmarks operate on the same values.
xj = jnp.array(x)

In [32]:
def bench_func(x, n_iters=10):
    """Apply the update ``x <- (x*x + x) / 2`` repeatedly and return the result.

    Only ``*``, ``+`` and ``/`` are used, so the same function runs unchanged
    on a NumPy array, a JAX array, or a plain Python number. For NumPy/JAX
    arrays these operators act elementwise.

    Args:
        x: Starting value.
        n_iters: Number of times the update is applied. Defaults to 10, the
            count that was previously hard-coded. It is used as a Python loop
            bound, so when wrapping this function with ``jax.jit`` either
            leave it at its default or mark it static
            (e.g. ``static_argnames="n_iters"``).

    Returns:
        The value obtained after ``n_iters`` updates.
    """
    for _ in range(n_iters):
        x = (x * x + x) / 2.0
    return x

In [33]:
# Run once on the NumPy input and display the result as a quick sanity check.
bench_func(x)

array([[1.7461327e-03, 1.8241390e-04, 8.1232237e-03, ..., 1.0861983e-02,
        1.4884750e-03, 2.0174329e-04],
       [1.1372954e-03, 1.4585739e-03, 2.4806077e-03, ..., 3.5619509e-04,
        9.1104209e-03, 1.0708585e-02],
       [2.6197333e-03, 4.5072101e-04, 5.3375732e-04, ..., 9.8322565e-03,
        6.0740311e-02, 8.1474639e-02],
       ...,
       [4.2800358e-03, 7.8390836e-04, 8.6207455e-03, ..., 3.4268037e-03,
        1.8181236e-03, 3.3391663e-04],
       [8.6220032e-05, 2.4278357e-04, 2.8487208e-04, ..., 1.8660809e-04,
        3.2633713e-03, 1.4509721e-02],
       [4.0727336e-04, 2.6338474e-05, 1.5617841e-01, ..., 8.4046302e-03,
        6.5404701e-04, 3.5888283e-04]], dtype=float32)

In [34]:
# NumPy baseline: time bench_func on the NumPy array x.
%timeit bench_func(x)

19.8 ms ± 124 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [35]:
%timeit jit(bench_func)(xj).block_until_ready()

783 µs ± 5.31 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Results

| System | Framework | Accelerator     | Result  |
| ------ | --------- | --------------- | ------- |
| M2 Max | NumPy     | CPU             | 19.8 ms |
| M2 Max | JAX       | Metal (30 Core) | 783 µs  |
| Colab  | NumPy     | CPU             | 112 ms  |
| Colab  | JAX       | TPU v4          | 492 µs  |