Is torch.distributed.all_reduce working as expected? #8
BTW, is …
That's a good question. I already replied to a similar question here. TL;DR I believe that `torch.distributed.all_reduce` works as expected in this case. I believe that the following snippet demonstrates this:

```python
import torch

n, m = 128, 4

def main_worker(gpu):
    torch.distributed.init_process_group(backend='nccl', init_method='tcp://localhost:12345', world_size=2, rank=gpu)
    torch.cuda.set_device(gpu)
    x = torch.rand(n, m).cuda()

    # compute xTx using all_reduce
    xTx = x.T @ x
    torch.distributed.all_reduce(xTx)
    if gpu == 0:
        print(xTx)

    # compute xTx using all_gather
    ys = [torch.empty(n, m).cuda() for _ in range(2)]
    torch.distributed.all_gather(ys, x)
    y = torch.cat(ys)
    yTy = y.T @ y
    if gpu == 0:
        print(yTy)

if __name__ == '__main__':
    torch.multiprocessing.spawn(main_worker, nprocs=2)
```

Both the `xTx` and `yTy` matrices come out the same.
After doing some tests with a differentiable all_gather, I realize that your implementation is an equivalent version. Very smart trick.
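For reference, here is a minimal sketch of the "differentiable all_gather" pattern being compared against. The class name `AllGather` is illustrative, not from this repo: the forward gathers `x` from every rank, and the backward all-reduces the incoming gradient and returns this rank's slice, which is what makes the op autograd-friendly.

```python
import torch
import torch.distributed as dist

class AllGather(torch.autograd.Function):
    """Hypothetical differentiable all_gather, for comparison only."""

    @staticmethod
    def forward(ctx, x):
        # gather this rank's x from every rank and concatenate
        ys = [torch.empty_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(ys, x)
        return torch.cat(ys)

    @staticmethod
    def backward(ctx, grad):
        # every rank holds a gradient for the full concatenated output;
        # sum them, then keep the chunk corresponding to this rank's input
        grad = grad.contiguous()
        dist.all_reduce(grad)
        n = grad.shape[0] // dist.get_world_size()
        r = dist.get_rank()
        return grad[r * n : (r + 1) * n]

# usage: y = AllGather.apply(x)
```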
@zhangdan8962, those two lines are there to handle training on multiple gpus. In the multi-gpu setting, each gpu computes a "local" cross-correlation matrix from a subset of the examples in a batch (for example, when training with a batch size of 2048 on 32 gpus, each gpu computes the cross-correlation matrix from 64 training examples). The call to `torch.distributed.all_reduce` then sums these local cross-correlation matrices, so that every gpu ends up with the same global cross-correlation matrix.
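As a rough sketch of that flow (shapes and names here are assumptions, not the repo's exact code): each rank computes a local cross-correlation matrix from its slice of the batch, normalizes by the global batch size, and then `all_reduce` sums the local matrices.

```python
import torch
import torch.distributed as dist

def global_cross_correlation(z1, z2, world_size):
    # z1, z2: (local_batch, dim) batch-normalized embeddings on this rank
    c = z1.T @ z2                       # local cross-correlation, (dim, dim)
    c = c / (z1.shape[0] * world_size)  # normalize by the global batch size
    dist.all_reduce(c)                  # sum local matrices -> same global matrix on every rank
    return c
```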
@jzbontar Thank you again for such a detailed explanation. And I guess this is also the reason for using global Batch Normalization? Also, I have a final question regarding this line: Line 128 in 7b1baec
I thought the loss on each GPU would be the same, since they all have the same global cross-correlation matrix. So what is the point of reducing the averaged loss to the first GPU? Thank you in advance for the help!
Yes, same reason. You want the output to be the same regardless of whether the code was run on one or many gpus.
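For the batch-norm point, the usual way to get "global" statistics under DDP is to convert the model's BatchNorm layers to synchronized ones. A sketch, assuming a standard torchvision backbone:

```python
import torch
import torchvision

model = torchvision.models.resnet50()
# BatchNorm statistics are now computed over the whole global batch,
# so the output matches the single-gpu setting
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
```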
Good catch. That particular reduce is not needed since, as you pointed out, the loss will be the same on all gpus. I removed that line from main.py.
I can confirm that `torch.distributed.all_reduce` behaves as expected here.
This line https://github.com/facebookresearch/barlowtwins/blob/main/main.py#L208 uses `torch.distributed.all_reduce` to sum the correlation matrices across all gpus. However, as far as I know, this op is not intended for forward computations that are followed by a backward pass. Instead, to apply a "correctly differentiable" distributed all-reduce, the official PyTorch documentation recommends using `torch.distributed.nn.*`: https://pytorch.org/docs/stable/distributed.html#autograd-enabled-communication-primitives
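For comparison, a minimal sketch of what the autograd-enabled variant from those docs looks like (`global_xtx` is a made-up helper): `torch.distributed.nn.all_reduce` returns a new tensor and is tracked by autograd, unlike the in-place `torch.distributed.all_reduce`.

```python
import torch
import torch.distributed.nn  # autograd-enabled collectives

def global_xtx(x):
    # x: (local_batch, dim) with requires_grad=True
    xTx = x.T @ x
    # differentiable sum across ranks; gradients flow back through the collective
    return torch.distributed.nn.all_reduce(xTx)
```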