Jax crashes after upgrade #15260

Closed
terafo opened this issue Mar 28, 2023 · 9 comments
Labels: bug, contributions welcome, NVIDIA GPU, type:support, Windows

Comments

@terafo

terafo commented Mar 28, 2023

Description

All operations crash outright after the upgrade. For example:

import jax
jax.numpy.zeros((5,5))

Outputs this error:

2023-03-28 15:01:11.001453: W external/xla/xla/stream_executor/cuda/cuda_dnn.cc:397] There was an error before creating cudnn handle: cudaErrorNotSupported : operation not supported
Array([0., 0., 0., 0., 0.], dtype=float32)

This at least outputs an array (a fallback path, I presume). If I try

import jax
jax.random.normal(jax.random.PRNGKey(0),(5,5))

I get this error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/random.py", line 137, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/prng.py", line 270, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/prng.py", line 561, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/prng.py", line 573, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/prng.py", line 578, in random_seed_impl_base
    return seed(seeds)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/prng.py", line 817, in threefry_seed
    k2 = convert(jnp.bitwise_and(seed, np.uint32(0xFFFFFFFF)))
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 166, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 248, in cache_miss
    outs, out_flat, out_tree, args_flat = _python_pjit_helper(
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 195, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 2591, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 362, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 816, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1269, in _pjit_call_impl
    return compiled.unsafe_call(*args)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1934, in __call__
    results = self.xla_executable.execute_sharded(input_bufs)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: stream is uninitialized or in an error state

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/random.py", line 137, in PRNGKey
    key = prng.seed_with_impl(impl, seed)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/prng.py", line 270, in seed_with_impl
    return random_seed(seed, impl=impl)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/prng.py", line 561, in random_seed
    return random_seed_p.bind(seeds_arr, impl=impl)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 359, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 362, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/core.py", line 816, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/prng.py", line 573, in random_seed_impl
    base_arr = random_seed_impl_base(seeds, impl=impl)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/prng.py", line 578, in random_seed_impl_base
    return seed(seeds)
  File "/home/terafo/jt/.venv/lib/python3.10/site-packages/jax/_src/prng.py", line 817, in threefry_seed
    k2 = convert(jnp.bitwise_and(seed, np.uint32(0xFFFFFFFF)))
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: stream is uninitialized or in an error state

I get this issue with both a locally installed CUDA and the pip-installed CUDA.

What jax/jaxlib version are you using?

0.4.7

Which accelerator(s) are you using?

GPU

Additional system info

Python 3.10 in WSL with CUDNN 8.8.1 and CUDA 12.1
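For anyone reproducing this, the system details above can be gathered with a small script. This is only a sketch using the Python standard library: the helper names (`looks_like_wsl`, `installed_version`, `environment_report`) are made up for illustration, and the WSL check is a heuristic that assumes the kernel release string contains "microsoft", which is typical for WSL kernels but not guaranteed.

```python
import platform
from importlib import metadata


def looks_like_wsl(kernel_release: str) -> bool:
    """Heuristic (assumption): WSL kernel release strings contain 'microsoft'."""
    return "microsoft" in kernel_release.lower()


def installed_version(package: str) -> str:
    """Return the installed version of a package, or 'not installed'."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return "not installed"


def environment_report() -> dict:
    """Collect the details the issue template asks for."""
    release = platform.release()
    return {
        "python": platform.python_version(),
        "kernel": release,
        "wsl": looks_like_wsl(release),
        "jax": installed_version("jax"),
        "jaxlib": installed_version("jaxlib"),
    }


if __name__ == "__main__":
    for key, value in environment_report().items():
        print(f"{key}: {value}")
```

It does not query the GPU or the CUDA/cuDNN versions; those still have to come from `nvidia-smi` and the installed toolkit.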

NVIDIA GPU info

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.41.03              Driver Version: 531.41       CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                  Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf            Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA GeForce GTX 1080 Ti      On | 00000000:09:00.0  On |                  N/A |
| 20%   46C    P8               13W / 250W|  10019MiB / 11264MiB |      2%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A        22      G   /Xwayland                                 N/A      |
|    0   N/A  N/A       487      C   /python3.10                               N/A      |
+---------------------------------------------------------------------------------------+

terafo added the bug label on Mar 28, 2023
hawkinsp added the NVIDIA GPU label on Mar 28, 2023

@hawkinsp
Member

We don't ourselves support WSL, but I expect it should work. Do any CUDA 12 apps work in this setup?

@terafo
Author

terafo commented Mar 28, 2023

I successfully compiled and ran one CUDA sample (transpose) and one cuDNN sample (mnistCUDNN) with no issues. As far as I can tell, torch runs fine as well. When I used CUDA 11.8 + cuDNN 8.6, JAX worked just fine.

@terafo
Author

terafo commented Mar 28, 2023

I managed to get it working with CUDA 11, so the problem is definitely localized to CUDA 12. Maybe it has something to do with me using a Pascal GPU?

hawkinsp added the contributions welcome and Windows labels on Mar 29, 2023

@hawkinsp
Member

Unfortunately we don't ourselves test or support Windows or WSL2, but if you can figure out the problem we certainly would accept contributions to fix it!

@mikelaud

mikelaud commented Mar 29, 2023

Hi,
I just tested JAX on WSL2 and found no problem:

  1. WSL2 + Ubuntu-20.04 + jax-0.4.7

https://github.com/google/jax#installation
wsl --install -d Ubuntu-20.04
sudo su

apt update
apt -y full-upgrade

apt -y install python3-pip
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

python3

>>> import jax
>>> jax.numpy.zeros((5,5))
Array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)
>>> jax.default_backend()
'gpu'

Windows Edition: Windows 11 Pro
Version: 21H2
OS build: 22000.1696

Studio Driver
Version 531.41
NVIDIA GeForce RTX 3090

  2. WSL2 + Ubuntu-22.04 + jax-0.4.7

https://github.com/google/jax#installation
wsl --install -d Ubuntu-22.04
sudo su

apt update
apt -y full-upgrade

apt -y install python3-pip
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

python3

>>> import jax
>>> jax.numpy.zeros((5,5))
Array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]], dtype=float32)
>>> jax.default_backend()
'gpu'

Windows Edition: Windows 11 Pro
Version: 21H2
OS build: 22000.1696

Studio Driver
Version 531.41
NVIDIA GeForce RTX 3090

@terafo
Author

terafo commented Mar 29, 2023

@mikelaud I just tried the same sequence on a fresh install of Ubuntu 20.04 and I get the same error. Not sure what the issue is. I have the 531.41 driver (Game Ready, though) and Windows 11 Pro 22H2, build 22621.1413. My GPU is a 1080 Ti.
Update: switched to the Studio driver, still the same error. I feel cursed.

@terafo
Author

terafo commented Mar 30, 2023

OK, now I know that JAX is broken on Pascal under WSL with any version of CUDA (fewer features are broken with CUDA 11, more with CUDA 12). I tested a couple of commits and found that all my issues started after XLA was upgraded from 0f31407ee498e6dba242d03f8d382ebcfcc61790 to 79ca8d03c296ede04dc9a86ce9dde79ed909dda8. I'm not sure which XLA commit in that range is to blame, and I'm not sure how to build JAX against different versions of XLA.
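Narrowing a regression down to a single commit in a range like this is usually done by bisection. Here is a rough sketch of the idea in plain Python. It is not how the commits above were actually tested: `first_bad_commit` and the `is_bad` callback are hypothetical names, and in practice `is_bad` would have to build and run JAX at each commit. It assumes the first commit in the list is good, the last is bad, and every commit after the first bad one is also bad.

```python
from typing import Callable, Sequence


def first_bad_commit(commits: Sequence[str], is_bad: Callable[[str], bool]) -> str:
    """Binary-search an ordered commit list for the first failing commit.

    Assumes commits[0] is known good and commits[-1] is known bad.
    """
    if len(commits) < 2:
        raise ValueError("need at least one good and one bad commit")
    lo, hi = 0, len(commits) - 1  # invariant: commits[lo] good, commits[hi] bad
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if is_bad(commits[mid]):
            hi = mid  # the regression is at mid or earlier
        else:
            lo = mid  # the regression is after mid
    return commits[hi]


# Placeholder commit names, oldest first; "c5" is where the breakage starts.
commits = ["c0", "c1", "c2", "c3", "c4", "c5", "c6", "c7"]
culprit = first_bad_commit(commits, lambda c: commits.index(c) >= 5)
print(culprit)
```

Each step halves the range, so even a few hundred commits need only a handful of builds.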

@mikelaud

@terafo Well, it looks like a bug should be filed against XLA. Once it is fixed on the XLA side, JAX can bump the XLA version it builds against.

@terafo
Author

terafo commented May 10, 2023

With the recent 0.4.9 upgrade the issue is resolved, since it uses a more recent XLA version in which this issue is fixed.
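A quick way to tell whether an installed version is at or past that release is to compare version tuples. This is just a sketch: `parse_version` and `has_fix` are illustrative helpers, the 0.4.9 threshold comes only from the comment above, and pre-release suffixes are ignored, so a 0.4.9 dev build would be treated the same as the release.

```python
FIXED_IN = (0, 4, 9)  # threshold taken from the comment above (assumption)


def parse_version(text: str) -> tuple:
    """Turn '0.4.9' into (0, 4, 9), ignoring suffixes such as '.dev0' or 'rc1'."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if ch not in "0123456789":
                break
            digits += ch
        if not digits:
            break  # stop at the first non-numeric component
        parts.append(int(digits))
    return tuple(parts)


def has_fix(version: str) -> bool:
    """True if the given version string is 0.4.9 or newer."""
    return parse_version(version) >= FIXED_IN


print(has_fix("0.4.7"))
print(has_fix("0.4.9"))
```

Comparing tuples instead of raw strings matters here because "0.4.10" sorts before "0.4.9" as text.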

terafo closed this as completed on May 10, 2023