
Failed to launch CUDA kernel when multiplying bool matrices with large batch size. #16286

Closed
BillHuang2001 opened this issue Jun 7, 2023 · 2 comments
Labels
bug Something isn't working

Comments

@BillHuang2001

Description

Here is a minimal example that reproduces the bug.
Tested using a single RTX 3090.

import jax
from jax import jit, vmap
import jax.numpy as jnp

@jit
def f(adj, mat):
    # Row-normalize the product: divide each row of adj @ mat by the row sums of adj.
    return adj @ mat / jnp.sum(adj, axis=1)[:, jnp.newaxis]

# A large batch of boolean adjacency matrices and float matrices.
adj = jnp.ones((1024 * 100, 10, 10), dtype=bool)
mat = jnp.ones((1024 * 100, 10, 100), dtype=float)

jax.jit(vmap(f))(adj, mat)

Running this code produces the following error:

Traceback (most recent call last):
  File "/****/bug.py", line 12, in <module>
    jax.jit(vmap(f))(adj, mat)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to launch CUDA kernel: triton_gemm_dot_0 with block dimensions: 128x1x1 and grid dimensions: 4x1x102400 and shared memory size: 65536: CUDA_ERROR_INVALID_VALUE: invalid argument

Here adj is an adjacency matrix of type bool and mat is just a random matrix.
Casting adj to float, or avoiding @ in favor of a combination of vmap and jnp.sum, works around the problem.
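As a sketch of the first workaround mentioned above (casting adj to float before the matmul, so no bool GEMM is emitted); the batch size is shrunk here purely for illustration:

```python
import jax
from jax import jit, vmap
import jax.numpy as jnp

@jit
def f(adj, mat):
    # Row-normalize the product: divide each row of adj @ mat by the row sums of adj.
    return adj @ mat / jnp.sum(adj, axis=1)[:, jnp.newaxis]

# Small batch for illustration; the original report used 1024 * 100.
adj = jnp.ones((4, 10, 10), dtype=bool)
mat = jnp.ones((4, 10, 100), dtype=float)

# Workaround: cast the bool adjacency matrix to mat's float dtype up front.
out = jax.jit(vmap(f))(adj.astype(mat.dtype), mat)
print(out.shape)  # (4, 10, 100)
```

Numerically this is identical to the bool version, since matmul promotes bool operands to float anyway; the cast only changes which code path XLA lowers the contraction to.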

What jax/jaxlib version are you using?

jax v0.4.11, jaxlib 0.4.11+cuda12.cudnn88

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.10, Linux

NVIDIA GPU info

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.41.03              Driver Version: 530.41.03    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3090         Off| 00000000:17:00.0 Off |                  N/A |
| 35%   36C    P8               27W / 350W|     19MiB / 24576MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090         Off| 00000000:B3:00.0 Off |                  N/A |
| 37%   44C    P8               21W / 350W|      6MiB / 24576MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
@hawkinsp (Member)

Thanks for the report, I filed an XLA bug.

If you need a workaround until it is fixed, try setting the environment variable XLA_FLAGS=--xla_gpu_enable_triton_gemm=false
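One way to apply that flag is to export it in the shell before launching the script (the script name bug.py here is just illustrative):

```shell
# Disable XLA's Triton GEMM emitter for this session (workaround suggested
# above); remove once a fixed jaxlib release is installed.
export XLA_FLAGS=--xla_gpu_enable_triton_gemm=false

# Then run the reproduction as usual, e.g.: python bug.py
echo "$XLA_FLAGS"
```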

@hawkinsp (Member)

openxla/xla#3530 fixed this, and should be in the next jaxlib release.
