jax einsum has different results between CPU and GPU #22557

Closed
John-zzh opened this issue Jul 22, 2024 · 6 comments
Assignees: jakevdp
Labels: bug (Something isn't working)

Comments

John-zzh commented Jul 22, 2024

Description

Hi there,
My jax.numpy.einsum gives poor accuracy on the GPU, but seems fine on the CPU.

This is running on GPU:

import numpy as np
import jax
import jax.numpy as jnp

A = np.random.rand(300,300)
B = np.random.rand(300,300,4)

numpy_result_double        = np.einsum("ab,caP->cbP", A, B)
numpy_result_single        = np.einsum("ab,caP->cbP", A.astype(np.float32), B.astype(np.float32))

jax_result                 = jnp.einsum("ab,caP->cbP", jnp.array(A), jnp.array(B))

print(np.linalg.norm(numpy_result_double - numpy_result_single))
print(np.linalg.norm(         jax_result - numpy_result_single))

I got this result:

0.011850594613419469
0.9229351

But when running on CPU:

import numpy as np
import jax
import jax.numpy as jnp
# running on CPU
jax.config.update('jax_platform_name', 'cpu')

A = np.random.rand(300,300)
B = np.random.rand(300,300,4)

numpy_result_double        = np.einsum("ab,caP->cbP", A, B)
numpy_result_single        = np.einsum("ab,caP->cbP", A.astype(np.float32), B.astype(np.float32))

jax_result                 = jnp.einsum("ab,caP->cbP", jnp.array(A), jnp.array(B))

print(np.linalg.norm(numpy_result_double - numpy_result_single))
print(np.linalg.norm(         jax_result - numpy_result_single))

and I got

0.011844848124626908
0.012697785

Both jax 0.4.28 and 0.4.30 show the same issue on my machine.
Is it because of the WSL2 environment, or the specific way jax was installed?
I used conda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia, and jax.print_environment_info() says:

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28
jaxlib: 0.4.28.dev20240711
numpy:  1.25.2
python: 3.9.18 | packaged by conda-forge | (main, Dec 23 2023, 16:33:10)  [GCC 12.3.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='jojolaptop', release='5.15.153.1-microsoft-standard-WSL2', version='#1 SMP Fri Mar 29 23:14:13 UTC 2024', machine='x86_64')

$ nvidia-smi
Mon Jul 22 17:46:39 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.58.02              Driver Version: 556.12         CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4050 ...    On  |   00000000:01:00.0  On |                  N/A |
| N/A   46C    P3              9W /   35W |     270MiB /   6141MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A     56103      C   /python3.9                                  N/A      |
+-----------------------------------------------------------------------------------------+

I also tried pip install jax[cuda12], and jax.print_environment_info() says:

jax:    0.4.30
jaxlib: 0.4.30
numpy:  2.0.1
python: 3.12.4 | packaged by Anaconda, Inc. | (main, Jun 18 2024, 15:12:24) [GCC 11.2.0]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='jojolaptop', release='5.15.153.1-microsoft-standard-WSL2', version='#1 SMP Fri Mar 29 23:14:13 UTC 2024', machine='x86_64')


$ nvidia-smi
Mon Jul 22 18:16:28 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 555.58.02              Driver Version: 556.12         CUDA Version: 12.5     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4050 ...    On  |   00000000:01:00.0  On |                  N/A |
| N/A   51C    P3              9W /   35W |     299MiB /   6141MiB |      6%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A     70120      C   /python3.12                                 N/A      |
+-----------------------------------------------------------------------------------------+
John-zzh added the bug (Something isn't working) label on Jul 22, 2024
Collaborator

jakevdp commented Jul 22, 2024

I suspect this is a matmul precision configuration issue.

JAX's dot-like operations default to a lower precision on some accelerators for performance reasons. If you want to force higher precision, you can do so either via the precision argument to einsum and other dot-like operations

jnp.einsum("ab,caP->cbP", A, B, precision='highest')

or you can modify the value globally using the jax_default_matmul_precision configuration:

jax.config.update('jax_default_matmul_precision', 'highest')
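
For example, here is a minimal sketch of your reproduction with the precision pinned (either the per-call argument or the global setting alone is enough; both are shown here):

import numpy as np
import jax
import jax.numpy as jnp

# Option 1: force full-precision matmuls globally (set before running any computation)
jax.config.update('jax_default_matmul_precision', 'highest')

A = np.random.rand(300, 300)
B = np.random.rand(300, 300, 4)

numpy_result_single = np.einsum("ab,caP->cbP", A.astype(np.float32), B.astype(np.float32))

# Option 2: request the highest precision for this one contraction only
jax_result = jnp.einsum("ab,caP->cbP", jnp.array(A), jnp.array(B), precision='highest')

# On GPU this difference should now be comparable to the CPU result,
# i.e. on the order of single-precision rounding error.
print(np.linalg.norm(jax_result - numpy_result_single))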

jakevdp self-assigned this on Jul 22, 2024
John-zzh (Author)

Thanks! That works! Now the GPU is as accurate as the CPU.

When should I use jax.config.update('jax_default_matmul_precision', 'highest')?
Two of my friends tried my code above without setting 'jax_default_matmul_precision', and their results are fine. I feel like it is system dependent.

Collaborator

jakevdp commented Jul 23, 2024

I feel like it is system dependent.

Exactly. At the moment, the precise meaning of the matmul precision settings varies by hardware: on some GPU chips 'high' and 'highest' are equivalent, while on others they aren't. It's not in a great state currently, and #18934 tracks improving this.
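
If you'd rather not change the setting globally, the same configuration is also available as a context manager (jax.default_matmul_precision), so you can scope the higher precision to just the block of code that needs it. A small sketch:

import numpy as np
import jax
import jax.numpy as jnp

A = jnp.array(np.random.rand(300, 300))
B = jnp.array(np.random.rand(300, 300, 4))

# Inside this block, dot-like operations use the 'highest' precision;
# code outside the block keeps the platform default.
with jax.default_matmul_precision('highest'):
    result = jnp.einsum("ab,caP->cbP", A, B)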


kcdodd commented Jul 27, 2024

Having just spent hours debugging why a particular computation was an order of magnitude different in one version of jax/cuda versus another (on the same machine, using the same GPU hardware), I was fairly disappointed to find that the culprit was einsum and the counter-intuitive consequence of its default precision. My opinion is that, whatever mechanism is chosen for specifying a lower matmul precision in general, the default should match what the dtype of the arguments would suggest, at least for ubiquitous operations that are presented as replacements for their numpy counterparts.

While I understand the motivation of the performance trade-off for some machine learning applications, this makes it questionable to use for more mathematically/scientifically rigorous applications, where algorithms depend on certain assumptions about floating-point precision and reproducibility.

John-zzh (Author)

Having just spent hours debugging why a particular computation was an order of magnitude different in one version of jax/cuda versus another (on the same machine, using the same GPU hardware), I was fairly disappointed to find that the culprit was einsum and the counter-intuitive consequence of its default precision. My opinion is that, whatever mechanism is chosen for specifying a lower matmul precision in general, the default should match what the dtype of the arguments would suggest, at least for ubiquitous operations that are presented as replacements for their numpy counterparts.

While I understand the motivation of the performance trade-off for some machine learning applications, this makes it questionable to use for more mathematically/scientifically rigorous applications, where algorithms depend on certain assumptions about floating-point precision and reproducibility.

Thanks for your effort. It does look dangerous to use a different default precision in scientific computing.
So is this bug fixed now, or will it be addressed in the next release?

Collaborator

jakevdp commented Aug 8, 2024

I think I'm going to close this issue now, since the root cause is tracked in #18934.

jakevdp closed this as completed on Aug 8, 2024