
Adasum Full GPU Ring-based Allreduce #1760

Closed · wants to merge 2 commits

Conversation

@vaeksare commented Mar 2, 2020

This PR adds a new Adasum op that is performed entirely on the GPU within a node. It mimics the NCCL ring allreduce using a custom algorithm built on CUDA-aware MPI send/receive primitives. On machines that can support it, it allows much higher throughput than the other Adasum operations.

Note that the current version only works on DGX1-like machines with 8 GPUs. It could be extended to other configurations in the future, but this is the main use case where the two currently existing Adasum modes (CPU and Hierarchical) both have large drawbacks.

To use it, Horovod should be compiled with HOROVOD_GPU_ALLREDUCE=MPI and op=hvd.Adasum should be passed to the optimizer/allreduce calls. A CUDA-aware MPI implementation is required (we used Open MPI with UCX).

@nvcastet (Collaborator) commented Mar 4, 2020

@vaeksare Thanks for the PR!
What limits this new op to DGX1-like machines?
Could you share some performance numbers comparing the Adasum ops with each other and with a regular NCCL allreduce op?

@vaeksare (Author) commented Mar 5, 2020

@nvcastet thanks for taking a look!
The ring building algorithm using for the Allreduce is specialized for a DGX1 in terms of number of GPUs per ring and number of rings used. It will always build 4 rings currently, 2 "fat" rings and 2 "skinny" rings with half the capacity. This is only ideal on an 8-GPU hybrid cube-mesh interconnection that a DGX1 has, in which each GPU is connected via NVLink to 4 other GPUs, and some of the NVLinks have double the effective bandwidth of others. Additionally, the rings are hardcoded to have 8 GPUs in each. In order to support varying configurations, we would need to to have a new algorithm to dynamically build the rings based on the exact configuration. Right now, the rings are built dynamically but there are always 4 of them and it's assumed the NVLink configuration follow the restrains mentioned above.

In terms of rough performance numbers, running pytorch_synthetic_benchmark.py, the throughput results are approximately the following (all tested on a single DGX1):

NCCL Allreduce: 310 images/sec/gpu (all averaging)
Ring-based Adasum: 290 images/sec/gpu (all Adasum operation)
Hierarchical "Adasum": 310 images/sec/gpu (note that because this is single node, this will actually just run NCCL Averaging since it only does Adasum cross-node. So this carries none of the Adasum convergence benefits)
CPU Adasum: 100 images/sec/gpu (all Adasum operation. This is so slow because it's not able to utilize GPU P2P communication through NVLinks).

@vaeksare (Author) commented Mar 5, 2020

It also seems like the CI builds are currently failing with a rather strange error (not being able to find cmake, I think?). Do you happen to have any insights into why that could be the case? It builds fine locally for me.

@nvcastet (Collaborator) commented Mar 9, 2020

@vaeksare You may want to rebase your branch. We fixed some CI build issues recently.
@vaeksare All the logic to calculate how many rings to use to best utilize the hardware topology is already nicely done inside NCCL. From what I understood, you would need ncclAllReduce to support a custom kernel for the reduction op, instead of just avg/sum/min/max, to implement your algorithm?
@sjeaugey @romerojosh Do you know if there is a way to customize the reduce op of a ncclAllReduce?

@sjeaugey commented Mar 9, 2020

The reduction operation is at the very core of the NCCL CUDA kernels, and we generate kernels for each operation/datatype, because we need to properly unroll and merge it with the rest of the algorithm. We looked at ways to use user-defined operations but it was either hard to use or very slow.

As far as I understand, the goal here is to scale values down as we sum, to avoid reaching the limits of the format, and end up with the average instead of the sum. Is that right?

@vaeksare (Author) commented Mar 9, 2020

@sjeaugey @nvcastet Scaling down is not quite accurate; what we do is actually quite a bit more complex. Instead of summing the two vectors, they are effectively projected onto each other (through a dot product and norm calculations). This interpolates between a sum and an average depending on how parallel or orthogonal the vectors are. For more details, see #1485.
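
For reference, the two-vector reduction is roughly the following (notation mine; see #1485 for the authoritative definition):

$$
\mathrm{Adasum}(g_1, g_2) = \left(1 - \frac{g_1 \cdot g_2}{2\,\lVert g_1 \rVert^2}\right) g_1 + \left(1 - \frac{g_1 \cdot g_2}{2\,\lVert g_2 \rVert^2}\right) g_2
$$

so orthogonal gradients are summed and identical gradients are averaged, matching the interpolation described above.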

The big issue with this is that the operation is not associative. We have looked quite closely at NCCL internals before and had some discussions, and we don't believe there is any reasonable way to implement this with existing NCCL operations: the way the Allreduce is coded in NCCL, it cannot support a non-associative custom op without losing all of its performance.

It is definitely true that we could write ring formation based on how NCCL does it, but I believe this would still require largely rewriting that portion of NCCL inside our Adasum ops, which was outside the scope of this immediate work.

I will rebase shortly for the CI.

@romerojosh (Collaborator) left a comment

Thanks for the PR @vaeksare! I just took a first rough pass on the code and left a few comments to start out.


find_package(CUDA)

list(APPEND CUDA_NVCC_FLAGS "--compiler-options -fPIC -D_FORCE_INLINES -arch=sm_60")
Collaborator:

Can we generalize these flags to build for more than just sm_60? Might be good to expose this also as a build option so a user can specify their own build flags.


template<typename T, typename TACC>
__global__
void CudaDotProductKernel(int count, const T* a, const T* b, TACC* out) {
Collaborator:

Would it be possible to use cuBLAS functions for this dot product operation, like cublasDotEx or, similarly, cublasNrm2Ex?
https://docs.nvidia.com/cuda/cublas/index.html#cublas-dotEx
https://docs.nvidia.com/cuda/cublas/index.html#cublas-nrm2Ex

The use of atomicAdd in this kernel will introduce non-determinism (which may be a concern), while these APIs are deterministic. There are also Thrust APIs for reductions and inner products, but the docs suggest those are also non-deterministic.
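
For reference, a minimal sketch (not the PR's code) of what that might look like, assuming float buffers and host-side results:

```cpp
#include <cublas_v2.h>

// Sketch only: computes the dot product and squared norms the Adasum
// reduction needs, using deterministic cuBLAS routines. Error checking is
// omitted; fp16 inputs would need the corresponding cudaDataType values.
void AdasumDotAndNorms(cublasHandle_t handle, int n,
                       const float* d_a, const float* d_b,
                       float* dot, float* norm_a_sq, float* norm_b_sq) {
  cublasDotEx(handle, n,
              d_a, CUDA_R_32F, 1,
              d_b, CUDA_R_32F, 1,
              dot, CUDA_R_32F, CUDA_R_32F);
  float norm_a = 0.f, norm_b = 0.f;
  cublasNrm2Ex(handle, n, d_a, CUDA_R_32F, 1, &norm_a, CUDA_R_32F, CUDA_R_32F);
  cublasNrm2Ex(handle, n, d_b, CUDA_R_32F, 1, &norm_b, CUDA_R_32F, CUDA_R_32F);
  *norm_a_sq = norm_a * norm_a;
  *norm_b_sq = norm_b * norm_b;
}
```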

Author:

I tried using the cuBLAS operations for this previously, but found that the performance was actually a lot lower than with the handwritten kernel (~20-25% slower). I don't believe it's worth the trade-off.

Collaborator:

I'd expect using individual cuBLAS calls to be slower than the fused approach you have, but the benefit is determinism. Is the 20-25% slowdown end-to-end, or just from timing this particular set of operations? @tgaddair, do you have any comments on this regarding deterministic operations?

Collaborator:

In general, my preference is to prefer performance over determinism by default, but provide a means for users to enforce stricter determinism guarantees if desired. So we could toggle between the fused and unfused ops with an environment variable or horovodrun arg, for example. I feel this would be consistent with the way we handle similar cases (e.g., tensor fusion). Does that sound reasonable to you all?
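
For example, something along these lines (HOROVOD_ADASUM_DETERMINISTIC is a hypothetical name here, just to illustrate the toggle):

```cpp
#include <cstdlib>
#include <cstring>

// Sketch only: choose the deterministic (unfused) path when the hypothetical
// HOROVOD_ADASUM_DETERMINISTIC variable is set to "1"; otherwise use the
// faster fused kernels that rely on atomicAdd.
static bool UseDeterministicAdasumKernels() {
  const char* value = std::getenv("HOROVOD_ADASUM_DETERMINISTIC");
  return value != nullptr && std::strcmp(value, "1") == 0;
}
```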

Collaborator:

I think that sounds reasonable. A second deterministic version of this fused kernel approach can be written. Instead of atomic adds, the first kernel should write intermediate block sums to a temporary array. Then a second kernel can be launched to sum the temporary results in the array.
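
A rough sketch of that two-pass approach (not the PR's code; with a fixed launch configuration the summation order, and hence the result, is the same on every run):

```cuda
// Sketch only: deterministic dot product in two kernel launches.
// Pass 1 writes one partial sum per block; pass 2 reduces the partials in a
// single block, so no atomicAdd is needed and the summation order is fixed.
template <typename T, typename TACC, int kThreads = 256>
__global__ void BlockDotKernel(int count, const T* a, const T* b,
                               TACC* block_sums) {
  __shared__ TACC shared[kThreads];
  TACC sum = TACC(0);
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < count;
       i += gridDim.x * blockDim.x) {
    sum += static_cast<TACC>(a[i]) * static_cast<TACC>(b[i]);
  }
  shared[threadIdx.x] = sum;
  __syncthreads();
  for (int stride = kThreads / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) shared[threadIdx.x] += shared[threadIdx.x + stride];
    __syncthreads();
  }
  if (threadIdx.x == 0) block_sums[blockIdx.x] = shared[0];
}

template <typename TACC, int kThreads = 256>
__global__ void FinalSumKernel(int num_blocks, const TACC* block_sums, TACC* out) {
  __shared__ TACC shared[kThreads];
  TACC sum = TACC(0);
  for (int i = threadIdx.x; i < num_blocks; i += kThreads) sum += block_sums[i];
  shared[threadIdx.x] = sum;
  __syncthreads();
  for (int stride = kThreads / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) shared[threadIdx.x] += shared[threadIdx.x + stride];
    __syncthreads();
  }
  if (threadIdx.x == 0) *out = shared[0];
}

// Launch, e.g.:
//   BlockDotKernel<float, double><<<num_blocks, 256, 0, stream>>>(count, a, b, block_sums);
//   FinalSumKernel<double><<<1, 256, 0, stream>>>(num_blocks, block_sums, out);
```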

Author:

I actually had this implemented exactly as you described originally (I think the part below is a leftover from me forgetting to remove all of that code). The 20-25% slowdown I observed with it was on the PyTorch synthetic benchmark example (ResNet-50 with batch size 32), so it was a significant slowdown in an end-to-end scenario, albeit a very communication-bound one. I am not sure how significant a concern this non-determinism is, but I could add back the deterministic path. I mostly removed it because I think the slowdown is too much for any real use case, and it adds a rather large amount of code.

horovod/common/ops/cuda/adasum_cuda_kernels.cu (outdated)
warnings.warn('Adasum reduction does not currently support GPU reduction using MPI. Tensors '
'are copied to CPU memory instead. To use Adasum for GPU reduction, please '
'compile Horovod with HOROVOD_GPU_ALLREDUCE=NCCL.')
if horovod_local_size != 8:
Collaborator:

Is there a reason this is limited to exactly 8 GPUs? It seems like most of the inter-GPU communication is carried out via MPI_Isend/MPI_Irecv and thus should generalize to any number of GPUs. I think there would be lower performance if the GPUs aren't P2P connected, but there isn't a guarantee that all 8-GPU systems have a DGX1-like topology.

Author:

No real good reason. The primary motivation was that this is only intended to be used on DGX1-like architectures, and checking the number of GPUs is much simpler than checking for the exact architecture. So this seemed better than no check at all, while being much cleaner than adding convoluted checks for the exact topology.

Collaborator:

I'd suggest having a default (but maybe slower) path that works on any number of GPUs. You could print a warning stating that the current design is optimized for DGX1-like topologies with 8 GPUs.

horovod/common/ops/cuda/adasum_cuda_kernels.cu (outdated)

template<typename T, typename TACC>
__global__
void CudaScaleAddKernel(int count, T* a, const T* b, TACC a_coeff, TACC b_coeff) {
Collaborator:

Similar to the above, could this operation be expressed with Thrust using thrust::transform? https://thrust.github.io/doc/group__transformations_ga68a3ba7d332887f1332ca3bc04453792.html#ga68a3ba7d332887f1332ca3bc04453792
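
For reference, a minimal sketch (not the PR's code) of the same scale-and-add expressed with thrust::transform:

```cpp
#include <thrust/device_ptr.h>
#include <thrust/transform.h>

// Sketch only: computes a[i] = a_coeff * a[i] + b_coeff * b[i], mirroring
// what a CudaScaleAddKernel would do, via thrust::transform.
template <typename T, typename TACC>
struct ScaleAdd {
  TACC a_coeff, b_coeff;
  __host__ __device__ T operator()(const T& a, const T& b) const {
    return static_cast<T>(a_coeff * static_cast<TACC>(a) +
                          b_coeff * static_cast<TACC>(b));
  }
};

template <typename T, typename TACC>
void ScaleAddThrust(int count, T* a, const T* b, TACC a_coeff, TACC b_coeff) {
  thrust::device_ptr<T> a_ptr(a);
  thrust::device_ptr<const T> b_ptr(b);
  // In-place: the output is written back into a. In practice this would run
  // on the op's CUDA stream via an execution policy.
  thrust::transform(a_ptr, a_ptr + count, b_ptr, a_ptr,
                    ScaleAdd<T, TACC>{a_coeff, b_coeff});
}
```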

Collaborator:

I looked again and I do not think this kernel is used either, so it should be removed.


template<typename T, typename TACC>
__global__
void CudaSingleAdasumKernel(int count, T* a, const T* b, TACC* out) {
@vaeksare (Author) commented:
Hi @romerojosh, thanks for the initial feedback! Regarding using cuBLAS and Thrust: I experimented with them but found the performance to be generally worse than writing the operations directly. And given that the operations don't add too much code, I think it's worth implementing them manually.
Seems like rebasing did not fix the build error. Let me investigate further.

stale bot commented Nov 6, 2020

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

stale bot added the wontfix label on Nov 6, 2020
stale bot closed this on Nov 13, 2020
@liuyunfeng2016 commented:

@ashbhandare @tgaddair @sblotner @vaeksare @romerojosh @sjeaugey
Hello, I found that hvd.AdaSum is incorrect in Horovod 0.21.0; it should be hvd.Adasum.
I also have some questions about using Adasum: how do I choose between the three modes (pure CPU, Ring, and Hierarchical)? I found HOROVOD_HIERARCHICAL_ALLREDUCE in the paper.
My environment is 8 Tesla V100s per node. How can I make the most of its capability? I found that performance deteriorates by 10x when Adasum is used.
The following are my build parameters:

ENV HOROVOD_WITHOUT_GLOO=1
ENV HOROVOD_GPU_OPERATIONS='NCCL'
ENV HOROVOD_CPU_OPERATIONS='MPI'
ENV HOROVOD_WITH_PYTORCH=1
ENV HOROVOD_NCCL_INCLUDE=/usr/include
ENV HOROVOD_NCCL_LIB=/usr/lib/x86_64-linux-gnu
ENV HOROVOD_MPICXX_SHOW="/usr/local/openmpi/bin/mpicxx -show"
ENV HOROVOD_WITH_TENSORFLOW=1
ENV HOROVOD_CUDA_HOME="/usr/local/cuda"
ENV TENSORFLOW_VERSION=1.15
