
Is torch.distributed.all_reduce working as expected? #8

Closed

WarBean opened this issue Mar 23, 2021 · 8 comments

Comments


WarBean commented Mar 23, 2021

This line https://github.com/facebookresearch/barlowtwins/blob/main/main.py#L208 uses torch.distributed.all_reduce to sum the correlation matrices across all gpus. However, as far as I know, this op is not intended for a forward computation that will later be backpropagated through. Instead, to apply a correctly differentiable distributed all-reduce, the official PyTorch documentation recommends using torch.distributed.nn.*: https://pytorch.org/docs/stable/distributed.html#autograd-enabled-communication-primitives


WarBean commented Mar 23, 2021

BTW, would torch.distributed.all_gather be more reasonable here? It can gather features from all gpus to construct a larger sample dimension, which may make the correlation estimate more accurate.

@WarBean WarBean closed this as completed Mar 23, 2021
@WarBean WarBean reopened this Mar 23, 2021

jzbontar commented Mar 23, 2021

This line https://github.com/facebookresearch/barlowtwins/blob/main/main.py#L208 uses torch.distributed.all_reduce to sum the correlation matrices across all gpus. However, as far as I know, this op is not intended for a forward computation that will later be backpropagated through. Instead, to apply a correctly differentiable distributed all-reduce, the official PyTorch documentation recommends using torch.distributed.nn.*: https://pytorch.org/docs/stable/distributed.html#autograd-enabled-communication-primitives

That's a good question. I already replied to a similar question here. TL;DR I believe that torch.distributed.all_reduce is correct, but torch.distributed.nn.all_reduce is cleaner and it agrees with the PyTorch documentation. We will be changing our code to use torch.distributed.nn.all_reduce. Thanks for pointing it out!
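
For reference, a minimal sketch of what the autograd-enabled call looks like, following the same 2-GPU spawn pattern as the example below; the port number is arbitrary and this is not the actual main.py change:

import torch
import torch.distributed.nn

n, m = 128, 4

def main_worker(gpu):
    torch.distributed.init_process_group(backend='nccl', init_method='tcp://localhost:12346', world_size=2, rank=gpu)
    torch.cuda.set_device(gpu)
    x = torch.rand(n, m).cuda().requires_grad_()

    # torch.distributed.nn.all_reduce returns the summed tensor and
    # participates in autograd (its backward is itself a collective)
    xTx = torch.distributed.nn.all_reduce(x.T @ x)
    xTx.sum().backward()
    if gpu == 0:
        print(x.grad[:2])

if __name__ == '__main__':
    torch.multiprocessing.spawn(main_worker, nprocs=2)

Note, though, the later comment in this thread pointing out that the gradient of torch.distributed.nn.all_reduce may not be what you expect (pytorch/pytorch#58005).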

BTW, would torch.distributed.all_gather be more reasonable here? It can gather features from all gpus to construct a larger sample dimension, which may make the correlation estimate more accurate.

I believe that the all_gather and all_reduce solutions are equivalent. Both compute the cross-correlation matrix using "a large sample dimension". Consider the following code (when run on a machine with 2 GPUs):

import torch

n, m = 128, 4

def main_worker(gpu):
    torch.distributed.init_process_group(backend='nccl', init_method='tcp://localhost:12345', world_size=2, rank=gpu)
    torch.cuda.set_device(gpu)
    x = torch.rand(n, m).cuda()

    # compute xTx using all_reduce
    xTx = x.T @ x
    torch.distributed.all_reduce(xTx)
    if gpu == 0:
        print(xTx)

    # compute xTx using all_gather
    ys = [torch.empty(n, m).cuda() for _ in range(2)]
    torch.distributed.all_gather(ys, x)
    y = torch.cat(ys)
    yTy = y.T @ y
    if gpu == 0:
        print(yTy)

if __name__ == '__main__':
    torch.multiprocessing.spawn(main_worker, nprocs=2)

Both the all_reduce and the all_gather solutions return exactly the same result.


WarBean commented Mar 24, 2021

After doing some tests with a differentiable all_gather, I realized that your implementation is an equivalent version. Very smart tricks.
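
For context, the "differentiable allgather" mentioned here is commonly implemented as a custom autograd.Function along the following lines; this is a generic sketch of that pattern, not necessarily the exact code used in these tests:

import torch
import torch.distributed as dist

class AllGather(torch.autograd.Function):
    """all_gather that lets gradients flow back to the local input."""

    @staticmethod
    def forward(ctx, x):
        gathered = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, x)
        return torch.cat(gathered, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        # every rank holds a gradient for the full gathered batch; sum them
        # across ranks, then return only the slice matching the local input
        grad_output = grad_output.clone()
        dist.all_reduce(grad_output)
        chunk = grad_output.shape[0] // dist.get_world_size()
        rank = dist.get_rank()
        return grad_output[rank * chunk:(rank + 1) * chunk]

Gathering features with a layer like this and computing y.T @ y on the concatenated batch matches the forward result of the all_reduce version shown above.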

@WarBean WarBean closed this as completed Mar 24, 2021
@zhangdan8962

@jzbontar Hi, I have a quick question here. I am wondering what the purpose of these two lines is:

barlowtwins/main.py

Lines 207 to 208 in 52ffe56

c.div_(self.args.batch_size)
torch.distributed.all_reduce(c)

Is it simply because you are using global BN?

@jzbontar
Contributor

@zhangdan8962, those two lines are there to handle training on multiple gpus. In the multi-gpu setting, each gpu computes a "local" cross-correlation matrix from a subset of the examples in a batch (for example, when training with a batch size of 2048 on 32 gpus, each gpu computes the cross-correlation matrix from 64 training examples). The call to all_reduce sums all of the "local" cross-correlation matrices. Computing the cross-correlation matrix in this way is the same as computing the cross-correlation matrix on one gpu and using the entire batch.
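
A quick single-process way to convince yourself of that identity, using plain CPU tensors with made-up sizes and setting aside the (global) batch normalization of the embeddings discussed below:

import torch

torch.manual_seed(0)
batch_size, dim, num_gpus = 2048, 16, 32

z1 = torch.randn(batch_size, dim)
z2 = torch.randn(batch_size, dim)

# cross-correlation computed from the entire batch on a single device
c_full = z1.T @ z2 / batch_size

# the same quantity assembled from per-gpu pieces: each "gpu" computes its
# local matrix from 64 examples, and all_reduce(SUM) would add them up
c_sum = sum(a.T @ b for a, b in zip(z1.chunk(num_gpus), z2.chunk(num_gpus))) / batch_size

print(torch.allclose(c_full, c_sum, atol=1e-5))  # True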


zhangdan8962 commented Apr 25, 2021

@jzbontar Thank you again for such a detailed explanation. And I guess this is also the reason for using global Batch Normalization?

Also I have a final question regarding this line:

torch.distributed.reduce(loss.div_(args.world_size), 0)

I thought the loss on each GPU would be the same since they all have the same global cross-correlation matrix. So what is the point of reducing the averaged loss to the first GPU?

Thank you in advance for the help!

@jzbontar
Contributor

And I guess this is also the reason for using global Batch Normalization?

Yes, same reason. You want the output to be the same regardless of whether the code was run on one or many gpus.
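
For completeness, a hedged sketch of how global batch normalization is typically enabled in a DistributedDataParallel setup (the toy model and sizes are placeholders, and a process group is assumed to be initialized as in the earlier example):

import torch.nn as nn

def build_model(gpu):
    # a toy MLP head standing in for the real projector in main.py
    model = nn.Sequential(nn.Linear(128, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Linear(256, 256))

    # convert every BatchNorm layer to SyncBatchNorm so the normalization
    # statistics are computed over the global batch across all processes,
    # not just each gpu's local slice
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = model.cuda(gpu)
    return nn.parallel.DistributedDataParallel(model, device_ids=[gpu])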

I thought the loss on each GPU would be the same since they all have the same global cross-correlation matrix. So what is the point of reducing the averaged loss to the first GPU?

Good catch. That particular reduce is not needed since, as you pointed out, the loss will be the same on all gpus. I removed that line from main.py.

@DrJimFan

I can confirm that torch.distributed.nn.all_reduce is mathematically incorrect: pytorch/pytorch#58005
torch.distributed.all_reduce is correct, but seems to be by accident rather than by design.
