In [2]:
import numpy as np
from numba import njit, prange
import jax
import jax.numpy as jnp
import time

# Numba parallel sum
@njit(parallel=True)
def numba_sum(arr):
    total = 0.0
    for i in prange(len(arr)):
        total += arr[i]
    return total

# JAX sum
@jax.jit
def jax_sum(arr):
    return jnp.sum(arr)

In [3]:
arr_np = np.random.rand(10_000_000).astype(np.float64)
arr_jax = jnp.array(arr_np)

# Warm-up
numba_sum(arr_np);
jax_sum(arr_jax).block_until_ready();

In [4]:
# Timing
start = time.time()
numba_sum(arr_np)
print("Numba time:", time.time() - start)

start = time.time()
jax_sum(arr_jax).block_until_ready()
print("JAX time:", time.time() - start)

Numba time: 0.0033452510833740234
JAX time: 0.0074198246002197266


In [5]:
# JAX version
import jax
import jax.numpy as jnp
import time

x = jnp.ones((8192, 8192))
y = jnp.ones((8192, 8192))

jax.block_until_ready(x @ y)  # warm-up
start = time.time()
out = x @ y
jax.block_until_ready(out)
print("JAX time:", time.time() - start)

JAX time: 1.7573049068450928


In [6]:
!pip install mlx

  pid, fd = os.forkpty()


Collecting mlx
  Downloading mlx-0.26.1-cp310-cp310-macosx_15_0_arm64.whl.metadata (5.3 kB)
Downloading mlx-0.26.1-cp310-cp310-macosx_15_0_arm64.whl (31.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.9/31.9 MB[0m [31m23.4 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: mlx
Successfully installed mlx-0.26.1


In [7]:
# MLX version
import mlx.core as mx
import time

a = mx.ones((8192, 8192))
b = mx.ones((8192, 8192))

mx.eval(a @ b)  # warm-up
start = time.time()
c = a @ b
mx.eval(c)
print("MLX time:", time.time() - start)


MLX time: 0.2774200439453125
