Will grouped_allgather be supported in the future? #3325

Closed
wuyujiji opened this issue Dec 16, 2021 · 3 comments

wuyujiji commented Dec 16, 2021

Hello, Horovod 0.21.0 supports grouped allreduce to enable more efficient tensor fusion and deterministic training. Will the group mechanism also be extended to allgather?

The motivation: I find that the forward pass of SyncBatchNorm executes three allgathers (count, mean, and invstd), which costs a lot of time, since these three allgathers generate three separate requests in Horovod, whereas a grouped_allgather would generate only one request (analogous to grouped_allreduce).

  # calculate mean/invstd for input.
  mean, invstd = torch.batch_norm_stats(input, eps)

  # launch one allgather request per tensor (allgather_async/synchronize from horovod.torch)
  count_handle = allgather_async(count.unsqueeze(0), name='sync_batch_norm.count')
  mean_handle = allgather_async(mean.unsqueeze(0), name='sync_batch_norm.mean')
  invstd_handle = allgather_async(invstd.unsqueeze(0), name='sync_batch_norm.invstd')

  # wait on the async communication to finish
  count_all = synchronize(count_handle)
  mean_all = synchronize(mean_handle)
  invstd_all = synchronize(invstd_handle)
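For illustration only, if a grouped allgather mirrored the grouped_allreduce interface, the forward pass could issue a single request. The name grouped_allgather_async and the single-handle synchronize below are assumptions, not an existing Horovod API:

  # Hypothetical API sketch -- grouped_allgather_async does not exist in Horovod at the time of writing.
  handle = grouped_allgather_async(
      [count.unsqueeze(0), mean.unsqueeze(0), invstd.unsqueeze(0)],
      name='sync_batch_norm.count_mean_invstd')

  # one synchronize would return all three gathered tensors from the single request
  count_all, mean_all, invstd_all = synchronize(handle)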

In the backward pass, I can replace the two allreduces for sum_dy and sum_dy_xmu with a single grouped_allreduce, as shown below:

# The original implementation using two separate allreduces
sum_dy_handle = allreduce_async(sum_dy, op=Sum, name='sync_batch_norm.sum_dy')
sum_dy_xmu_handle = allreduce_async(sum_dy_xmu, op=Sum, name='sync_batch_norm.sum_dy_xmu')
sum_dy = synchronize(sum_dy_handle)
sum_dy_xmu = synchronize(sum_dy_xmu_handle)

# The equivalent implementation using a single grouped_allreduce
sum_dy, sum_dy_xmu = grouped_allreduce([sum_dy, sum_dy_xmu], op=Sum, name='sync_batch_norm.sum_dy_and_dy_xmu')
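To keep the asynchronous pattern of the original code, horovod.torch also provides grouped_allreduce_async. A minimal sketch, assuming synchronize on the grouped handle returns the list of reduced tensors as in the non-grouped case:

from horovod.torch import grouped_allreduce_async, synchronize, Sum

# both reductions go out as one grouped request; a single handle is awaited
handle = grouped_allreduce_async([sum_dy, sum_dy_xmu], op=Sum,
                                 name='sync_batch_norm.sum_dy_and_dy_xmu')
sum_dy, sum_dy_xmu = synchronize(handle)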

For comparison, the apex SyncBatchNorm implementation achieves a grouped allgather in the forward pass and a grouped allreduce in the backward pass by concatenating tensors and calling torch.distributed directly. In my experiments, the performance of apex SyncBatchNorm is better than Horovod's.

# forward
count_t = torch.empty(1, dtype=mean.dtype, device=mean.device).fill_(count)
# concatenate mean, biased variance, and count so a single all_gather is enough
combined = torch.cat([mean.view(-1), var_biased.view(-1), count_t], dim=0)
combined_list = [torch.empty_like(combined) for k in range(world_size)]
torch.distributed.all_gather(combined_list, combined, process_group)
combined = torch.stack(combined_list, dim=0)
mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

# backward
num_channels = sum_dy.shape[0]
# concatenate the two gradient statistics so a single all_reduce is enough
combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
torch.distributed.all_reduce(combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
sum_dy, sum_dy_xmu = torch.split(combined, num_channels)
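Until a native grouped_allgather is available, the same concatenation trick could be applied with Horovod's existing allgather so that only one request is issued per forward pass. This is a sketch, not Horovod's SyncBatchNorm code; the tensor name 'sync_batch_norm.combined_stats' is made up, and count is assumed to already be a tensor as in the snippet above:

import torch
from horovod.torch import allgather_async, synchronize

# pack mean, invstd, and count into one tensor so a single allgather request is issued
num_channels = mean.numel()
count_t = count.reshape(-1).to(mean.dtype)
combined = torch.cat([mean.view(-1), invstd.view(-1), count_t], dim=0).unsqueeze(0)

handle = allgather_async(combined, name='sync_batch_norm.combined_stats')
combined_all = synchronize(handle)  # shape: [world_size, 2 * num_channels + count_t.numel()]

mean_all, invstd_all, count_all = torch.split(
    combined_all, [num_channels, num_channels, count_t.numel()], dim=1)
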
wuyujiji (Author) commented:

@tgaddair @romerojosh


stale bot commented Feb 18, 2022

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the wontfix label on Feb 18, 2022
stale bot closed this as completed on Feb 25, 2022
maxhgerlach (Collaborator) commented:

@wuyujiji, I've started to work on this in #3594. Support for PyTorch is still missing, though.
