
Don't sum the counts
EnricoMi committed May 14, 2021
1 parent c5cf6fb · commit 687ae0d
Showing 1 changed file with 5 additions and 4 deletions.
horovod/torch/sync_batch_norm.py (5 additions, 4 deletions)
```diff
@@ -168,9 +168,10 @@ def backward(self, grad_output):
 
             if _SYNC_BN_V4:
                 # from 1.9.0 on we need a count tensor on all devices
-                count_all_handle = allreduce_async(count_all, op=Sum, name='sync_batch_norm.count_all')
-                count_all = synchronize(count_all_handle)
-                count_all = count_all.view(-1).int().to(grad_output.device)
+                count = torch.tensor([count_all])
+                count_handle = allgather_async(count.unsqueeze(0), name='sync_batch_norm.count')
+                count = synchronize(count_handle)
+                counts_for_bnbe = count.view(-1).int().to(grad_output.device)
             elif _SYNC_BN_V2 or _SYNC_BN_V3:
                 # before 1.9.0 we need the count as an integer to compute mean values
                 count = count_all.sum()
@@ -192,7 +193,7 @@ def backward(self, grad_output):
                     weight,
                     sum_dy,
                     sum_dy_xmu,
-                    count_all
+                    counts_for_bnbe
                 )
             else:
                 # before 1.9.0, mean parameters expected, not sums and count
```
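In short: the old code allreduce-summed the counts into a single total, while from PyTorch 1.9.0 on `torch.batch_norm_backward_elemt` takes a count tensor, so the commit switches to an allgather that keeps one entry per rank instead of pre-summing. The sketch below illustrates that communication pattern, substituting the blocking `hvd.allgather` for the diff's `allgather_async`/`synchronize` pair; the local count value (1024) is made up for illustration and is not from the commit.

```python
# Minimal sketch (assumed values) of the gather-don't-sum pattern above.
import torch
import horovod.torch as hvd

hvd.init()

# Each rank contributes its local element count, shaped [1, 1] as with
# count.unsqueeze(0) in the diff.
count = torch.tensor([1024]).unsqueeze(0)

# allreduce(op=Sum) would collapse this to a single total; allgather
# concatenates along dim 0 instead, keeping one entry per rank.
counts = hvd.allgather(count, name='sync_batch_norm.count')

# Flatten to a per-rank count vector, as done for counts_for_bnbe in the
# diff; on two ranks this prints tensor([1024, 1024], dtype=torch.int32).
counts_for_bnbe = counts.view(-1).int()
print(counts_for_bnbe)
```

Run under e.g. `horovodrun -np 2 python sketch.py`. Summing first (the pre-1.9.0 `count_all.sum()` path, which the `elif` branch still provides) would discard the per-rank breakdown that the commit title says not to collapse.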
