"NCHW" data_format in Conv not working with latest CUDA #83
Comments
Hey @mil-ad, do you have cuDNN installed? I think you need to install it separately from the GPU drivers. I've tested your snippet in Google Colab on an Nvidia P100 GPU and it seems to work fine: https://colab.research.google.com/gist/tomhennigan/6fd38842b05d46b8418cf44a4083be9d/nchw-test.ipynb
I do have cuDNN installed, but perhaps, as with CUDA, I need to point JAX to it explicitly?
Also the
See google/jax#4920. Can you try setting My guess about
Thanks @hawkinsp, that does make it go further, although it still fails:
Let's keep the discussion in the JAX bug; I think this is not Haiku-specific.
I'm not able to use the NCHW data format in conv layers: the snippet above works fine on the CPU, but on the GPU it produces the TensorFlow-style spew of errors below. The problem goes away if I change data_format to NHWC. I'm running fairly recent versions of the NVIDIA driver and CUDA, and the same snippet seems to run on older versions (according to a few people I sent it to), so I'm fairly sure the problem is related to those. My versions are:
Error: