
Introducing Adasum algorithm to do allreduction. #1485

Merged
merged 20 commits into horovod:master from Tixxx:alpha_light_official Nov 25, 2019

Conversation

Tixxx (Contributor) commented Oct 29, 2019

  1. Adasum operations for both the CPU and NCCL builds of Horovod
  2. Framework support in TensorFlow and PyTorch to enable Adasum
  3. A new optimizer for TensorFlow and PyTorch that delivers a more accurate estimate when using Adasum

Main contributors:
Olli Saarikivi (olsaarik)
Vadim Eksarevskiy (vaeksare)
Jaliya Ekanayake (jaliyae)
Todd Mytkowicz (klipto)
Saeed Maleki (saeedmaleki)
Sergii Dymchenko (kit1980)
Tianju Xu (Tixxx)

Tixxx (Contributor, Author) commented Oct 29, 2019

Adasum

What is Adasum

Scaling DNN training to many GPUs usually comes at the cost of degraded convergence. This is because with larger batch sizes the gradients are averaged, so the per-example learning rate becomes smaller. To compensate, the learning rate is typically scaled up, but that can cause the model parameters to diverge. Adasum addresses both issues without introducing any new hyperparameter.

Suppose there are two almost-parallel gradients, g1 and g2, coming from two different GPUs, and they need to be reduced as shown in the figure below. The two common reductions are g1+g2 (the gray vector) and (g1+g2)/2 (the green vector). g1+g2 may cause the model to diverge, since it effectively moves in the direction of g1 or g2 with twice the magnitude of either. Therefore, (g1+g2)/2 is generally the safer and more desirable choice.
[Figure: two almost-parallel gradients g1 and g2, with g1+g2 in gray and (g1+g2)/2 in green]

Now consider the two orthogonal gradients g1 and g2 in the figure below. Since g1 and g2 are in two different dimensions and independent of each other, g1+g2 may not cause divergence.
[Figure: two orthogonal gradients g1 and g2]

Finally, consider the third scenario, where g1 and g2 are neither parallel nor orthogonal, as shown in the figure below. In this case, Adasum projects g2 onto the orthogonal complement of g1 (the pink vector) and adds that to g1 to produce the reduced vector. The final vector then moves in each dimension only as much as g1 or g2 does, and therefore causes no divergence.
[Figure: gradients g1 and g2 that are neither parallel nor orthogonal, with the projection of g2 onto the orthogonal complement of g1 shown in pink]

This idea extends to many gradients as well. Suppose there are 2^n gradients coming from 2^n different GPUs. Adasum inductively takes pairs of gradients and reduces them using the method above until all of them are reduced into one gradient.
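To make this concrete, here is a minimal NumPy sketch of the pairwise rule and the inductive pairing described above. The closed-form expression in adasum_pair is an assumption chosen to match the behavior in the figures (parallel gradients are averaged, orthogonal gradients are added); the C++ code in this PR is the authoritative implementation.

import numpy as np

def adasum_pair(g1, g2):
    # Assumed pairwise rule: interpolates between (g1 + g2) / 2 for parallel
    # gradients and g1 + g2 for orthogonal gradients (assumes non-zero inputs).
    d = float(np.dot(g1, g2))
    return (1.0 - d / (2.0 * np.dot(g1, g1))) * g1 + \
           (1.0 - d / (2.0 * np.dot(g2, g2))) * g2

def adasum_reduce(grads):
    # Inductively reduce 2^n gradients pair by pair until one remains.
    while len(grads) > 1:
        grads = [adasum_pair(grads[i], grads[i + 1])
                 for i in range(0, len(grads), 2)]
    return grads[0]

# Sanity checks matching the figures above:
g = np.array([1.0, 2.0])
print(adasum_pair(g, g))                                        # parallel: (g1 + g2) / 2
print(adasum_pair(np.array([1.0, 0.0]), np.array([0.0, 1.0])))  # orthogonal: g1 + g2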

Highlights of code changes

For extensibility, we provide an algorithmic interface that does not depend on any particular communication library. An MPI implementation of Adasum backs the new operations we have added to Horovod. Here is the list of proposed changes:

  • Adasum class in horovod/common/ops/adasum/adasum.h: Algorithmic interface of Adasum which contains the main logic.

  • AdasumMPI class in horovod/common/ops/adasum/adasum_mpi.h and adasum_mpi.cc: An MPI implementation of Adasum algorithm.

  • AdasumMPIAllreduceOp class in horovod/common/ops/adasum_mpi_operations.h and adasum_mpi_operations.cc: A new operation class that inherits from AdasumMPI and Horovod's AllreduceOp. This utilizes the fusion buffer to perform efficient Adasum reductions on CPU when HOROVOD_GPU_ALLREDUCE is set to None.

  • AdasumCudaAllreduceOp class in horovod/common/ops/adasum_cuda_operations.h and adasum_cuda_operations.cc: A new operation class that inherits from AdasumMPI and Horovod's NCCLAllreduce. This is a hierarchical operation that uses NCCL for intra-node averaging and the Adasum algorithm for inter-node reductions (a simplified sketch follows this list). This op requires Horovod to be compiled with HOROVOD_GPU_ALLREDUCE=NCCL.

  • A new response and request type has been introduced in addition to the existing ones:

enum ResponseType { ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2, ADASUM = 3, ERROR = 4};

  • A new environment variable, HOROVOD_ADASUM_MPI_CHUNK_SIZE, has been introduced to improve MPI communication efficiency on some platform configurations (e.g., Azure NC-series machines with Intel MPI).
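To illustrate the hierarchical flow of AdasumCudaAllreduceOp referenced above, the following is a hypothetical, simplified NumPy simulation. The function names and data layout are illustrative only; the real operation lives in adasum_cuda_operations.cc and uses NCCL and MPI rather than NumPy.

import numpy as np

def adasum_pair(g1, g2):
    # Same assumed pairwise rule as in the earlier sketch.
    d = float(np.dot(g1, g2))
    return (1.0 - d / (2.0 * np.dot(g1, g1))) * g1 + \
           (1.0 - d / (2.0 * np.dot(g2, g2))) * g2

def hierarchical_adasum(grads_per_node):
    # grads_per_node[i][j] is the gradient from GPU j on node i.
    # Step 1: intra-node averaging (performed with NCCL in the real op).
    node_grads = [np.mean(np.stack(gpu_grads), axis=0) for gpu_grads in grads_per_node]
    # Step 2: inter-node reduction with the pairwise Adasum rule (performed over MPI).
    while len(node_grads) > 1:
        node_grads = [adasum_pair(node_grads[i], node_grads[i + 1])
                      for i in range(0, len(node_grads), 2)]
    return node_grads[0]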

In addition to the above changes in Horovod's common library, we also made a set of changes at the framework layer, for both TensorFlow and PyTorch, to make Adasum easy to use:

  • An enum listing the allreduce operations has been introduced so users can select among Average, Sum, and Adasum. This keeps the API backward compatible and makes it easy to add more ops in the future.

  • An optional op parameter has been added to the DistributedOptimizer and allreduce APIs so users can specify which operation to perform.

  • A new distributed optimizer has been added to both frameworks to support the Adasum algorithm. Because Adasum needs to operate on the full magnitude of the update, the new optimizer reduces the difference between the weights before and after the wrapped optimizer's step, which gives a more accurate estimate (a schematic sketch follows the examples below). When op=hvd.Adasum is specified, the new optimizer is used.

    DistributedOptimizer example for Tensorflow:

    opt = tf.train.AdamOptimizer(0.001)

    opt = hvd.DistributedOptimizer(opt, backward_passes_per_step=5, op=hvd.Adasum)

    Allreduce example for Tensorflow:

    hvd.allreduce(tensor, op=hvd.Adasum)

    DistributedOptimizer example for Pytorch:

    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters(), compression=compression, backward_passes_per_step=5, op=hvd.Adasum)

    Allreduce example for Pytorch:

    hvd.allreduce(tensor, op=hvd.Adasum)
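For intuition, the following is a schematic PyTorch-style sketch of the weight-delta idea behind the new optimizer. It is not the optimizer implementation in this PR; only hvd.allreduce(..., op=hvd.Adasum) is the API being introduced here, and the helper function and its name are illustrative.

import horovod.torch as hvd

def adasum_style_step(optimizer, model):
    # 1. Snapshot the weights before the wrapped optimizer's local step.
    before = [p.detach().clone() for p in model.parameters()]
    # 2. Let the wrapped optimizer take its normal local step.
    optimizer.step()
    # 3. Reduce the resulting weight deltas with Adasum and re-apply them,
    #    rather than reducing the raw gradients.
    for p, w0 in zip(model.parameters(), before):
        delta = p.detach() - w0
        reduced = hvd.allreduce(delta, op=hvd.Adasum)
        p.data.copy_(w0 + reduced)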

Additional notes

  • Adasum ensures correct convergence behavior even with large effective batch sizes.

  • As the number of ranks scales up, the learning rate does not need to be scaled when the Adasum reduction runs on CPU. If Horovod is compiled with HOROVOD_GPU_ALLREDUCE=NCCL, the learning rate should be scaled by the number of GPUs on each node (see the snippet after these notes).

  • PyTorch training in fp16 is not yet supported by this pull request. We are in the process of integrating Apex into the new optimizer to enable full mixed-precision training with Adasum in PyTorch.

  • When Horovod is compiled with HOROVOD_GPU_ALLREDUCE=NCCL and training runs on a single node, reductions are performed only through NCCL averaging; the Adasum algorithm does not take place in this configuration.
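A small TensorFlow-flavored snippet of the learning-rate note above, assuming the same setup as the TensorFlow DistributedOptimizer example earlier in this description (the scaling choice is the only point being illustrated):

import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()

base_lr = 0.001
# CPU Adasum path: keep base_lr unchanged regardless of the number of ranks.
# NCCL build (HOROVOD_GPU_ALLREDUCE=NCCL): scale by the number of GPUs per node,
# since intra-node reduction is a plain NCCL average.
lr = base_lr * hvd.local_size()

opt = tf.train.AdamOptimizer(lr)
opt = hvd.DistributedOptimizer(opt, op=hvd.Adasum)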

@Tixxx Tixxx force-pushed the Tixxx:alpha_light_official branch 5 times, most recently from c315af2 to a0e2a48 Oct 29, 2019
tgaddair (Collaborator) left a comment

Still going through the code in detail, but wanted to give you all some quick feedback as I was getting things to compile on my local Mac laptop.

@Tixxx Tixxx force-pushed the Tixxx:alpha_light_official branch from 46e00f0 to 731552e Nov 4, 2019
jaliyae commented Nov 5, 2019

@jaliyae is added to the review. #Closed

@Tixxx Tixxx closed this Nov 6, 2019
Tixxx (Contributor, Author) commented Nov 6, 2019

re-opening to trigger CI

@Tixxx Tixxx reopened this Nov 6, 2019
@Tixxx Tixxx force-pushed the Tixxx:alpha_light_official branch 2 times, most recently from afcadd7 to b2bcf91 Nov 6, 2019
tgaddair (Collaborator) commented Nov 15, 2019

Hey @Tixxx, can you rebase off of master? There was a problem with unit tests caused by changes to tf-nightly. The tests were fixed by #1515.

Tixxx added 12 commits Oct 29, 2019
  • made test to be compatible with python27
  • added adasum as an option in pytorch resnet example
  • fixed incorrect lr scaling in examples when using Adasum
  • improved tf test to support tf 2.0
Tixxx added 8 commits Nov 5, 2019
  • …alized.
  • dont run cpu tests if mpi is not available
  • SKip Gloo tests for Adasum
  • …t Horovod is compiled without GPU-ALLREDUCE flag.
  • Fixed mxnet import failure.
@Tixxx Tixxx force-pushed the Tixxx:alpha_light_official branch from 88b0a0d to 08cf4e9 Nov 15, 2019
tgaddair (Collaborator) left a comment

Everything looks good from our end. Let's go ahead and merge this, with plans for a few follow-up PRs:

  • Documentation describing Adasum (could use description from this PR) along with guidance on when to expect it to significantly improve results.
  • TensorFlow 2.0 support.
  • Keras support.
@tgaddair tgaddair merged commit 5fa1d7a into horovod:master Nov 25, 2019
2 checks passed:
  • DCO
  • buildkite/horovod/pr: Build #1385 passed (51 minutes, 51 seconds)
nvcastet (Contributor) commented Nov 25, 2019

@Tixxx Are there any comparative experiment results for Adasum showing the speed and accuracy of the training for some benchmarks?

thyeros commented Nov 26, 2019

@Tixxx does AdaSum guarantee two critical convergence properties of SGD?

  • unbiased gradient estimator
  • bounded variance
saeedmaleki commented Nov 27, 2019

@Tixxx does AdaSum guarantee two critical convergence properties of SGD?

  • unbiased gradient estimator
  • bounded variance

Hi @thyeros,

The expected value of the gradients computed with Adasum is not necessarily the true gradient over all training examples, but it has a positive inner product with it. Note that the loss value decreases if the model is updated in any direction that has a positive inner product with the gradient (provided that the higher-order terms in the Taylor series are negligible).
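For reference, the standard first-order Taylor argument behind this statement (general SGD reasoning, not specific to this PR), written in LaTeX:

L(w - \eta\, d) \;\approx\; L(w) - \eta\, \langle \nabla L(w),\, d \rangle

so for a sufficiently small learning rate \eta, any update direction d with \langle \nabla L(w), d \rangle > 0 decreases the loss.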

The variance of the gradients computed with Adasum is bounded as long as the learning rate is scheduled properly.

saeedmaleki commented Nov 27, 2019

@Tixxx Are there any comparative experiment results for Adasum showing the speed and accuracy of the training for some benchmarks?

Hi @nvcastet,

Yes, we will provide those benchmarks and numbers in the near future.

jeffdaily added a commit to ROCmSoftwarePlatform/horovod that referenced this pull request Nov 27, 2019
* Introducing Adasum algorithm to do allreduction.
@@ -96,6 +97,10 @@ def extension_available(ext_base_name, verbose=False):
    return _check_extension_lambda(
        ext_base_name, available_fn, 'built', verbose) or False

def gpu_available(ext_base_name, verbose=False):
    available_fn = lambda ext: ext._check_has_gpu()

ConeyLiu (Contributor) commented Dec 3, 2019

Hi, I got the following errors with horovodrun:

Checking whether extension tensorflow was running with GPU.
Traceback (most recent call last):
  File "/home/xianyang/opt/miniconda3/lib/python3.7/site-packages/horovod/common/util.py", line 73, in _target_fn
    ext = importlib.import_module('.' + ext_base_name, 'horovod')
  File "/home/xianyang/opt/miniconda3/lib/python3.7/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1006, in _gcd_import
  File "<frozen importlib._bootstrap>", line 983, in _find_and_load
  File "<frozen importlib._bootstrap>", line 967, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 677, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 728, in exec_module
  File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
  File "/home/xianyang/opt/miniconda3/lib/python3.7/site-packages/horovod/tensorflow/__init__.py", line 43, in <module>
    has_gpu = gpu_available('tensorflow')
  File "/home/xianyang/opt/miniconda3/lib/python3.7/site-packages/horovod/common/util.py", line 103, in gpu_available
    ext_base_name, available_fn, 'running with GPU', verbose) or False
  File "/home/xianyang/opt/miniconda3/lib/python3.7/site-packages/horovod/common/util.py", line 90, in _check_extension_lambda
    p.start()
  File "/home/xianyang/opt/miniconda3/lib/python3.7/multiprocessing/process.py", line 110, in start
    'daemonic processes are not allowed to have children'
AssertionError: daemonic processes are not allowed to have children
Extension tensorflow was NOT running with GPU.

This could be reproduced with:
horovod.common.util.gpu_available('tensorflow', True)

nvcastet added a commit to nvcastet/horovod that referenced this pull request Dec 10, 2019
Add extra checks for x86 compiler flags and x86 AVX headers/functions

Signed-off-by: Nicolas V Castet <nvcastet@us.ibm.com>