
Support overlapping NCCL collective communication with compute on GPU #4666

Closed
yxd886 opened this issue Oct 21, 2020 · 8 comments
Assignees
Labels
enhancement New feature or request NVIDIA GPU Issues specific to NVIDIA GPUs P3 (no schedule) We have no plan to work on this and, if it is unassigned, we would be happy to review a PR XLA

Comments

@yxd886

yxd886 commented Oct 21, 2020

Hello, I am running an experiment on a server with 4 GPUs. I ran the script spmd_mnist_classifier_fromscratch.py from the examples folder.

Then I used NVIDIA Nsight Systems to profile the process; the picture below shows the timeline of GPU kernels.

[screenshot: Nsight Systems timeline of GPU kernels]

It seems that the ncclAllReduce kernel executes sequentially with the other computation kernels, without any overlap. Is this really the case? If so, I am wondering how to improve it.
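
For context, here is a minimal sketch (simplified, not the exact example script) of the pmap/psum pattern being profiled; the cross-device lax.psum is what appears as ncclAllReduce in the timeline:

```python
# Minimal sketch (simplified, not the exact example script) of the
# pmap/psum training step being profiled.
import functools

import jax
import jax.numpy as jnp

def loss(w, x):
    return jnp.mean((x @ w) ** 2)

@functools.partial(jax.pmap, axis_name="devices")
def step(w, x):
    g = jax.grad(loss)(w, x)
    # Cross-device gradient aggregation -> ncclAllReduce kernel on GPU.
    g = jax.lax.psum(g, "devices")
    return w - 0.1 * g

n = jax.local_device_count()
w = jnp.stack([jnp.ones((8, 8))] * n)   # replicate params across devices
x = jnp.ones((n, 4, 8))                 # one shard of data per device
w = step(w, x)
```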

Thanks very much.

@hawkinsp
Member

Currently we don't support overlapping NCCL reductions with compute. There's some code for it in XLA (using multiple CUDA streams), but it is buggy (produces wrong outputs) and not enabled. We could work on fixing it but it simply hasn't made it to the top of the priority list.

However, I've usually found that doesn't matter a whole lot at single-host scales. I think the NCCL all-reduce sticks out in the MNIST example because it doesn't actually have enough compute to keep the GPUs busy. A small MNIST model doesn't make sense to run on multiple GPUs in the first place, and the example you are looking at is intended more for explanatory purposes than as something you would actually want to run. Try a more realistic model?

For example, I often use the Flax mixed-precision ResNet-50 model for multi-GPU benchmarking (https://github.com/google/flax/blob/master/examples/imagenet/README.md). On an 8xV100 VM I observe that model spends around 3.5% of its execution time on all-reduce operations, which, while not nothing, means the benefits of overlapping communication and compute would be relatively small.
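
As a rough back-of-the-envelope bound (assuming the 3.5% figure above and that the all-reduce could be hidden entirely behind compute):

```python
# Upper bound on the benefit of perfectly overlapping all-reduce with
# compute, assuming the ~3.5% all-reduce fraction measured above.
allreduce_fraction = 0.035
overlapped_step_time = 1.0 - allreduce_fraction      # normalized step time
max_speedup = 1.0 / overlapped_step_time
print(f"best-case speedup from overlap: {max_speedup:.3f}x")  # ~1.036x
```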

What do you think?

@hawkinsp hawkinsp changed the title Can the computation of BP and the ncclAllreduce be overlapped when using pmap and psum? Support overlapping NCCL collective communication with compute on GPU Oct 21, 2020
@hawkinsp hawkinsp added the enhancement New feature or request label Oct 21, 2020
@yxd886
Author

yxd886 commented Oct 22, 2020

Dear hawkinsp,
Thanks for your prompt reply.
I still think that overlapping all-reduce communication with computation is very important. It is true that for tiny models such as MNIST it is unnecessary to consider this. However, for models with a large number of parameters (ResNet-200, BERT-large, XLNet, etc.), training speed will be seriously impacted if gradient aggregation is not overlapped with computation. There is also a large body of work on accelerating training by maximizing this overlap, such as "A Generic Communication Scheduler for Distributed DNN Training Acceleration" (SOSP '19).

I think it is a very important feature for a training framework aimed at high-performance machine learning research.
Thanks very much.

@byronyi

byronyi commented Oct 22, 2020

@cloudhan
Contributor

@byronyi that patch is only in NVIDIA's fork; any chance they will merge it into upstream TF2?

@cloudhan

This comment was marked as off-topic.

@hawkinsp hawkinsp removed their assignment May 14, 2021
@apaszke apaszke added the P3 (no schedule) We have no plan to work on this and, if it is unassigned, we would be happy to review a PR label Jun 2, 2021
@cicirori

cicirori commented Dec 1, 2021

I think this optimization pass solves this problem to some extent: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/async_collective_creator.cc
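
A hedged sketch of how one might check this from jax (assuming a reasonably recent jax/jaxlib where `.lower()`, `.compile()`, and `.as_text()` are available): look for all-reduce-start / all-reduce-done pairs in the optimized HLO, which is what this pass emits.

```python
# Hedged sketch (assumes a recent jax where .lower()/.compile()/.as_text()
# exist): inspect the optimized HLO that XLA actually runs and look for
# all-reduce-start / all-reduce-done, the output of async_collective_creator.
import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name="d")
def allreduce(x):
    return jax.lax.psum(x, "d")

x = jnp.ones((jax.local_device_count(), 4))
hlo = allreduce.lower(x).compile().as_text()  # optimized (post-pass) HLO
print("async all-reduce present:", "all-reduce-start" in hlo)
```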

@sudhakarsingh27 sudhakarsingh27 added NVIDIA GPU Issues specific to NVIDIA GPUs P1 (soon) Assignee is working on this now, among other tasks. (Assignee required) labels Aug 10, 2022
@hawkinsp hawkinsp added the XLA label Aug 12, 2022
@sudhakarsingh27 sudhakarsingh27 removed the P1 (soon) Assignee is working on this now, among other tasks. (Assignee required) label Sep 21, 2022
@jon-chuang
Contributor

Hello, to what extent has the scheduling of collective-communication overlap been improved, and are there any remaining issues?

Could we document the current status here?

@jfurtek
Collaborator

jfurtek commented Apr 4, 2023

XLA has some code for a scheduler intended to improve the overlap of communication and compute.

https://github.com/openxla/xla/search?p=1&q=latencyhidingscheduler&type=commits
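
For anyone finding this later, a hedged sketch of how one might try it out (the flag name is taken from the linked commits; availability and default behavior depend on the jaxlib/XLA version):

```python
# Hedged sketch: enable the XLA latency-hiding scheduler via XLA_FLAGS.
# The flag name is assumed from the linked commits; whether it exists (and
# whether it is already on by default) depends on the jaxlib/XLA version.
import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_gpu_enable_latency_hiding_scheduler=true"
)

import jax  # must be imported after XLA_FLAGS is set

print(jax.devices())
```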

I can't speak to its status - perhaps the developers at Google have some results. Closing...

@jfurtek jfurtek closed this as completed Apr 4, 2023