
non-addressable data access inside a jitted function #22218

Closed
ASKabalan opened this issue Jul 1, 2024 · 2 comments
Labels
bug Something isn't working

Comments

@ASKabalan

Description

Hello,

This is a follow-up to my discussion #22212.
I was able to make it work, but inside a bigger jitted function it fails with this error:

RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards.

I suspect that there is a bug here:
https://github.com/google/jax/blob/9653f58fa291895ccd15bb20dac00495367e1980/jax/_src/interpreters/mlir.py#L1399

This is the call stack:


  File "/gpfsdswork/projects/rech/tkc/commun/JaxPM/scripts/fastpm_jaxdecomp_cic_paint.py", line 506, in <module>
    field = forward_fn(z, kvec, a=1.)
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/pjit.py", line 327, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/pjit.py", line 185, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/core.py", line 2834, in bind
    return self.bind_with_trace(top_trace, args, params)
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/core.py", line 420, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/core.py", line 921, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/pjit.py", line 1635, in _pjit_call_impl
    return xc._xla.pjit(
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/pjit.py", line 1614, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/pjit.py", line 1536, in _pjit_call_impl_python
    compiled = _resolve_and_lower(
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/pjit.py", line 1503, in _resolve_and_lower
    lowered = _pjit_lower(
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/pjit.py", line 1645, in _pjit_lower
    return _pjit_lower_cached(*args, **kwargs)
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/pjit.py", line 1685, in _pjit_lower_cached
    return pxla.lower_sharding_computation(
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2253, in lower_sharding_computation
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2017, in _cached_lowering_to_hlo
    lowering_result = mlir.lower_jaxpr_to_module(
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 952, in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1393, in lower_jaxpr_to_fun
    consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1393, in <listcomp>
    consts = [ir_constants(xla.canonicalize_dtype(x)) for x in jaxpr.consts]
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 252, in ir_constants
    out = handler(val)
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/array.py", line 1010, in _array_mlir_constant_handler
    return mlir.ir_constants(val._value)
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/gpfsdswork/projects/rech/tkc/commun/venv/v100/lib/python3.10/site-packages/jax/_src/array.py", line 602, in _value
    raise RuntimeError(
RuntimeError: Fetching value for `jax.Array` that spans non-addressable (non process local) devices is not possible. You can use `jax.experimental.multihost_utils.process_allgather` to print the global array or use `.addressable_shards` method of jax.Array to inspect the addressable (process local) shards.
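
For reference, the two workarounds the error message suggests look roughly like this (a hedged sketch; `global_arr` stands for a hypothetical globally sharded array, not a name from my script):

from jax.experimental import multihost_utils

# Gather the full global value onto this process (returns a host-local array):
full_value = multihost_utils.process_allgather(global_arr)
print(full_value)

# Or inspect only the shards that are addressable from this process:
for shard in global_arr.addressable_shards:
    print(shard.device, shard.data.shape)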

My understanding is that inside a jitted function the compiler should not raise this error.
What puzzles me is that one of the HLO constants is a non-addressable array.
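
A minimal sketch of the pattern I mean (illustrative names, not my actual script; assumes a multi-process run where some shards live on non-addressable devices):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical array sharded across all devices of the mesh.
mesh = Mesh(jax.devices(), axis_names=("x",))
global_arr = jax.device_put(jnp.arange(8.0), NamedSharding(mesh, P("x")))

@jax.jit
def forward(z):
    # global_arr is captured by closure, so it lands in jaxpr.consts;
    # lowering then calls ir_constants on it (see the traceback above),
    # which needs the full value and fails when shards are non-addressable.
    return z + global_arr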

System info (python version, jaxlib version, accelerator, etc.)


jax:    0.4.30
jaxlib: 0.4.30
numpy:  2.0.0
python: 3.10.4 | packaged by conda-forge | (main, Mar 24 2022, 17:38:57) [GCC 10.3.0]
jax.devices (4 total, 4 local): [cuda(id=0) cuda(id=1) cuda(id=2) cuda(id=3)]
process_count: 1
platform: uname_result(system='Linux', node='r9i2n4', release='5.14.0-284.55.1.el9_2.x86_64', version='#1 SMP PREEMPT_DYNAMIC Mon Feb 19 16:57:59 EST 2024', machine='x86_64')


$ nvidia-smi
Mon Jul  1 21:50:56 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla V100-SXM2-32GB           On  |   00000000:1A:00.0 Off |                    0 |
| N/A   44C    P0             61W /  300W |     311MiB /  32768MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla V100-SXM2-32GB           On  |   00000000:1C:00.0 Off |                    0 |
| N/A   45C    P0             62W /  300W |     311MiB /  32768MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  Tesla V100-SXM2-32GB           On  |   00000000:88:00.0 Off |                    0 |
| N/A   46C    P0             61W /  300W |     311MiB /  32768MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   3  Tesla V100-SXM2-32GB           On  |   00000000:8A:00.0 Off |                    0 |
| N/A   46C    P0             60W /  300W |     311MiB /  32768MiB |      1%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A    672556      C   ...ech/tkc/commun/venv/v100/bin/python        308MiB |
|    1   N/A  N/A    672556      C   ...ech/tkc/commun/venv/v100/bin/python        308MiB |
|    2   N/A  N/A    672556      C   ...ech/tkc/commun/venv/v100/bin/python        308MiB |
|    3   N/A  N/A    672556      C   ...ech/tkc/commun/venv/v100/bin/python        308MiB |
+-----------------------------------------------------------------------------------------+

ASKabalan added the bug label Jul 1, 2024
@yashk2810
Collaborator

If you upgrade to the latest nightly, you will see a better error: 89c404e

You are closing over a global jax.Array which is not allowed. Instead pass it as an argument to the function.
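
A sketch of the suggested change (illustrative names, continuing the hypothetical example above):

@jax.jit
def forward(z, arr):  # take the global array as an explicit argument
    return z + arr

out = forward(z, global_arr)  # sharded input; no constant embedding at lowering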

I am closing the issue since this is WAI (working as intended), but please re-open if passing the array as an argument still fails.

@ASKabalan
Author

Yup, that was it. It slipped my mind completely.

Thank you 🙏 🙏
