No kernel image is available for execution on the device #5723
Comments
Can you share the output of |
Added now. Thanks for flagging. |
I'm not sure why this happens, but I see the same thing if I have driver version 450.102.04 but CUDA version 11.1. According to NVIDIA, these two versions should work together (https://docs.nvidia.com/deploy/cuda-compatibility/index.html); I don't know why they don't. I suggest either upgrading your driver or installing an older CUDA release. |
Hello, I am experiencing a similar problem, even with a newer version of jaxlib
In my case,
and
Fri Mar 12 13:54:27 2021 +-----------------------------------------------------------------------------+
The actual error that I am getting is the following:
Note that I manually set the environment variable to the CUDA path with
Thanks for helping. |
We've confirmed this issue is due to using too new a version of CUDA with too old a driver version. If you see this issue, the workaround is either to use an older CUDA release or a newer NVIDIA driver. We may also be able to work around it at the JAX level in the future. |
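For anyone who wants to check their own setup against this diagnosis, here is a minimal Python sketch that reads the driver version from nvidia-smi and compares it with a minimum you supply. The helper names (`parse_version`, `driver_is_at_least`, `installed_driver_version`) are made up for illustration, and the required version is an input you would take from NVIDIA's compatibility table for your CUDA release; nothing here is part of JAX.

```python
import re
import shutil
import subprocess


def parse_version(text):
    """Turn a dotted version such as '450.102.04' into a tuple of ints."""
    return tuple(int(part) for part in re.findall(r"\d+", text))


def driver_is_at_least(installed, required):
    """Compare two dotted driver versions component by component."""
    return parse_version(installed) >= parse_version(required)


def installed_driver_version():
    """Ask nvidia-smi for the driver version; return None if unavailable."""
    if shutil.which("nvidia-smi") is None:
        return None
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=False,
    )
    lines = result.stdout.strip().splitlines()
    if result.returncode != 0 or not lines:
        return None
    return lines[0].strip()


if __name__ == "__main__":
    # Hypothetical threshold: replace with the minimum driver that NVIDIA
    # documents for the CUDA release your jaxlib was built against.
    required = "470.57.02"
    installed = installed_driver_version()
    if installed is None:
        print("Could not query nvidia-smi")
    else:
        print(installed, "ok" if driver_is_at_least(installed, required) else "too old")
```

The threshold in the `__main__` block is only a placeholder taken from a driver version that a commenter in this thread reported as working; it is not an official minimum.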
I encountered this problem on an HPC of which I'm not an admin, and I set |
Setting the XLA flag works for me as well. But it comes with the warning
Not being multithreaded sounds undesirable. Can we get a more permanent solution? |
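For reference, the flag in question is spelled out later in this thread as XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1. If you cannot set it in your shell (for example inside a notebook), a rough sketch of setting it from Python follows. The helper `add_xla_flag` is a made-up name, and the assumption is that the variable should be set before jax is imported so that the backend sees it.

```python
import os


def add_xla_flag(flag, environ=os.environ):
    """Append `flag` to XLA_FLAGS unless it is already present."""
    current = environ.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    environ["XLA_FLAGS"] = " ".join(current)
    return environ["XLA_FLAGS"]


# Assumption: run this before `import jax` so the setting is picked up.
add_xla_flag("--xla_gpu_force_compilation_parallelism=1")
```

Appending rather than overwriting keeps any other XLA flags you already rely on. As noted above, this is a workaround that gives up parallel compilation, not a fix.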
I'm also seeing this issue. I'm on NixOS 20.09.3301.42809feaa9f, jaxlib 0.1.71, and here's my nvidia-smi:
Based on the NVIDIA docs, it seems like these two versions should be compatible. Passing |
I just upgraded to NixOS 21.05.2796.110a2c9ebbf to get driver version 470.57.02, and the issue has gone away. |
I am using CentOS 7 and was having the issue mentioned here; the problem was solved after upgrading the NVIDIA driver from 460 to 470.74. |
I'm having the same issue on an HPC using a singularity container. I'm not the admin so I can't update the nvidia driver. If a jax workaround that still allows multithreaded compilation is possible, that would be awesome! |
I'm getting the same error, and XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1 doesn't work for me; it introduces a new error.
jax version
nvcc and nvidia-smi
Before XLA flag
...
After including flag
Without the XLA flag, with TF_CPP_MIN_LOG_LEVEL=0 as in issue #7118: returns the same.
With export XLA_PYTHON_CLIENT_PREALLOCATE=false and without the XLA flag (the solution in #7118): also returns the same. |
I have the same problem on an RTX 3090. Running with xla_gpu_force_compilation_parallelism=1 gives another error (could be related to #8506):
Edit: |
Hello, can you tell me specifically how to install this package? |
@qinggeduoqing Install everything with conda-forge, e.g. |
@lkhphuc Thank you so much for your reply, let me try this way... |
I get an error with the following command,
The stack trace below excludes JAX-internal frames.
JAX can see the GPUs though,
I installed jaxlib 1.77.0 and jax 0.2.28 from source. I am using CUDA 11.5 and cuDNN 8.3.0. I made sure that the PATH env variable is set up properly and that the Python session is loading the correct CUDA libraries. I'm not sure what else can be wrong. I'm running the program on Ubuntu 20.04 with 2 GTX 1080s. |
@mshafiei Can you share the output of |
@hawkinsp sure,
|
Hmm. Interesting. I would have expected that to work. My best suggestion for something for you to try to get unblocked would be to build jaxlib from source, explicitly opting in to the CUDA capability for your device (https://jax.readthedocs.io/en/latest/developer.html). There's an option to specify a list of CUDA compute capabilities to the |
@hawkinsp I am actually building jaxlib from source and passing the cuda specifications as below,
Are these flags what you were referring to? |
In my case, making sure that When
and
I was getting the error above, and I could only work around it by setting |
@lee-van-oetz is this fixed now? |
I'm pretty sure this is fixed in recent jaxlib releases. We added code to jaxlib that falls back to not using |
Hi, I am also facing this problem. I use the following :
After I installed all the packages related to C++ on VS2019, I installed the CUDA Toolkit.
I even ran CUDA 11.7, 11.8 and 12.0 on VS2019 and VS2022, but the error still exists. Please help me out. |
Anyone else still struggling with this? I've tried dozens of combinations of NVIDIA drivers, Ubuntu's nvidia-cuda-toolkit, different versions of CUDA from NVIDIA's website, jaxlib and jax; nothing has solved the problem. I'm using
nvidia-smi
nvcc -V
jaxlib==0.4.20+cuda11.cudnn86
print(jax.devices()) shows the GPU, but trying to use it results in
FWIW, I had jax working fine just a few days ago; this might be caused by a recent update to 23.04. |
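Since most replies in this thread start by asking for the same details (nvidia-smi, nvcc, and the jax/jaxlib versions), here is a small Python sketch that gathers them in one go for pasting into a report. The function names are invented for this example, and it only reports what is installed; it does not diagnose anything.

```python
import shutil
import subprocess
from importlib import metadata


def run_tool(command):
    """Run a command and return its output, or a note if the tool is missing."""
    if shutil.which(command[0]) is None:
        return f"not found: {command[0]}"
    result = subprocess.run(command, capture_output=True, text=True, check=False)
    return (result.stdout + result.stderr).strip()


def package_version(name):
    """Return the installed version of a Python package, or a note if absent."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return "not installed"


def collect_report():
    """Gather the details that are repeatedly requested in this thread."""
    return {
        "nvidia-smi": run_tool(["nvidia-smi"]),
        "nvcc": run_tool(["nvcc", "--version"]),
        "jax": package_version("jax"),
        "jaxlib": package_version("jaxlib"),
    }


if __name__ == "__main__":
    for name, value in collect_report().items():
        print(f"== {name} ==\n{value}\n")
```

Reading versions through `importlib.metadata` avoids importing jax itself, so the script still runs when the GPU backend fails to initialize.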
@deoxyribose I think the problem is that we are building for GPUs with SM version 5.2 at a minimum: Line 68 in 961ba3c
but your GPU appears to have SM version 5.0. The fix is to build jaxlib yourself, explicitly specifying your SM version. Try:
We've actually never shipped support for that model of GPU. |
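To make the SM version comparison above concrete, here is a rough Python sketch. It assumes you already know your GPU's compute capability (for example from NVIDIA's published list) and the minimum the wheel was built for; the helper names are hypothetical, and the simple "at least the minimum" rule is a simplification of how CUDA binaries are actually matched to devices.

```python
def parse_compute_capability(text):
    """Convert '5.0', 'sm_50' or 'compute_52' into a (major, minor) tuple."""
    cleaned = text.strip().lower()
    for prefix in ("sm_", "compute_"):
        if cleaned.startswith(prefix):
            digits = cleaned[len(prefix):]
            # The last digit is the minor version, the rest is the major.
            return int(digits[:-1]), int(digits[-1])
    major, minor = cleaned.split(".")
    return int(major), int(minor)


def meets_minimum(device, minimum):
    """True if the device capability is at least the build's minimum."""
    return parse_compute_capability(device) >= parse_compute_capability(minimum)


# The case described above: an SM 5.0 device against a 5.2 minimum.
print(meets_minimum("5.0", "sm_52"))
```

If this prints False for your card, the released wheel has no kernels for it, which matches the "no kernel image is available" error.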
This improves compatibility with older Maxwell cards, and it probably doesn't matter a whole lot for performance. See: #5723 (comment) PiperOrigin-RevId: 584641856
@deoxyribose #18644 will add sm_50 support to the next jaxlib release. |
@hawkinsp Thanks for the quick reply! I tried removing and installing CUDA with
I'm not sure if this is the expected location:
I tried with |
This improves compatibility with older Maxwell cards, and it probably doesn't matter a whole lot for performance. See: #5723 (comment) PiperOrigin-RevId: 585967281
I have installed through
pip3 install --upgrade jax jaxlib==0.1.61+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
nvcc --version shows CUDA 11.1 is at /usr/local/cuda-11.1
and yet I am getting this error when trying the quickstart example.
Output of nvidia-smi: