Commit: Incorporate count_all fix
Signed-off-by: Travis Addair <tgaddair@gmail.com>
tgaddair committed May 15, 2021
1 parent 38d8ae0 commit 138b1a5
1 changed file with 2 additions and 3 deletions: horovod/torch/sync_batch_norm.py
@@ -168,9 +168,8 @@ def backward(self, grad_output):
         if _SYNC_BN_V4:
             # from 1.9.0 on we need a count tensor on all devices
-            count_all_handle = allreduce_async(count_all, op=Sum, name='sync_batch_norm.count_all')
-            count_all = synchronize(count_all_handle)
-            count_all = count_all.view(-1).int().to(grad_output.device)
+            # count_all is calculated as total count across all ranks in forward function
+            count_all = count_all.to(dtype=torch.int, device=grad_output.device)
         elif _SYNC_BN_V2 or _SYNC_BN_V3:
             # before 1.9.0 we need the count as an integer to compute mean values
             count = count_all.sum()
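
Below is a minimal sketch (not Horovod's actual sync_batch_norm implementation) of the pattern this commit relies on: gather the per-rank element counts once in forward and save the result, so backward can reuse it instead of issuing the extra allreduce that the removed lines performed. The class name SyncBNCountSketch and all variable names are illustrative assumptions; only hvd.allgather and standard torch.autograd.Function machinery are real APIs.

# A minimal sketch of the count_all plumbing, assuming hvd.init() has been
# called with one process per device. All names are illustrative.
import torch
import horovod.torch as hvd

class SyncBNCountSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # elements per channel contributing to this rank's batch statistics
        count = torch.full((1,), x.numel() // x.size(1),
                           dtype=torch.int, device=x.device)
        # single collective in forward: every rank receives every rank's count
        count_all = hvd.allgather(count, name='sketch.count_all')
        ctx.save_for_backward(count_all)
        return x  # normalization math elided; only the count handling is shown

    @staticmethod
    def backward(ctx, grad_output):
        (count_all,) = ctx.saved_tensors
        # the fix: no second collective here -- count_all already covers all
        # ranks, so only the dtype/device need adjusting for the backward pass
        count_all = count_all.to(dtype=torch.int, device=grad_output.device)
        return grad_output

In this reading, the removed allreduce was redundant: the forward pass already produces a count tensor covering all ranks, so the backward pass only needs the cheap dtype/device conversion kept by the commit.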
