Hello. Horovod 0.21.0 supports grouped allreduce to enable more efficient tensor fusion and deterministic training. Will the group mechanism be extended to allgather as well?
The situation is that the forward pass of SyncBatchNorm executes three allgathers (for count, mean, and invstd), which costs a lot of time: the three allgathers generate three separate requests in Horovod, whereas a grouped_allgather would generate only one request (just as grouped_allreduce does for allreduce).
```python
# calculate mean/invstd for input
mean, invstd = torch.batch_norm_stats(input, eps)

count_handle = allgather_async(count.unsqueeze(0), name='sync_batch_norm.count')
mean_handle = allgather_async(mean.unsqueeze(0), name='sync_batch_norm.mean')
invstd_handle = allgather_async(invstd.unsqueeze(0), name='sync_batch_norm.invstd')

# wait on the async communication to finish
count_all = synchronize(count_handle)
mean_all = synchronize(mean_handle)
invstd_all = synchronize(invstd_handle)
```
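If the group mechanism were available for allgather, the three requests could be fused into one. A minimal sketch of what that could look like, assuming a hypothetical grouped_allgather_async that mirrors the existing grouped_allreduce API (no such function exists in Horovod 0.21.0):

```python
# hypothetical API: gather all three statistics in a single Horovod request
handle = grouped_allgather_async(
    [count.unsqueeze(0), mean.unsqueeze(0), invstd.unsqueeze(0)],
    name='sync_batch_norm.count_mean_invstd')

# a single synchronize would return all three gathered tensors together
count_all, mean_all, invstd_all = synchronize(handle)
```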
In the backward pass, I can replace the two allreduces for sum_dy and sum_dy_xmu with a single grouped_allreduce, as shown below:
```python
# The original implementation by allreduce
sum_dy_handle = allreduce_async(sum_dy, op=Sum, name='sync_batch_norm.sum_dy')
sum_dy_xmu_handle = allreduce_async(sum_dy_xmu, op=Sum, name='sync_batch_norm.sum_dy_xmu')
sum_dy = synchronize(sum_dy_handle)
sum_dy_xmu = synchronize(sum_dy_xmu_handle)

# The implementation by grouped_allreduce
sum_dy, sum_dy_xmu = grouped_allreduce([sum_dy, sum_dy_xmu], op=Sum,
                                       name='sync_batch_norm.sum_dy_and_dy_xmu')
```
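To preserve the overlap with other backward computation that the original async calls provide, the grouped form can also be issued asynchronously. A sketch, assuming horovod.torch's grouped_allreduce_async (added alongside grouped_allreduce) and that synchronize on a grouped handle returns the reduced tensors together:

```python
# one non-blocking request covering both tensors
handle = grouped_allreduce_async([sum_dy, sum_dy_xmu], op=Sum,
                                 name='sync_batch_norm.sum_dy_and_dy_xmu')
# ... other backward work can proceed here ...
sum_dy, sum_dy_xmu = synchronize(handle)
```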
By contrast, apex's SyncBatchNorm uses a grouped allgather in the forward pass and a grouped allreduce in the backward pass, built directly on the torch.distributed module. In my experiments, apex's SyncBatchNorm performs better than Horovod's.
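For reference, the same fusion can be emulated with plain torch.distributed by packing the statistics into one buffer before a single all_gather. This is only an illustrative sketch of the idea, not apex's actual code (names and shapes are assumptions; mean and invstd are taken to be 1-D per-channel tensors):

```python
import torch
import torch.distributed as dist

# pack count/mean/invstd into one buffer so one all_gather moves all three;
# cast count to the float dtype of the statistics so they can share a buffer
c = mean.numel()
combined = torch.cat([count.reshape(1).to(mean.dtype), mean, invstd])
gathered = [torch.empty_like(combined) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, combined)

# unpack the per-rank statistics after the single collective
count_all = torch.stack([t[0] for t in gathered])
mean_all = torch.stack([t[1:1 + c] for t in gathered])
invstd_all = torch.stack([t[1 + c:] for t in gathered])
```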
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.