Add support for gradient_predivide_factor and averaging in Horovod backend. #1949

Merged
30 commits, merged Aug 17, 2020

Commits
69885d5
Add support for gradient_predivide_factor and averaging in Horovod ba…
romerojosh May 13, 2020
63f1025
Add files to MANIFEST.in. Fix gloo only builds.
romerojosh May 13, 2020
b8b6548
Add missing root_rank arg in MXNet code. Change to default -1 instead…
romerojosh May 13, 2020
be8d15f
Revert use of MPI_IN_PLACE in ccl_operations. Add horovod_cuda_lib to…
romerojosh May 13, 2020
3bf34f0
Compile half.cc for gloo builds.
romerojosh May 14, 2020
394b167
Fixes to torch v1 build. Build Horovod CUDA kernels based on framewor…
romerojosh May 14, 2020
445cf63
Move postscale_factor modification for average into backend.
romerojosh May 14, 2020
1cf2b3f
Extend backend averaging to Adasum. Fix to ppc64le build.
romerojosh May 15, 2020
0daea1a
Extend gradient_predivide_factor support to Keras.
romerojosh May 15, 2020
116e138
Fix ppc64le build.
romerojosh May 15, 2020
788eef8
Add gradient_predivide_factor to examples. Fix torch optimizer.py.
romerojosh May 15, 2020
96b9aa7
Raise exception if op != Average and gradient_predivide_factor is set…
romerojosh Jul 1, 2020
5217244
Remove gradient_predivide_factor arg from some examples.
romerojosh Jul 1, 2020
d641cd9
Document HOROVOD_BUILD_CUDA_CC_LIST env var.
romerojosh Jul 1, 2020
346ee50
Cleanup/fixes after rebase.
romerojosh Jul 21, 2020
adbc290
Update TF allreduce gradient to include pre/postscale factors.
romerojosh Jul 21, 2020
c59c324
Convert prescale and postscale factor args to scalar tensors to maint…
romerojosh Aug 5, 2020
e7f07c5
More robust testing of prescale and postscale factor behavior.
romerojosh Aug 5, 2020
cb1efdf
Fixes after rebase.
romerojosh Aug 5, 2020
369cf5c
Use size_op() to compute postscale_factor.
romerojosh Aug 5, 2020
db105fb
Revert "Use size_op() to compute postscale_factor."
romerojosh Aug 6, 2020
9ce7caa
Revert "Convert prescale and postscale factor args to scalar tensors …
romerojosh Aug 7, 2020
b9437cd
Skip FP64 prescaling/postscaling tests for TensorFlow.
romerojosh Aug 7, 2020
04c4d57
Remove size() usage in Python when computing scaling factors for TF f…
romerojosh Aug 7, 2020
4442229
Fix __CUDA_ARCH__ usage so half2 specialized kernel is invoked on sup…
romerojosh Aug 7, 2020
4213b43
Fix up pre/postscale torch tests for torch 1.12 multiplication behavior.
romerojosh Aug 7, 2020
38363ce
Update supported compute capability detection.
romerojosh Aug 9, 2020
1247043
Fix pre/postscaling tests for MXNet 1.4.
romerojosh Aug 9, 2020
58af1c1
Update pre/postscale tests. Deal with HOROVOD_MIXED_INSTALL cases.
romerojosh Aug 9, 2020
f0bcf58
Fix pre/postscale test for PyTorch HOROVOD_MIXED_INSTALL case.
romerojosh Aug 10, 2020
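
For context, here is a minimal usage sketch of the feature this PR adds, based on the example diffs below. The model and optimizer setup are hypothetical placeholders; only the DistributedOptimizer arguments come from this PR.

# Sketch: using gradient_predivide_factor with the PyTorch DistributedOptimizer.
# Per commit 96b9aa7, a factor other than 1.0 is only valid with op=hvd.Average.
import torch
import horovod.torch as hvd

hvd.init()
model = torch.nn.Linear(10, 1)                       # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Gradients are divided by the factor before the allreduce; the backend then
# completes the average with a postscale of factor / size().
optimizer = hvd.DistributedOptimizer(optimizer,
                                     named_parameters=model.named_parameters(),
                                     op=hvd.Average,
                                     gradient_predivide_factor=2.0)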
1 change: 1 addition & 0 deletions Jenkinsfile.ppc64le
@@ -25,6 +25,7 @@ pipeline {
git submodule update --init --recursive
. ${CONDA_INIT}
conda activate ${CONDA_ENV}
conda install -y cmake make
set -xe
HOROVOD_WITHOUT_MXNET=1 HOROVOD_WITHOUT_GLOO=1 HOROVOD_WITH_PYTORCH=1 HOROVOD_WITH_TENSORFLOW=1 \
HOROVOD_CUDA_HOME=$CONDA_PREFIX HOROVOD_GPU_BROADCAST=NCCL HOROVOD_GPU_ALLREDUCE=NCCL \
5 changes: 4 additions & 1 deletion MANIFEST.in
@@ -1,4 +1,4 @@
recursive-include * *.h *.hpp *.cc *.md
recursive-include * *.h *.hpp *.cc *.cu *.md

include LICENSE horovod.lds horovod.exp
prune .eggs
@@ -19,3 +19,6 @@ exclude third_party/eigen/Eigen/src/SparseCholesky/*
graft third_party/gloo/cmake
recursive-include third_party/gloo CMakeLists.txt
recursive-include third_party/gloo *.in

# include cmake related files for CUDA compilation
include horovod/common/ops/cuda/CMakeLists.txt
1 change: 1 addition & 0 deletions docs/install.rst
@@ -224,6 +224,7 @@ Possible values are given in curly brackets: {}.
* ``HOROVOD_CUDA_HOME`` - path where CUDA include and lib directories can be found.
* ``HOROVOD_CUDA_INCLUDE`` - path to CUDA include directory.
* ``HOROVOD_CUDA_LIB`` - path to CUDA lib directory.
* ``HOROVOD_BUILD_CUDA_CC_LIST`` - comma-separated list of compute capabilities to build Horovod CUDA kernels for (example: ``HOROVOD_BUILD_CUDA_CC_LIST=60,70,75``).
* ``HOROVOD_ROCM_HOME`` - path where ROCm include and lib directories can be found.
* ``HOROVOD_NCCL_HOME`` - path where NCCL include and lib directories can be found.
* ``HOROVOD_NCCL_INCLUDE`` - path to NCCL include directory.
8 changes: 6 additions & 2 deletions examples/mxnet_imagenet_resnet50.py
@@ -89,6 +89,8 @@
help='number of batches to wait before logging (default: 0)')
parser.add_argument('--save-frequency', type=int, default=0,
help='frequency of model saving (default: 0)')
parser.add_argument('--gradient-predivide-factor', type=float, default=1.0,
help='apply gradient predivide factor in optimizer (default: 1.0)')


args = parser.parse_args()
@@ -322,7 +324,8 @@ def evaluate(epoch):
opt = mx.optimizer.create('sgd', **optimizer_params)

# Horovod: create DistributedTrainer, a subclass of gluon.Trainer
trainer = hvd.DistributedTrainer(params, opt)
trainer = hvd.DistributedTrainer(params, opt,
gradient_predivide_factor=args.gradient_predivide_factor)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
@@ -427,7 +430,8 @@ def train_module():
opt = mx.optimizer.create('sgd', **optimizer_params)

# Horovod: wrap optimizer with DistributedOptimizer
dist_opt = hvd.DistributedOptimizer(opt)
dist_opt = hvd.DistributedOptimizer(opt,
gradient_predivide_factor=args.gradient_predivide_factor)

# Setup validation data and callback during training
eval_data = None
5 changes: 4 additions & 1 deletion examples/mxnet_mnist.py
@@ -24,6 +24,8 @@
help='SGD momentum (default: 0.9)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disable training on GPU (default: False)')
parser.add_argument('--gradient-predivide-factor', type=float, default=1.0,
help='apply gradient predivide factor in optimizer (default: 1.0)')
args = parser.parse_args()

if not args.no_cuda:
@@ -128,7 +130,8 @@ def evaluate(model, data_iter, context):
hvd.broadcast_parameters(params, root_rank=0)

# Horovod: create DistributedTrainer, a subclass of gluon.Trainer
trainer = hvd.DistributedTrainer(params, opt)
trainer = hvd.DistributedTrainer(params, opt,
gradient_predivide_factor=args.gradient_predivide_factor)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
5 changes: 4 additions & 1 deletion examples/pytorch_imagenet_resnet50.py
@@ -31,6 +31,8 @@
'total batch size.')
parser.add_argument('--use-adasum', action='store_true', default=False,
help='use adasum algorithm to do reduction')
parser.add_argument('--gradient-predivide-factor', type=float, default=1.0,
help='apply gradient predivide factor in optimizer (default: 1.0)')

# Default settings from https://arxiv.org/abs/1706.02677.
parser.add_argument('--batch-size', type=int, default=32,
@@ -272,7 +274,8 @@ def avg(self):
optimizer, named_parameters=model.named_parameters(),
compression=compression,
backward_passes_per_step=args.batches_per_allreduce,
op=hvd.Adasum if args.use_adasum else hvd.Average)
op=hvd.Adasum if args.use_adasum else hvd.Average,
gradient_predivide_factor=args.gradient_predivide_factor)

# Restore from a previous checkpoint, if initial_epoch is specified.
# Horovod: restore on the first worker which will broadcast weights to other workers.
5 changes: 4 additions & 1 deletion examples/pytorch_mnist.py
@@ -29,6 +29,8 @@
help='use fp16 compression during allreduce')
parser.add_argument('--use-adasum', action='store_true', default=False,
help='use adasum algorithm to do reduction')
parser.add_argument('--gradient-predivide-factor', type=float, default=1.0,
help='apply gradient predivide factor in optimizer (default: 1.0)')


class Net(nn.Module):
@@ -179,7 +181,8 @@ def test():
optimizer = hvd.DistributedOptimizer(optimizer,
named_parameters=model.named_parameters(),
compression=compression,
op=hvd.Adasum if args.use_adasum else hvd.Average)
op=hvd.Adasum if args.use_adasum else hvd.Average,
gradient_predivide_factor=args.gradient_predivide_factor)

for epoch in range(1, args.epochs + 1):
train(epoch)
1 change: 1 addition & 0 deletions examples/tensorflow2_synthetic_benchmark.py
@@ -42,6 +42,7 @@
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')


args = parser.parse_args()
args.cuda = not args.no_cuda

5 changes: 4 additions & 1 deletion examples/tensorflow_mnist.py
@@ -30,6 +30,8 @@
parser = argparse.ArgumentParser(description='Tensorflow MNIST Example')
parser.add_argument('--use-adasum', action='store_true', default=False,
help='use adasum algorithm to do reduction')
parser.add_argument('--gradient-predivide-factor', type=float, default=1.0,
help='apply gradient predivide factor in optimizer (default: 1.0)')
args = parser.parse_args()

def conv_model(feature, target, mode):
@@ -127,7 +129,8 @@ def main(_):
opt = tf.train.AdamOptimizer(0.001 * lr_scaler)

# Horovod: add Horovod Distributed Optimizer.
opt = hvd.DistributedOptimizer(opt, op=hvd.Adasum if args.use_adasum else hvd.Average)
opt = hvd.DistributedOptimizer(opt, op=hvd.Adasum if args.use_adasum else hvd.Average,
gradient_predivide_factor=args.gradient_predivide_factor)

global_step = tf.train.get_or_create_global_step()
train_op = opt.minimize(loss, global_step=global_step)
1 change: 0 additions & 1 deletion examples/tensorflow_mnist_eager.py
@@ -16,7 +16,6 @@
import tensorflow as tf
import horovod.tensorflow as hvd


def main(_):
# Horovod: initialize Horovod.
hvd.init()
25 changes: 21 additions & 4 deletions horovod/_keras/__init__.py
@@ -23,7 +23,7 @@


def create_distributed_optimizer(keras, optimizer, name, device_dense, device_sparse,
compression, sparse_as_dense):
compression, sparse_as_dense, gradient_predivide_factor):
class _DistributedOptimizer(keras.optimizers.Optimizer):
_HAS_AGGREGATE_GRAD = True

@@ -34,6 +34,7 @@ def __init__(self, **kwargs):
self._compression = compression
self._sparse_as_dense = sparse_as_dense
self._aggregated_gradients = False
self._gradient_predivide_factor = gradient_predivide_factor
super(self.__class__, self).__init__(**kwargs)

def get_gradients(self, loss, params):
@@ -60,6 +61,17 @@ def _aggregate_gradients(self, grads_and_vars):
def _allreduce(self, gradients):
self._aggregated_gradients = True
if hvd.size() > 1:
if self._gradient_predivide_factor != 1.0:
# Perform averaging via pre/postscaling factors.
# Split average operation across pre/postscale factors
prescale_factor = 1.0 / gradient_predivide_factor
postscale_factor = gradient_predivide_factor / hvd.size()
do_average = False
else:
prescale_factor = 1.0
postscale_factor = 1.0
do_average = True

averaged_gradients = []
with tf.name_scope(self._name + "_Allreduce"):
for grad in gradients:
@@ -68,9 +80,12 @@
isinstance(grad, tf.IndexedSlices):
grad = tf.convert_to_tensor(grad)
avg_grad = hvd.allreduce(grad,
average=do_average,
device_dense=self._device_dense,
device_sparse=self._device_sparse,
compression=self._compression)
compression=self._compression,
prescale_factor=prescale_factor,
postscale_factor=postscale_factor)
averaged_gradients.append(avg_grad)
else:
averaged_gradients.append(None)
@@ -108,8 +123,10 @@ def broadcast_global_variables(backend, root_rank):
return _eval(backend, hvd.broadcast_global_variables(root_rank))


def allreduce(backend, value, name, average):
return _eval(backend, hvd.allreduce(tf.constant(value, name=name), average=average))
def allreduce(backend, value, name, average, prescale_factor, postscale_factor):
return _eval(backend, hvd.allreduce(tf.constant(value, name=name), average=average,
prescale_factor=prescale_factor,
postscale_factor=postscale_factor))


def allgather(backend, value, name):
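
The _allreduce change above splits the average into a prescale of 1.0 / gradient_predivide_factor and a postscale of gradient_predivide_factor / hvd.size(), with do_average disabled. A standalone NumPy check (no Horovod required) that this split recovers the plain mean regardless of the chosen factor:

# Sketch: summing prescaled gradients and applying the postscale factor
# reproduces the ordinary average, independent of the predivide factor f.
import numpy as np

def split_average(grads_per_rank, f):
    summed = sum(g / f for g in grads_per_rank)   # prescale, then sum (the allreduce)
    return summed * (f / len(grads_per_rank))     # postscale completes the average

grads = [np.array([1.0, 2.0]), np.array([3.0, 6.0])]
assert np.allclose(split_average(grads, 2.0), np.mean(grads, axis=0))
assert np.allclose(split_average(grads, 8.0), np.mean(grads, axis=0))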
16 changes: 16 additions & 0 deletions horovod/common/basics.py
@@ -215,3 +215,19 @@ def ccl_built(self):
A boolean value indicating whether oneCCL support was compiled.
"""
return bool(self.MPI_LIB_CTYPES.horovod_ccl_built())

def cuda_built(self):
"""Returns True if Horovod was compiled with CUDA support.

Returns:
A boolean value indicating whether CUDA support was compiled.
"""
return bool(self.MPI_LIB_CTYPES.horovod_cuda_built())

def rocm_built(self):
"""Returns True if Horovod was compiled with ROCm support.

Returns:
A boolean value indicating whether ROCm support was compiled.
"""
return bool(self.MPI_LIB_CTYPES.horovod_rocm_built())
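
A short usage sketch for the new build-introspection helpers, assuming the framework modules re-export them the same way they re-export mpi_built() and the other basics methods; the device-selection logic is illustrative only:

# Sketch: choose a reduction device based on how Horovod was built.
import horovod.torch as hvd

hvd.init()
if hvd.cuda_built() or hvd.rocm_built():
    device = 'GPU'   # GPU kernels (e.g. the pre/postscale CUDA kernels) are available
else:
    device = 'CPU'
print('rank %d will reduce on %s' % (hvd.rank(), device))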
38 changes: 37 additions & 1 deletion horovod/common/controller.cc
@@ -451,6 +451,36 @@ Response Controller::ConstructResponse(std::string& name, int joined_size) {
}
}

// If we are doing an allreduce, check that prescaling and postscaling factors
// are identical across ranks.
double prescale_factor;
double postscale_factor;
if (message_type == Request::ALLREDUCE ||
message_type == Request::ADASUM) {
prescale_factor = requests[0].prescale_factor();
postscale_factor = requests[0].postscale_factor();

for (unsigned int i = 1; i < requests.size(); ++i) {
if (error) {
break;
}
double request_prescale_factor = requests[i].prescale_factor();
double request_postscale_factor = requests[i].postscale_factor();

if (prescale_factor != request_prescale_factor ||
postscale_factor != request_postscale_factor) {
error = true;
error_message_stream
<< "Mismatched prescale and/or postscale factors: "
<< "One rank sent factors (" << prescale_factor
<< ", " << postscale_factor << "), but another rank "
<< "sent factors (" << request_prescale_factor
<< ", " << request_postscale_factor << ").";
break;
}
}
}

std::vector<int64_t> tensor_sizes;
if (message_type == Request::ALLGATHER ||
message_type == Request::ALLTOALL) {
@@ -601,6 +631,8 @@ Response Controller::ConstructResponse(std::string& name, int joined_size) {
response.add_tensor_size(dim);
}
response.set_tensor_type(data_type);
response.set_prescale_factor(prescale_factor);
response.set_postscale_factor(postscale_factor);
} else if (message_type == Request::BROADCAST) {
response.set_response_type(Response::BROADCAST);
} else if (message_type == Request::ALLTOALL) {
@@ -611,6 +643,8 @@
response.add_tensor_size(dim);
}
response.set_tensor_type(data_type);
response.set_prescale_factor(prescale_factor);
response.set_postscale_factor(postscale_factor);
}
response.set_devices(devices);

@@ -675,7 +709,9 @@ ResponseList Controller::FuseResponses(std::deque<Response>& responses) {
if (response.response_type() == new_response.response_type() &&
response.devices() == new_response.devices() &&
response.tensor_type() == new_response.tensor_type() &&
tensor_size + new_tensor_size <= TensorFusionThresholdBytes()) {
tensor_size + new_tensor_size <= TensorFusionThresholdBytes() &&
response.prescale_factor() == new_response.prescale_factor() &&
response.postscale_factor() == new_response.postscale_factor()) {
// These tensors will fuse together well.
tensor_size += new_tensor_size;
response.add_tensor_name(std::move(new_response.tensor_names()[0]));
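
A simplified Python rendering of the coordination logic added above, for readers skimming the C++: every rank participating in an allreduce (or Adasum) must submit identical prescale/postscale factors, and FuseResponses only fuses responses whose factors match. The helper below is illustrative, not part of the PR:

# Sketch: mirror of the ConstructResponse scale-factor consistency check.
def check_scale_factors(requests):
    prescale, postscale = requests[0]
    for req_prescale, req_postscale in requests[1:]:
        if (req_prescale, req_postscale) != (prescale, postscale):
            raise ValueError(
                'Mismatched prescale and/or postscale factors: one rank sent '
                'factors (%g, %g), but another rank sent factors (%g, %g).'
                % (prescale, postscale, req_prescale, req_postscale))
    return prescale, postscale

check_scale_factors([(0.5, 2.0), (0.5, 2.0)])     # consistent: returns (0.5, 2.0)
# check_scale_factors([(0.5, 2.0), (1.0, 1.0)])   # inconsistent: raises ValueError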
2 changes: 2 additions & 0 deletions horovod/common/half.cc
@@ -39,6 +39,7 @@ bool is_avx_and_f16c() {
}
#endif

#if HAVE_MPI
// float16 custom data type summation operation.
void float16_sum(void* invec, void* inoutvec, int* len,
MPI_Datatype* datatype) {
@@ -73,6 +74,7 @@ void float16_sum(void* invec, void* inoutvec, int* len,
Float2HalfBits(&inout_float, inout + i);
}
}
#endif

} // namespace common
} // namespace horovod
12 changes: 10 additions & 2 deletions horovod/common/half.h
@@ -28,13 +28,19 @@

#include <stdint.h>

#if HAVE_MPI
#define OMPI_SKIP_MPICXX
#include "mpi.h"
#endif

namespace horovod {
namespace common {

inline void HalfBits2Float(unsigned short* src, float* res) {
#if __AVX__ && __F16C__
bool is_avx_and_f16c();
#endif

inline void HalfBits2Float(const unsigned short* src, float* res) {
unsigned h = *src;
int sign = ((h >> 15) & 1);
int exp = ((h >> 10) & 0x1f);
@@ -70,7 +76,7 @@ inline void HalfBits2Float(unsigned short* src, float* res) {
*res = *reinterpret_cast<float const*>(&f);
}

inline void Float2HalfBits(float* src, unsigned short* dest) {
inline void Float2HalfBits(const float* src, unsigned short* dest) {
// software implementation rounds toward nearest even
unsigned const& s = *reinterpret_cast<unsigned const*>(src);
uint16_t sign = uint16_t((s >> 16) & 0x8000);
@@ -132,7 +138,9 @@ inline void Float2HalfBits(float* src, unsigned short* dest) {
*dest = u;
}

#if HAVE_MPI
void float16_sum(void* invec, void* inoutvec, int* len, MPI_Datatype* datatype);
#endif

} // namespace common
} // namespace horovod
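
The float16_sum reduction above converts each half-precision element to float, adds in single precision, and converts back. A NumPy approximation of that behavior (the scalar fallback path, not the AVX/F16C path), shown only to illustrate the rounding semantics:

# Sketch: elementwise float16 summation performed in float32, then rounded
# back to half precision, mirroring HalfBits2Float -> add -> Float2HalfBits.
import numpy as np

def float16_sum(invec, inoutvec):
    result = invec.astype(np.float32) + inoutvec.astype(np.float32)
    return result.astype(np.float16)   # round back to nearest half-precision value

a = np.array([0.1, 2.5], dtype=np.float16)
b = np.array([0.2, 4.5], dtype=np.float16)
print(float16_sum(a, b))   # approximately [0.3, 7.0]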