## Results

### 2048 × 2048 float32 matrices, 3 iterations

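| System                       | Framework                | Accelerator | Mean time per call |
| ---------------------------- | ------------------------ | ----------- | ------------------ |
| Intel i5 13500 + RTX 4070 Ti | JAX (`jit`)              | GPU         | 1.99 ms            |
| Intel i5 13500 + RTX 4070 Ti | Torch + `torch.compile`  | CPU\*       | 84.5 ms            |
| Intel i5 13500 + RTX 4070 Ti | Torch (eager)            | CPU\*       | 91.5 ms            |
| Intel i5 13500               | NumPy                    | CPU         | 125 ms             |

\* As written, the Torch cells build `xt` with `torch.tensor(x)`. That creates a CPU tensor, and the code never moves it to the GPU. These two rows therefore most likely measure CPU execution, which fits their times being close to the NumPy figure.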
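The table's mean times can be put on a common scale by converting them to throughput. Each call performs three 2048 × 2048 matrix products. This sketch assumes the common approximation that a square n × n matmul costs about 2n³ floating-point operations. It ignores the elementwise add and divide, which cost only about 2n² each. The FLOP model is an assumption and was not measured here.

```python
# Approximate throughput implied by the mean times in the Results table.
# Assumption: a square n x n matmul costs about 2 * n**3 FLOPs; the
# elementwise add and divide (about 2 * n**2 FLOPs) are ignored.
N = 2048          # matrix dimension used in the code cells
ITERATIONS = 3    # matmuls per bench_func call

def gflops_per_second(seconds_per_call, n=N, iterations=ITERATIONS):
    """Return approximate GFLOP/s for one benchmark call."""
    flops = iterations * 2 * n**3
    return flops / seconds_per_call / 1e9

# Mean times from the Results table, in seconds.
results = {
    "JAX": 1.99e-3,
    "Torch + torch.compile": 84.5e-3,
    "Torch": 91.5e-3,
    "NumPy": 125e-3,
}

for name, seconds in results.items():
    print(f"{name:>22}: {gflops_per_second(seconds):9.1f} GFLOP/s")
```

Under this model, the JAX time corresponds to roughly 26 TFLOP/s and the NumPy time to roughly 0.41 TFLOP/s.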




## NumPy reference

```python
import numpy as np

# 2048 x 2048 float32 matrix with entries in [0, 0.2)
x = np.random.rand(2048, 2048).astype(np.float32) / 5.0

def bench_func(x):
    # Dividing by 1000 keeps the values from growing across iterations
    for i in range(3):
        x = (np.matmul(x, x) + x) / 1000.0
    return x

%timeit bench_func(x)
```

```
125 ms ± 352 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
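`%timeit` is an IPython magic, so the cell above only runs in a notebook or an IPython shell. Below is a minimal sketch of a similar measurement with the standard-library `timeit` module.

* It assumes the same `bench_func` as above.
* The helper name `time_it` is hypothetical.
* The default `repeat=7, number=10` mirrors the "7 runs, 10 loops each" reported above.

```python
import statistics
import timeit

import numpy as np

def bench_func(x):
    # Three rounds of matmul + add, scaled down to keep values bounded
    for _ in range(3):
        x = (np.matmul(x, x) + x) / 1000.0
    return x

def time_it(fn, *args, repeat=7, number=10):
    """Return (mean, stdev) seconds per call, in the spirit of IPython's %timeit."""
    fn(*args)  # one untimed warm-up call
    totals = timeit.repeat(lambda: fn(*args), repeat=repeat, number=number)
    per_call = [t / number for t in totals]
    return statistics.mean(per_call), statistics.stdev(per_call)
```

Calling `time_it(bench_func, x)` with the 2048 × 2048 `x` from above returns a mean and a standard deviation per call in seconds, in the same form as the `%timeit` line.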


## JAX

```python
import jax
from jax import jit
import jax.numpy as jnp

# Copy the NumPy array to JAX's default device (the GPU on this system)
xj = jnp.array(x)

def bench_func_j(x):
    for i in range(3):
        x = (jnp.matmul(x, x) + x) / 1000.0
    return x

# block_until_ready() waits for the asynchronously dispatched result
%timeit jit(bench_func_j)(xj).block_until_ready()
```

```
1.99 ms ± 73.8 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
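The cell above calls `jit` on `bench_func_j` inside the timed statement. JAX reuses its compilation cache when the same function object is jitted again. `%timeit` also discards its calibration calls. Compilation is therefore most likely excluded from the reported figure. The "1 loop each" suggests the first, compiling call was slow enough to end calibration early.

It is clearer, though, to create the jitted function once and warm it up before timing. Below is a minimal sketch.

* It assumes JAX's default device, which is the GPU if one is visible.
* The names `bench_func_jit` and `time_jax` are hypothetical.

```python
import time

import numpy as np
import jax
import jax.numpy as jnp

def bench_func_j(x):
    # The Python loop is unrolled at trace time into three matmul + add steps
    for _ in range(3):
        x = (jnp.matmul(x, x) + x) / 1000.0
    return x

# Create the jitted function once instead of inside the timed statement
bench_func_jit = jax.jit(bench_func_j)

def time_jax(fn, x, runs=7):
    """Return wall-clock seconds for `runs` calls, after one untimed warm-up."""
    fn(x).block_until_ready()  # first call traces and compiles
    times = []
    for _ in range(runs):
        start = time.perf_counter()
        # JAX dispatches asynchronously, so wait for the result before reading the clock
        fn(x).block_until_ready()
        times.append(time.perf_counter() - start)
    return times
```

With `xj` from the cell above, `time_jax(bench_func_jit, xj)` returns seven per-call timings in seconds.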


## Torch

```python
import torch

# torch.tensor() copies x into a new tensor on the CPU; it is never moved with .to("cuda")
xt = torch.tensor(x)

def bench_func_t(x):
    for i in range(3):
        x = (torch.matmul(x, x) + x) / 1000.0
    return x

%timeit bench_func_t(xt)
```

```
91.5 ms ± 190 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```

```python
bench_func_tc = torch.compile(bench_func_t)

%timeit bench_func_tc(xt)
```

```
84.5 ms ± 72.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
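As written, `xt` is created with `torch.tensor(x)`, which is a CPU tensor, so neither Torch measurement exercises the GPU. Below is a minimal sketch of a device-explicit variant.

* It assumes a CUDA build of PyTorch and falls back to the CPU when no GPU is visible.
* CUDA kernels launch asynchronously, so the timer calls `torch.cuda.synchronize()` before reading the clock.
* The helpers `sync` and `time_torch` are hypothetical.
* No timings from this sketch are reported here.

```python
import time

import numpy as np
import torch

def bench_func_t(x):
    for _ in range(3):
        x = (torch.matmul(x, x) + x) / 1000.0
    return x

# Use the GPU when one is available; the original cell never moves xt off the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def sync():
    # CUDA work is queued asynchronously; wait for it so the timer sees the real run time
    if device.type == "cuda":
        torch.cuda.synchronize()

def time_torch(fn, x, runs=7):
    """Return wall-clock seconds for `runs` calls, after one untimed warm-up call."""
    fn(x)  # warm-up: CUDA initialization, or compilation for a torch.compile'd function
    sync()
    times = []
    for _ in range(runs):
        start = time.perf_counter()
        fn(x)
        sync()
        times.append(time.perf_counter() - start)
    return times
```

For example, build `xt = torch.from_numpy(x).to(device)`. Then time it with `time_torch(bench_func_t, xt)` or `time_torch(torch.compile(bench_func_t), xt)`.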
