[How?] Debugging crash in jax.jit function when running under pytest-xdist #10242
Labels: bug, needs info
For NetKet we run the test suite under pytest-xdist (for those not familiar with it, xdist is an extension that forks the Python process and then splits the test suite among those processes).
For most of our tests we run calculations with relatively small matrix sizes; however, a few tests involve compiling some fairly involved functions/gradients, which takes quite some time.
We run our tests both locally and under GitHub Actions (2 processes).
Between 2 and 3 months ago we started seeing spurious, non-reproducible failures where the worker (a pytest process) crashes without a stack trace. The test that crashes is always the same.
This only happens when running under distributed pytest-xdist (`pytest -nX`). If you disable xdist (`pytest -n0`), those crashes never occur.
At first I thought it could be an OOM error, because I could not reproduce the crashes on my beefy workstation, but by running the tests several times I am able to reproduce them locally as well.
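Since the worker dies without a stack trace, one option that might help is Python's standard `faulthandler` module, which dumps the Python tracebacks of all threads when the process receives a fatal signal such as a segfault. The following is only a minimal sketch, assuming the crash is signal-based; the helper name `enable_crash_dump` and the log file naming are hypothetical, while `PYTEST_XDIST_WORKER` is the environment variable that xdist sets in each worker.

```python
import faulthandler
import os
import tempfile


def enable_crash_dump(log_dir=None):
    """Write fatal-signal tracebacks to a per-worker log file.

    Hypothetical helper: call it once per process, for example at the
    top of conftest.py, and keep the returned file handle alive.
    """
    # xdist sets PYTEST_XDIST_WORKER (e.g. "gw0") in worker processes;
    # fall back to "main" when running without xdist.
    worker = os.environ.get("PYTEST_XDIST_WORKER", "main")
    log_dir = log_dir or tempfile.gettempdir()
    path = os.path.join(log_dir, f"faulthandler-{worker}-{os.getpid()}.log")

    # The file must stay open for as long as faulthandler is enabled.
    handle = open(path, "w")
    faulthandler.enable(file=handle, all_threads=True)
    return path, handle
```

Writing to a file rather than stderr is a deliberate choice here: a worker's stderr may not make it back to the main pytest process if the worker dies abruptly, whereas a file on disk can be inspected afterwards.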
You can see the test that is crashing here: https://github.com/netket/netket/runs/5978272340?check_suite_focus=true#step:8:5977
I tried to debug it for a while, and the culprit is a pretty large jit-compiled function.
Do you have any suggestions on what I could do to instrument jax to understand what is happening?
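As a starting point for instrumentation, one possibility is to run the suite with a few debugging environment variables set, so that every xdist worker inherits them. This is a sketch under assumptions, not a confirmed diagnosis path: `JAX_LOG_COMPILES` (logs each compilation), the XLA flag `--xla_dump_to` (dumps the compiled HLO modules to a directory), and `PYTHONFAULTHANDLER` (enables `faulthandler` at interpreter startup) are existing settings to the best of my knowledge, while the helper names `debug_env` and `run_tests_with_debug_env` are hypothetical.

```python
import os
import subprocess
import sys


def debug_env(dump_dir, base_env=None):
    """Return a copy of the environment with debugging settings added."""
    env = dict(os.environ if base_env is None else base_env)

    # Ask JAX to log every function it compiles.
    env["JAX_LOG_COMPILES"] = "1"

    # Append the XLA dump flag without discarding flags already set.
    existing = env.get("XLA_FLAGS", "")
    env["XLA_FLAGS"] = f"{existing} --xla_dump_to={dump_dir}".strip()

    # Have Python dump tracebacks on fatal signals (segfault, abort, ...).
    env["PYTHONFAULTHANDLER"] = "1"
    return env


def run_tests_with_debug_env(dump_dir, pytest_args=("-n", "2")):
    """Run pytest in a subprocess whose workers inherit the debug settings."""
    cmd = [sys.executable, "-m", "pytest", *pytest_args]
    return subprocess.run(cmd, env=debug_env(dump_dir)).returncode
```

The last module dumped or logged before a worker disappears would at least narrow down which compilation or execution was in flight at the time of the crash.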