## Results

### Matrix size 2048 × 2048, 3 iterations, Python 3.12

| System         | Framework         | Accelerator | Time per call |
| -------------- | ----------------- | ----------- | ------------: |
| i5 + 4070ti    | JAX               | GPU         |       1.99 ms |
| M2 Max (30 c.) | JAX               | Metal GPU   |        6.2 ms |
| Colab          | JAX               | TPU v4      |       12.5 ms |
| M2 Max (30 c.) | `torch.compile()` | Metal GPU   |       28.3 ms |
| M2 Max (30 c.) | Torch             | Metal GPU   |       30.6 ms |
| i5 + 4070ti    | `torch.compile()` | GPU         |       84.5 ms |
| i5 + 4070ti    | Torch             | GPU         |       91.5 ms |
| Intel i5 13500 | NumPy             | CPU         |        125 ms |
| M2 Max         | NumPy             | CPU         |        153 ms |
| Colab          | NumPy             | CPU         |        736 ms |

Note: on Python 3.12, `torch.compile()` currently requires a nightly build of Torch 2.4, on both Nvidia and Apple hardware.
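The timings in the table can be converted into approximate arithmetic throughput, which makes rows easier to compare against a device's theoretical peak. The sketch below is illustrative only: it assumes the 2048 × 2048 matrix used in the notebook cells, counts `2 * n**3` operations per matmul, and ignores the cheaper elementwise add and divide. The helper names `matmul_flops` and `tflops` are made up for this example.

```python
def matmul_flops(n: int, iterations: int) -> int:
    """Approximate floating-point operations for the benchmark loop.

    Counts 2 * n**3 per n x n matmul. The elementwise add and divide
    (about 2 * n**2 per iteration) are ignored as negligible for large n.
    """
    return iterations * 2 * n**3


def tflops(n: int, iterations: int, seconds: float) -> float:
    """Approximate throughput in TFLOP/s for one timed call."""
    return matmul_flops(n, iterations) / seconds / 1e12


# Timings copied from the results table, in milliseconds.
timings_ms = {
    "i5 + 4070ti, JAX, GPU": 1.99,
    "M2 Max, JAX, Metal GPU": 6.2,
    "M2 Max, NumPy, CPU": 153.0,
}

for label, ms in timings_ms.items():
    print(f"{label}: {tflops(2048, 3, ms / 1000.0):.2f} TFLOP/s")
```

These figures are only as good as the underlying timings, so they are best read as rough estimates rather than measured hardware throughput.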


## NumPy reference

In [1]:
import numpy as np

In [16]:
np.version.version

'1.26.4'

In [2]:
x = np.random.rand(2048, 2048).astype(dtype=np.float32) / 5.0

In [9]:
def bench_func(x):
    # Three rounds of matmul plus add; the division by 1000 keeps the values small.
    for _ in range(3):
        x = (np.matmul(x, x) + x) / 1000.0
    return x

In [10]:
%timeit bench_func(x)

153 ms ± 3.43 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
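`%timeit` is an IPython magic, so the cell above only runs inside a notebook or IPython shell. A minimal sketch of a comparable measurement in a plain Python script, using the standard-library `timeit` module, might look like the following. The helper name `time_like_timeit` is hypothetical, and the use of the population standard deviation is an assumption about how best to mirror the `mean ± std. dev.` line.

```python
import statistics
import timeit

import numpy as np


def bench_func(x):
    # Same workload as the notebook cell: three rounds of matmul plus add.
    for _ in range(3):
        x = (np.matmul(x, x) + x) / 1000.0
    return x


def time_like_timeit(func, arg, repeat=7, number=10):
    """Return (mean, std dev) of the per-call time in seconds.

    `repeat` runs of `number` calls each, like "7 runs, 10 loops each".
    """
    totals = timeit.repeat(lambda: func(arg), repeat=repeat, number=number)
    per_call = [total / number for total in totals]
    return statistics.mean(per_call), statistics.pstdev(per_call)
```

Calling `time_like_timeit(bench_func, x)` with the `x` defined earlier returns the two numbers in seconds; multiplying by 1000 gives milliseconds for comparison with the table.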


## JAX

In [11]:
import jax
from jax import jit
import jax.numpy as jnp

In [17]:
jax.__version__

'0.4.26'

In [12]:
xj = jnp.array(x)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!

Metal device set to: Apple M2 Max

In [13]:
def bench_func_j(x):
    # Same workload as the NumPy version, expressed with jax.numpy so it can be jitted.
    for _ in range(3):
        x = (jnp.matmul(x, x) + x) / 1000.0
    return x

In [14]:
%timeit jit(bench_func_j)(xj).block_until_ready()

6.2 ms ± 22.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
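The timed expression above re-wraps the function with `jit` on every call and relies on JAX's compilation cache to avoid recompiling. A possible variant, sketched here under the assumption that only steady-state execution should be measured, wraps the function once and runs a warm-up call before timing. The names `bench_func_jit` and `time_jitted` are hypothetical.

```python
import time

import jax
import jax.numpy as jnp


def bench_func_j(x):
    # Same workload as the notebook cell.
    for _ in range(3):
        x = (jnp.matmul(x, x) + x) / 1000.0
    return x


# Wrap once, outside the timed region.
bench_func_jit = jax.jit(bench_func_j)


def time_jitted(x, repeats=100):
    """Return the mean seconds per call, excluding tracing and compilation."""
    # The first call traces and compiles; keep it out of the measurement.
    bench_func_jit(x).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # JAX dispatches asynchronously, so wait for each result to be ready.
        bench_func_jit(x).block_until_ready()
    return (time.perf_counter() - start) / repeats
```

With the array from the cell above, `time_jitted(xj) * 1000` gives a per-call time in milliseconds. Whether it differs noticeably from the `%timeit` figure depends on how much the repeated `jit` wrapping costs on a given setup.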


## Torch

In [3]:
import torch

In [18]:
torch.__version__

'2.4.0.dev20240428'

In [4]:
xt = torch.tensor(x)

In [5]:
def bench_func_t(x):
    # Same workload as the NumPy version, using torch ops on whatever device x lives on.
    for _ in range(3):
        x = (torch.matmul(x, x) + x) / 1000.0
    return x

In [6]:
%timeit bench_func_t(xt)

30.6 ms ± 68.3 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
bench_func_tc = torch.compile(bench_func_t)

In [8]:
%timeit bench_func_tc(xt)

28.3 ms ± 200 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
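One thing to check when reproducing the Torch numbers: `torch.tensor(x)` on a NumPy array returns a CPU tensor unless a default device has been configured, and the cells above do not show the tensor being moved to the Metal or CUDA device. GPU work is also generally queued asynchronously, so a timing loop should synchronize before stopping the clock. The following is a minimal sketch of explicit device placement and synchronization. The helper names `pick_device`, `sync`, and `time_on_device` are hypothetical, and this is not necessarily how the table's figures were produced.

```python
import time

import torch


def pick_device() -> torch.device:
    """Prefer CUDA, then Apple's MPS backend, then fall back to the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def sync(device: torch.device) -> None:
    """Wait for queued GPU work to finish; a no-op on the CPU."""
    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "mps":
        torch.mps.synchronize()


def bench_func_t(x):
    # Same workload as the notebook cell.
    for _ in range(3):
        x = (torch.matmul(x, x) + x) / 1000.0
    return x


def time_on_device(x_cpu, device, repeats=10):
    """Return the mean seconds per call with the input placed on `device`."""
    x = x_cpu.to(device)
    bench_func_t(x)  # warm-up call, excluded from the measurement
    sync(device)
    start = time.perf_counter()
    for _ in range(repeats):
        bench_func_t(x)
    sync(device)  # make sure all queued kernels have completed
    return (time.perf_counter() - start) / repeats
```

For example, `time_on_device(xt, pick_device())` times the eager function on the best available device; passing `torch.device("cpu")` instead gives a CPU baseline to compare against.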
