Sync Batch Norm for PyTorch #1923

Merged: 6 commits merged into master from sync_bn on Apr 30, 2020

Conversation

@alsrgv (Member) commented Apr 29, 2020

Implementation of https://pytorch.org/docs/stable/nn.html#syncbatchnorm using Horovod. The current version uses the optimized CUDA kernels written for PyTorch and has two different invocation paths because the kernel signatures changed between PyTorch releases. We can consider making our own implementation if that turns out to be a hassle.

Why Sync Batch Norm?

As evidenced by https://arxiv.org/abs/1804.07612, small batches are great for training neural networks. However, https://arxiv.org/abs/1803.08494 noted that BatchNorm with small per-GPU batches is detrimental to performance in multi-GPU training.

SyncBatchNorm improves training where each worker can only hold a few examples (1-4) and the total number of workers is not too high; beyond that point the effective batch becomes large and BatchNorm starts to lose its regularization benefit.

Fixes #1384
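
A minimal usage sketch (not part of the PR itself; the import path follows the horovod/torch/sync_batch_norm.py module added here, and the model and shapes are purely illustrative):

    import torch
    import horovod.torch as hvd
    from horovod.torch.sync_batch_norm import SyncBatchNorm

    hvd.init()
    torch.cuda.set_device(hvd.local_rank())

    # SyncBatchNorm is a drop-in replacement for torch.nn.BatchNorm2d whose
    # statistics are synchronized across all Horovod workers during training
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 10, kernel_size=3, padding=1),
        SyncBatchNorm(10),
        torch.nn.ReLU(),
    ).cuda()

    x = torch.rand(2, 3, 8, 8).cuda()  # small per-worker batch; GPU input only
    y = model(x)

Launched with horovodrun (e.g. horovodrun -np 4 python train.py, where train.py is a placeholder script name), each worker contributes its local batch statistics to the shared mean and variance.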

@tgaddair (Collaborator) left a comment

This is awesome! Just a few nits.

horovod/torch/sync_batch_norm.py
                            self.eps, self.momentum)

    def forward(self, input):
        # currently only GPU input is supported

Collaborator:

Can you add a TODO stating what would be needed to get it to work on CPU?

@alsrgv (Member Author):

This is constrained by PyTorch internal kernels that we use. They're only implemented on GPU. If we want to make it work on CPU, we should consider rolling our own implementation w/o dependency on PyTorch internals.
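
For illustration, a rough sketch of what a kernel-independent forward pass could look like, using only public tensor ops plus hvd.allreduce; this is a hypothetical simplification that assumes hvd.init() has been called, ignores running statistics and the backward pass, and assumes equal batch sizes on every worker:

    import torch
    import horovod.torch as hvd

    def sync_bn_forward(x, weight, bias, eps=1e-5):
        # per-channel statistics over the batch and spatial dims...
        dims = [0] + list(range(2, x.dim()))
        # ...averaged across workers (hvd.allreduce averages by default)
        mean = hvd.allreduce(x.mean(dims), name='sync_bn.mean')
        sqr_mean = hvd.allreduce((x * x).mean(dims), name='sync_bn.sqr_mean')
        var = sqr_mean - mean * mean

        shape = [1, -1] + [1] * (x.dim() - 2)  # broadcast along the channel dim
        x_hat = (x - mean.view(shape)) * torch.rsqrt(var.view(shape) + eps)
        return x_hat * weight.view(shape) + bias.view(shape)

A real implementation would also need to handle unequal per-worker counts, update running statistics, and define the matching backward pass.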

@@ -1667,6 +1667,59 @@ def test_horovod_join_broadcast(self):
            ret = hvd.join(hvd.local_rank())
        else:
            ret = hvd.join()

    def test_horovod_sync_batch_norm(self):

Collaborator:

May also want to include a skip if the Python version is < 3 (the integration tests are passing only because we no longer have Python 2 GPU tests).
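
A minimal sketch of such a skip, assuming the tests are unittest-based (class name and message are illustrative):

    import sys
    import unittest

    class TorchTests(unittest.TestCase):
        @unittest.skipIf(sys.version_info[0] < 3,
                         'Sync BatchNorm is only tested on Python 3')
        def test_horovod_sync_batch_norm(self):
            pass  # the real test body lives in the PR's test file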

@alsrgv (Member Author):

Added skip, but we should deprecate all Py2 tests :-)

Collaborator:

Agreed. Tentative plan is to drop Python 2 in 0.20.0.

@tgaddair (Collaborator) left a comment

LGTM! Let's merge it!

@tgaddair merged commit 2a3f43f into master Apr 30, 2020
@tgaddair deleted the sync_bn branch April 30, 2020 20:12

Comment on lines +105 to +107
count_handle = allgather_async(count.unsqueeze(0), name='sync_batch_norm.count')
mean_handle = allgather_async(mean.unsqueeze(0), name='sync_batch_norm.mean')
invstd_handle = allgather_async(invstd.unsqueeze(0), name='sync_batch_norm.invstd')

Comment:

We would have a lot of BatchNorm instances in a network. Would we need to add the instance id to the name arguments in allreduce_async calls here and below (line 157, 158) ?

@alsrgv (Member Author):

Didn't seem necessary so far - the current implementation of SyncBN requires each worker to execute the model in the same sequence with respect to batch norms.
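
If unique names ever did become necessary, one hypothetical direction (not part of this PR) would be to assign each module an instance id at construction time and fold it into the collective names:

    import itertools

    from horovod.torch.sync_batch_norm import SyncBatchNorm

    _instance_ids = itertools.count()

    class NamedSyncBatchNorm(SyncBatchNorm):
        """Hypothetical subclass carrying a per-instance name prefix."""

        def __init__(self, *args, **kwargs):
            super(NamedSyncBatchNorm, self).__init__(*args, **kwargs)
            # e.g. 'sync_batch_norm.3.mean' instead of the shared
            # 'sync_batch_norm.mean'; the allgather/allreduce calls would
            # then need to accept and use this prefix
            self.name_prefix = 'sync_batch_norm.%d' % next(_instance_ids)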

Comment:

Got it, thanks!

Comment on lines +168 to +171
# before 1.6.0, sum_dy was sum of means from every worker, so we just
# need to divide it by number of workers
mean_dy = sum_dy / size()
mean_dy_xmu = sum_dy_xmu / size()

@thuyen commented May 25, 2020:

I think we should divide by count_all.sum() for both PyTorch 1.5 and 1.6.

For PyTorch 1.5, if I comment out the _run_bn code path (so torch.nn.BatchNorm is not used) and run the following on v0.19.3 locally:

    # imports added here to make the snippet self-contained; SyncBatchNorm is
    # the class added by this PR in horovod/torch/sync_batch_norm.py
    import torch
    import horovod.torch as hvd
    from horovod.torch.sync_batch_norm import SyncBatchNorm

    hvd.init()  # single local worker, so size() == 1

    torch.cuda.manual_seed(2020)
    bn = torch.nn.BatchNorm2d(10).cuda()
    bn_hvd = SyncBatchNorm(10).cuda()

    x = torch.rand(3, 10, 8, 8).cuda()
    x1 = x.clone().requires_grad_()
    x2 = x.clone().requires_grad_()
    y = bn(x1)
    y_hvd = bn_hvd(x2)

    y.sum().backward()
    y_hvd.sum().backward()
    # with a single worker the two gradients should match almost exactly
    print((x1.grad - x2.grad).abs().sum())

I got

tensor(1298650., device='cuda:0')

But if we use count_all.sum() here, the result is more reasonable:

tensor(0.0004, device='cuda:0')
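
A sketch of the change suggested above, assuming count_all is the allgathered per-worker count tensor from the forward pass; it replaces the division by size():

    # divide by the global element count rather than by the number of workers
    total_count = count_all.sum()
    mean_dy = sum_dy / total_count
    mean_dy_xmu = sum_dy_xmu / total_count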

@alsrgv (Member Author):

Great catch, I've raised #1980 to address this. Apparently they made breaking changes in both 1.5.0 and 1.6.0.

@eric-haibin-lin (Collaborator):

Does SyncBatchNorm support the use case where I have 2 nodes, each with 8 GPUs, and the sync only happens intra-node instead of inter-node (sync 2 groups of 8 workers instead of 16)?

@alsrgv (Member Author) commented Jul 13, 2020:

@eric-haibin-lin, this implementation does not, but please feel free to extend it!

@romerojosh (Collaborator):

@alsrgv @eric-haibin-lin I think the main limitation here is that Horovod is currently designed around performing collectives across the global communicator, with no option to perform collectives on subsets of workers. @tgaddair, perhaps this is something we should begin to think about.

@tgaddair (Collaborator):

Good point, @romerojosh. Added #2139 to track.
