The code will result in the following error when compiling:
Traceback (most recent call last):
  File "/****/bug.py", line 12, in <module>
    jax.jit(vmap(f))(adj, mat)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to launch CUDA kernel: triton_gemm_dot_0 with block dimensions: 128x1x1 and grid dimensions: 4x1x102400 and shared memory size: 65536: CUDA_ERROR_INVALID_VALUE: invalid argument
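One plausible reading of this error, offered as an assumption rather than a confirmed diagnosis: CUDA caps gridDim.y and gridDim.z at 65535 blocks (only gridDim.x may go up to 2^31 - 1), and the reported launch puts 102400 blocks on z, presumably one per vmapped batch element. The hypothetical helpers below (parse_dims and launch_violations are illustration-only names) check the dimensions quoted in the error message against those documented limits. They do not check the 65536-byte shared-memory request.

```python
# Documented CUDA launch limits (compute capability 3.0 and newer).
MAX_GRID_DIMS = (2**31 - 1, 65535, 65535)  # grid x, y, z
MAX_THREADS_PER_BLOCK = 1024


def parse_dims(text: str) -> tuple[int, int, int]:
    """Parse an 'XxYxZ' string such as '4x1x102400' into three ints."""
    x, y, z = (int(part) for part in text.split("x"))
    return x, y, z


def launch_violations(block: str, grid: str) -> list[str]:
    """List every way a kernel launch exceeds the limits above."""
    problems = []
    bx, by, bz = parse_dims(block)
    threads = bx * by * bz
    if threads > MAX_THREADS_PER_BLOCK:
        problems.append(f"block has {threads} threads > {MAX_THREADS_PER_BLOCK}")
    # Compare each grid axis against its own limit.
    for axis, size, limit in zip("xyz", parse_dims(grid), MAX_GRID_DIMS):
        if size > limit:
            problems.append(f"grid.{axis} = {size} > {limit}")
    return problems


# Dimensions copied from the error message above.
print(launch_violations("128x1x1", "4x1x102400"))
# ['grid.z = 102400 > 65535']
```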
In the failing call, adj is a boolean adjacency matrix and mat is a random matrix.
Passing adj as a float array, or avoiding @ by computing the product with a combination of vmap and jnp.sum, works around the error.
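Both workarounds can be sketched as follows. This is a minimal sketch, not the report's original code: the body of f (a plain adj @ mat), the helper name f_sum, and the shapes are hypothetical. At these small sizes it runs on any backend, so it shows that the workarounds compute the same result rather than reproducing the GPU failure.

```python
import jax
import jax.numpy as jnp
from jax import vmap


def f(adj, mat):
    # Hypothetical stand-in for the report's f: boolean adjacency times float matrix.
    return adj @ mat


def f_sum(adj, mat):
    # Workaround 2: no `@`. Row i of the product is sum_k adj[i, k] * mat[k, :],
    # computed by broadcasting one adjacency row against mat and reducing with jnp.sum.
    # bool * float32 promotes to float32, so the result matches adj @ mat.
    row_product = lambda row: jnp.sum(row[:, None] * mat, axis=0)
    return vmap(row_product)(adj)


# Small hypothetical shapes; the batch size used in the report is not known.
batch, n, d = 4, 8, 16
key_adj, key_mat = jax.random.split(jax.random.PRNGKey(0))
adj = jax.random.bernoulli(key_adj, 0.3, (batch, n, n))  # dtype bool
mat = jax.random.normal(key_mat, (batch, n, d))          # dtype float32

# Workaround 1: cast adj to the matrix dtype before calling the jitted function,
# so the compiled matmul only ever receives floating-point inputs.
out_cast = jax.jit(vmap(f))(adj.astype(mat.dtype), mat)

# Workaround 2: keep adj boolean but replace `@` with vmap + jnp.sum.
out_sum = jax.jit(vmap(f_sum))(adj, mat)
```

Casting outside the jitted function mirrors "passing adj as a float array"; whether a cast placed inside f would also avoid the failing kernel is untested here.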
What jax/jaxlib version are you using?
jax v0.4.11, jaxlib 0.4.11+cuda12.cudnn88
Which accelerator(s) are you using?
GPU
Additional system info
Python 3.10, Linux
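Recent JAX releases can gather most of this information in one call. The sketch below assumes jax.print_environment_info is available in the installed version (the report does not mention it). When nvidia-smi is on PATH, its output should be included as well.

```python
import jax

# With return_string=True the report is returned instead of printed. It lists the
# jax, jaxlib, numpy and Python versions and the visible devices, and typically
# appends nvidia-smi output when that tool can be run.
info = jax.print_environment_info(return_string=True)
print(info)

# The jax version alone is also exposed as a module attribute.
print(jax.__version__)
```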
NVIDIA GPU info
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.41.03              Driver Version: 530.41.03    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce RTX 3090         Off| 00000000:17:00.0 Off |                  N/A |
| 35%   36C    P8               27W / 350W|     19MiB / 24576MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090         Off| 00000000:B3:00.0 Off |                  N/A |
| 37%   44C    P8               21W / 350W|      6MiB / 24576MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
+---------------------------------------------------------------------------------------+
The minimal example reproducing the bug was tested on a single RTX 3090.