## Comparing the jax.numpy.fft with numpy.fft

In [3]:
import jax.numpy as jnp
from jax import jit
import numpy as np
import numpy.testing as npt

from sklearn.metrics import mean_squared_error
import scipy

# Generate test data
np.random.seed(0)  # Ensure reproducibility
data = np.random.rand(256, 256)  # 2D array of random numbers

# Adjusted tolerance values
atol = 1e-6  # Absolute tolerance
rtol = 1e-6  # Relative tolerance

a = np.fft.fft2(data).flatten()
b = jnp.fft.fft2(data).flatten()
c = scipy.fft.fft2(data).flatten()

a_real = a.real.flatten()
a_imag = a.imag.flatten()
b_real = b.real.flatten()
b_imag = b.imag.flatten()
c_real = c.real.flatten()
c_imag = c.imag.flatten()

# print(np.fft.fft2(data) - jnp.fft.fft2(data))
# print('####')
# print(np.fft.fft2(data) - scipy.fft.fft2(data))

print(np.corrcoef(a_real, b_real), np.corrcoef(a_imag, b_imag), mean_squared_error(a_real, b_real), mean_squared_error(a_imag, b_imag))
print('#')
print(np.corrcoef(a_real, c_real), np.corrcoef(a_imag, c_imag), mean_squared_error(a_real, c_real), mean_squared_error(a_imag, c_imag))
print('#')
print(np.corrcoef(b_real, c_real), np.corrcoef(b_imag, c_imag), mean_squared_error(b_real, c_real), mean_squared_error(b_imag, c_imag))
# print(jnp.fft.fft2(data))



[[1. 1.]
 [1. 1.]] [[1. 1.]
 [1. 1.]] 1.3331533165183426e-10 1.7689410108875597e-10
#
[[1. 1.]
 [1. 1.]] [[1. 1.]
 [1. 1.]] 3.406227136621809e-28 3.415452867676893e-28
#
[[1. 1.]
 [1. 1.]] [[1. 1.]
 [1. 1.]] 1.3331533163442366e-10 1.7689410107305212e-10


In [19]:
import jax.numpy as jnp
from jax import jit
import numpy as np
import time

# Define the JAX function and JIT compile it
@jit
def jax_func(x):
    return jnp.sin(x) ** 2 + jnp.cos(x) ** 2

# Define the NumPy function (not JIT compiled, for comparison)
def numpy_func(x):
    return np.sin(x) ** 2 + np.cos(x) ** 2

@jit
def numpy_jax_func(x):
    return np.sin(x) ** 2 + np.cos(x) ** 2

# Generate a large NumPy array as input
x = np.random.rand(10000)

# Warm-up and compile the JAX function
_ = jax_func(x)
_ = numpy_jax_func(x)

# Measure JAX function time
start_time = time.time()
for _ in range(500):  # Run the JAX function 500 times
    _ = jax_func(x)
jax_time = time.time() - start_time

# Measure NumPy function time
start_time = time.time()
for _ in range(500):  # Run the NumPy function 500 times
    _ = numpy_func(x)
numpy_time = time.time() - start_time

# Measure NumPy function time
start_time = time.time()
for _ in range(500):  # Run the NumPy function 500 times
    _ = numpy_jax_func(x)
numpy_jax_time = time.time() - start_time

print(f"JAX computation time: {jax_time} seconds")
print(f"NumPy-JAX computation time: {numpy_jax_time} seconds")
print(f"NumPy computation time: {numpy_time} seconds")


TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[10000].
The error occurred while tracing the function numpy_jax_func at /var/folders/x6/fx3v22fd3h33fqnrs23l8_sh0000gn/T/ipykernel_78474/457597371.py:15 for jit. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerArrayConversionError