
Matrix multiplication inaccurate on A100 in both TF32 and FP32 #19444

Open · ggoldsh opened this issue Jan 19, 2024 · 6 comments

Labels: bug (Something isn't working) · NVIDIA GPU (Issues specific to NVIDIA GPUs)

Comments

ggoldsh commented Jan 19, 2024

Issue Description

In my code I use JAX to compute an m x n matrix that I call Ohat, with m << n. I then form the square m x m matrix T = Ohat @ Ohat.T / m, and my code relies on T being positive semidefinite up to small numerical error. On an RTX 2080 Ti GPU this assumption holds with error of about 1e-5, meaning the smallest eigenvalue of T is no smaller than roughly -1e-5. On an A100 GPU, however, the error is several orders of magnitude larger: T has an eigenvalue of about -2e-3. That is large enough to make my code fail, and the two-orders-of-magnitude gap makes me suspect a bug in JAX or in the GPUs themselves.

I initially thought the issue was due to the use of TensorFloat-32, but the problem persists, and in fact gets even worse (the smallest eigenvalue becomes about -2e-2), when I turn TF32 off with the flag NVIDIA_TF32_OVERRIDE=0.
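
For reference, a minimal sketch of applying the override from Python (an assumption here is that the variable must be set before JAX initializes CUDA, since the CUDA libraries read it at startup; exporting it in the shell before launch is equivalent):

import os
os.environ["NVIDIA_TF32_OVERRIDE"] = "0"  # must happen before the first GPU use

import jax.numpy as jnp  # subsequent float32 matmuls should then avoid TF32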

I also thought the issue might be in the eigh function, but I verified that it occurs if and only if the matrix multiplication is done in JAX. For example, if I compute T = Ohat @ Ohat.T / m in JAX and then use np.linalg.eigh to find the smallest eigenvalue of T, the issue persists. If I instead compute T in NumPy and use JAX for the eigh call, the issue is gone.

My RTX 2080 Ti GPU is on the Savio cluster at Berkeley (https://docs-research-it.berkeley.edu/services/high-performance-computing/user-guide/) and my A100 is on Perlmutter (https://docs.nersc.gov/systems/perlmutter/). I previously corresponded with the Perlmutter help desk, and they found that the issue reproduces on A100s on several other clusters as well, which is why we think it is deeper than any specific cluster.

I've tried both CUDA 11 and 12 with corresponding versions of jax and jaxlib on both clusters, and the results are the same regardless of which CUDA version I use (errors much bigger on A100s in either case).

Any help is appreciated!

How to reproduce the issue

Here is a problematic Ohat matrix stored on GDrive: https://drive.google.com/file/d/1021CTuND0sBOshmyfc7Gt6PQ9EK520-Q/view?usp=drive_link. This matrix is of shape (1000, 653360).

The script below reproduces the issue using this matrix. I found that jitting the get_T function reduces memory overhead and makes the script easier to run on different hardware; however, the issue reproduces equally without the jit call. Also note that the script uses JAX only to form T, not to compute the eigenvalue decomposition. You can verify, by tweaking the script as sketched after it, that the numerical error is large if and only if T is computed in JAX.

import numpy as np
import jax
import jax.numpy as jnp

# Load the (1000, 653360) matrix from the linked file onto the default device.
Ohat = jnp.load("Ohat.npy")

def get_T(Ohat):
  # T should be positive semidefinite up to small numerical error.
  T = Ohat @ Ohat.T
  m = T.shape[0]
  T = T / m
  return T

get_T = jax.jit(get_T)  # jit reduces memory overhead; the issue reproduces without it
T = get_T(Ohat)

# Only the matmul above uses JAX; the eigendecomposition runs in NumPy.
min_eig = np.linalg.eigh(T)[0][0]
print(min_eig)
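
For example, a minimal sketch of the NumPy cross-check (moving only the matmul to NumPy while keeping everything else the same):

Ohat_np = np.asarray(Ohat)                    # copy the array back to the host
T_np = Ohat_np @ Ohat_np.T / Ohat_np.shape[0]
print(np.linalg.eigh(T_np)[0][0])             # the large negative eigenvalue is gone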

My results

On A100: -0.0022894708
On A100, with NVIDIA_TF32_OVERRIDE=0: -0.26022542
On RTX 2080 Ti: -2.1887377e-06

Version info

What jax/jaxlib version are you using?

jax 0.4.23, jaxlib 0.4.23+cuda12.cudnn89

Which accelerator(s) are you using?

A100 GPU

Additional system info?

uname_result(system='Linux', node='nid001644', release='5.14.21-150400.24.81_12.0.87-cray_shasta_c', version='#1 SMP Mon Nov 6 21:51:33 UTC 2023 (e30c7c1)', machine='x86_64')

sys.version: 3.9.16 (main, Mar 8 2023, 14:00:05)
[GCC 11.2.0]

numpy.version: 1.25.2

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  On   | 00000000:C1:00.0 Off |                    0 |
| N/A   31C    P0    51W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

ggoldsh added the bug label on Jan 19, 2024
jakevdp (Collaborator) commented Jan 19, 2024

Have you seen the precision argument to jnp.matmul and others? I'm not sure what the default setting is on A100, but you can force it to use the highest precision setting using the jax_default_matmul_precision flag. For example, at the top of your script you could write:

jax.config.update("jax_default_matmul_precision", "highest")

Alternatively, you can adjust the value locally using a context manager:

def get_T(Ohat):
  with jax.default_matmul_precision('float32'):
    T = Ohat @ Ohat.T
  m = T.shape[0]
  T = T / m
  return T
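
For a single call, the precision argument mentioned above accepts the same values; a minimal sketch (using Ohat from the report):

T = jnp.matmul(Ohat, Ohat.T, precision=jax.lax.Precision.HIGHEST)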

Do you see the expected accuracy when you run the code this way?

ggoldsh (Author) commented Jan 19, 2024

Thanks for the quick response!

I just tried both on the A100; with either suggestion I get -0.009320012. So the result changed slightly, but it is still much worse than on my other cluster.

hawkinsp added the NVIDIA GPU label on Jan 22, 2024
jaro-sevcik (Contributor) commented
I looked at the generated code, and the compiler just offloads the matrix multiplication to cuBLAS. When I examined the precision of individual elements of the result, the errors seem to be well within the range that arises from different re-association of the additions in the matrix multiplication.
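
To illustrate the scale, a minimal sketch (a toy example, not the compiler's actual schedule): accumulating the same float32 terms in two different orders, with the inner dimension taken from the reported Ohat, already produces differences far above float32 epsilon.

import numpy as np

rng = np.random.default_rng(0)
terms = rng.random(653360).astype(np.float32)   # inner dimension of the reported Ohat

ref = terms.astype(np.float64).sum()            # float64 reference value
pairwise = terms.sum(dtype=np.float32)          # NumPy's pairwise reduction order

sequential = np.float32(0.0)
for t in terms:                                 # strict left-to-right float32 order
    sequential += t

print(abs(pairwise - ref) / ref)                # typically around 1e-7
print(abs(sequential - ref) / ref)              # typically orders of magnitude larger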

bwohlberg commented
I initially intended to open a new issue about numerical precision, but having found this one, it seems more appropriate to post here, since the problems appear to be closely related or even the same. In my case the issue manifested as an unexpectedly high error in the associativity of matrix multiplication, as demonstrated by this example script:

import numpy as np

import jax
import jax.numpy as jnp


def dev_info():
    dev = jax.devices()[0]
    info = "CPU" if dev.platform == "cpu" else dev.device_kind
    return info


def relres(ax, b):
    nrm = max(jnp.linalg.norm(ax.ravel()), jnp.linalg.norm(b.ravel()))
    if nrm == 0.0:
        return 0.0
    return jnp.linalg.norm((b - ax).ravel()) / nrm


def mx_mul_assoc_error(A, B, x):
    abx1 = (A @ B) @ x
    abx2 = A @ (B @ x)
    return relres(abx1, abx2)



np.random.seed(1234)
dtype = np.float32
M, N = (8, 16)
A = jnp.array(np.random.randn(M, N).astype(dtype=dtype))
B = jnp.array(np.random.randn(N, M).astype(dtype=dtype))
x = jnp.array(np.random.randn(M, 1).astype(dtype=dtype))

print(f"Running test on device {dev_info()}. Matrix mult. assoc. error at:")
print(f"    default matmul precision: {mx_mul_assoc_error(A, B, x):.3e}")
jax.config.update("jax_default_matmul_precision", "high")
print(f"    high matmul precision:    {mx_mul_assoc_error(A, B, x):.3e}")
jax.config.update("jax_default_matmul_precision", "highest")
print(f"    highest matmul precision: {mx_mul_assoc_error(A, B, x):.3e}")

On the relatively low-performance RTX 2080 Ti the output is

Running test on device NVIDIA GeForce RTX 2080 Ti. Matrix mult. assoc. error at:
    default matmul precision: 9.324e-08
    high matmul precision:    9.324e-08
    highest matmul precision: 9.324e-08

while on the much more powerful A100 it's

Running test on device NVIDIA A100-SXM4-80GB. Matrix mult. assoc. error at:
    default matmul precision: 3.650e-04
    high matmul precision:    3.650e-04
    highest matmul precision: 1.727e-07

Aside from whether "high" is a more appropriate choice of default precision than "highest" (as discussed in #2161), a relative error of 3.7e-4 seems remarkably high for "high" precision multiplication of such small matrices. Some specific questions:

  • Is this a bug, or is such low precision, which is clearly unusable for any serious numerical calculations, really considered acceptable for the "high" precision setting?
  • Is "high" the current default across all devices?
  • Is it clear why there is such a big difference between "high" and "highest" on an A100, while the two levels give the same results on an RTX 2080 Ti?
  • This huge difference in accuracy complicates writing code that's portable across different devices. If the observation for the A100 is not considered a bug, is there any other solution other than always globally applying the "highest" precision setting? And if it's not a bug, it would be advisable to prominently note this issue in the JAX docs!

jakevdp (Collaborator) commented Jul 22, 2024

Hi - I think the larger issue here is the lack of clarity around the meaning of each precision setting on each platform. This is tracked in #18934, but there's not been much progress because it's complicated. To try to answer your questions:

Is this a bug, or is such low precision, which is clearly unusable for any serious numerical calculations, really considered acceptable for the "high" precision setting?

I think this is working as intended. Accelerators have different hardware matmul implementations, but in general they favor speed over accuracy with the default setting.

Is "high" the current default across all devices?

No, default is the current default across all devices. What that maps to in terms of what actual operations the device performs differs from device to device. On CPU, default is the same as highest. On other backends, default is the same as bfloat16.

Is it clear why there is such a big difference between "high" and "highest" on an A100 while the two levels give the same results on an RTX 2080 Ti?

The meaning of "high" and "highest" is dependent on the hardware you are using.

This huge difference in accuracy complicates writing code that's portable across different devices. If the observation for the A100 is not considered a bug, is there any other solution other than always globally applying the "highest" precision setting? And if it's not a bug, it would be advisable to prominently note this issue in the JAX docs!

Agreed – this is something that we hope to address, but it's difficult. #18934 mentions these issues explicitly, and is probably the relevant issue to follow.
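
In the meantime, one portable pattern is to condition the global setting on the backend; a minimal sketch (assuming, per the above, that on CPU the default already maps to highest):

import jax

if jax.default_backend() != "cpu":   # "gpu" or "tpu": default is reduced precision
    jax.config.update("jax_default_matmul_precision", "highest")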

bwohlberg commented
Understood. Thanks for the thorough response!
