Multi-GPU Workflow Hangs #11637
Comments
I know you mentioned you couldn't reproduce on a 4xT4 node @hawkinsp. Based on your experience, do you have a recommended path we can explore on our side to mitigate? As you can imagine, this is a blocker.
Actually I have reproduced it and I'm looking at it. It's not deterministic; I just needed to run it more times.
Thanks! Let me know if I can be of assistance in running any tests / dev branches.
I think the XLA patch below fixes things. I'll send a cleaned-up version of it out for review and it should make it into the next jaxlib. The patch adds a barrier that ensures all participants complete ncclCommInitRank() before any of them may issue its collective operation; the full deadlock analysis (threads A and B) is in the commit message quoted below.
Thanks for such a prompt response!
… ncclCommInitRank() before allowing any participant to proceed. Without a barrier, we can experience deadlocks. As best I understand it, the deadlock scenario looks like this:

Thread A:
* calls ncclCommInitRank(), which succeeds,
* issues the collective operation,
* calls an operation that manipulates the device page tables, e.g., copying a device buffer to an unpinned host buffer. Since this action manipulates the device page tables, it seems that it blocks waiting for the device stream.

Thread B:
* calls ncclCommInitRank(), which calls cudaMalloc().
* cudaMalloc() also manipulates device page tables, and cannot proceed without acquiring an internal lock around the device page-table state.

But thread A already holds this lock, and thread A cannot make progress until thread B issues its collective. This is a deadlock: neither thread can make progress. We can avoid the problem by requiring a barrier after the calls to ncclCommInitRank(), requiring all GPUs to finish initialization before any of them can issue their collective operation.

Fixes google/jax#11637

PiperOrigin-RevId: 464164328
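To make the interleaving concrete, here is a small pure-Python model of the scenario (an analogy only, not XLA's actual code): a `Lock` stands in for the device page-table lock, releasing a `Semaphore` stands in for issuing the collective, and the device-to-host copy waits on the "stream" while holding the lock. Adding a `Barrier` after init, as the patch does, removes the hang:

```python
import threading
import time

N = 2  # number of participating ranks/GPUs in this toy model

def rank(i, lock, issued, init_barrier, use_fix):
    time.sleep(0.3 * i)          # stagger ranks to hit the unlucky interleaving
    # ncclCommInitRank(): cudaMalloc() needs the page-table lock.
    with lock:
        time.sleep(0.1)
    if use_fix:
        init_barrier.wait()      # the patch: all ranks finish init before proceeding
    issued.release()             # enqueue this rank's collective on its stream
    # Device->host copy: holds the page-table lock while waiting on the stream,
    # which only drains once every rank has issued its collective.
    with lock:
        for _ in range(N):
            issued.acquire()
        for _ in range(N):
            issued.release()     # leave the permits for the other ranks' waits

def run(use_fix):
    lock = threading.Lock()
    issued = threading.Semaphore(0)
    barrier = threading.Barrier(N)
    ts = [threading.Thread(target=rank, args=(i, lock, issued, barrier, use_fix),
                           daemon=True) for i in range(N)]
    for t in ts:
        t.start()
    for t in ts:
        t.join(timeout=3)
    status = "hung" if any(t.is_alive() for t in ts) else "completed"
    print("with barrier:" if use_fix else "without barrier:", status)

run(use_fix=True)    # completed
run(use_fix=False)   # hung: rank 0 holds the "page-table lock" while waiting for
                     # rank 1's collective; rank 1 is stuck in init on the lock
```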
This is now fixed at head; you'll have to build jaxlib from source to pick up the fix until the next release.
Potentially related to #10969.
When running the following code on a single device, it runs as expected. However, if n_devices > 1, I'm observing that all devices allocate memory but an inconsistent number of them perform computation (as displayed via nvtop / nvidia-smi), leading to an indefinite hang.
Environment:
While the actual code in which I'm seeing this occur is far more complex, I've created a simpler repro.
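The repro snippet itself didn't survive in the thread above. A hypothetical minimal example in the same spirit (a pmap'd collective across all local devices followed by a device-to-host copy, matching the deadlock analysis earlier in the thread) might look like:

```python
import functools

import jax
import jax.numpy as jnp

# Normalize across devices with a psum collective, then copy back to host.
@functools.partial(jax.pmap, axis_name="i")
def normalize(x):
    return x / jax.lax.psum(x, axis_name="i")

n_devices = jax.local_device_count()
x = jnp.ones((n_devices, 1024, 1024))
y = normalize(x)                 # each device issues the NCCL collective
print(jax.device_get(y).sum())   # device->host copy blocks on the device streams
```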