
ShardedDataParallel doesn't work with multiple nodes #397

Closed
kamo-naoyuki opened this issue Feb 18, 2021 · 19 comments

Labels
bug Something isn't working

kamo-naoyuki commented Feb 18, 2021

🐛 Bug

ShardedDataParallel works with 8 GPUs x 1 node, but fails with the following error on 8 GPUs x 2 nodes:

AssertionError: A bucket failed to be sent, probably unused parameters.Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-

However, there are obviously no unused parameters. Note that torch.nn.DistributedDataParallel works in the same environment.

To Reproduce

import os
import sys 

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from fairscale.optim.oss import OSS 
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP


def train(rank, local_rank, world_size, init_method):
    print("DDP init", world_size, rank, local_rank, init_method, file=sys.stderr)
    dist.init_process_group(backend='nccl', init_method=init_method, rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)

    model = torch.nn.Linear(3, 3).cuda()
    base_optimizer = torch.optim.Adam
    base_optimizer_arguments = {}

    optimizer = OSS(params=model.parameters(), optim=base_optimizer, **base_optimizer_arguments)
    model = ShardedDDP(model, optimizer)

    print("Train", file=sys.stderr)
    model.train()
    model.zero_grad()
    outputs = model(torch.randn(2,3).cuda())
    loss = outputs.sum()
    loss.backward()
    optimizer.step()
    print("finish", file=sys.stderr)

I used shared file system initialization for init_method.

init_method="file:..."

Environment

Note that I tested both TCP and InfiniBand connections.

PyTorch version: 1.7.1
Is debug build: False
CUDA used to build PyTorch: 11.0
ROCM used to build PyTorch: N/A

OS: CentOS Linux 7 (Core) (x86_64)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-44)
Clang version: Could not collect
CMake version: version 2.8.12.2

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: Tesla V100-SXM2-32GB
GPU 1: Tesla V100-SXM2-32GB
GPU 2: Tesla V100-SXM2-32GB
GPU 3: Tesla V100-SXM2-32GB
GPU 4: Tesla V100-SXM2-32GB
GPU 5: Tesla V100-SXM2-32GB
GPU 6: Tesla V100-SXM2-32GB
GPU 7: Tesla V100-SXM2-32GB

Nvidia driver version: 450.51.06
cuDNN version: Probably one of the following:
/usr/lib64/libcudnn.so.8.0.4
/usr/lib64/libcudnn_adv_infer.so.8.0.4
/usr/lib64/libcudnn_adv_train.so.8.0.4
/usr/lib64/libcudnn_cnn_infer.so.8.0.4
/usr/lib64/libcudnn_cnn_train.so.8.0.4
/usr/lib64/libcudnn_ops_infer.so.8.0.4
/usr/lib64/libcudnn_ops_train.so.8.0.4
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] pytorch-ranger==0.1.1
[pip3] pytorch-wpe==0.0.0
[pip3] torch==1.7.1
[pip3] torch-complex==0.2.0
[pip3] torch-optimizer==0.0.1a17
[pip3] torchaudio==0.7.2
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               11.0.221             h6bb024c_0  
[conda] mkl                       2020.2                      256  
[conda] mkl-service               2.3.0            py38he904b0f_0  
[conda] mkl_fft                   1.2.0            py38h23d657b_0  
[conda] mkl_random                1.1.1            py38h0573a6f_0  
[conda] numpy                     1.19.2           py38h54aff64_0  
[conda] numpy-base                1.19.2           py38hfa32c7d_0  
[conda] pytorch                   1.7.1           py3.8_cuda11.0.221_cudnn8.0.5_0    pytorch
[conda] pytorch-ranger            0.1.1                    pypi_0    pypi
[conda] pytorch-wpe               0.0.0                    pypi_0    pypi
[conda] torch-complex             0.2.0                    pypi_0    pypi
[conda] torch-optimizer           0.0.1a17                 pypi_0    pypi
[conda] torchaudio                0.7.2                    pypi_0    pypi
kamo-naoyuki changed the title from "Multiple nodes training doesn't work" to "ShardedDataParallel doesn't work with multiple nodes" on Feb 18, 2021
blefaudeux (Contributor) commented Feb 18, 2021

So, the assert is telling you what's wrong: you need to find which parameter is not being used, or deactivate the buckets by passing reduce_buffer_size=0 at init time.
Edit: ah no, sorry, this obviously should not happen with this model. Can you confirm that it happens with the current master?
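
For reference, deactivating the buckets per the assertion message would be a one-line change in the repro (a sketch, assuming the reduce_buffer_size keyword the error mentions):

# Sketch: disable ShardedDDP reduce buckets, as suggested by the assertion message
model = ShardedDDP(model, optimizer, reduce_buffer_size=0)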

blefaudeux added documentation and bug labels, and removed the documentation label, on Feb 18, 2021
blefaudeux (Contributor):

Adding a unit test so that future releases are not broken, sorry about that
#398

blefaudeux (Contributor) commented Feb 18, 2021

@kamo-naoyuki, could you confirm that this does not happen on master?

kamo-naoyuki (Author):

I tested the master version and that error disappears, but I got another error:

misc/argcheck.cc:39 NCCL WARN Reduce : invalid root -1 (root should be in the 0..16 range)

Note that torch.nn.DistributedDataParallel works.
You just added a unit test on a single node, but did you test multiple nodes yourself using NCCL?

blefaudeux mentioned this issue on Feb 19, 2021
blefaudeux (Contributor):

> I tested the master version and that error disappears, but I got another error:
>
> misc/argcheck.cc:39 NCCL WARN Reduce : invalid root -1 (root should be in the 0..16 range)
>
> Note that torch.nn.DistributedDataParallel works.
> You just added a unit test on a single node, but did you test multiple nodes yourself using NCCL?

Yes, of course, just not on CircleCI so it's not visible here, and I did not hit this issue. There's no root parameter for reduce, which is strange. Can you test 0.1.7, which was just released?

blefaudeux (Contributor):

Ah wait, there are just not enough parameters in your example, so some shards are simply empty; that's why. Just try with a bigger model. I can add an assert for that, but I'm guessing it does not happen very often. DDP does not shard, which is why you were not seeing this.
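
A quick way to check whether a model is at risk is to compare its number of parameter tensors to the world size; a minimal sketch using the names from the repro above, assuming parameters are partitioned per tensor across ranks as the later comments describe:

# Sketch: detect models with fewer parameter tensors than ranks (some shards would be empty)
n_param_tensors = sum(1 for _ in model.parameters())
world_size = dist.get_world_size()
if n_param_tensors < world_size:
    print(f"{world_size - n_param_tensors} rank(s) would own no parameters "
          f"({n_param_tensors} tensors for {world_size} ranks)")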

blefaudeux (Contributor):

Closing, because the title does not really reflect that; I'll create an issue to assert when some shards are empty.

kamo-naoyuki (Author):

I got the same error with a large model.

    model = torch.nn.Sequential(
        torch.nn.Linear(10000, 10000),
        torch.nn.Linear(10000, 10000),
        torch.nn.Linear(10000, 10000),
        torch.nn.Linear(10000, 10000),
        torch.nn.Linear(10000, 10000),
        torch.nn.Linear(10000, 10000),
        torch.nn.Linear(10000, 10000),
    ).cuda()

It would be better to actually test my code. You seem to just read the code and guess at the reason.

kamo-naoyuki (Author):

Also note that:

  • It works with 1 node with 8 GPUs.
  • It works with 2 nodes with 2 GPUs per node.
  • The error happens with 2 nodes with 8 GPUs per node.
  • The error happens both with reduce_buffer_size=0 and with reduce_buffer_size != 0.

Obviously there are some bugs.

blefaudeux reopened this on Feb 20, 2021
blefaudeux (Contributor) commented Feb 20, 2021

> It would be better to actually test my code. You seem to just read the code and guess at the reason.

Has it occurred to you that maybe I cannot test your code on multiple nodes easily? fairscale is used all the time on multiple nodes, just not with your code, and it should not change much; this error was never raised. Empty shards were a real corner case which your first example created. Now I don't know what happens, except that root is not a parameter for a reduce, so it makes little sense from the Python API. A quick lookup surfaces NVIDIA/nccl#276. Could you try with the gloo backend?

> Obviously there are some bugs.

This we can agree on. FYI, NCCL bugs are also a thing.
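
For reference, switching the repro above to gloo is a one-line change (a sketch):

dist.init_process_group(backend="gloo", init_method=init_method, rank=rank, world_size=world_size)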

blefaudeux (Contributor) commented Feb 20, 2021

cc @mrshenli in case that rings a bell; the error points to https://github.com/NVIDIA/nccl/blob/master/src/misc/argcheck.cc#L39

blefaudeux (Contributor):

No problem with your "local_rank" variable, by the way?

kamo-naoyuki (Author):

> No problem with your "local_rank" variable, by the way?

What do you mean?

kamo-naoyuki (Author):

Gloo doesn't work either.

  File "a.py", line 45, in train
    optimizer.step()
  File ".../fairscale/fairscale/optim/oss.py", line 238, in step
    self._broadcast_params()
  File ".../site-packages/torch/autograd/grad_mode.py", line 26, in decorate_context
    return func(*args, **kwargs)
  File ".../fairscale/fairscale/optim/oss.py", line 555, in _broadcast_params
    last_work_handle = dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True)
  File ".../torch/distributed/distributed_c10d.py", line 860, in broadcast
    work = _default_pg.broadcast([tensor], opts)
ValueError: ProcessGroupGloo::broadcast: invalid root rank: 16

I debugged why this happens.

Here self.buckets[device] is enumerated, so its length determines the src_rank values:

for src_rank, bucket in enumerate(self.buckets[device]):

but in this case I got len(self.buckets[device]) = 17. This is caused by the following, since the length is not checked here:

else:
    self.buckets[device].append(torch.zeros(1, device=device))

which in turn is caused by the following:

for param_group in self.param_groups:
    for param in param_group["params"]:
        device = param.device
        if self._per_device_params.get(device) is None:
            self._per_device_params[device] = [[] for _ in range(self.world_size)]
        self._per_device_params[device][self.param_to_rank[param]] += [param]

There are ranks which don't have any parameters, so some of these per-rank lists stay empty; len(params) is 0 in that case, and the bucket list ends up with more entries than world_size, which produces the invalid root rank.
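
A standalone sketch of the situation described above (the round-robin assignment below is a simplification of fairscale's actual partitioning, used only to illustrate the empty-list case):

# Simplified illustration: 14 parameter tensors spread over 16 ranks leaves
# two ranks with empty lists, which is what triggers the placeholder branch above.
world_size = 16
n_param_tensors = 14                                   # e.g. the 7-layer Sequential model in this thread
param_to_rank = [i % world_size for i in range(n_param_tensors)]

per_rank_params = [[] for _ in range(world_size)]
for param, rank in zip(range(n_param_tensors), param_to_rank):
    per_rank_params[rank].append(param)

empty_ranks = [r for r, plist in enumerate(per_rank_params) if not plist]
print(empty_ranks)                                     # [14, 15]: ranks owning no parameters at all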

blefaudeux (Contributor) commented Feb 20, 2021

"There are ranks which don't have parameters" which is the reason I mentioned above and why I wanted to add an assert which would have told you so (this will land after the week end #406). But thanks for the debugging !

blefaudeux (Contributor):

I have to admit that the fact that it shows up as an NCCL error was, and still is, strange to me. I'm guessing that the broadcast buckets are the reason; I should have guarded against that earlier.

blefaudeux (Contributor):

OK to close?

kamo-naoyuki (Author):

How can I solve this issue?

kamo-naoyuki (Author):

Okay, I understand now.
You mean ShardedDataParallel needs more parameter objects in the model than world_size? The model size has nothing to do with this issue.

    model = torch.nn.Sequential(
        torch.nn.Linear(10000, 10000),
        torch.nn.Linear(10000, 10000),
        torch.nn.Linear(10000, 10000),
        torch.nn.Linear(10000, 10000),
        torch.nn.Linear(10000, 10000),
        torch.nn.Linear(10000, 10000),
        torch.nn.Linear(10000, 10000),
    )

This model is large, but it has only 14 parameter tensors.
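
(A quick check of the parameter-tensor count, for reference:)

model = torch.nn.Sequential(*[torch.nn.Linear(10000, 10000) for _ in range(7)])
# 7 Linear layers x (weight + bias) = 14 parameter tensors, fewer than the 16 ranks of 2 nodes x 8 GPUs
print(sum(1 for _ in model.parameters()))  # 14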
