Jax crashes on TPU in version 0.3.19 #12550

Closed
gerdm opened this issue Sep 28, 2022 · 10 comments

Labels: bug (Something isn't working), P1 (soon) (Assignee is working on this now, among other tasks; Assignee required)

gerdm commented Sep 28, 2022

Description

Hi,

I installed Jax on a TPU V3-8:

pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

However, when running Jax, I get the following error.

(base) gerardoduran@t1v-n-7177f451-w-0:~$ python
Python 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:35:26) [GCC 10.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> import jax.numpy as jnp
>>> jnp.sqrt(2)
tcmalloc: large alloc 377396076847104 bytes == (nil) @  0x7fefdec9d680 0x7fefdecbe824 0x7fedeee9d2da 0x7fedeee4ebae 0x7fedea76487a 0x7fedea763a05 0x7fedea765c62 0x7fede9a4a82e 0x7fee98171b66 0x7fee971e4541 0x7fee971d0912 0x7fee971ca61d 0x7fee95b3048c 0x7fee95b3f176 0x7fee93f64ded 0x7fee93d40ac8 0x7fee93d412d3 0x7fee93d1c916 0x55ced19903cc 0x55ced1989738 0x55ced199df80 0x55ced1981107 0x55ced199086f 0x55ced198299f 0x55ced199086f 0x55ced19800ff 0x55ced199086f 0x55ced19800ff 0x55ced199086f 0x55ced199e7f8 0x55ced198299f
Unhandled exception:
    @     0x7fedeedb9b62  (unknown)
    @     0x7fedeeeba4e6  (unknown)
    @     0x7fedeeeba03b  (unknown)
    @     0x7fedeeeb9fb4  (unknown)
    @     0x7fedeee9d32b  (unknown)
    @     0x7fedeee4ebae  (unknown)
    @     0x7fedea76487a  (unknown)
    @     0x7fedea763a05  (unknown)
    @     0x7fedea765c62  (unknown)
    @     0x7fede9a4a82e  TpuCompiler_RunHloPasses
    @     0x7fee98171b66  xla::(anonymous namespace)::TpuCompiler::RunHloPasses()
    @     0x7fee971e4541  xla::Service::BuildExecutable()
    @     0x7fee971d0912  xla::LocalService::CompileExecutables()
    @     0x7fee971ca61d  xla::LocalClient::Compile()
    @     0x7fee95b3048c  xla::PjRtStreamExecutorClient::Compile()
    @     0x7fee95b3f176  xla::PjRtStreamExecutorClient::Compile()
    @     0x7fee93f64ded  xla::PyClient::CompileMlir()
    @     0x7fee93d40ac8  pybind11::detail::argument_loader<>::call_impl<>()
    @     0x7fee93d412d3  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
    @     0x7fee93d1c916  pybind11::cpp_function::dispatcher()
    @     0x55ced19903cc  cfunction_call
https://symbolize.stripped_domain/r/?trace=7fedeedb9b62,7fedeeeba4e5,7fedeeeba03a,7fedeeeb9fb3,7fedeee9d32a,7fedeee4ebad,7fedea764879,7fedea763a04,7fedea765c61,7fede9a4a82d,7fee98171b65,7fee971e4540,7fee971d0911,7fee971ca61c,7fee95b3048b,7fee95b3f175,7fee93f64dec,7fee93d40ac7,7fee93d412d2,7fee93d1c915,55ced19903cb&map=ca08008df67fa564c14ead76d3f2385a:7feddef57000-7fedef062c00 
libc++abi: terminating due to uncaught exception of type std::bad_alloc: std::bad_alloc
https://symbolize.stripped_domain/r/?trace=7fefde94600b,7fefdec6f41f,7fedeeea17c8,7fedeeeba4e5,7fedeeeba03a,7fedeeeb9fb3,7fedeee9d32a,7fedeee4ebad,7fedea764879,7fedea763a04,7fedea765c61,7fede9a4a82d,7fee98171b65,7fee971e4540,7fee971d0911,7fee971ca61c,7fee95b3048b,7fee95b3f175,7fee93f64dec,7fee93d40ac7,7fee93d412d2,7fee93d1c915,55ced19903cb&map=ca08008df67fa564c14ead76d3f2385a:7feddef57000-7fedef062c00 
*** SIGABRT received by PID 12344 (TID 12344) on cpu 47 from PID 12344; ***
E0928 10:21:25.866065   12344 coredump_hook.cc:395] RAW: Remote crash data gathering hook invoked.
E0928 10:21:25.866084   12344 coredump_hook.cc:441] RAW: Skipping coredump since rlimit was 0 at process start.
E0928 10:21:25.866093   12344 client.cc:243] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0928 10:21:25.866101   12344 coredump_hook.cc:502] RAW: Sending fingerprint to remote end.
E0928 10:21:25.866109   12344 coredump_socket.cc:120] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0928 10:21:25.866121   12344 coredump_hook.cc:506] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0928 10:21:25.866130   12344 coredump_hook.cc:580] RAW: Discarding core.
E0928 10:21:26.115471   12344 process_state.cc:774] RAW: Raising signal 6 with default behavior
Aborted (core dumped)

I've tried both reinstalling Jax and creating a new TPU V3-8, but I get the exact same error.

Running jax.devices() does show the TPUs I have on the VM:

>>> jax.devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]

The problem seems to be specific to executing Jax on the TPU. Replicating @mattjj's code from this issue, I'm able to run the cpu-backed jitted function, but not the tpu-backed one:

(base) gerardoduran@t1v-n-7177f451-w-0:~$ python
Python 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:35:26) [GCC 10.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from jax import jit
>>> def f(x): return x**2
... 
>>> f_cpu = jit(f, backend='cpu')
>>> f_tpu = jit(f, backend='tpu')
>>> 
>>> f_cpu(2)
DeviceArray(4, dtype=int32, weak_type=True)
>>> f_tpu(2)
tcmalloc: large alloc 378179171590144 bytes == (nil) @  0x7f78decaf680 0x7f78decd0824 0x7f76eeeaf2da 0x7f76eee60bae 0x7f76ea77687a 0x7f76ea775a05 0x7f76ea777c62 0x7f76e9a5c82e 0x7f7798183b66 0x7f77971f6541 0x7f77971e2912 0x7f77971dc61d 0x7f7795b4248c 0x7f7795b51176 0x7f7793f76ded 0x7f7793d52ac8 0x7f7793d532d3 0x7f7793d2e916 0x55fc669c83cc 0x55fc669c1738 0x55fc669d5f80 0x55fc669b9107 0x55fc669c886f 0x55fc669ba99f 0x55fc669c886f 0x55fc669b80ff 0x55fc669c886f 0x55fc669b80ff 0x55fc669c886f 0x55fc669d67f8 0x55fc669ba99f
# ... more errors

What jax/jaxlib version are you using?

jax 0.3.19 / jaxlib 0.3.15

Which accelerator(s) are you using?

TPU

Additional system info

No response

NVIDIA GPU info

No response

@gerdm gerdm added the bug label Sep 28, 2022

gerdm commented Sep 28, 2022

Update: the problem seems to be related to version 0.3.19 of Jax. If I downgrade to 0.3.17, I no longer get the error:

Python 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:35:26) [GCC 10.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.__version__
'0.3.17'
>>> jax.numpy.sqrt(2)
DeviceArray(1.4142135, dtype=float32, weak_type=True)
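
For reference, a pinned reinstall along these lines should reproduce the working setup (the exact command is an assumption on my part, mirroring the install command from the original report):

pip install "jax[tpu]==0.3.17" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html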

@gerdm gerdm changed the title from "Jax crashes on TPU V3-8" to "Jax crashes on TPU in version 0.3.19" Sep 28, 2022
@hawkinsp hawkinsp self-assigned this Sep 28, 2022
hawkinsp (Collaborator) commented Sep 28, 2022

The issue is that the new jax release pins an incompatible libtpu version. We messed up.

Another workaround for now is to install the new jax but to downgrade your libtpu version:

pip install libtpu-nightly==0.1.dev20220723 -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

I'm working on a new release with a fix.
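
To confirm the workaround took effect, a quick sanity check could be (my own suggestion, not part of the workaround above):

pip show libtpu-nightly   # Version: should report 0.1.dev20220723
python -c "import jax; print(jax.numpy.sqrt(2.0))"   # should print 1.4142135 instead of aborting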

@hawkinsp hawkinsp added the P1 (soon) label Sep 28, 2022

dluo96 commented Sep 28, 2022

@hawkinsp I just tested this in a simple Python 3.8 image (see here); unfortunately, this workaround does not seem to work for me.

hawkinsp (Collaborator) commented Sep 28, 2022

@dluo96 Does that happen outside of a Docker container? I think that issue is specifically related to Docker.


pcuenca commented Sep 28, 2022

@hawkinsp The problem happens outside Docker for me, but jax 0.3.17 fixes it.

This is my repro on a v2-32, in case it helps:

TPU_CHIPS_PER_PROCESS_BOUNDS=1,1,1 TPU_PROCESS_BOUNDS=1,1,1 TPU_VISIBLE_DEVICES=0,1,2,3 python -c "import jax; jax.random.PRNGKey(0)" # hangs forever or crashes

Running Python 3.9.12 in my case.

hawkinsp (Collaborator) commented Sep 28, 2022

@pcuenca That doesn't sound like the same issue reported in the first post of this issue.


pcuenca commented Sep 28, 2022

@hawkinsp Sorry, I assumed it was the same because the behaviour is similar and affects the same version. I can open a new issue with any details you need, no problem :)


dluo96 commented Sep 28, 2022

Hi @hawkinsp, thanks for looking into this. I tested with the same version of Docker (20.10.6, build 370c289) and an earlier version of JAX (0.3.13), and it worked (I've detailed a minimal reproducible example in #12548). This is what I ran inside a Python 3.8 container on the TPU VM:

Python 3.8.14 (default, Sep 13 2022, 15:03:48)
[GCC 10.2.1 20210110] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.__version__
'0.3.13'
>>> jax.local_devices()
[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
>>> key = jax.random.PRNGKey(0)
>>> key
DeviceArray([0, 0], dtype=uint32)

hawkinsp (Collaborator) commented

This should be fixed by jax/jaxlib v0.3.20, which we just released. Hope that helps!
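
For anyone hitting this later, upgrading should look something like this (the version pin is inferred from this comment; the index URL matches the one in the original report):

pip install --upgrade "jax[tpu]==0.3.20" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html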


dluo96 commented Sep 29, 2022

That solved it for me, thank you so much @hawkinsp! 🚀
