[BUG] The code for deepspeed.comm.comm.monitored_barrier() #4488

Open
phalexo opened this issue Oct 9, 2023 · 12 comments
Labels: bug (Something isn't working), training

phalexo commented Oct 9, 2023

deepspeed.comm.comm.monitored_barrier() takes a "timeout" parameter but then internally calls a function, *.barrier(), that does NOT take a "timeout" parameter.

So, it is useless.
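A minimal sketch of the kind of call involved (hypothetical repro; the exact failure depends on the installed DeepSpeed version):

from datetime import timedelta

import deepspeed
import deepspeed.comm as dist

# Assumes a multi-rank launch, e.g. via the deepspeed launcher.
deepspeed.init_distributed("nccl")

# monitored_barrier() advertises a timeout argument, but it forwards the call to a
# *.barrier() that does not accept one, so the timeout is dropped (or a TypeError is raised).
dist.monitored_barrier(timeout=timedelta(minutes=300))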


phalexo added the bug and training labels Oct 9, 2023
mrwyattii self-assigned this Oct 9, 2023

mrwyattii (Contributor) commented:
@phalexo thanks for bringing this to our attention. I believe that barrier was mistakenly used instead of monitored_barrier here:

return cdb.barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)

@Quentin-Anthony you were the last one to touch this line. Can you comment? Thanks!
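In other words, a sketch of the intended routing inside deepspeed/comm/comm.py would be (assuming the torch backend object cdb exposes a monitored_barrier wrapper around torch.distributed.monitored_barrier):

# sketch of the corrected forwarding, not the merged change itself
return cdb.monitored_barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)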

phalexo (Author) commented Oct 9, 2023

> @phalexo thanks for bringing this to our attention. I believe that barrier was mistakenly used instead of monitored_barrier here:
>
> return cdb.barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)
>
> @Quentin-Anthony you were the last one to touch this line. Can you comment? Thanks!

Thank you for looking at this. The initial problem for me is that there does not seem to be a way to change the timeout. I went as far as changing the hardcoded value of "default_pg_timeout" to 5 hours, but for some reason an exception is raised at 30 minutes anyway.

There must be some way to control the timeout for the default group.
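For context, with plain torch.distributed the default process-group timeout is normally set when the group is created, before any collectives run (a sketch; whether that value is honored end-to-end is exactly what is in question here):

from datetime import timedelta

import torch.distributed as torch_dist

# The timeout applies to the default process group and must be passed at creation time;
# changing it after the group exists has no effect.
torch_dist.init_process_group(backend="nccl", timeout=timedelta(hours=5))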

phalexo (Author) commented Oct 10, 2023

Fixing this bug may or may not make it possible to change the timeout. See below.

#pragma once

#include <condition_variable>
#include <unordered_map>
// (other standard-library #include lines lost their header names in formatting)

#include <ATen/ATen.h>
#include <c10/macros/Macros.h>

#include <torch/csrc/distributed/c10d/Work.hpp>
#include <torch/csrc/distributed/c10d/Types.hpp>
#include <torch/csrc/distributed/c10d/Utils.hpp>
#include <torch/csrc/distributed/c10d/debug.h>
#include <torch/csrc/distributed/c10d/sequence_num.hpp>

constexpr auto kBackendDefaultTimeout =
    std::chrono::milliseconds(30 * 60 * 1000); // This is most likely the culprit, nested several layers deep in include files.

namespace c10d {

class TORCH_API Backend : public torch::CustomClassHolder {
public:

// Backend Options is a base struct that defines the basic options
// when constructing a Backend. Each Backend subclass should
// extend this struct and define its options if it wants to provide more
// config options (beyond basic ones defined here) to end user.

FILE: ~/pytorch/torch/csrc/distributed/c10d/Backend.hpp

Quentin-Anthony (Contributor) commented:
> @phalexo thanks for bringing this to our attention. I believe that barrier was mistakenly used instead of monitored_barrier here:
>
> return cdb.barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)
>
> @Quentin-Anthony you were the last one to touch this line. Can you comment? Thanks!

@mrwyattii -- You're correct! Looks like a typo. I introduced a PR in #4496.

@phalexo -- I believe the cause of your issue is that torch.distributed.barrier() doesn't have a timeout arg, so your deepspeed.comm.monitored_barrier() call dropped the timeout arg. Since my patch in #4496 corrects us to instead route to torch.distributed.monitored_barrier(), your timeout should get picked up. I was able to verify this myself with the nccl backend and timeouts of both 5 and 45 minutes. Give it a try and let me know if it works for you!

phalexo (Author) commented Oct 11, 2023

> @phalexo thanks for bringing this to our attention. I believe that barrier was mistakenly used instead of monitored_barrier here:
>
> return cdb.barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)
>
> @Quentin-Anthony you were the last one to touch this line. Can you comment? Thanks!
>
> @mrwyattii -- You're correct! Looks like a typo. I introduced a PR in #4496.
>
> @phalexo -- I believe the cause of your issue is that torch.distributed.barrier() doesn't have a timeout arg, so your deepspeed.comm.monitored_barrier() call dropped the timeout arg. Since my patch in #4496 corrects us to instead route to torch.distributed.monitored_barrier(), your timeout should get picked up. I was able to verify this myself with the nccl backend and timeouts of both 5 and 45 minutes. Give it a try and let me know if it works for you!

I am not sure how you managed to test it with 'nccl' backend. This is what I got:

Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:20<00:00, 10.06s/it]
Using pad_token, but it is not set yet.
Traceback (most recent call last):
File "fine-tune.py", line 260, in <module>
train()
File "fine-tune.py", line 198, in train
monitored_barrier(timeout=datetime.timedelta(minutes=300))
File "/home/developer/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 3380, in monitored_barrier
raise RuntimeError("monitored_barrier is only implemented for GLOO backend.")
RuntimeError: monitored_barrier is only implemented for GLOO backend.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:21<00:00, 10.57s/it]
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:21<00:00, 10.72s/it]
Using pad_token, but it is not set yet.
Traceback (most recent call last):
File "fine-tune.py", line 260, in <module>
train()
File "fine-tune.py", line 198, in train
monitored_barrier(timeout=datetime.timedelta(minutes=300))
File "/home/developer/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 3380, in monitored_barrier
raise RuntimeError("monitored_barrier is only implemented for GLOO backend.")
RuntimeError: monitored_barrier is only implemented for GLOO backend.
Using pad_token, but it is not set yet.
Have loaded the dataset.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:22<00:00, 11.36s/it]
Using pad_token, but it is not set yet.
Traceback (most recent call last):
File "fine-tune.py", line 260, in <module>
train()
File "fine-tune.py", line 198, in train
monitored_barrier(timeout=datetime.timedelta(minutes=300))
File "/home/developer/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 3380, in monitored_barrier
raise RuntimeError("monitored_barrier is only implemented for GLOO backend.")
RuntimeError: monitored_barrier is only implemented for GLOO backend.

There is this bit of code, in the distributed_c10d.py file.

if get_backend(group) != Backend.GLOO:
    raise RuntimeError("monitored_barrier is only implemented for GLOO backend.")
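Given that check, a sketch of a workaround (not DeepSpeed code) would be to guard on the backend and only use monitored_barrier() when Gloo is active:

from datetime import timedelta

import torch.distributed as torch_dist

def sync_ranks(timeout=timedelta(minutes=300)):
    # monitored_barrier() honors a per-call timeout, but only exists for Gloo.
    if torch_dist.get_backend() == "gloo":
        torch_dist.monitored_barrier(timeout=timeout)
    else:
        # NCCL path: barrier() takes no timeout and is bound by the process-group timeout.
        torch_dist.barrier()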

It seems only the deepspeed/comm/torch.py has been modified, but deepspeed/comm/comm.py remains the same.

I will test with both and see if it makes a difference.

So, torch.distributed.monitored_barrier() is NOT implemented for "nccl", just for "gloo".

deepspeed/comm/comm.py has not been modified, so it gives the same error:

Traceback (most recent call last):
File "fine-tune.py", line 260, in <module>
train()
File "fine-tune.py", line 201, in train
deepspeed.comm.barrier(timeout=datetime.timedelta(minutes=300))
File "/home/developer/mambaforge/envs/FinGPT/lib/python3.8/site-packages/deepspeed/comm/comm.py", line 117, in log_wrapper
return func(*args, **kwargs)
TypeError: barrier() got an unexpected keyword argument 'timeout'
Using pad_token, but it is not set yet.
Traceback (most recent call last):
File "fine-tune.py", line 260, in <module>
train()
File "fine-tune.py", line 201, in train
deepspeed.comm.barrier(timeout=datetime.timedelta(minutes=300))
File "/home/developer/mambaforge/envs/FinGPT/lib/python3.8/site-packages/deepspeed/comm/comm.py", line 117, in log_wrapper
return func(*args, **kwargs)
TypeError: barrier() got an unexpected keyword argument 'timeout'

And then there is this:

constexpr auto kBackendDefaultTimeout =
    std::chrono::milliseconds(30 * 60 * 1000); // This is most likely the culprit, nested several layers deep in include files.

namespace c10d {

class TORCH_API Backend : public torch::CustomClassHolder {
public:

// Backend Options is a base struct that defines the basic options
// when constructing a Backend. Each Backend subclass should
// extend this struct and define its options if it wants to provide more
// config options (beyond basic ones defined here) to end user.

FILE: ~/pytorch/torch/csrc/distributed/c10d/Backend.hpp

phalexo (Author) commented Oct 11, 2023

The exception happens deep inside the C/C++ code in pytorch/nccl, @Quentin-Anthony

Somehow one has to be able to adjust the timeout inside the backend.

frame #29: PyEval_EvalCodeEx + 0x39 (0x5568d06e5fd9 in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #30: PyEval_EvalCode + 0x1b (0x5568d06e5f9b in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #31: + 0x1eb929 (0x5568d0706929 in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #32: + 0x1ea923 (0x5568d0705923 in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #33: + 0x9a00f (0x5568d05b500f in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #34: PyRun_SimpleFileExFlags + 0x364 (0x5568d05b4b13 in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #35: + 0x8cfdc (0x5568d05a7fdc in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #36: Py_BytesMain + 0x39 (0x5568d06d97b9 in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #37: __libc_start_main + 0xf3 (0x7fca24cf9083 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #38: + 0x1be6bd (0x5568d06d96bd in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
. This may indicate a possible application crash on rank 0 or a network set up issue.
Traceback (most recent call last):
File "fine-tune.py", line 263, in <module>
train()
File "fine-tune.py", line 201, in train
barrier()
File "/home/developer/.local/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 3313, in barrier
work = default_pg.barrier(opts=opts)
RuntimeError: [3] is setting up NCCL communicator and retrieving ncclUniqueId from [0] via c10d key-value store by key '0', but store->get('0') got error: Socket Timeout
Exception raised from recvBytes at ../torch/csrc/distributed/c10d/Utils.hpp:604 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7f298cf424d7 in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, char const*) + 0x68 (0x7f298cf0c434 in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libc10.so)
frame #2: c10d::TCPStore::doWait(c10::ArrayRefstd::string, std::chrono::duration<long, std::ratio<1l, 1000l> >) + 0xd8 (0x7f28fc3d87c8 in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #3: c10d::TCPStore::doGet(std::string const&) + 0x22 (0x7f28fc3d9472 in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #4: c10d::TCPStore::get(std::string const&) + 0x59 (0x7f28fc3d94f9 in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #5: c10d::PrefixStore::get(std::string const&) + 0x31 (0x7f28fc3989c1 in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #6: c10d::PrefixStore::get(std::string const&) + 0x31 (0x7f28fc3989c1 in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #7: c10d::PrefixStore::get(std::string const&) + 0x31 (0x7f28fc3989c1 in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #8: c10d::ProcessGroupNCCL::broadcastUniqueNCCLID(ncclUniqueId*, bool, std::string const&, int) + 0xaf (0x7f28ab3a7eef in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #9: c10d::ProcessGroupNCCL::getNCCLComm(std::string const&, std::vector<c10::Device, std::allocatorc10::Device > const&, c10d::OpType, int, bool) + 0x201 (0x7f28ab3abba1 in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #10: + 0x1322ccd (0x7f28ab3b2ccd in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #11: c10d::ProcessGroupNCCL::allreduce_impl(std::vector<at::Tensor, std::allocatorat::Tensor >&, c10d::AllreduceOptions const&) + 0x21 (0x7f28ab3b40d1 in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #12: c10d::ProcessGroupNCCL::allreduce(std::vector<at::Tensor, std::allocatorat::Tensor >&, c10d::AllreduceOptions const&) + 0x39d (0x7f28ab3b6d8d in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #13: c10d::ProcessGroupNCCL::barrier(c10d::BarrierOptions const&) + 0x851 (0x7f28ab3c5ce1 in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libtorch_cuda.so)
frame #14: + 0x54861d9 (0x7f28fc38d1d9 in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #15: + 0x5489f2a (0x7f28fc390f2a in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #16: + 0x5498d90 (0x7f28fc39fd90 in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libtorch_cpu.so)
frame #17: + 0xb7409e (0x7f2910fd009e in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #18: + 0x3bfba0 (0x7f291081bba0 in /home/developer/.local/lib/python3.8/site-packages/torch/lib/libtorch_python.so)
frame #19: PyCFunction_Call + 0x52 (0x5605b928ffd2 in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #20: _PyObject_MakeTpCall + 0x3db (0x5605b927b29b in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #21: + 0x136d7d (0x5605b928fd7d in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #22: _PyEval_EvalFrameDefault + 0x11bb (0x5605b92734ab in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #23: + 0x1cbe00 (0x5605b9324e00 in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #24: _PyEval_EvalFrameDefault + 0x3aa (0x5605b927269a in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #25: _PyEval_EvalCodeWithName + 0x2f1 (0x5605b9271261 in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #26: _PyFunction_Vectorcall + 0x19c (0x5605b928289c in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #27: _PyEval_EvalFrameDefault + 0x3aa (0x5605b927269a in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #28: _PyEval_EvalCodeWithName + 0x2f1 (0x5605b9271261 in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #29: PyEval_EvalCodeEx + 0x39 (0x5605b9323fd9 in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #30: PyEval_EvalCode + 0x1b (0x5605b9323f9b in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #31: + 0x1eb929 (0x5605b9344929 in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #32: + 0x1ea923 (0x5605b9343923 in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #33: + 0x9a00f (0x5605b91f300f in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #34: PyRun_SimpleFileExFlags + 0x364 (0x5605b91f2b13 in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #35: + 0x8cfdc (0x5605b91e5fdc in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #36: Py_BytesMain + 0x39 (0x5605b93177b9 in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
frame #37: __libc_start_main + 0xf3 (0x7f2992fe4083 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #38: + 0x1be6bd (0x5605b93176bd in /home/developer/mambaforge/envs/FinGPT/bin/python3.8)
. This may indicate a possible application crash on rank 0 or a network set up issue.

phalexo (Author) commented Oct 12, 2023

> @phalexo thanks for bringing this to our attention. I believe that barrier was mistakenly used instead of monitored_barrier here:
>
> return cdb.barrier(group=group, timeout=timeout, wait_all_ranks=wait_all_ranks)
>
> @Quentin-Anthony you were the last one to touch this line. Can you comment? Thanks!
>
> @mrwyattii -- You're correct! Looks like a typo. I introduced a PR in #4496.
>
> @phalexo -- I believe the cause of your issue is that torch.distributed.barrier() doesn't have a timeout arg, so your deepspeed.comm.monitored_barrier() call dropped the timeout arg. Since my patch in #4496 corrects us to instead route to torch.distributed.monitored_barrier(), your timeout should get picked up. I was able to verify this myself with the nccl backend and timeouts of both 5 and 45 minutes. Give it a try and let me know if it works for you!

As an experiment I replaced the "barrier" calls with "all_reduce", since from a timing point of view it performs essentially the same synchronization as a barrier.

I got exactly the same error, with the socket timeout, except the function name in the trace was now "all_reduce" instead of "barrier".
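For clarity, the stand-in I used looks roughly like this (a sketch; it still rides on the same NCCL process group, so it is subject to the same timeouts):

import torch
import torch.distributed as torch_dist

def pseudo_barrier():
    # A tiny all_reduce forces every rank to rendezvous, much like barrier() does,
    # but it is still subject to the default process-group / socket timeouts.
    flag = torch.zeros(1, device="cuda")
    torch_dist.all_reduce(flag)
    torch.cuda.synchronize()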

Could you post your code to experiment with 5, 45 minute timeouts?

Quentin-Anthony (Contributor) commented:
> It seems only the deepspeed/comm/torch.py has been modified, but deepspeed/comm/comm.py remains the same.

I don't think this is correct? Here's the PR change: https://github.com/microsoft/DeepSpeed/pull/4496/files#diff-2bf8bf2cd0ea4637cdd3397e6e27d38ce59e0ebc8daafcc583dbfcd8c092b0beR419

> So, torch.distributed.monitored_barrier() is NOT implemented for "nccl", just for "gloo".

Yep, you're right. I was using an older custom version of torch with monitored_barrier implemented for NCCL in a hacky way (it spawns an MPI process group to track ranks instead of Gloo), but I realize that could be confusing, so I apologize.

What I ran was essentially the following:

from datetime import timedelta

import deepspeed
import deepspeed.comm as dist

if __name__ == "__main__":
    deepspeed.init_distributed("gloo", dist_init_required=True)
    # Hold every rank except rank 1 at the barrier so the timeout path is exercised.
    if dist.get_rank() != 1:
        dist.monitored_barrier(timeout=timedelta(minutes=5))

Can you try this?

Also, I think your socket timeouts:

RuntimeError: [3] is setting up NCCL communicator and retrieving ncclUniqueId from [0] via c10d key-value store by key '0', but store->get('0') got error: Socket Timeout

Are completely unrelated to the monitored_barrier timeout arg. This is a generic error message and the error source is unclear from your trace. I suspect you would get the same error by calling the pure torch.distributed calls without DeepSpeed, especially if you get the same error with AllReduce.

phalexo (Author) commented Oct 16, 2023

@Quentin-Anthony
I will try it with the "gloo" backend. Edit: it was able to get past 30 minutes with "gloo" as the backend, and now I have a pretty clear notion of what is happening.

That said, what I think happens is this:

if rank > 0:
    # barrier() here would time out after 30 minutes; ranks > 0 would then keep running
    # and try to communicate with rank 0, which is still busy with its preprocessing,
    # so the socket communications time out too.
    # monitored_barrier() with a long timeout blocks ranks > 0 for much longer than
    # 30 minutes, so rank 0 can continue its work past the 30-minute mark.
    deepspeed.distributed.monitored_barrier(timeout=timedelta(minutes=300))

# ... lengthy training-data preprocessing on rank 0 that exceeds 30 minutes ...

if rank == 0:
    # barrier() here is not reached until rank 0 completes all the preprocessing. If that
    # takes more than 30 minutes then, as far as ranks > 0 are concerned, rank 0 has died:
    # their communication attempts hit socket timeouts and they throw an exception, and
    # this barrier is NEVER reached.
    deepspeed.distributed.monitored_barrier(timeout=timedelta(minutes=300))

All ranks above 0 are stopped at the barrier above, waiting for rank 0 to finish preprocessing and rejoin at the second barrier. After about 30 minutes they assume rank 0 is dead and some kind of exception processing happens. I have looked at some settings for sockets and the timeout is 10 minutes, but there are 3 retries for sockets. So maybe 30 minutes come from that.

Why isn't monitored_barrier implemented for nccl? I thought it is faster than gloo.

if dist.get_rank() != 1: // Should this be 1? or -1, just a choice? or 0?

p.s. If I change the backend to gloo and barrier() to monitored_barrier(timeout=timedelta(minutes=...)), the function should internally adjust the timeout for sockets too. What is the point of having a process wait at a barrier if a socket timeout causes an exception anyway? Sockets have to be set to "keep_alive" or a longer timeout.

CONCLUSIONS:
NCCL has better performance than GLOO, so this is certainly still a problem for the NCCL backend, and for anyone who needs to save money on electricity, GPU rentals, etc. for training.

It is still a problem for GLOO, unless one is aware that one cannot use barrier() and has to use monitored_barrier() with a timeout.

So, the socket timeout is related to the barrier timeout in the sense that a correctly working barrier with a longer timeout prevents ranks > 0 from ever reaching the socket-timeout point.

I don't understand though why changing the default timeout values in torch source, in Backend.hpp, ProcessGroup.hpp and rebuilding PyTorch did not solve the problem.
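For reference, the Gloo-backed pattern that did get past the 30-minute mark for me looks roughly like this (a sketch; preprocess_dataset() is a hypothetical stand-in for the lengthy rank-0 work in fine-tune.py):

from datetime import timedelta

import deepspeed
import deepspeed.comm as dist

deepspeed.init_distributed("gloo", dist_init_required=True)

if dist.get_rank() > 0:
    # Ranks > 0 wait here, with a timeout long enough to cover rank 0's preprocessing.
    dist.monitored_barrier(timeout=timedelta(minutes=300))
else:
    preprocess_dataset()  # hypothetical: the lengthy rank-0 preprocessing step
    dist.monitored_barrier(timeout=timedelta(minutes=300))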


Quentin-Anthony (Contributor) commented:
> Why isn't monitored_barrier implemented for nccl? I thought it is faster than gloo.

Because monitored_barrier uses a CPU process group to monitor ranks and NCCL is GPU-only.

> if dist.get_rank() != 1: // Should this be 1? or -1, just a choice? or 0?

Just a choice. My code is just triggering the case where a single rank (1 in this case) doesn't enter the barrier.

> I have looked at some settings for sockets and the timeout is 10 minutes, but there are 3 retries for sockets. So maybe 30 minutes come from that.

Wouldn't the 30-minute timeout just be from the default PyTorch process group timeout?

timeout=default_pg_timeout,

and
https://github.com/pytorch/pytorch/blob/5c3955200c3ca1916f669a32d99cd7976f3d4930/torch/csrc/distributed/c10d/ProcessGroup.hpp#L25
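One possible way around the Gloo-only restriction, while keeping NCCL for the actual training collectives, could be a second Gloo-backed group used only for the CPU-side barrier (a sketch, not something DeepSpeed sets up for you):

from datetime import timedelta

import torch.distributed as torch_dist

# Main (default) process group on NCCL for the training collectives.
torch_dist.init_process_group(backend="nccl")

# A second group over the same ranks, backed by Gloo, used only for monitored_barrier().
cpu_group = torch_dist.new_group(backend="gloo")

if torch_dist.get_rank() != 0:
    torch_dist.monitored_barrier(group=cpu_group, timeout=timedelta(hours=5))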

phalexo (Author) commented Oct 16, 2023

> Why isn't monitored_barrier implemented for nccl? I thought it is faster than gloo.
>
> Because monitored_barrier uses a CPU process group to monitor ranks and NCCL is GPU-only.
>
> if dist.get_rank() != 1: // Should this be 1? or -1, just a choice? or 0?
>
> Just a choice. My code is just triggering the case where a single rank (1 in this case) doesn't enter the barrier.
>
> I have looked at some settings for sockets and the timeout is 10 minutes, but there are 3 retries for sockets. So maybe 30 minutes come from that.
>
> Wouldn't the 30-minute timeout just be from the default PyTorch process group timeout?
>
> timeout=default_pg_timeout,
>
> and
> https://github.com/pytorch/pytorch/blob/5c3955200c3ca1916f669a32d99cd7976f3d4930/torch/csrc/distributed/c10d/ProcessGroup.hpp#L25

default_pg_timeout itself is NOT set to an actual value, but instead is set to another variable.

This is a recent change

default_pg_timeout = timedelta(minutes=int(os.getenv("DEEPSPEED_TIMEOUT", default=30)))
INFERENCE_GENERIC_MODE = 'generic'
INFERENCE_SPECIALIZED_MODE = 'specialized'

Although this works to change the value, it does not appear to work to change the actual timeout. I don't know why.
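For completeness, the two ways I would expect this constant to take effect, assuming the timeout has to be in place before the process group is created (and assuming deepspeed.init_distributed forwards its timeout argument to the group):

# Either export the environment variable before launching, e.g.
#   DEEPSPEED_TIMEOUT=300 deepspeed fine-tune.py ...
# or pass the timeout explicitly at init time:
from datetime import timedelta

import deepspeed

deepspeed.init_distributed("nccl", timeout=timedelta(minutes=300))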

It is still a total puzzle to me why values defined in Backend.hpp and ProcessGroup.hpp are not controlling all the other values, including default_pg_timeout.

I built pytorch from source after modifying those timeout values by 10x, but it had no effect.

If nccl is strictly GPU side, then why can't one have CPU side functions for barriers? Why do I need to change the backend? I am not quite getting this. I am not trying to synch the GPU processes, just the CPU side preprocessing.

phalexo (Author) commented Oct 17, 2023

Still a problem, even with the GLOO backend. The timeout for the barriers is set to 10 hours, but they timed out at about 6.6 hours (the 24000000 ms send timeout in the Gloo trace below is roughly 6.7 hours).

The argument trust_remote_code is to be used with Auto classes. It has no effect here and is ignored.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:27<00:00, 13.69s/it]
Using pad_token, but it is not set yet.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:28<00:00, 14.35s/it]
Using pad_token, but it is not set yet.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:30<00:00, 15.12s/it]
Using pad_token, but it is not set yet.
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:30<00:00, 15.45s/it]
Using pad_token, but it is not set yet.
Map (num_proc=14): 70%|██████████████████████████████████████████████████████████████████▌ | 652383/930514 [6:39:53<4:31:33, 17.07 examples/s][E ProcessGroupGloo.cpp:138] Rank 3 successfully reached monitoredBarrier, but received errors while waiting for send/recv from rank 0. Please check rank 0 logs for faulty rank.
Traceback (most recent call last):
File "fine-tune.py", line 305, in <module>
train()
File "fine-tune.py", line 232, in train
dist.monitored_barrier(timeout=datetime.timedelta(minutes=600))
File "/home/developer/pytorch/torch/distributed/distributed_c10d.py", line 3401, in monitored_barrier
return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks)
RuntimeError: Rank 3 successfully reached monitoredBarrier, but received errors while waiting for send/recv from rank 0. Please check rank 0 logs for faulty rank.
Original exception:
[/home/developer/pytorch/third_party/gloo/gloo/transport/tcp/unbound_buffer.cc:133] Timed out waiting 24000000ms for send operation to complete
Map (num_proc=14): 70%|██████████████████████████████████████████████████████████████████▌ | 652391/930514 [6:39:54<4:45:07, 16.26 examples/s][E ProcessGroupGloo.cpp:138] Rank 1 successfully reached monitoredBarrier, but received errors while waiting for send/recv from rank 0. Please check rank 0 logs for faulty rank.
Traceback (most recent call last):
File "fine-tune.py", line 305, in <module>
train()
File "fine-tune.py", line 232, in train
dist.monitored_barrier(timeout=datetime.timedelta(minutes=600))
File "/home/developer/pytorch/torch/distributed/distributed_c10d.py", line 3401, in monitored_barrier
return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks)
RuntimeError: Rank 1 successfully reached monitoredBarrier, but received errors while waiting for send/recv from rank 0. Please check rank 0 logs for faulty rank.
Original exception:
[/home/developer/pytorch/third_party/gloo/gloo/transport/tcp/unbound_buffer.cc:133] Timed out waiting 24000000ms for send operation to complete
Map (num_proc=14): 70%|██████████████████████████████████████████████████████████████████▌ | 652417/930514 [6:39:56<7:05:10, 10.90 examples/s][E ProcessGroupGloo.cpp:138] Rank 2 successfully reached monitoredBarrier, but received errors while waiting for send/recv from rank 0. Please check rank 0 logs for faulty rank.
Traceback (most recent call last):
File "fine-tune.py", line 305, in <module>
train()
File "fine-tune.py", line 232, in train
dist.monitored_barrier(timeout=datetime.timedelta(minutes=600))
File "/home/developer/pytorch/torch/distributed/distributed_c10d.py", line 3401, in monitored_barrier
return group_to_use.monitored_barrier(timeout, wait_all_ranks=wait_all_ranks)
RuntimeError: Rank 2 successfully reached monitoredBarrier, but received errors while waiting for send/recv from rank 0. Please check rank 0 logs for faulty rank.
Original exception:
[/home/developer/pytorch/third_party/gloo/gloo/transport/tcp/unbound_buffer.cc:133] Timed out waiting 24000000ms for send operation to complete
Map (num_proc=14): 70%|██████████████████████████████████████████████████████████████████▌ | 652430/930514 [6:39:57<6:55:04, 11.17 examples/s][2023-10-17 06:15:53,453] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 47356
[2023-10-17 06:15:54,809] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 47357
[2023-10-17 06:15:54,810] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 47358
[2023-10-17 06:15:54,811] [INFO] [launch.py:315:sigkill_handler] Killing subprocess 47390
