Merge branch 'horovod:master' into spark-torch-gradient-accumulation
Signed-off-by: Li Jiang <bnujli@gmail.com>
thinkall committed Sep 13, 2022
2 parents ac55b99 + 25ed803 commit e6133e3
Showing 12 changed files with 116 additions and 54 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Added `register_local_source` and `use_generic_names` functionality to DistributedGradientTape. ([#3628](https://github.com/horovod/horovod/pull/3628))
- Added `transformation_edit_fields` and `transformation_removed_fields` param for EstimatorParams. ([#3651](https://github.com/horovod/horovod/pull/3651))
- Added `PartialDistributedGradientTape()` API for model parallel use cases. ([#3643](https://github.com/horovod/horovod/pull/3643))
- Enable use of native `ncclAvg` op for NCCL allreduces. ([#3646](https://github.com/horovod/horovod/pull/3646))

### Changed

@@ -40,6 +41,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- PyTorch: Fixed Reducescatter functions to raise `HorovodInternalError` rather than `RuntimeError`. ([#3594](https://github.com/horovod/horovod/pull/3594))
- PyTorch on GPUs without GPU operations: Fixed grouped allreduce to set CPU device in tensor table. ([#3594](https://github.com/horovod/horovod/pull/3594))
- Fixed race condition in PyTorch allocation handling. ([#3639](https://github.com/horovod/horovod/pull/3639))
- Build: Fixed finding nvcc (if not in $PATH) with older versions of CMake. ([#3682](https://github.com/horovod/horovod/pull/3682))


## [v0.25.0] - 2022-06-20
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -163,6 +163,7 @@ if(NOT CMAKE_CUDA_COMPILER)
find_package(CUDAToolkit)
if(CUDAToolkit_BIN_DIR)
message("CUDA compiler was not found in $PATH, but searching again in CUDA Toolkit binary directory")
unset(CMAKE_CUDA_COMPILER CACHE) # need to clear this from cache, else some versions of CMake go into an infinite loop
set(CMAKE_CUDA_COMPILER "${CUDAToolkit_BIN_DIR}/nvcc")
check_language(CUDA)
endif()
6 changes: 4 additions & 2 deletions Dockerfile.test.cpu
@@ -102,8 +102,10 @@ RUN pip install --no-cache-dir ray==1.7.0

# Install MPI.
RUN if [[ ${MPI_KIND} == "OpenMPI" ]]; then \
wget --progress=dot:mega -O /tmp/openmpi-3.0.0-bin.tar.gz https://github.com/horovod/horovod/files/1596799/openmpi-3.0.0-bin.tar.gz && \
cd /usr/local && tar -zxf /tmp/openmpi-3.0.0-bin.tar.gz && ldconfig && \
wget --progress=dot:mega -O /tmp/openmpi-4.1.4-bin.tar.gz https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.4.tar.gz && \
cd /tmp && tar -zxf /tmp/openmpi-4.1.4-bin.tar.gz && \
mkdir openmpi-4.1.4/build && cd openmpi-4.1.4/build && ../configure --prefix=/usr/local && \
make -j all && make install && ldconfig && \
echo "mpirun -allow-run-as-root -np 2 -H localhost:2 -bind-to none -map-by slot -mca mpi_abort_print_stack 1" > /mpirun_command; \
elif [[ ${MPI_KIND} == "ONECCL" ]]; then \
wget --progress=dot:mega -O /tmp/oneccl.tar.gz https://github.com/oneapi-src/oneCCL/archive/${CCL_PACKAGE}.tar.gz && \
6 changes: 4 additions & 2 deletions Dockerfile.test.gpu
@@ -106,8 +106,10 @@ RUN pip install --no-cache-dir ray==1.3.0

# Install MPI.
RUN if [[ ${MPI_KIND} == "OpenMPI" ]]; then \
wget --progress=dot:mega -O /tmp/openmpi-3.0.0-bin.tar.gz https://github.com/horovod/horovod/files/1596799/openmpi-3.0.0-bin.tar.gz && \
cd /usr/local && tar -zxf /tmp/openmpi-3.0.0-bin.tar.gz && ldconfig && \
wget --progress=dot:mega -O /tmp/openmpi-4.1.4-bin.tar.gz https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.4.tar.gz && \
cd /tmp && tar -zxf /tmp/openmpi-4.1.4-bin.tar.gz && \
mkdir openmpi-4.1.4/build && cd openmpi-4.1.4/build && ../configure --prefix=/usr/local && \
make -j all && make install && ldconfig && \
echo "mpirun -allow-run-as-root -np 2 -H localhost:2 -bind-to none -map-by slot -mca mpi_abort_print_stack 1" > /mpirun_command; \
elif [[ ${MPI_KIND} == "MPICH" ]]; then \
apt-get update -qq && apt-get install -y mpich && \
8 changes: 4 additions & 4 deletions docker/horovod-cpu/Dockerfile
@@ -41,10 +41,10 @@ RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-
&& apt-get clean && rm -rf /var/lib/apt/lists/*

# Install Open MPI
RUN wget --progress=dot:mega -O /tmp/openmpi-3.0.0-bin.tar.gz https://github.com/horovod/horovod/files/1596799/openmpi-3.0.0-bin.tar.gz && \
cd /usr/local && \
tar -zxf /tmp/openmpi-3.0.0-bin.tar.gz && \
ldconfig && \
RUN wget --progress=dot:mega -O /tmp/openmpi-4.1.4-bin.tar.gz https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.4.tar.gz && \
cd /tmp && tar -zxf /tmp/openmpi-4.1.4-bin.tar.gz && \
mkdir openmpi-4.1.4/build && cd openmpi-4.1.4/build && ../configure --prefix=/usr/local && \
make -j all && make install && ldconfig && \
mpirun --version

# Allow OpenSSH to talk to containers without asking for confirmation
9 changes: 2 additions & 7 deletions docker/horovod-nvtabular/Dockerfile
@@ -128,14 +128,9 @@ RUN if [[ ${MPI_KIND} != "None" ]]; then \

# Install TensorFlow and Keras (releases).
# Pin scipy!=1.4.0: https://github.com/scipy/scipy/issues/11237
# Pin protobuf~=3.20 for tensorflow<2.6.5: https://github.com/tensorflow/tensorflow/issues/56077
# Pin protobuf<4 for tensorflow: https://github.com/tensorflow/tensorflow/issues/56815
RUN if [[ ${TENSORFLOW_PACKAGE} != "tf-nightly-gpu" ]]; then \
PROTOBUF_PACKAGE=""; \
if [[ ${TENSORFLOW_PACKAGE} == tensorflow-gpu==1.15.* ]] || \
[[ ${TENSORFLOW_PACKAGE} == tensorflow-gpu==2.[012345].* ]]; then \
PROTOBUF_PACKAGE="protobuf~=3.20"; \
fi; \
pip install --no-cache-dir ${TENSORFLOW_PACKAGE} ${PROTOBUF_PACKAGE}; \
pip install --no-cache-dir ${TENSORFLOW_PACKAGE} "protobuf<4"; \
if [[ ${KERAS_PACKAGE} != "None" ]]; then \
pip uninstall -y keras; \
pip install --no-cache-dir ${KERAS_PACKAGE} "scipy!=1.4.0" "pandas<1.1.0"; \
8 changes: 4 additions & 4 deletions docker/horovod/Dockerfile
@@ -57,10 +57,10 @@ RUN apt-get update && apt-get install -y --allow-downgrades --allow-change-held-
&& apt-get clean && rm -rf /var/lib/apt/lists/*

# Install Open MPI
RUN wget --progress=dot:mega -O /tmp/openmpi-3.0.0-bin.tar.gz https://github.com/horovod/horovod/files/1596799/openmpi-3.0.0-bin.tar.gz && \
cd /usr/local && \
tar -zxf /tmp/openmpi-3.0.0-bin.tar.gz && \
ldconfig && \
RUN wget --progress=dot:mega -O /tmp/openmpi-4.1.4-bin.tar.gz https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-4.1.4.tar.gz && \
cd /tmp && tar -zxf /tmp/openmpi-4.1.4-bin.tar.gz && \
mkdir openmpi-4.1.4/build && cd openmpi-4.1.4/build && ../configure --prefix=/usr/local && \
make -j all && make install && ldconfig && \
mpirun --version

# Allow OpenSSH to talk to containers without asking for confirmation
6 changes: 3 additions & 3 deletions docs/troubleshooting.rst
@@ -408,9 +408,9 @@ We recommend reinstalling Open MPI with the ``--enable-orterun-prefix-by-default``
.. code-block:: bash
$ wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.0.0.tar.gz
$ tar zxf openmpi-4.0.0.tar.gz
$ cd openmpi-4.0.0
$ wget https://www.open-mpi.org/software/ompi/v4.0/downloads/openmpi-4.1.4.tar.gz
$ tar zxf openmpi-4.1.4.tar.gz
$ cd openmpi-4.1.4
$ ./configure --enable-orterun-prefix-by-default
$ make -j $(nproc) all
$ make install
28 changes: 21 additions & 7 deletions horovod/common/ops/nccl_operations.cc
@@ -189,10 +189,24 @@ Status NCCLAllreduce::Execute(std::vector<TensorTableEntry>& entries,
void* buffer_data;
size_t buffer_len;

ncclRedOp_t ncclOp = ncclSum;
double prescale_factor = response.prescale_factor();
double postscale_factor = response.postscale_factor();
#ifdef NCCL_AVG_SUPPORTED
auto& process_set =
global_state_->process_set_table.Get(entries[0].process_set_id);
if (prescale_factor == 1.0 &&
postscale_factor == 1.0 / process_set.controller->GetSize()) {
// Use NCCLAvg op in place of postscale_factor
ncclOp = ncclAvg;
postscale_factor = 1.0;
}
#endif

// Copy (and possibly scale) tensors into the fusion buffer.
if (entries.size() > 1) {
ScaleMemcpyInFusionBuffer(entries, fused_input_data, buffer_data,
buffer_len, response.prescale_factor());
buffer_len, prescale_factor);
if (global_state_->timeline.Initialized()) {
gpu_context_->RecordEvent(gpu_op_context_.event_queue,
MEMCPY_IN_FUSION_BUFFER,
@@ -204,9 +218,9 @@ Status NCCLAllreduce::Execute(std::vector<TensorTableEntry>& entries,
buffer_len = (size_t)first_entry.output->size();
int64_t num_elements =
buffer_len / DataType_Size(first_entry.tensor->dtype());
if (response.prescale_factor() != 1.0) {
if (prescale_factor != 1.0) {
// Execute prescaling op
ScaleBuffer(response.prescale_factor(), entries, fused_input_data,
ScaleBuffer(prescale_factor, entries, fused_input_data,
buffer_data, num_elements);
fused_input_data = buffer_data; // for unfused, scale is done out of place
}
@@ -217,7 +231,7 @@ Status NCCLAllreduce::Execute(std::vector<TensorTableEntry>& entries,
buffer_len / DataType_Size(first_entry.tensor->dtype());
auto nccl_result =
ncclAllReduce(fused_input_data, buffer_data, (size_t)num_elements,
GetNCCLDataType(first_entry.tensor), ncclSum,
GetNCCLDataType(first_entry.tensor), ncclOp,
*nccl_op_context_.nccl_comm_, *gpu_op_context_.stream);
nccl_context_->ErrorCheck("ncclAllReduce", nccl_result,
*nccl_op_context_.nccl_comm_);
@@ -229,17 +243,17 @@ Status NCCLAllreduce::Execute(std::vector<TensorTableEntry>& entries,
// Copy (and possible scale) tensors out of the fusion buffer.
if (entries.size() > 1) {
ScaleMemcpyOutFusionBuffer(buffer_data, buffer_len,
response.postscale_factor(), entries);
postscale_factor, entries);

if (global_state_->timeline.Initialized()) {
gpu_context_->RecordEvent(gpu_op_context_.event_queue,
MEMCPY_OUT_FUSION_BUFFER,
*gpu_op_context_.stream);
}
} else {
if (response.postscale_factor() != 1.0) {
if (postscale_factor != 1.0) {
// Execute postscaling op
ScaleBuffer(response.postscale_factor(), entries, buffer_data,
ScaleBuffer(postscale_factor, entries, buffer_data,
buffer_data, num_elements);
}
}
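The substitution above relies on sum-then-postscale being numerically equivalent to a native average when prescale_factor is 1.0 and postscale_factor is 1.0 / size. A minimal Python sketch of that equivalence, using toy data and a hypothetical rank count purely for illustration:

import numpy as np

# Assumption: 4 ranks with random toy gradients; only the arithmetic is illustrated here.
size = 4
rank_grads = [np.random.rand(8).astype(np.float32) for _ in range(size)]

# Old path: ncclSum followed by a separate postscale kernel with factor 1.0 / size.
summed_then_scaled = np.sum(rank_grads, axis=0) * (1.0 / size)

# New path: a single fused average reduction (what ncclAvg computes on device).
averaged = np.mean(rank_grads, axis=0)

# When prescale_factor == 1.0 and postscale_factor == 1.0 / size, the two paths agree
# up to floating-point rounding, so the extra ScaleBuffer pass can be skipped.
np.testing.assert_allclose(summed_then_scaled, averaged, rtol=1e-6)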
3 changes: 3 additions & 0 deletions horovod/common/ops/nccl_operations.h
@@ -23,6 +23,9 @@
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 7, 0)
#define NCCL_P2P_SUPPORTED
#endif
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 10, 0)
#define NCCL_AVG_SUPPORTED
#endif
#elif HAVE_ROCM
#include <rccl.h>
#define NCCL_P2P_SUPPORTED
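For reference, NCCL encodes 2.10.0 as version code 21000, which is why the Python-side changes below compare nccl_built() against 21000. A small sketch of that encoding, assuming NCCL's scheme of major * 10000 + minor * 100 + patch for releases 2.9 and later:

# Hypothetical helper mirroring NCCL's version encoding for releases >= 2.9
# (assumption: major * 10000 + minor * 100 + patch).
def nccl_version_code(major: int, minor: int, patch: int) -> int:
    return major * 10000 + minor * 100 + patch

# 2.10.0 -> 21000, the threshold below which native ncclAvg is unavailable.
assert nccl_version_code(2, 10, 0) == 21000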
41 changes: 29 additions & 12 deletions horovod/mxnet/__init__.py
@@ -47,10 +47,16 @@ def __init__(self, optimizer, gradient_predivide_factor=1.0, num_groups=0, proce
raise ValueError('gradient_predivide_factor not supported yet with ROCm')

self._optimizer = optimizer
# Normalizing rescale_grad by Horovod size, which is equivalent to
# performing average in allreduce, has better performance.
self._optimizer.rescale_grad *= (gradient_predivide_factor / process_set.size())
self._gradient_predivide_factor = gradient_predivide_factor
self._average_in_framework=False
# C++ backend will apply additional 1 / size() factor to postscale_factor for op == Average.
self._postscale_factor = self._gradient_predivide_factor
if rocm_built() or nccl_built() < 21000:
# Perform average in framework via rescale_grad for ROCM or older NCCL versions
# without average support
self._optimizer.rescale_grad *= (gradient_predivide_factor / process_set.size())
self._postscale_factor = 1.0
self._average_in_framework=True
self._num_groups = num_groups
self._process_set = process_set

@@ -72,18 +78,21 @@ def _do_allreduce(self, index, grad):
index_split = split_list(index, self._num_groups)

for i, (grads, indices) in enumerate(zip(grad_split, index_split)):
grouped_allreduce_(tensors=grads, average=False, name="{}:{}".format(indices[0], indices[-1]), priority=-i,
grouped_allreduce_(tensors=grads, average=not self._average_in_framework, name="{}:{}".format(indices[0], indices[-1]), priority=-i,
prescale_factor=1.0 / self._gradient_predivide_factor,
postscale_factor=self._postscale_factor,
process_set=self._process_set)
else:
for i in range(len(index)):
allreduce_(grad[i], average=False,
allreduce_(grad[i], average=not self._average_in_framework,
name=str(index[i]), priority=-i,
prescale_factor=1.0 / self._gradient_predivide_factor,
postscale_factor=self._postscale_factor,
process_set=self._process_set)
else:
allreduce_(grad, average=False, name=str(index),
allreduce_(grad, average=not self._average_in_framework, name=str(index),
prescale_factor=1.0 / self._gradient_predivide_factor,
postscale_factor=self._postscale_factor,
process_set=self._process_set)

def update(self, index, weight, grad, state):
@@ -155,11 +164,16 @@ def __init__(self, params, optimizer, optimizer_params=None,
super(DistributedTrainer, self).__init__(
params, optimizer, optimizer_params=optimizer_params, kvstore=None)

# _scale is used to check and set rescale_grad for optimizer in Trainer.step()
# function. Normalizing it by Horovod size, which is equivalent to performing
# average in allreduce, has better performance.
self._scale *= (gradient_predivide_factor / process_set.size())
self._gradient_predivide_factor = gradient_predivide_factor
self._average_in_framework=False
# C++ backend will apply additional 1 / size() factor to postscale_factor for op == Average.
self._postscale_factor = self._gradient_predivide_factor
if rocm_built() or nccl_built() < 21000:
# Perform average in framework via rescale_grad for ROCM or older NCCL versions
# without average support
self._scale *= (gradient_predivide_factor / process_set.size())
self._postscale_factor = 1.0
self._average_in_framework=True
assert prefix is None or isinstance(prefix, str)
self._prefix = prefix if prefix else ""
self._num_groups = num_groups
@@ -193,8 +207,10 @@ def _allreduce_grads(self):

for entries in entries_by_dtype.values():
grads, names = zip(*entries)
grouped_allreduce_(tensors=grads, average=False, name="{}:{}".format(names[0], names[-1]), priority=-i,
grouped_allreduce_(tensors=grads, average=not self._average_in_framework,
name="{}:{}".format(names[0], names[-1]), priority=-i,
prescale_factor=1.0 / self._gradient_predivide_factor,
postscale_factor=self._postscale_factor,
process_set=self._process_set)

if self._compression != Compression.none:
@@ -208,9 +224,10 @@ def _allreduce_grads(self):
for i, param in enumerate(self._params):
if param.grad_req != 'null':
tensor_compressed, ctx = self._compression.compress(param.list_grad()[0])
allreduce_(tensor_compressed, average=False,
allreduce_(tensor_compressed, average=not self._average_in_framework,
name=self._prefix + str(i), priority=-i,
prescale_factor=1.0 / self._gradient_predivide_factor,
postscale_factor=self._postscale_factor,
process_set=self._process_set)

if self._compression != Compression.none:
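The MXNet changes above all follow the same rule: average in the framework (by folding 1/size into rescale_grad or _scale) when running on ROCm or an NCCL build without ncclAvg, otherwise let the backend average and pass only the predivide factor through postscale_factor. A minimal sketch of that decision, using hypothetical stand-ins for rocm_built(), nccl_built(), and the process-set size:

def choose_averaging(gradient_predivide_factor: float,
                     process_set_size: int,
                     rocm: bool,
                     nccl_version_code: int):
    """Return (grad_rescale_multiplier, postscale_factor, average_in_framework)."""
    if rocm or nccl_version_code < 21000:
        # No native average op: fold 1/size into the framework's rescale_grad and
        # issue plain sum allreduces (average=False).
        return gradient_predivide_factor / process_set_size, 1.0, True
    # Native average available: the backend applies the extra 1/size factor itself
    # when op == Average, so only the predivide factor goes into postscale_factor.
    return 1.0, gradient_predivide_factor, False

# Example: 8 workers, NCCL 2.10+, no gradient predivide.
print(choose_averaging(1.0, 8, rocm=False, nccl_version_code=21000))
# -> (1.0, 1.0, False): backend averaging, no framework-side rescale.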
