Support overlapping NCCL collective communication with compute on GPU #4666
Hello, I am running an experiment on a server with 4 GPUs. I ran the script spmd_mnist_classifier_fromscratch.py from the examples folder and then profiled the process with NVIDIA Nsight Systems. [Profiler screenshot omitted: timeline of GPU kernels.] The timeline shows the ncclAllReduce kernel executing sequentially with the other computation kernels, without any overlap. Is that really the case? If so, I am wondering how to improve it. Thanks very much.
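For reference, the profiling step described above might look like the following (a sketch assuming the Nsight Systems `nsys` CLI; the output name is illustrative):

```
nsys profile -t cuda,nvtx -o mnist_profile \
    python examples/spmd_mnist_classifier_fromscratch.py
```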
Currently we don't support overlapping NCCL reductions with compute. There is some code for it in XLA (using multiple CUDA streams), but it is buggy (it produces wrong outputs) and not enabled. We could work on fixing it, but it simply hasn't made it to the top of the priority list.

However, I've usually found that this doesn't matter much at single-host scales. I think the NCCL all-reduce stands out in the MNIST example because the model doesn't have enough compute to keep the GPUs busy. A small MNIST model doesn't make sense to run on multiple GPUs in the first place, and the example you are looking at is intended more for explanatory purposes than as something you would actually want to run at scale.

Try a more realistic model? For example, I often use the Flax mixed-precision ResNet-50 model for multi-GPU benchmarking (https://github.com/google/flax/blob/master/examples/imagenet/README.md). On an 8xV100 VM I observe that model spending around 3.5% of execution time on all-reduce operations, which, while not nothing, means the benefit of overlapping communication and compute would be relatively small. What do you think?
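To make the discussion concrete, here is a minimal sketch of the kind of data-parallel step in question (the model, names, and learning rate are hypothetical, not taken from the original example): the `lax.psum` over gradients is what lowers to the NCCL all-reduce visible in the profile.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

def loss_fn(params, batch):
    # Hypothetical linear model standing in for the MNIST classifier.
    inputs, targets = batch
    preds = inputs @ params
    return jnp.mean((preds - targets) ** 2)

@partial(jax.pmap, axis_name="devices")
def train_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # This psum lowers to a NCCL all-reduce across devices. As noted
    # above, XLA:GPU currently runs it serially with the surrounding
    # compute kernels rather than overlapping the two.
    grads = lax.psum(grads, axis_name="devices")
    return params - 1e-3 * grads
```

Since the all-reduce is not overlapped, its cost shows up directly in step time, which is why it dominates a model as small as MNIST but only ~3.5% of a mixed-precision ResNet-50 step.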
Dear hawkinsp, I think this is a very important feature for implementing a training framework for high-performance machine learning research.
@byronyi That patch is only in NVIDIA's fork; is there any chance they will merge it into upstream TF2?
I think this optimization pass addresses the problem to some extent: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/async_collective_creator.cc
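One way to check whether this pass has kicked in is to inspect the optimized HLO for a collective: after the rewrite, a synchronous `all-reduce` op becomes an `all-reduce-start`/`all-reduce-done` pair that the scheduler can overlap with compute. A minimal sketch, assuming a recent jax version where `.lower(...).compile().as_text()` is available:

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

@partial(jax.pmap, axis_name="i")
def allreduce(x):
    return lax.psum(x, axis_name="i")

x = jnp.ones((jax.local_device_count(), 4))
# If the pass has split the collective, the optimized HLO contains an
# all-reduce-start / all-reduce-done pair instead of a single all-reduce.
print(allreduce.lower(x).compile().as_text())
```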
Hello, to what extent has collective-communication overlap scheduling been improved, and are there any remaining issues? Could we document the current status here?
XLA has some code for a scheduler intended to improve the overlap of communication and compute: https://github.com/openxla/xla/search?p=1&q=latencyhidingscheduler&type=commits. I can't speak to its status; perhaps the developers at Google have some results. Closing...
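For anyone landing here later, the scheduler can be experimented with via `XLA_FLAGS` (a sketch; the flag name follows the LatencyHidingScheduler work linked above and may change between releases):

```python
import os

# The flag must be set before jax initializes its backends,
# i.e. before jax is imported.
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_gpu_enable_latency_hiding_scheduler=true"
)

import jax  # noqa: E402  (imported deliberately after setting the flag)
```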