In [None]:
import numpy as np
import jax.numpy as jnp
import jax
from jax import config

# --- precision setup ---
# NumPy computes in float64 by default, while JAX uses 32-bit types unless
# x64 mode is turned on. Enabling it keeps the NumPy-vs-JAX comparison below
# at the same precision. JAX's documentation says to set this flag at startup,
# before any JAX arrays are created.
config.update("jax_enable_x64", True)

# Report the devices of JAX's default backend.
print("Using:", jax.devices())

# Function we're benchmarking (works in both NumPy & JAX)
def f(x):
    """Return ``x.T @ (x - x.mean(axis=0))``.

    The function uses only array methods and the ``@`` operator, so the same
    definition runs on NumPy arrays and on ``jax.numpy`` arrays (including
    under ``jax.jit``).

    Presumably ``x`` is 2-D (rows = samples, columns = features). In that
    case the result is the unnormalized scatter matrix of the column-centered
    data.
    """
    # Subtract the per-column mean (reduction over axis 0).
    centered = x - x.mean(axis=0)
    return x.T @ centered

x_np = np.ones((1000, 1000)) 
%timeit f(x_np)  # measure NumPy runtime

%time x_jax = jax.device_put(x_np)  # measure JAX device transfer time
f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready()  # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready()  # measure JAX runtime