jax einsum has different results between CPU and GPU #22557
Comments
I suspect this is a matmul precision configuration issue. JAX's dot-like operations default to a lower precision on some accelerators for performance reasons. If you want to force higher precision, you can do so per call via jnp.einsum("ab,caP->cbP", A, B, precision='highest'), or you can change the default globally with jax.config.update('jax_default_matmul_precision', 'highest').
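As a sketch of the two options described above (the array shapes here are made up purely for illustration; only the einsum spec comes from the report):

```python
import jax
import jax.numpy as jnp

# Hypothetical shapes, chosen only to satisfy the "ab,caP->cbP" spec.
A = jnp.ones((4, 3))     # indices: a=4, b=3
B = jnp.ones((5, 4, 2))  # indices: c=5, a=4, P=2

# Option 1: request full precision for this one contraction.
out = jnp.einsum("ab,caP->cbP", A, B, precision="highest")
# out[c, b, P] = sum over a of A[a, b] * B[c, a, P]; with all-ones
# inputs every entry is 4.0 and the result shape is (5, 3, 2).

# Option 2: raise the default for every matmul-like op in the process.
jax.config.update("jax_default_matmul_precision", "highest")
```

On CPU these settings are effectively no-ops; they only change the accumulation strategy on accelerators that default to reduced-precision math.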
Thanks, that works! Now the GPU is as accurate as the CPU. When should I use the higher-precision setting?
Precisely. Currently the precise meaning of the matmul precision settings varies by hardware. On some GPU chips, for example, the default allows reduced-precision TensorFloat-32 math, while 'highest' forces full float32 accumulation.
Having just spent hours debugging why a particular computation differed by an order of magnitude between one version of jax/cuda and another (on the same machine, and both using the same GPU hardware), I was fairly disappointed to find that the culprit was this lowered default matmul precision. While I understand the motivation as a performance trade-off for some machine learning applications, this makes JAX questionable for more mathematically/scientifically rigorous applications, where algorithms depend on certain assumptions about floating-point precision and reproducibility.
Thanks for your effort. It does look dangerous to use a different default precision in scientific computing.
I think I'm going to close this issue now, because the root cause is tracked by #18934 |
Description
Hi there,
My jax.numpy.einsum call has very poor accuracy on the GPU, but seems fine on the CPU.
Running this on the GPU, I got this result:
But when running on the CPU, I got:
Both jax 0.4.28 and 0.4.30 have the same issue on my machine.
Is it because of the WSL2 environment or the specific way jax was installed?
I used
conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia
and jax.print_environment_info() says:
System info (python version, jaxlib version, accelerator, etc.)
I also tried
pip install jax[cuda12]
and jax.print_environment_info() says: