(A6000 GPU) Matrix Multiplication: Large Changes in Numerical Stabilities Between Runs #9784

Closed
3 tasks done
HanGuo97 opened this issue Mar 7, 2022 · 4 comments
Labels
bug Something isn't working

HanGuo97 commented Mar 7, 2022

I noticed that when I changed the association order of a matrix multiplication, the numerical error can vary wildly (by four orders of magnitude). Furthermore, even with the same seed, whether an operation is numerically stable can vary significantly between runs.

The following was observed when running on an A6000 GPU. Everything appears fine when running on CPU, on a V100, or on Google Colab (K80).

Please:

- [x] Check for duplicate issues.
- [x] Provide a complete example of how to reproduce the bug, wrapped in triple backticks like this:

```python
import jax
# jax.config.update("jax_platform_name", "cpu")

import numpy as np
import jax.numpy as jnp

d = 97
for r in range(20, 40):
    key = jax.random.PRNGKey(r)

    key, subkey = jax.random.split(key)
    a = jax.random.normal(key)

    key, subkey = jax.random.split(key)
    D = jax.random.normal(key, shape=[d, r])

    key, subkey = jax.random.split(key)
    v = jax.random.normal(key, shape=[d])

    diff = D @ (D.T @ v) - D @ D.T @ v
    print(
        f"Rank={r}: "
        f"a: {a:.3f}, \t"
        f"D(norm) {jnp.linalg.norm(D):.3f}, \t"
        f"v(norm) {jnp.linalg.norm(v):.2f}, \t"
        f"Error: {jnp.linalg.norm(diff):.3e}")
```
- [x] If applicable, include full error messages/tracebacks.

```
# Output from Run 1
Rank=20: a: 0.753, 	D(norm) 43.465, 	v(norm) 10.75, 	Error: 1.057e-01   <---
Rank=21: a: 0.963, 	D(norm) 44.669, 	v(norm) 9.45, 	Error: 6.323e-05
Rank=22: a: 0.269, 	D(norm) 46.610, 	v(norm) 10.44, 	Error: 8.321e-05
Rank=23: a: -0.232, 	D(norm) 47.038, 	v(norm) 10.07, 	Error: 7.781e-05
Rank=24: a: -1.412, 	D(norm) 48.142, 	v(norm) 10.59, 	Error: 8.764e-05
Rank=25: a: 0.038, 	D(norm) 48.904, 	v(norm) 8.98, 	Error: 7.407e-05
Rank=26: a: -0.816, 	D(norm) 49.992, 	v(norm) 10.28, 	Error: 9.853e-05
Rank=27: a: 1.559, 	D(norm) 51.330, 	v(norm) 10.49, 	Error: 6.900e-05
Rank=28: a: 0.204, 	D(norm) 52.076, 	v(norm) 9.70, 	Error: 8.498e-05
Rank=29: a: -0.143, 	D(norm) 52.809, 	v(norm) 10.20, 	Error: 9.320e-05
Rank=30: a: 0.024, 	D(norm) 53.732, 	v(norm) 9.13, 	Error: 8.484e-05
Rank=31: a: 0.889, 	D(norm) 55.202, 	v(norm) 10.40, 	Error: 9.107e-05
Rank=32: a: 0.427, 	D(norm) 56.583, 	v(norm) 10.34, 	Error: 8.759e-05
Rank=33: a: 0.398, 	D(norm) 57.486, 	v(norm) 10.12, 	Error: 8.744e-05
Rank=34: a: -1.186, 	D(norm) 56.714, 	v(norm) 9.43, 	Error: 8.845e-05
Rank=35: a: -0.012, 	D(norm) 58.444, 	v(norm) 9.27, 	Error: 1.707e-01   <---
Rank=36: a: 0.712, 	D(norm) 58.687, 	v(norm) 9.36, 	Error: 7.037e-05
Rank=37: a: -0.177, 	D(norm) 60.725, 	v(norm) 9.32, 	Error: 8.887e-05
Rank=38: a: 0.247, 	D(norm) 61.139, 	v(norm) 10.34, 	Error: 8.436e-05
Rank=39: a: -1.059, 	D(norm) 62.312, 	v(norm) 9.58, 	Error: 9.310e-05
# Output from Run 2
Rank=20: a: 0.753, 	D(norm) 43.465, 	v(norm) 10.75, 	Error: 1.057e-01   <---
Rank=21: a: 0.963, 	D(norm) 44.669, 	v(norm) 9.45, 	Error: 6.323e-05
Rank=22: a: 0.269, 	D(norm) 46.610, 	v(norm) 10.44, 	Error: 8.321e-05
Rank=23: a: -0.232, 	D(norm) 47.038, 	v(norm) 10.07, 	Error: 7.781e-05
Rank=24: a: -1.412, 	D(norm) 48.142, 	v(norm) 10.59, 	Error: 8.764e-05
Rank=25: a: 0.038, 	D(norm) 48.904, 	v(norm) 8.98, 	Error: 7.407e-05
Rank=26: a: -0.816, 	D(norm) 49.992, 	v(norm) 10.28, 	Error: 9.853e-05
Rank=27: a: 1.559, 	D(norm) 51.330, 	v(norm) 10.49, 	Error: 6.900e-05
Rank=28: a: 0.204, 	D(norm) 52.076, 	v(norm) 9.70, 	Error: 1.292e-01   <---
Rank=29: a: -0.143, 	D(norm) 52.809, 	v(norm) 10.20, 	Error: 9.320e-05
Rank=30: a: 0.024, 	D(norm) 53.732, 	v(norm) 9.13, 	Error: 8.484e-05
Rank=31: a: 0.889, 	D(norm) 55.202, 	v(norm) 10.40, 	Error: 9.107e-05
Rank=32: a: 0.427, 	D(norm) 56.583, 	v(norm) 10.34, 	Error: 8.759e-05
Rank=33: a: 0.398, 	D(norm) 57.486, 	v(norm) 10.12, 	Error: 1.684e-01   <---
Rank=34: a: -1.186, 	D(norm) 56.714, 	v(norm) 9.43, 	Error: 8.845e-05
Rank=35: a: -0.012, 	D(norm) 58.444, 	v(norm) 9.27, 	Error: 7.667e-05
Rank=36: a: 0.712, 	D(norm) 58.687, 	v(norm) 9.36, 	Error: 7.037e-05
Rank=37: a: -0.177, 	D(norm) 60.725, 	v(norm) 9.32, 	Error: 8.887e-05
Rank=38: a: 0.247, 	D(norm) 61.139, 	v(norm) 10.34, 	Error: 8.436e-05
Rank=39: a: -1.059, 	D(norm) 62.312, 	v(norm) 9.58, 	Error: 9.310e-05
# Output from Run 3
Rank=20: a: 0.753, 	D(norm) 43.465, 	v(norm) 10.75, 	Error: 1.057e-01   <---
Rank=21: a: 0.963, 	D(norm) 44.669, 	v(norm) 9.45, 	Error: 6.323e-05
Rank=22: a: 0.269, 	D(norm) 46.610, 	v(norm) 10.44, 	Error: 8.321e-05
Rank=23: a: -0.232, 	D(norm) 47.038, 	v(norm) 10.07, 	Error: 7.781e-05
Rank=24: a: -1.412, 	D(norm) 48.142, 	v(norm) 10.59, 	Error: 1.489e-01   <---
Rank=25: a: 0.038, 	D(norm) 48.904, 	v(norm) 8.98, 	Error: 7.407e-05
Rank=26: a: -0.816, 	D(norm) 49.992, 	v(norm) 10.28, 	Error: 9.853e-05
Rank=27: a: 1.559, 	D(norm) 51.330, 	v(norm) 10.49, 	Error: 6.900e-05
Rank=28: a: 0.204, 	D(norm) 52.076, 	v(norm) 9.70, 	Error: 8.498e-05
Rank=29: a: -0.143, 	D(norm) 52.809, 	v(norm) 10.20, 	Error: 9.320e-05
Rank=30: a: 0.024, 	D(norm) 53.732, 	v(norm) 9.13, 	Error: 8.484e-05
Rank=31: a: 0.889, 	D(norm) 55.202, 	v(norm) 10.40, 	Error: 9.107e-05
Rank=32: a: 0.427, 	D(norm) 56.583, 	v(norm) 10.34, 	Error: 8.759e-05
Rank=33: a: 0.398, 	D(norm) 57.486, 	v(norm) 10.12, 	Error: 8.744e-05
Rank=34: a: -1.186, 	D(norm) 56.714, 	v(norm) 9.43, 	Error: 8.845e-05
Rank=35: a: -0.012, 	D(norm) 58.444, 	v(norm) 9.27, 	Error: 7.667e-05
Rank=36: a: 0.712, 	D(norm) 58.687, 	v(norm) 9.36, 	Error: 7.037e-05
Rank=37: a: -0.177, 	D(norm) 60.725, 	v(norm) 9.32, 	Error: 8.887e-05
Rank=38: a: 0.247, 	D(norm) 61.139, 	v(norm) 10.34, 	Error: 8.436e-05
Rank=39: a: -1.059, 	D(norm) 62.312, 	v(norm) 9.58, 	Error: 9.310e-05
```
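
For context, the two expressions differ only in associativity, `D @ (D.T @ v)` versus `(D @ D.T) @ v`, so any gap between them is pure floating-point rounding. As a baseline (a sketch, assuming the standard `jax_enable_x64` flag), the same difference in float64 on CPU should be many orders of magnitude smaller:

```python
import jax
jax.config.update("jax_enable_x64", True)      # default to float64
jax.config.update("jax_platform_name", "cpu")  # force the CPU backend

import jax.numpy as jnp

key = jax.random.PRNGKey(20)
D = jax.random.normal(key, shape=[97, 20])     # float64 under x64 mode
v = jax.random.normal(key, shape=[97])
diff = D @ (D.T @ v) - D @ D.T @ v
print(f"float64 CPU error: {jnp.linalg.norm(diff):.3e}")
```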
HanGuo97 added the bug (Something isn't working) label on Mar 7, 2022
HanGuo97 changed the title from "(GPU) Matrix Multiplication: Large Changes in Numerical Stabilities Between Runs" to "(A6000 GPU) Matrix Multiplication: Large Changes in Numerical Stabilities Between Runs" on Mar 7, 2022
hawkinsp (Member) commented Mar 8, 2022

I'm curious: is this a JAX-specific problem? I ask because all JAX is most likely doing here is calling cuBLAS, which is the same library that, say, TensorFlow or PyTorch use to perform matrix multiplication on GPU.
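
One way to narrow that down is to run the same re-association check outside JAX. A hypothetical cross-check in PyTorch on the same GPU (assuming a PyTorch version that has the `torch.backends.cuda.matmul.allow_tf32` flag):

```python
import torch

torch.manual_seed(0)
d, r = 97, 20
D = torch.randn(d, r, device="cuda")
v = torch.randn(d, device="cuda")

for allow_tf32 in (True, False):
    # Toggle TF32 TensorCore math for float32 matmuls.
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32
    diff = D @ (D.T @ v) - (D @ D.T) @ v
    print(f"allow_tf32={allow_tf32}: error {diff.norm():.3e}")
```

If the large error shows up there as well, the problem would live below JAX, in cuBLAS or the driver.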

One guess I have is that this is related to the use of Tensor Core (TF32) math on Ampere-series GPUs.
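
A quick way to probe that guess from the JAX side, assuming a JAX version that supports the `jax_default_matmul_precision` option (setting the environment variable `NVIDIA_TF32_OVERRIDE=0` should have a similar effect at the cuBLAS level):

```python
import jax
# Ask for full-float32 matmuls instead of TF32 TensorCore math.
jax.config.update("jax_default_matmul_precision", "float32")

import jax.numpy as jnp

key = jax.random.PRNGKey(20)
D = jax.random.normal(key, shape=[97, 20])
v = jax.random.normal(key, shape=[97])
diff = D @ (D.T @ v) - D @ D.T @ v
print(f"error at float32 matmul precision: {jnp.linalg.norm(diff):.3e}")
```

The same thing can be requested per call with `precision=jax.lax.Precision.HIGHEST` on `jnp.matmul`.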

HanGuo97 (Author) commented Mar 8, 2022

Interestingly enough, this problem seems largely resolved after I upgraded the Docker image.

Notably, I was using a slightly older CUDA (11.1) and cuDNN, inherited from a PyTorch base Docker image. After switching to NVIDIA's image with CUDA 11.4 and the corresponding cuDNN, the numerical errors look normal again.
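
For anyone hitting something similar, a small sanity check of the environment after swapping images (a sketch; the version-module layout may differ across jaxlib releases):

```python
import jax
import jaxlib.version

print("jax version:   ", jax.__version__)
print("jaxlib version:", jaxlib.version.__version__)
print("devices:       ", jax.devices())
print("device kind:   ", jax.devices()[0].device_kind)  # e.g. the A6000 here
```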

hawkinsp (Member) commented Mar 8, 2022

That's good! Can we declare this working as intended?

HanGuo97 (Author) commented Mar 8, 2022

Yes, closing it now. Thanks again for the help!
