This is a follow-up to my discussion #22212.
I was able to make it work, but inside a bigger jitted function it fails with this error:
RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards.
File "/gpfsdswork/projects/rech/tkc/commun/JaxPM/scripts/fastpm_jaxdecomp_cic_paint.py", line 506, in <module>
field = forward_fn(z, kvec, a=1.)
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/pjit.py", line 327, in cache_miss
outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/pjit.py", line 185, in _python_pjit_helper
out_flat = pjit_p.bind(*args_flat, **params)
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/core.py", line 2834, in bind
return self.bind_with_trace(top_trace, args, params)
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/core.py", line 420, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/core.py", line 921, in process_primitive
return primitive.impl(*tracers, **params)
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/pjit.py", line 1635, in _pjit_call_impl
return xc._xla.pjit(
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/pjit.py", line 1614, in call_impl_cache_miss
out_flat, compiled = _pjit_call_impl_python(
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/pjit.py", line 1536, in _pjit_call_impl_python
compiled = _resolve_and_lower(
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/pjit.py", line 1503, in _resolve_and_lower
lowered = _pjit_lower(
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/pjit.py", line 1645, in _pjit_lower
return _pjit_lower_cached(*args, **kwargs)
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/pjit.py", line 1685, in _pjit_lower_cached
return pxla.lower_sharding_computation(
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
return func(*args, **kwargs)
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2253, in lower_sharding_computation
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2017, in _cached_lowering_to_hlo
lowering_result = mlir.lower_jaxpr_to_module(
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 952, in lower_jaxpr_to_module
lower_jaxpr_to_fun(
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1393, in lower_jaxpr_to_fun
consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1393, in <listcomp>
consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 252, in ir_constants
out = handler(val)
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/array.py", line 1010, in _array_mlir_constant_handler
return mlir.ir_constants(val._value)
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
return func(*args, **kwargs)
File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/array.py", line 602, in _value
raise RuntimeError(
RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards.
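The traceback ends in `mlir.ir_constants` materializing `jaxpr.consts`. A minimal single-process sketch (so nothing fails here) of how an array closed over by a jitted function becomes such a constant:

```python
import jax
import jax.numpy as jnp

# An array captured from the enclosing scope, rather than passed as an
# argument, is baked into the jaxpr as a constant.
captured = jnp.ones((4,))

def f(z):
    return z + captured

jaxpr = jax.make_jaxpr(f)(1.0)

# The captured array appears in jaxpr.consts. At lowering time, each const
# is materialized via mlir.ir_constants, which is the call that raises
# RuntimeError when the const spans non-addressable devices.
print(len(jaxpr.consts))
```

Scalar literals are inlined directly into the jaxpr, but a non-scalar array like `captured` is kept as a const, which is why it reaches the constant handler during lowering.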
I understood that inside a jitted function the compiler should not raise this error. What is puzzling me is that one of the HLO constants is a non-addressable array.
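A common workaround, sketched under the assumption that the non-addressable constant is a globally sharded array captured by the jitted function: pass the array as an explicit argument so it stays a traced input instead of being materialized as a constant at lowering. The mesh and function names below are illustrative, not from the original script:

```python
import os
# Force 4 CPU "devices" so this single-host sketch can build a mesh.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()[:4]), ("x",))
global_array = jax.device_put(jnp.arange(16.0), NamedSharding(mesh, P("x")))

# Captured: `global_array` becomes a jaxpr constant and its value must be
# fetched at lowering time, which fails when shards live on other processes.
@jax.jit
def forward_captured(z):
    return z * global_array

# Passed as an argument: the array stays symbolic during lowering.
@jax.jit
def forward(z, arr):
    return z * arr

out = forward(2.0, global_array)
print(out.shape)
```

In a single-process run both versions work, since every shard is addressable; the difference only surfaces in a multi-process run.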
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.30
jaxlib: 0.4.30
numpy: 2.0.0
python: 3.10.4 | packaged by conda-forge | (main, Mar 24 2022, 17:38:57) [GCC 10.3.0]
jax.devices (4 total, 4 local): [cuda(id=0) cuda(id=1) cuda(id=2) cuda(id=3)]
process_count: 1
platform: uname_result(system='Linux', node='r9i2n4', release='5.14.0-284.55.1.el9_2.x86_64', version='#1 SMP PREEMPT_DYNAMIC Mon Feb 19 16:57:59 EST 2024', machine='x86_64')
$ nvidia-smi
Mon Jul 1 21:50:56 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 Tesla V100-SXM2-32GB On | 00000000:1A:00.0 Off | 0 |
| N/A 44C P0 61W / 300W | 311MiB / 32768MiB | 1% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 1 Tesla V100-SXM2-32GB On | 00000000:1C:00.0 Off | 0 |
| N/A 45C P0 62W / 300W | 311MiB / 32768MiB | 2% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 2 Tesla V100-SXM2-32GB On | 00000000:88:00.0 Off | 0 |
| N/A 46C P0 61W / 300W | 311MiB / 32768MiB | 1% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
| 3 Tesla V100-SXM2-32GB On | 00000000:8A:00.0 Off | 0 |
| N/A 46C P0 60W / 300W | 311MiB / 32768MiB | 1% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 672556 C ...ech/tkc/commun/venv/v100/bin/python 308MiB |
| 1 N/A N/A 672556 C ...ech/tkc/commun/venv/v100/bin/python 308MiB |
| 2 N/A N/A 672556 C ...ech/tkc/commun/venv/v100/bin/python 308MiB |
| 3 N/A N/A 672556 C ...ech/tkc/commun/venv/v100/bin/python 308MiB |
+-----------------------------------------------------------------------------------------+
I suspect that there is a bug here:
https://github.com/google/jax/blob/9653f58fa291895ccd15bb20dac00495367e1980/jax/_src/interpreters/mlir.py#L1399
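For debugging, the error message's own suggestions can be sketched like this (run here on a single process, where every shard is addressable; `process_allgather` only matters in a true multi-process run):

```python
import jax
import jax.numpy as jnp

x = jnp.arange(8.0)

# Inspect only the shards that live on this process. In a multi-process
# run this works even when fetching the full value would raise the
# RuntimeError above.
shard_shapes = [s.data.shape for s in x.addressable_shards]
print(shard_shapes)

# In a multi-process run, the full value can be gathered explicitly:
# from jax.experimental import multihost_utils
# full = multihost_utils.process_allgather(x)
```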