Poor multithreading even for simple vectorization #32707

@SUSYUSTC

Description

For this simple vectorized function, jax.jit uses only 40% of the CPU on an 8-core machine and 15% on a 32-core machine. The input array x has 5 million elements. If I implement the same thing in torch, numba, or C++ with OpenMP, I get 100% CPU usage in every case and much faster timings. I understand that vectorization is different from parallelization, but I don't understand why JAX can't reach full CPU usage for such a simple example.

import jax
import jax.numpy as jnp

@jax.jit
def sin_jax(x):
    # Apply sin elementwise 100 times in a row.
    for _ in range(100):
        x = jnp.sin(x)
    return x
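To make the CPU-usage figures easier to reproduce, here is a minimal timing sketch built around the function above. The `measure` helper, the float32 `linspace` input, and the best-of-three timing are assumptions for illustration and are not part of the original report; only the array length of 5 million and the `sin_jax` body come from it.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def sin_jax(x):
    # Same kernel as in the report: 100 chained elementwise sines.
    for _ in range(100):
        x = jnp.sin(x)
    return x


def measure(n=5_000_000, repeats=3):
    """Hypothetical helper: (wall seconds, CPU seconds) of the fastest call."""
    # Assumed input: float32 values in [0, 1]; the report gives only the length.
    x = jnp.asarray(np.linspace(0.0, 1.0, n, dtype=np.float32))
    sin_jax(x).block_until_ready()  # warm-up call so compilation is not timed
    best = None
    for _ in range(repeats):
        wall0, cpu0 = time.perf_counter(), time.process_time()
        sin_jax(x).block_until_ready()  # wait for the asynchronous dispatch
        wall = time.perf_counter() - wall0
        cpu = time.process_time() - cpu0  # CPU time summed over all threads
        if best is None or wall < best[0]:
            best = (wall, cpu)
    return best
```

Calling `wall, cpu = measure()` and computing `cpu / wall` gives a rough estimate of the average number of busy cores, which can be compared against the core count of the machine.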

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.7.2
jaxlib: 0.7.2
numpy: 2.2.0
python: 3.11.11 (main, Dec 11 2024, 16:28:39) [GCC 11.2.0]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='jiace-XPS-8930', release='6.14.0-33-generic', version='#33~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 17:02:30 UTC 2', machine='x86_64')
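For anyone comparing against their own machine, a short sketch of how a summary like the one above, plus the logical core count, can be collected. It assumes a standard JAX installation; the core count is not part of the original report.

```python
import os

import jax

# Logical cores visible to the operating system.
print("logical cores:", os.cpu_count())
# Devices of the default backend; the report lists a single local CPU device.
print("jax devices:", jax.devices())
# Prints the jax / jaxlib / numpy / python / platform summary.
jax.print_environment_info()
```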

Metadata

Labels: bug (Something isn't working)
