## Results

### Matrix size 2048×2048 (float32), 3 iterations, Python 3.12

| System         | Framework       | Accelerator    | Time      |
| -------------- | --------------- | -------------- | --------- |
| i5 + 4070ti    | JAX             | GPU            |   1.99 ms |
| M2 Max (30 c.) | JAX             | Metal GPU      |   6.07 ms |
| M2 Max (30 c.) | MLX.compile()   | Metal GPU      |   6.46 ms |
| M2 Max (30 c.) | MLX             | Metal GPU      |   6.56 ms |
| Google Colab   | JAX             | TPU v4         |  12.5 ms  |
| M2 Max (30 c.) | Torch.compile() | Metal GPU      |  28.3 ms  |
| M2 Max (30 c.) | Torch           | Metal GPU      |  30.6 ms  |
| M2 Max (30 c.) | Numpy 2.1.1     | Accelerate     |  33.8 ms  |
| i5 + 4070ti    | Torch.compile() | GPU            |  84.5 ms  |
| i5 + 4070ti    | Torch           | GPU            |  91.5 ms  |
| Intel i5 13500 | Numpy 2.1.2     | CPU            |  93.1 ms  |
| Google Colab   | Numpy 1.26.x    | CPU            | 736.0 ms  |

Note: on Python 3.12, `torch.compile()` requires Torch 2.4 or later, on both Nvidia and Apple hardware.
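
To put the timings in the table above into perspective, they can be converted into an approximate throughput. The sketch below assumes the benchmark defined in the notebook sections that follow (three rounds of a 2048×2048 `matmul` plus one element-wise add and one divide) and counts an n×n matmul as 2·n³ floating-point operations. The helper names `benchmark_flops` and `approx_tflops` are illustrative and not part of any library.

```python
# Rough throughput estimate for the benchmark: 3 rounds of (x @ x + x) / 1000
# on n x n matrices.
# Assumption: one n x n matmul costs about 2 * n**3 floating-point operations,
# and the element-wise add and divide cost n**2 operations each.

N = 2048
ITERATIONS = 3


def benchmark_flops(n: int = N, iterations: int = ITERATIONS) -> int:
    """Approximate number of floating-point operations in one benchmark call."""
    matmul = 2 * n**3
    elementwise = 2 * n**2  # one add and one divide per element
    return iterations * (matmul + elementwise)


def approx_tflops(time_ms: float, n: int = N, iterations: int = ITERATIONS) -> float:
    """Convert a time per call in milliseconds into TFLOP/s."""
    return benchmark_flops(n, iterations) / (time_ms * 1e-3) / 1e12


# A few timings copied from the results table (milliseconds per call).
timings_ms = {
    "i5 + 4070ti, JAX, GPU": 1.99,
    "M2 Max, JAX, Metal GPU": 6.07,
    "M2 Max, MLX, Metal GPU": 6.56,
    "Google Colab, Numpy, CPU": 736.0,
}

for name, ms in timings_ms.items():
    print(f"{name:26s} {approx_tflops(ms):6.2f} TFLOP/s")
```

Under these assumptions the 1.99 ms row works out to roughly 26 TFLOP/s and the 736 ms row to roughly 0.07 TFLOP/s. The timings include dispatch overhead, so these figures are estimates of effective throughput rather than measurements of peak matmul performance.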

## Numpy reference

In [1]:
import numpy as np

In [2]:
np.version.version

'2.1.3'

In [3]:
x = np.random.rand(2048, 2048).astype(dtype=np.float32) / 5.0

In [4]:
def bench_func(x):
    for i in range(3):
        x = (np.matmul(x,x)+x)/1000.0
    return x

In [5]:
%timeit bench_func(x)

69 ms ± 1.83 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
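
Outside IPython, a figure of the same kind can be approximated with the standard-library `timeit` module. The sketch below assumes the reporting convention visible in the output above (7 runs, a fixed number of loops per run, mean and standard deviation of the per-loop time) and uses the population standard deviation, which is an assumption about how the spread is computed. `summarize_timing` is a hypothetical helper, and the workload is a pure-Python stand-in so that the block runs without NumPy.

```python
import statistics
import timeit


def summarize_timing(fn, repeat: int = 7, number: int = 10):
    """Return (mean_ms, std_ms) of the per-loop time over `repeat` runs."""
    # Each entry is the total time in seconds for `number` consecutive calls.
    totals = timeit.repeat(fn, repeat=repeat, number=number)
    per_loop_ms = [total / number * 1e3 for total in totals]
    return statistics.mean(per_loop_ms), statistics.pstdev(per_loop_ms)


def workload():
    # Stand-in workload; in the notebook this would be `lambda: bench_func(x)`.
    return sum(i * i for i in range(10_000))


mean_ms, std_ms = summarize_timing(workload)
print(
    f"{mean_ms:.3f} ms ± {std_ms:.3f} ms per loop "
    "(mean ± std. dev. of 7 runs, 10 loops each)"
)
```

To time the NumPy benchmark this way, pass `lambda: bench_func(x)` as `fn`.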


## JAX

In [6]:
import jax
from jax import jit
import jax.numpy as jnp

In [7]:
jax.__version__

'0.4.34'

In [8]:
xj = jnp.array(x)

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!


Metal device set to: Apple M2 Max


I0000 00:00:1750593949.527028 5227974 service.cc:145] XLA service 0x60000170c400 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1750593949.527039 5227974 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1750593949.528337 5227974 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1750593949.528352 5227974 mps_client.cc:384] XLA backend will use up to 22906109952 bytes on device 0 for SimpleAllocator.


In [9]:
def bench_func_j(x):
    for i in range(3):
        x = (jnp.matmul(x,x)+x)/1000.0
    return x

In [10]:
%timeit jit(bench_func_j)(xj).block_until_ready()

7.37 ms ± 1.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
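
The cell above wraps the function with `jit` inside the timed expression. JAX caches compiled functions, so this mostly works, but the first call still pays for tracing and compilation. A minimal sketch of a variant that creates the jitted function once and runs a warm-up call before timing is shown below. It generates its own random input instead of reusing `xj`, runs on whatever backend `jax.default_backend()` reports, and uses the standard-library `timeit` module in place of `%timeit`; the name `bench_func_jit` is illustrative.

```python
import timeit

import jax
import jax.numpy as jnp


def bench_func_j(x):
    for _ in range(3):
        x = (jnp.matmul(x, x) + x) / 1000.0
    return x


# Wrap once, outside the timed region.
bench_func_jit = jax.jit(bench_func_j)

# Random float32 input with the same shape and scaling as in the notebook.
key = jax.random.PRNGKey(0)
xj = jax.random.uniform(key, (2048, 2048), dtype=jnp.float32) / 5.0

# Warm-up call: triggers tracing and compilation so they are not timed.
bench_func_jit(xj).block_until_ready()

# block_until_ready() waits for the asynchronous computation to finish.
runs = timeit.repeat(
    lambda: bench_func_jit(xj).block_until_ready(), repeat=7, number=1
)
print(f"best of 7: {min(runs) * 1e3:.2f} ms on backend '{jax.default_backend()}'")
```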


## Torch

In [11]:
import torch

In [12]:
torch.__version__

'2.7.1'

In [13]:
xt = torch.tensor(x)

In [14]:
def bench_func_t(x):
    for i in range(3):
        x = (torch.matmul(x,x)+x)/1000.0
    return x

In [15]:
%timeit bench_func_t(xt)

116 ms ± 24.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [16]:
bench_func_tc = torch.compile(bench_func_t)

In [17]:
%timeit bench_func_tc(xt)

The slowest run took 5.43 times longer than the fastest. This could mean that an intermediate result is being cached.
70.9 ms ± 51.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
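
Note that `torch.tensor(x)` places the tensor on the CPU unless a default device has been set elsewhere, so the cells above may be timing CPU execution rather than the Metal or CUDA backend listed in the results table. The sketch below shows one way to select a device explicitly, run a warm-up call, and synchronize before the clock stops, since GPU kernels are launched asynchronously. It generates its own random input, times only the eager function, and uses the standard-library `timeit` module in place of `%timeit`. `pick_device`, `sync`, and `timed_call` are hypothetical helper names.

```python
import timeit

import torch


def pick_device() -> torch.device:
    """Prefer CUDA, then Apple's MPS backend, then fall back to the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def sync(device: torch.device) -> None:
    """Wait for queued GPU work to finish; a no-op on the CPU."""
    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "mps":
        torch.mps.synchronize()


def bench_func_t(x):
    for _ in range(3):
        x = (torch.matmul(x, x) + x) / 1000.0
    return x


device = pick_device()
# Random float32 input created directly on the selected device.
xt = torch.rand(2048, 2048, dtype=torch.float32, device=device) / 5.0


def timed_call():
    bench_func_t(xt)
    sync(device)  # include the actual kernel execution in the measurement


timed_call()  # warm-up: keeps one-time initialization out of the timing
runs = timeit.repeat(timed_call, repeat=7, number=1)
print(f"best of 7: {min(runs) * 1e3:.2f} ms on device '{device.type}'")
```

The same warm-up and synchronization pattern should apply to a function wrapped with `torch.compile()`, where the first call additionally includes compilation.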


## MLX

In [18]:
import mlx
import mlx.core as mx

In [19]:
mx.__version__

'0.26.1'

In [20]:
xm = mx.array(x)

In [21]:
def bench_func_m(x1):
    for _ in range(3):
        x1 = (mx.matmul(x1,x1) +x1)/mx.array(1000.0)
    return x1

In [22]:
%timeit mx.eval(bench_func_m(xm))

9.79 ms ± 1.47 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [23]:
bench_func_mc = mx.compile(bench_func_m)

In [24]:
b=bench_func_mc(xm)

In [25]:
%timeit mx.eval(bench_func_mc(xm))

8.98 ms ± 1.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
