Description
For this simple vectorized computation, jax.jit uses only about 40% CPU on an 8-core machine and about 15% CPU on a 32-core machine. The input array x has length 5 million. If I implement the same loop in PyTorch, Numba, or C++ with OpenMP, I get 100% CPU usage and much faster timings. I understand that vectorization is different from parallelization, but I don't understand why JAX can't fully use the CPU for such a simple example.
import jax
import jax.numpy as jnp

@jax.jit
def sin_jax(x):
    # Apply sin 100 times; each application is a simple elementwise op.
    for _ in range(100):
        x = jnp.sin(x)
    return x
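For reference, this is roughly how I timed it: a minimal sketch (the array size and dtype are my assumption to match the description above), using block_until_ready() so JAX's asynchronous dispatch doesn't distort the wall-clock numbers. It reuses the sin_jax defined above.

import time

x = jnp.ones(5_000_000, dtype=jnp.float32)

sin_jax(x).block_until_ready()  # warm-up call triggers JIT compilation

start = time.perf_counter()
sin_jax(x).block_until_ready()  # block until the async computation actually finishes
print(f"elapsed: {time.perf_counter() - start:.3f} s")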
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.7.2
jaxlib: 0.7.2
numpy: 2.2.0
python: 3.11.11 (main, Dec 11 2024, 16:28:39) [GCC 11.2.0]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='jiace-XPS-8930', release='6.14.0-33-generic', version='#33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2', machine='x86_64')
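For what it's worth, one workaround I tried sketching is to split the host CPU into several XLA devices via the --xla_force_host_platform_device_count flag and shard the input across them. This is a minimal sketch only, assuming the flag is honored on this platform; the device count of 8 and the sharding layout are my assumptions, and I have not verified that this recovers full CPU usage:

import os
# Must be set before importing jax: splits the host into 8 XLA CPU devices (assumed count).
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

@jax.jit
def sin_jax(x):
    for _ in range(100):
        x = jnp.sin(x)
    return x

# Lay the 8 devices out along one mesh axis and shard x along it,
# so each device computes its own 625k-element slice.
mesh = Mesh(np.array(jax.devices()), axis_names=("i",))
x = jax.device_put(jnp.ones(5_000_000, dtype=jnp.float32),
                   NamedSharding(mesh, P("i")))

result = sin_jax(x).block_until_ready()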