RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm #8506
I couldn't repro this in my environment; it's likely RTX 3080 specific. I'm asking our GPU experts to take a look.
Yep, this is Ampere-specific, and I was able to repro on an A6000 using the previous release. Yesterday's release of jaxlib 0.1.74 fixes it on my machine: can you try that?
The latest release doesn't fix it.

$ nvidia-smi
What version of CuDNN do you have installed?
Sorry, I hadn't seen your reply. I have the latest CUDA and cuDNN versions, 11.5 and 8.3.1. I have tried building the jax project for my computer, but that hasn't worked either. Any other suggestions to try?
@pseudo-rnd-thoughts:
Sadly that hasn't fixed it either; I get the same error about cudnn convolutions. This is really strange, as a couple of other people have had this bug but have all got their code working. I'm not sure what is unusual about my system.

== Version information ==

== Error ==

== Jax / Flax Code ==

== Tensorflow code ==

Thanks
@pseudo-rnd-thoughts Go to the folder /usr/local and check which CUDA installation you have (in my case it is cuda-11.3, as I work with the docker image).
Replace
Seeing the same issue on a Quadro T2000; I tried the various fixes above and none worked.

== Version information ==
Fixed it: it is a memory allocation issue, as suggested below, though a different one. I found a previous discussion that had a very similar problem to mine. That discussion noted the way that JAX allocates memory: by default it preallocates 90% of GPU memory on the first JAX operation, which for us was the convolution. @rems75 does this fix the issue for you?
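As a sketch of the workaround being discussed (`XLA_PYTHON_CLIENT_MEM_FRACTION` and `XLA_PYTHON_CLIENT_PREALLOCATE` are JAX's documented GPU-memory environment variables; the 0.8 value is simply what worked in this thread, not a universal recommendation):

```python
import os

# XLA_PYTHON_CLIENT_MEM_FRACTION controls what fraction of GPU memory
# JAX preallocates on first use (default: 0.9). It must be set before
# jax is imported, because preallocation happens when the GPU backend
# initialises.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.8"

# Alternatively, disable preallocation entirely and allocate on demand:
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

# import jax  # import only after the variables above are set
print(os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"])
```

The same effect can be had from the shell with `XLA_PYTHON_CLIENT_MEM_FRACTION=0.8 python train_script.py` (script name hypothetical).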
@pseudo-rnd-thoughts I think the fix is that we need to reserve a minimum absolute amount of GPU RAM for CuDNN. How much GPU RAM do you have? Is 0.7 the largest value that works? E.g., does, say, 0.8 work?
@hawkinsp I have an Nvidia 3080 with 10GB RAM. I did a bit of testing: 80% and 85% work, while 90% causes the crash. Do you have any other questions?
@pseudo-rnd-thoughts No, that seems roughly in line with what I expect. You have 10016MiB, of which JAX claims 90% (9014MiB). Your system processes claim another 300MiB (9314MiB in total), leaving only ~700MiB for CuDNN. This is apparently not enough. I think the way to fix this is for JAX to ensure that at least, say, 1GiB is left free after its allocation, so that CuDNN can work. I don't know what the right value for "1GiB" is, but clearly ~700MiB is too low.
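The arithmetic above can be checked directly (the 300MiB figure for system processes is the rough number quoted in this thread, not a measured constant):

```python
total_mib = 10016                    # reported GPU memory on the 3080
jax_claim = int(total_mib * 0.90)    # default XLA_PYTHON_CLIENT_MEM_FRACTION
system_mib = 300                     # approximate usage by other processes

# What remains for cuDNN's workspace at the default 90% fraction:
free_for_cudnn = total_mib - jax_claim - system_mib
print(jax_claim, free_for_cudnn)     # 9014 702 -- the ~700MiB that is too low

# Dropping the fraction to 0.8 frees roughly another 1GiB:
free_at_80 = total_mib - int(total_mib * 0.80) - system_mib
print(free_at_80)                    # 1704
```

This matches the reported behaviour: 0.8 and 0.85 leave over a gigabyte free and work, while 0.9 leaves ~700MiB and crashes.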
@hawkinsp Thanks, I was imagining that the cudnn memory usage would be within the JAX preallocated amount. That makes a lot of sense now.
Worked for me as well. Cool. The numbers are different in my case. This is the GPU memory with XLA_PYTHON_CLIENT_MEM_FRACTION=0.8:

Note that this is through WSL2 on a laptop running Windows 11.
Hi everyone, I had the exact same issue described above. I am running on WSL2 on Windows 10. I installed CUDA and cuDNN and then installed

Pre-updated memory fraction

Post-updated memory fraction
I am very new to any sort of collaboration on repositories, so apologies if my etiquette is somewhat off, but I was wondering whether there have been any updates on this? Any "best practice" ways to correct it? I am currently setting

Also, slightly unrelated, but what would be the best way to keep up with updates to this repository? I will be using JAX pretty religiously to build SVGP models, as I love its flexibility, and so would like to keep up to date. Thanks for any help!
Running convolutional layers causes an error where JAX cannot determine which cudnn optimisation algorithm to use.
This error appears to be JAX-only, as I have replicated the code with TensorFlow and no error occurs.
My jax version is 0.2.24 and my jaxlib version is 0.1.74+cuda11.cudnn82, with an Nvidia 3080.
The example is taken from the flax readme (https://github.com/google/flax).
The bug appears to affect only convolutions, as the error does not occur for the MLP example.
I haven't been able to check whether the error replicates elsewhere, as I don't have another GPU to test with.
I found this similar issue from someone who, like me, uses a 3080 (#7953).
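A minimal sketch of the kind of operation that hits the failing code path (this uses a bare jax.lax convolution rather than the full flax readme CNN from the report, purely to isolate the single op that triggers cuDNN algorithm selection; shapes are illustrative):

```python
import jax
import jax.numpy as jnp

# A single 2D convolution is enough to exercise cuDNN algorithm
# selection on a GPU backend; layouts are NCHW for the input and
# OIHW for the kernel.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (1, 1, 28, 28))   # batch of one 28x28 image
w = jax.random.normal(key, (8, 1, 3, 3))     # 8 output channels, 3x3 kernel

y = jax.lax.conv(x, w, window_strides=(1, 1), padding="SAME")
print(y.shape)  # (1, 8, 28, 28)
```

On an affected machine, the `jax.lax.conv` call is where "RuntimeError: UNKNOWN: Failed to determine best cudnn convolution algorithm" would be raised; on CPU, or with enough free GPU memory, it runs normally.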