Jax crashes after upgrade #15260
Comments
We don't support WSL ourselves, but I would expect it to work. Do any CUDA 12 apps work in this setup?
I successfully compiled and ran one CUDA sample (transpose) and one cuDNN sample (mnistCUDNN) with no issues. As far as I can tell, torch runs fine as well. When I used CUDA 11.8 + cuDNN 8.6, JAX worked just fine.
I managed to get it working with CUDA 11, so the problem is definitely specific to CUDA 12. Maybe it has something to do with my using a Pascal GPU?
Unfortunately, we don't test or support Windows or WSL2 ourselves, but if you can figure out the problem, we would certainly accept contributions to fix it!
Hi,
https://github.com/google/jax#installation
apt update
apt -y install python3-pip python3
Windows Edition: Windows 11 Pro
Studio Driver
@mikelaud I just tried the same sequence with a fresh install of Ubuntu 20.04 and got the same error. I'm not sure what the issue is. I have the 531.41 driver (Game Ready, though), Windows 11 Pro 22H2 (build 22621.1413), and a 1080 Ti GPU.
OK, now I know that JAX is broken on Pascal under WSL with any version of CUDA (fewer features break with CUDA 11, more with CUDA 12). I tested a couple of commits and found that all my issues started after XLA was upgraded from 0f31407ee498e6dba242d03f8d382ebcfcc61790 to 79ca8d03c296ede04dc9a86ce9dde79ed909dda8. I'm not sure which XLA change is to blame, and I'm not sure how to build JAX against different versions of XLA.
@terafo Well, it looks like someone should file a bug against XLA. Once it is fixed on the XLA side, JAX can bump its XLA version.
With the recent 0.4.9 release the issue is resolved, since it uses a more recent XLA version in which this bug is fixed.
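To check whether an installed environment already has that release, something like the following sketch can help. It assumes that jaxlib is the package that matters here (jaxlib is the one that bundles the compiled XLA) and treats 0.4.9 as the fixed version, as stated above; the helper names `parse_release` and `has_xla_fix` are hypothetical, and pre-release tags such as `rc1` are simply ignored.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Release reported in this thread as picking up the XLA fix (assumption: 0.4.9).
FIXED_IN = (0, 4, 9)


def parse_release(ver: str) -> tuple:
    """Hypothetical helper: turn '0.4.7+cuda12.cudnn88' into (0, 4, 7).

    The local tag after '+' is dropped, and only the leading digits of each
    dot-separated component are kept ('10.dev2023' stops after 10).
    """
    public = ver.split("+", 1)[0]
    parts = []
    for piece in public.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component, e.g. 'dev2023'
        parts.append(int(match.group()))
    return tuple(parts)


def has_xla_fix(ver: str) -> bool:
    """True if `ver` is at or above the release assumed to contain the fix."""
    return parse_release(ver) >= FIXED_IN


if __name__ == "__main__":
    for pkg in ("jax", "jaxlib"):
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            print(f"{pkg}: not installed")
            continue
        status = "includes" if has_xla_fix(installed) else "predates"
        print(f"{pkg} {installed}: {status} the 0.4.9 XLA update")
```

If the script reports an older jaxlib, upgrading jax and jaxlib together is the usual route, since the two are released in lockstep.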
Description
All operations crash outright after the upgrade. For example:
Outputs this error:
This, at least, outputs an array (a fallback path, I presume). If I try
I get this error:
I get this issue with both a local CUDA installation and the pip-installed CUDA.
What jax/jaxlib version are you using?
0.4.7
Which accelerator(s) are you using?
GPU
Additional system info
Python 3.10 in WSL with cuDNN 8.8.1 and CUDA 12.1
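When reporting a setup like this one, a small script can collect the same facts consistently. The sketch below is an illustration, not a JAX tool: `system_report` and `looks_like_wsl` are hypothetical helpers. WSL detection relies on the heuristic that WSL kernels include "microsoft" in their release string. The pip CUDA package names (`nvidia-cudnn-cu12`, `nvidia-cuda-runtime-cu12`) are assumptions that only apply to the pip-installed CUDA route.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Packages worth reporting; the nvidia-* names are assumptions for pip-installed CUDA.
PACKAGES = ("jax", "jaxlib", "nvidia-cudnn-cu12", "nvidia-cuda-runtime-cu12")


def looks_like_wsl(kernel_release: str) -> bool:
    """Heuristic: WSL kernels report 'microsoft' in their release string,
    e.g. '5.15.90.1-microsoft-standard-WSL2'."""
    return "microsoft" in kernel_release.lower()


def pkg_version(name: str) -> str:
    """Installed version of `name`, or 'not installed'."""
    try:
        return version(name)
    except PackageNotFoundError:
        return "not installed"


def system_report() -> dict:
    """Collect Python, OS/kernel, WSL flag, and package versions."""
    uname = platform.uname()
    report = {
        "python": platform.python_version(),
        "os": f"{uname.system} {uname.release}",
        "wsl": looks_like_wsl(uname.release),
    }
    for pkg in PACKAGES:
        report[pkg] = pkg_version(pkg)
    return report


if __name__ == "__main__":
    for key, value in system_report().items():
        print(f"{key}: {value}")
```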
NVIDIA GPU info
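Since the thread narrows the problem to Pascal GPUs, it helps to report the GPU's compute capability alongside its name and driver version. Pascal cards report compute capability 6.x (a 1080 Ti is 6.1). This sketch shells out to `nvidia-smi --query-gpu=name,driver_version,compute_cap --format=csv,noheader`. It assumes a driver recent enough to support the `compute_cap` field. The helpers `parse_gpu_line` and `query_gpus` are hypothetical names for illustration.

```python
import subprocess

QUERY = ["nvidia-smi", "--query-gpu=name,driver_version,compute_cap",
         "--format=csv,noheader"]


def parse_gpu_line(line: str) -> dict:
    """Parse one CSV row such as 'NVIDIA GeForce GTX 1080 Ti, 531.41, 6.1'."""
    # rsplit keeps the name intact even if it ever contained a comma.
    name, driver, cap = (field.strip() for field in line.rsplit(",", 2))
    major = int(cap.split(".")[0])
    return {
        "name": name,
        "driver": driver,
        "compute_cap": cap,
        "pascal": major == 6,  # Pascal = compute capability 6.0/6.1/6.2
    }


def query_gpus() -> list:
    """Run nvidia-smi and parse each GPU row; return [] if the query fails."""
    try:
        result = subprocess.run(QUERY, capture_output=True, text=True, check=True)
    except (OSError, subprocess.CalledProcessError) as err:
        print(f"nvidia-smi query failed: {err}")
        return []
    return [parse_gpu_line(row) for row in result.stdout.splitlines() if row.strip()]


if __name__ == "__main__":
    for gpu in query_gpus():
        print(gpu)
```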