Matrix multiplication inaccurate on A100 in both TF32 and FP32 #19444
Have you seen the `jax.config.update("jax_default_matmul_precision", "highest")` option? Alternatively, you can adjust the value locally using a context manager:

```python
def get_T(Ohat):
    with jax.default_matmul_precision('float32'):
        T = Ohat @ Ohat.T
        m = T.shape[0]
        T = T / m
    return T
```

Do you see the expected accuracy when you run the code this way?
Thanks for the quick response! I just tried those on the A100; using either suggestion I then get the result of
I looked at the generated code, and the compiler is just offloading the matrix multiplication to cuBLAS. When I examined the precision of individual elements of the result, they seem to be well within the range that comes from different re-associations of the additions in the matrix multiplication.
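For context on the re-association point (this sketch is not from the thread): float32 addition is not associative, so two mathematically identical reductions can legitimately round to different results, and the gap grows with the length of the reduction. A minimal NumPy illustration, with sizes chosen arbitrarily:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(100_000).astype(np.float32)

# Same mathematical sum, two association orders, both in float32.
seq = np.float32(0.0)
for v in x:                      # strictly left-to-right accumulation
    seq = np.float32(seq + v)

pair = np.float32(x.sum())       # NumPy's pairwise (tree) summation

ref = float(x.sum(dtype=np.float64))  # float64 reference value

print(f"sequential: {float(seq):.6f}")
print(f"pairwise:   {float(pair):.6f}")
print(f"difference: {abs(float(seq) - float(pair)):.2e}")
```

The two float32 results generally differ by far more than one ulp of either value, which is the same effect that makes differently blocked matmul kernels disagree.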
I initially intended to create a new issue regarding numerical precision, but having discovered this one, it seems more appropriate to post here, since the problems appear to be closely related or even the same. In my case, the issue manifested itself as an unexpectedly high error in the associativity of matrix multiplication, as demonstrated by this example script:

```python
import numpy as np
import jax
import jax.numpy as jnp

def dev_info():
    dev = jax.devices()[0]
    info = "CPU" if dev.platform == "cpu" else dev.device_kind
    return info

def relres(ax, b):
    nrm = max(jnp.linalg.norm(ax.ravel()), jnp.linalg.norm(b.ravel()))
    if nrm == 0.0:
        return 0.0
    return jnp.linalg.norm((b - ax).ravel()) / nrm

def mx_mul_assoc_error(A, B, x):
    abx1 = (A @ B) @ x
    abx2 = A @ (B @ x)
    return relres(abx1, abx2)

np.random.seed(1234)
dtype = np.float32
M, N = (8, 16)
A = jnp.array(np.random.randn(M, N).astype(dtype=dtype))
B = jnp.array(np.random.randn(N, M).astype(dtype=dtype))
x = jnp.array(np.random.randn(M, 1).astype(dtype=dtype))

print(f"Running test on device {dev_info()}. Matrix mult. assoc. error at:")
print(f"  default matmul precision: {mx_mul_assoc_error(A, B, x):.3e}")
jax.config.update("jax_default_matmul_precision", "high")
print(f"  high matmul precision:    {mx_mul_assoc_error(A, B, x):.3e}")
jax.config.update("jax_default_matmul_precision", "highest")
print(f"  highest matmul precision: {mx_mul_assoc_error(A, B, x):.3e}")
```

On the relatively low-performance RTX 2080 the output is

while on the much more powerful A100 it's
Aside from whether "high" is a more appropriate default precision than "highest" (as discussed in #2161), a relative error of 3.7e-4 seems remarkably high for "high"-precision multiplication of such small matrices. Some specific questions:
Hi - I think the larger issue here is the lack of clarity around the meaning of each precision setting on each platform. This is tracked in #18934, but there's not been much progress because it's complicated. To try to answer your questions:

1. I think this is working as intended. Accelerators have different hardware matmul implementations, but in general they favor speed over accuracy at the default setting.
2. No,
3. The meaning of "high" and "highest" depends on the hardware you are using.
4. Agreed; this is something that we hope to address, but it's difficult. #18934 mentions these issues explicitly and is probably the relevant issue to follow.
Understood. Thanks for the thorough response!
Issue Description
In my code I use JAX to calculate an `m x n` matrix that I call `Ohat`, with `m << n`. I then calculate a square `m x m` matrix `T = Ohat @ Ohat.T / m`, and my code relies on the fact that `T` is positive semidefinite up to some small numerical error. When running on an RTX 2080Ti GPU, this assumption is satisfied with an error of approximately 1e-5, meaning that the smallest eigenvalue of `T` is no smaller than about `-1e-5`. However, when I run the same code on an A100 GPU the error is multiple orders of magnitude bigger: specifically, I find that `T` has an eigenvalue of about `-2e-3`. This is large enough to cause my code to fail, and the two-order-of-magnitude gap makes me think there is a bug in JAX or in the GPUs themselves.

I initially thought the issue was due to the use of TensorFloat32, but the problem persists, and in fact gets even worse (the smallest eigenvalue becomes about `-2e-2`), when I turn TF32 off using the flag `NVIDIA_TF32_OVERRIDE=0`.

I also thought the issue might be in the `eigh` function, but I was able to verify that the issue happens when and only when the matrix multiplication is done in JAX. For example, if I compute `T = Ohat @ Ohat.T / m` in JAX and then use `np.linalg.eigh` to find the smallest eigenvalue of `T`, the issue persists. On the other hand, if I calculate `T` in NumPy and then use JAX for the `eigh` call, the issue is gone.

My RTX 2080Ti GPU is on the Savio cluster at Berkeley (https://docs-research-it.berkeley.edu/services/high-performance-computing/user-guide/) and my A100 is on Perlmutter (https://docs.nersc.gov/systems/perlmutter/). I previously corresponded with the Perlmutter help desk; they tried the issue on some other clusters and found that it reproduces on A100s there as well, which is why we think the issue goes deeper than any specific cluster.

I've tried both CUDA 11 and 12, with corresponding versions of jax and jaxlib, on both clusters, and the results are the same regardless of which CUDA version I use (errors much bigger on A100s in either case).

Any help is appreciated!
How to reproduce the issue
Here is a problematic `Ohat` matrix stored on Google Drive: https://drive.google.com/file/d/1021CTuND0sBOshmyfc7Gt6PQ9EK520-Q/view?usp=drive_link. This matrix has shape (1000, 653360).
The script below reproduces the issue using this matrix. I found that applying `jit` to the `get_T` function reduces memory overhead and makes it easier to run the script on different hardware; however, the issue should reproduce equally without the `jit` call. Also note that the script only uses JAX to form `T`, not to calculate the eigenvalue decomposition. You can verify, by tweaking the script, that the numerical error is large when and only when the `T` matrix is calculated using JAX.

My results
On A100: -0.0022894708
On A100, with `NVIDIA_TF32_OVERRIDE=0`: -0.26022542
On RTX 2080Ti: -2.1887377e-06
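The reproduction script itself did not survive the page scrape. A minimal sketch of what the text describes (forming `T` with a jitted `get_T` in JAX and checking the smallest eigenvalue in NumPy) is below; the small random stand-in for the downloaded `Ohat` matrix and its shape are assumptions made so the sketch is self-contained:

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def get_T(Ohat):
    # Form T = Ohat @ Ohat.T / m entirely in JAX, as described in the issue.
    T = Ohat @ Ohat.T
    m = T.shape[0]
    return T / m

# The original script loads the (1000, 653360) Ohat matrix from the GDrive
# link above; a small random float32 matrix stands in here so the sketch
# runs anywhere (shape and seed are arbitrary choices, not from the issue).
rng = np.random.default_rng(0)
Ohat = jnp.asarray(rng.standard_normal((100, 5000)).astype(np.float32))

T = np.asarray(get_T(Ohat))

# Only the matmul runs through JAX; the eigendecomposition is NumPy's,
# so a large negative eigenvalue would implicate the matrix product itself.
smallest_eig = float(np.linalg.eigvalsh(T).min())
print(f"smallest eigenvalue of T: {smallest_eig:.3e}")
```

With a well-conditioned random `Ohat`, `T` is positive definite, so any clearly negative eigenvalue printed on a given device points at accumulated matmul error of the kind reported above.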
Version info
What jax/jaxlib version are you using?
jax 0.4.23, jaxlib 0.4.23+cuda12.cudnn89
Which accelerator(s) are you using?
A100 GPU
Additional system info?
uname_result(system='Linux', node='nid001644', release='5.14.21-150400.24.81_12.0.87-cray_shasta_c', version='#1 SMP Mon Nov 6 21:51:33 UTC 2023 (e30c7c1)', machine='x86_64')
sys.version: 3.9.16 (main, Mar 8 2023, 14:00:05)
[GCC 11.2.0]
numpy.version: 1.25.2
NVIDIA GPU info
```
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:C1:00.0 Off |                    0 |
| N/A   31C    P0    51W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+
```