[How?] Debugging crash in jax.jit function when running under pytest-xdist #10242

Open
PhilipVinc opened this issue Apr 12, 2022 · 4 comments
Labels
bug (Something isn't working), needs info (More information is required to diagnose & prioritize the issue.)

Comments

@PhilipVinc
Contributor

For NetKet we run the test-suite under pytest-xdist (for those not familiar with it, xdist is an extension that forks the python process and then splits the test suite among those processes).

For most of our tests, we run calculations with relatively small matrix sizes; however, a few tests involve compiling some fairly involved functions/gradients, which takes quite some time.

We run our tests both locally and under GitHub-actions (2 processes).

Between 2 and 3 months ago we started seeing spurious, non-reproducible failures where the worker (a pytest process) crashes without a stack trace. The test that crashes is always the same.

This only happens when running distributed under pytest-xdist (pytest -nX). If you disable xdist (pytest -n0), the crashes never occur.

At first I thought it could be an OOM error, because I could not reproduce the crashes on my beefy workstation, but by running the tests several times I am able to reproduce them locally as well.

You can see here the test that is crashing https://github.com/netket/netket/runs/5978272340?check_suite_focus=true#step:8:5977

I tried to debug it for a while, and the culprit is a pretty large jit-compiled function.

Do you have any suggestion on what I could do to instrument jax to understand what is happening?

@PhilipVinc added the bug label on Apr 12, 2022
@hawkinsp
Collaborator

We also run our Github CI tests with pytest-xdist, but I'm actually advocating we migrate away from it because it makes it impossible to isolate and debug this kind of problem. We've seen a few random segfaults in some of our tests on our Github CI as well, but I have been completely unable to reproduce them outside Github actions.

If you can reproduce this on a workstation and it's, say, a SIGSEGV, I think what I would probably try to do is capture a core dump (ulimit -c unlimited). We can then inspect the core dump after the fact with gdb. The stack trace in the core dump would probably tell us something (bt in gdb), although we may need to rebuild JAX with debug symbols to get meaningful output.
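When the shell that launches pytest is not under your control (for example, workers spawned by xdist or a CI runner), the same limit can be raised from inside the Python process. The following is a minimal sketch using the standard `resource` module; the function name `raise_core_limit` is hypothetical, and the approach assumes a Unix-like system.

```python
import resource


def raise_core_limit():
    """Raise the soft core-dump size limit to the hard limit.

    Roughly equivalent to `ulimit -c unlimited` when the hard limit is
    unlimited; an unprivileged process cannot go above the hard limit.
    """
    soft, hard = resource.getrlimit(resource.RLIMIT_CORE)
    resource.setrlimit(resource.RLIMIT_CORE, (hard, hard))
    # Return the limits now in effect so the caller can log them.
    return resource.getrlimit(resource.RLIMIT_CORE)
```

Calling this early in `conftest.py` would apply it to each worker process. On Linux, where the core file ends up (or whether it is piped to a crash handler instead) is governed by `/proc/sys/kernel/core_pattern`, so it is worth checking that setting if no core file appears.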

Is it easy to reproduce?

@PhilipVinc
Contributor Author

How do I tell how it crashed? The only thing pytest tells me is [gw1] node down: Not properly terminated [5959](https://github.com/netket/netket/runs/5978272340?check_suite_focus=true#step:8:5959).

I can try to reproduce it on Github Actions manually by using termux.
Do you have some prebuilt linux cpu jax wheels with debugging symbols I can use?

@hawkinsp
Collaborator

I suspect that means the job segfaulted, so it should dump a core if you raise the ulimit.
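One way to confirm how the process died, sketched here under the assumption of a POSIX system, is to run the failing test in a child process and inspect its return code: Python's `subprocess` reports a negative return code `-N` when the child was terminated by signal `N`. The helper names `describe_exit` and `run_and_describe` are hypothetical.

```python
import signal
import subprocess
import sys


def describe_exit(returncode):
    """Translate a subprocess return code into a readable description."""
    if returncode < 0:
        # On POSIX, a negative value means the child was killed by that signal.
        return f"killed by {signal.Signals(-returncode).name}"
    return f"exited with status {returncode}"


def run_and_describe(args):
    """Run a command to completion and describe how it terminated."""
    proc = subprocess.run(args)
    return describe_exit(proc.returncode)
```

For example, `run_and_describe([sys.executable, "-m", "pytest", "-n0", "path/to/test_file.py"])`, where the test path is a placeholder, would report `killed by SIGSEGV` for a segfault. A result of `killed by SIGKILL` would instead be consistent with something external, such as the Linux OOM killer, terminating the process.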

We don't have prebuilt wheels with symbols.

It would be most helpful if we can reproduce outside the github actions environment...

@PhilipVinc
Contributor Author

I don't have time this week but will try early next week.

I'll post here again.

Thanks for your help in the meantime.

@froystig added the needs info label on Apr 16, 2022
3 participants