New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sync Batch Norm for PyTorch #1923
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is awesome! Just a few nits.
horovod/torch/sync_batch_norm.py
Outdated
self.eps, self.momentum) | ||
|
||
def forward(self, input): | ||
# currently only GPU input is supported |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a TODO stating what would be needed to get it work on CPU?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is constrained by PyTorch internal kernels that we use. They're only implemented on GPU. If we want to make it work on CPU, we should consider rolling our own implementation w/o dependency on PyTorch internals.
@@ -1667,6 +1667,59 @@ def test_horovod_join_broadcast(self): | |||
ret = hvd.join(hvd.local_rank()) | |||
else: | |||
ret = hvd.join() | |||
|
|||
def test_horovod_sync_batch_norm(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
May also want to include a skip if Python version is < 3 (integration tests passing because we have no Python 2 GPU tests anymore).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added skip, but we should deprecate all Py2 tests :-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed. Tentative plan is to drop Python 2 in 0.20.0.
Signed-off-by: Alex Sergeev <alexander.sergeev@live.com>
Signed-off-by: Alex Sergeev <alexander.sergeev@live.com>
Signed-off-by: Alex Sergeev <alexander.sergeev@live.com>
Signed-off-by: Alex Sergeev <alexander.sergeev@live.com>
Signed-off-by: Alex Sergeev <alexander.sergeev@live.com>
Signed-off-by: Alex Sergeev <alexander.sergeev@live.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Let's merge it!
count_handle = allgather_async(count.unsqueeze(0), name='sync_batch_norm.count') | ||
mean_handle = allgather_async(mean.unsqueeze(0), name='sync_batch_norm.mean') | ||
invstd_handle = allgather_async(invstd.unsqueeze(0), name='sync_batch_norm.invstd') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We would have a lot of BatchNorm
instances in a network. Would we need to add the instance id
to the name
arguments in allreduce_async
calls here and below (line 157, 158) ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Didn't seem necessary so far - current implementation of SyncBN requires each worker to execute model in the same sequence with respect to batch norms.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Got you thanks!
# before 1.6.0, sum_dy was sum of means from every worker, so we just | ||
# need to divide it by number of workers | ||
mean_dy = sum_dy / size() | ||
mean_dy_xmu = sum_dy_xmu / size() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should divide by count_all.sum()
for both pytorch 1.5 and 1.6.
For pytorch 1.5, right now if I comment out _run_bn
code path (not using torch.nn.BatchNorm) and run the following on v0.19.3
locally:
torch.cuda.manual_seed(2020)
bn = torch.nn.BatchNorm2d(10).cuda()
bn_hvd = SyncBatchNorm(10).cuda()
x = torch.rand(3, 10, 8, 8).cuda()
x1 = x.clone().requires_grad_()
x2 = x.clone().requires_grad_()
y = bn(x1)
y_hvd = bn_hvd(x2)
y.sum().backward()
y_hvd.sum().backward()
print((x1.grad - x2.grad).abs().sum())
I got
tensor(1298650., device='cuda:0')
But if we use count_all.sum()
here, the result is more reasonable:
tensor(0.0004, device='cuda:0')
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great catch, I've raised #1980 to address this. Apparently they made breaking changes in both 1.5.0 and 1.6.0.
Does SyncBatchNorm support the use case where I have 2 nodes, each with 8 GPUs, and the sync only happens intra-node instead of inter-node (sync 2 groups of 8 workers instead of 16) |
@eric-haibin-lin, this implementation does not, but please feel free to extend it! |
@alsrgv @eric-haibin-lin I think the main limitation here is that Horovod still is currently designed around performing collectives across the global communicator, with no options currently to perform collectives on subsets of workers. @tgaddair, perhaps this is something we should begin to think about. |
Good point, @romerojosh. Added #2139 to track. |
Implementation of https://pytorch.org/docs/stable/nn.html#syncbatchnorm using Horovod. Current version uses optimized CUDA kernels written for PyTorch and have two different invocations because they changed. We can consider making our own implementation if it turns out to be a hassle.
Why Sync Batch Norm?
As evidenced by https://arxiv.org/abs/1804.07612, small batches are great for training neural networks. However, https://arxiv.org/abs/1803.08494 noted that multi-GPU training with small-batch BatchNorm is detrimental to performance.
SyncBatchNorm improves training where each worker can only hold few examples (1..4) and total number of workers is not too high - at which point it starts to lose its regularization abilities.
Fixes #1384