Add support for gradient_predivide_factor and averaging in Horovod backend. (#1949)

Signed-off-by: Josh Romero <joshr@nvidia.com>
romerojosh committed Aug 17, 2020
1 parent b83e0e2 commit e4554de
Showing 57 changed files with 1,469 additions and 185 deletions.
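At its core, this change lets Horovod express gradient averaging as a summed allreduce bracketed by scaling factors. A minimal NumPy sketch of the arithmetic that gradient_predivide_factor relies on (the values are hypothetical stand-ins for per-rank gradients):

import numpy as np

# Hypothetical gradients from 4 ranks, and a predivide factor of 2.0.
rank_grads = [np.array([1.0, 2.0]), np.array([3.0, 4.0]),
              np.array([5.0, 6.0]), np.array([7.0, 8.0])]
size, factor = len(rank_grads), 2.0

# The average is split across pre/postscale factors: each rank prescales
# by 1/factor, the backend sums, and the result is postscaled by
# factor/size. The outcome matches a plain mean.
prescale, postscale = 1.0 / factor, factor / size
summed = sum(g * prescale for g in rank_grads)
assert np.allclose(summed * postscale, np.mean(rank_grads, axis=0))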
1 change: 1 addition & 0 deletions Jenkinsfile.ppc64le
@@ -25,6 +25,7 @@ pipeline {
git submodule update --init --recursive
. ${CONDA_INIT}
conda activate ${CONDA_ENV}
conda install -y cmake make
set -xe
HOROVOD_WITHOUT_MXNET=1 HOROVOD_WITHOUT_GLOO=1 HOROVOD_WITH_PYTORCH=1 HOROVOD_WITH_TENSORFLOW=1 \
HOROVOD_CUDA_HOME=$CONDA_PREFIX HOROVOD_GPU_BROADCAST=NCCL HOROVOD_GPU_ALLREDUCE=NCCL \
5 changes: 4 additions & 1 deletion MANIFEST.in
@@ -1,4 +1,4 @@
recursive-include * *.h *.hpp *.cc *.md
recursive-include * *.h *.hpp *.cc *.cu *.md

include LICENSE horovod.lds horovod.exp
prune .eggs
@@ -19,3 +19,6 @@ exclude third_party/eigen/Eigen/src/SparseCholesky/*
graft third_party/gloo/cmake
recursive-include third_party/gloo CMakeLists.txt
recursive-include third_party/gloo *.in

# include cmake related files for CUDA compilation
include horovod/common/ops/cuda/CMakeLists.txt
1 change: 1 addition & 0 deletions docs/install.rst
@@ -224,6 +224,7 @@ Possible values are given in curly brackets: {}.
* ``HOROVOD_CUDA_HOME`` - path where CUDA include and lib directories can be found.
* ``HOROVOD_CUDA_INCLUDE`` - path to CUDA include directory.
* ``HOROVOD_CUDA_LIB`` - path to CUDA lib directory.
* ``HOROVOD_BUILD_CUDA_CC_LIST`` - list of compute capabilities to build Horovod CUDA kernels for (example: ``HOROVOD_BUILD_CUDA_CC_LIST=60,70,75``).
* ``HOROVOD_ROCM_HOME`` - path where ROCm include and lib directories can be found.
* ``HOROVOD_NCCL_HOME`` - path where NCCL include and lib directories can be found.
* ``HOROVOD_NCCL_INCLUDE`` - path to NCCL include directory.
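For context, a hedged sketch of driving these build variables from Python when installing from source (the CUDA path and compute-capability list are assumptions; adjust them to your toolchain):

import os
import subprocess

env = dict(os.environ,
           HOROVOD_CUDA_HOME="/usr/local/cuda",
           HOROVOD_GPU_ALLREDUCE="NCCL",
           HOROVOD_GPU_BROADCAST="NCCL",
           HOROVOD_BUILD_CUDA_CC_LIST="60,70,75")
subprocess.check_call(["pip", "install", "--no-cache-dir", "horovod"], env=env)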
8 changes: 6 additions & 2 deletions examples/mxnet_imagenet_resnet50.py
@@ -89,6 +89,8 @@
help='number of batches to wait before logging (default: 0)')
parser.add_argument('--save-frequency', type=int, default=0,
help='frequency of model saving (default: 0)')
parser.add_argument('--gradient-predivide-factor', type=float, default=1.0,
help='apply gradient predivide factor in optimizer (default: 1.0)')


args = parser.parse_args()
@@ -322,7 +324,8 @@ def evaluate(epoch):
opt = mx.optimizer.create('sgd', **optimizer_params)

# Horovod: create DistributedTrainer, a subclass of gluon.Trainer
trainer = hvd.DistributedTrainer(params, opt)
trainer = hvd.DistributedTrainer(params, opt,
gradient_predivide_factor=args.gradient_predivide_factor)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
@@ -427,7 +430,8 @@ def train_module():
opt = mx.optimizer.create('sgd', **optimizer_params)

# Horovod: wrap optimizer with DistributedOptimizer
dist_opt = hvd.DistributedOptimizer(opt)
dist_opt = hvd.DistributedOptimizer(opt,
gradient_predivide_factor=args.gradient_predivide_factor)

# Setup validation data and callback during training
eval_data = None
5 changes: 4 additions & 1 deletion examples/mxnet_mnist.py
@@ -24,6 +24,8 @@
help='SGD momentum (default: 0.9)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disable training on GPU (default: False)')
parser.add_argument('--gradient-predivide-factor', type=float, default=1.0,
help='apply gradient predivide factor in optimizer (default: 1.0)')
args = parser.parse_args()

if not args.no_cuda:
@@ -128,7 +130,8 @@ def evaluate(model, data_iter, context):
hvd.broadcast_parameters(params, root_rank=0)

# Horovod: create DistributedTrainer, a subclass of gluon.Trainer
trainer = hvd.DistributedTrainer(params, opt)
trainer = hvd.DistributedTrainer(params, opt,
gradient_predivide_factor=args.gradient_predivide_factor)

# Create loss function and train metric
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
5 changes: 4 additions & 1 deletion examples/pytorch_imagenet_resnet50.py
@@ -31,6 +31,8 @@
'total batch size.')
parser.add_argument('--use-adasum', action='store_true', default=False,
help='use adasum algorithm to do reduction')
parser.add_argument('--gradient-predivide-factor', type=float, default=1.0,
help='apply gradient predivide factor in optimizer (default: 1.0)')

# Default settings from https://arxiv.org/abs/1706.02677.
parser.add_argument('--batch-size', type=int, default=32,
@@ -272,7 +274,8 @@ def avg(self):
optimizer, named_parameters=model.named_parameters(),
compression=compression,
backward_passes_per_step=args.batches_per_allreduce,
op=hvd.Adasum if args.use_adasum else hvd.Average)
op=hvd.Adasum if args.use_adasum else hvd.Average,
gradient_predivide_factor=args.gradient_predivide_factor)

# Restore from a previous checkpoint, if initial_epoch is specified.
# Horovod: restore on the first worker which will broadcast weights to other workers.
5 changes: 4 additions & 1 deletion examples/pytorch_mnist.py
@@ -29,6 +29,8 @@
help='use fp16 compression during allreduce')
parser.add_argument('--use-adasum', action='store_true', default=False,
help='use adasum algorithm to do reduction')
parser.add_argument('--gradient-predivide-factor', type=float, default=1.0,
help='apply gradient predivide factor in optimizer (default: 1.0)')


class Net(nn.Module):
@@ -179,7 +181,8 @@ def test():
optimizer = hvd.DistributedOptimizer(optimizer,
named_parameters=model.named_parameters(),
compression=compression,
op=hvd.Adasum if args.use_adasum else hvd.Average)
op=hvd.Adasum if args.use_adasum else hvd.Average,
gradient_predivide_factor=args.gradient_predivide_factor)

for epoch in range(1, args.epochs + 1):
train(epoch)
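A possible way to exercise the new flag end to end, assuming a 4-process run (the process count is arbitrary; horovodrun ships with Horovod):

import subprocess

subprocess.check_call([
    "horovodrun", "-np", "4",
    "python", "pytorch_mnist.py", "--gradient-predivide-factor", "2.0",
])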
1 change: 1 addition & 0 deletions examples/tensorflow2_synthetic_benchmark.py
@@ -42,6 +42,7 @@
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')


args = parser.parse_args()
args.cuda = not args.no_cuda

5 changes: 4 additions & 1 deletion examples/tensorflow_mnist.py
@@ -30,6 +30,8 @@
parser = argparse.ArgumentParser(description='Tensorflow MNIST Example')
parser.add_argument('--use-adasum', action='store_true', default=False,
help='use adasum algorithm to do reduction')
parser.add_argument('--gradient-predivide-factor', type=float, default=1.0,
help='apply gradient predivide factor in optimizer (default: 1.0)')
args = parser.parse_args()

def conv_model(feature, target, mode):
@@ -127,7 +129,8 @@ def main(_):
opt = tf.train.AdamOptimizer(0.001 * lr_scaler)

# Horovod: add Horovod Distributed Optimizer.
opt = hvd.DistributedOptimizer(opt, op=hvd.Adasum if args.use_adasum else hvd.Average)
opt = hvd.DistributedOptimizer(opt, op=hvd.Adasum if args.use_adasum else hvd.Average,
gradient_predivide_factor=args.gradient_predivide_factor)

global_step = tf.train.get_or_create_global_step()
train_op = opt.minimize(loss, global_step=global_step)
1 change: 0 additions & 1 deletion examples/tensorflow_mnist_eager.py
@@ -16,7 +16,6 @@
import tensorflow as tf
import horovod.tensorflow as hvd


def main(_):
# Horovod: initialize Horovod.
hvd.init()
25 changes: 21 additions & 4 deletions horovod/_keras/__init__.py
@@ -23,7 +23,7 @@


def create_distributed_optimizer(keras, optimizer, name, device_dense, device_sparse,
compression, sparse_as_dense):
compression, sparse_as_dense, gradient_predivide_factor):
class _DistributedOptimizer(keras.optimizers.Optimizer):
_HAS_AGGREGATE_GRAD = True

@@ -34,6 +34,7 @@ def __init__(self, **kwargs):
self._compression = compression
self._sparse_as_dense = sparse_as_dense
self._aggregated_gradients = False
self._gradient_predivide_factor = gradient_predivide_factor
super(self.__class__, self).__init__(**kwargs)

def get_gradients(self, loss, params):
@@ -60,6 +61,17 @@ def _aggregate_gradients(self, grads_and_vars):
def _allreduce(self, gradients):
self._aggregated_gradients = True
if hvd.size() > 1:
if self._gradient_predivide_factor != 1.0:
# Perform averaging via pre/postscaling factors.
# Split average operation across pre/postscale factors
prescale_factor = 1.0 / gradient_predivide_factor
postscale_factor = gradient_predivide_factor / hvd.size()
do_average = False
else:
prescale_factor = 1.0
postscale_factor = 1.0
do_average = True

averaged_gradients = []
with tf.name_scope(self._name + "_Allreduce"):
for grad in gradients:
@@ -68,9 +80,12 @@ def _allreduce(self, gradients):
isinstance(grad, tf.IndexedSlices):
grad = tf.convert_to_tensor(grad)
avg_grad = hvd.allreduce(grad,
average=do_average,
device_dense=self._device_dense,
device_sparse=self._device_sparse,
compression=self._compression)
compression=self._compression,
prescale_factor=prescale_factor,
postscale_factor=postscale_factor)
averaged_gradients.append(avg_grad)
else:
averaged_gradients.append(None)
@@ -108,8 +123,10 @@ def broadcast_global_variables(backend, root_rank):
return _eval(backend, hvd.broadcast_global_variables(root_rank))


def allreduce(backend, value, name, average):
return _eval(backend, hvd.allreduce(tf.constant(value, name=name), average=average))
def allreduce(backend, value, name, average, prescale_factor, postscale_factor):
return _eval(backend, hvd.allreduce(tf.constant(value, name=name), average=average,
prescale_factor=prescale_factor,
postscale_factor=postscale_factor))


def allgather(backend, value, name):
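End users reach this code path through the Keras wrapper, which forwards the new argument into create_distributed_optimizer above. A sketch assuming horovod.tensorflow.keras exposes the same gradient_predivide_factor keyword:

import tensorflow as tf
import horovod.tensorflow.keras as hvd

hvd.init()
opt = tf.keras.optimizers.SGD(0.01 * hvd.size())
# With a factor != 1.0, _allreduce() above switches to a summed allreduce
# bracketed by prescale (1/factor) and postscale (factor/size) scalings.
opt = hvd.DistributedOptimizer(opt, gradient_predivide_factor=2.0)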
16 changes: 16 additions & 0 deletions horovod/common/basics.py
@@ -215,3 +215,19 @@ def ccl_built(self):
A boolean value indicating whether oneCCL support was compiled.
"""
return bool(self.MPI_LIB_CTYPES.horovod_ccl_built())

def cuda_built(self):
"""Returns True if Horovod was compiled with CUDA support.
Returns:
A boolean value indicating whether CUDA support was compiled.
"""
return bool(self.MPI_LIB_CTYPES.horovod_cuda_built())

def rocm_built(self):
"""Returns True if Horovod was compiled with ROCm support.
Returns:
A boolean value indicating whether ROCm support was compiled.
"""
return bool(self.MPI_LIB_CTYPES.horovod_rocm_built())
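A short usage sketch for the new introspection calls, mirroring how ccl_built() is typically used (the framework module choice here is arbitrary):

import horovod.tensorflow as hvd

hvd.init()
# Gate device selection on how Horovod was compiled, not on what the
# local machine happens to have installed.
if hvd.cuda_built() or hvd.rocm_built():
    device = '/gpu:%d' % hvd.local_rank()
else:
    device = '/cpu:0'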
38 changes: 37 additions & 1 deletion horovod/common/controller.cc
@@ -451,6 +451,36 @@ Response Controller::ConstructResponse(std::string& name, int joined_size) {
}
}

// If we are doing an allreduce, check that prescaling and postscaling factors
// are identical across ranks.
double prescale_factor;
double postscale_factor;
if (message_type == Request::ALLREDUCE ||
message_type == Request::ADASUM) {
prescale_factor = requests[0].prescale_factor();
postscale_factor = requests[0].postscale_factor();

for (unsigned int i = 1; i < requests.size(); ++i) {
if (error) {
break;
}
double request_prescale_factor = requests[i].prescale_factor();
double request_postscale_factor = requests[i].postscale_factor();

if (prescale_factor != request_prescale_factor ||
postscale_factor != request_postscale_factor) {
error = true;
error_message_stream
<< "Mismatched prescale and/or postscale factors: "
<< "One rank sent factors (" << prescale_factor
<< ", " << postscale_factor << "), but another rank "
<< "sent factors (" << request_prescale_factor
<< ", " << request_postscale_factor << ").";
break;
}
}
}

std::vector<int64_t> tensor_sizes;
if (message_type == Request::ALLGATHER ||
message_type == Request::ALLTOALL) {
@@ -601,6 +631,8 @@ Response Controller::ConstructResponse(std::string& name, int joined_size) {
response.add_tensor_size(dim);
}
response.set_tensor_type(data_type);
response.set_prescale_factor(prescale_factor);
response.set_postscale_factor(postscale_factor);
} else if (message_type == Request::BROADCAST) {
response.set_response_type(Response::BROADCAST);
} else if (message_type == Request::ALLTOALL) {
Expand All @@ -611,6 +643,8 @@ Response Controller::ConstructResponse(std::string& name, int joined_size) {
response.add_tensor_size(dim);
}
response.set_tensor_type(data_type);
response.set_prescale_factor(prescale_factor);
response.set_postscale_factor(postscale_factor);
}
response.set_devices(devices);

@@ -675,7 +709,9 @@ ResponseList Controller::FuseResponses(std::deque<Response>& responses) {
if (response.response_type() == new_response.response_type() &&
response.devices() == new_response.devices() &&
response.tensor_type() == new_response.tensor_type() &&
tensor_size + new_tensor_size <= TensorFusionThresholdBytes()) {
tensor_size + new_tensor_size <= TensorFusionThresholdBytes() &&
response.prescale_factor() == new_response.prescale_factor() &&
response.postscale_factor() == new_response.postscale_factor()) {
// These tensors will fuse together well.
tensor_size += new_tensor_size;
response.add_tensor_name(std::move(new_response.tensor_names()[0]));
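The practical consequence of the new checks: every rank must submit identical factors for a given tensor, or the coordinator reports the mismatch error constructed above, and fusion only groups tensors whose factors agree. A hedged Python-side illustration (tensor values are hypothetical):

import torch
import horovod.torch as hvd

hvd.init()
grad = torch.ones(4)

# Fine: all ranks agree on the factors, so the requests validate and fuse.
avg = hvd.allreduce(grad, prescale_factor=0.5, postscale_factor=2.0)

# Not fine: a rank-dependent factor would trip the
# "Mismatched prescale and/or postscale factors" error on the coordinator.
# bad = hvd.allreduce(grad, prescale_factor=1.0 / (hvd.rank() + 1))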
2 changes: 2 additions & 0 deletions horovod/common/half.cc
@@ -39,6 +39,7 @@ bool is_avx_and_f16c() {
}
#endif

#if HAVE_MPI
// float16 custom data type summation operation.
void float16_sum(void* invec, void* inoutvec, int* len,
MPI_Datatype* datatype) {
@@ -73,6 +74,7 @@ void float16_sum(void* invec, void* inoutvec, int* len,
Float2HalfBits(&inout_float, inout + i);
}
}
#endif

} // namespace common
} // namespace horovod
12 changes: 10 additions & 2 deletions horovod/common/half.h
@@ -28,13 +28,19 @@

#include <stdint.h>

#if HAVE_MPI
#define OMPI_SKIP_MPICXX
#include "mpi.h"
#endif

namespace horovod {
namespace common {

inline void HalfBits2Float(unsigned short* src, float* res) {
#if __AVX__ && __F16C__
bool is_avx_and_f16c();
#endif

inline void HalfBits2Float(const unsigned short* src, float* res) {
unsigned h = *src;
int sign = ((h >> 15) & 1);
int exp = ((h >> 10) & 0x1f);
@@ -70,7 +76,7 @@ inline void HalfBits2Float(unsigned short* src, float* res) {
*res = *reinterpret_cast<float const*>(&f);
}

inline void Float2HalfBits(float* src, unsigned short* dest) {
inline void Float2HalfBits(const float* src, unsigned short* dest) {
// software implementation rounds toward nearest even
unsigned const& s = *reinterpret_cast<unsigned const*>(src);
uint16_t sign = uint16_t((s >> 16) & 0x8000);
@@ -132,7 +138,9 @@ inline void Float2HalfBits(float* src, unsigned short* dest) {
*dest = u;
}

#if HAVE_MPI
void float16_sum(void* invec, void* inoutvec, int* len, MPI_Datatype* datatype);
#endif

} // namespace common
} // namespace horovod
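The const-correctness changes above do not alter behavior. For intuition, the round trip these helpers implement (and the float32 accumulation used by float16_sum) can be mimicked with NumPy; this is an illustration of the semantics, not the actual code path:

import numpy as np

def half_bits_to_float(bits):
    # Reinterpret uint16 bit patterns as IEEE half, then widen to float32.
    return np.asarray(bits, dtype=np.uint16).view(np.float16).astype(np.float32)

def float_to_half_bits(values):
    # Narrow to half with round-to-nearest-even, then expose the raw bits.
    return np.asarray(values, dtype=np.float32).astype(np.float16).view(np.uint16)

a = float_to_half_bits([1.5, 2.5])
b = float_to_half_bits([0.25, 4.0])
summed = float_to_half_bits(half_bits_to_float(a) + half_bits_to_float(b))
print(summed.view(np.float16))  # [1.75 6.5]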
