Skip to content

Commit

Permalink
Add support for additional reduction operations for allreduce (min, m…
Browse files Browse the repository at this point in the history
…ax, product). (#3660)

* Add support for additional reduction operations for allreduce (min, max, product).

Signed-off-by: Josh Romero <joshr@nvidia.com>

* Fix compilation errors.

Signed-off-by: Josh Romero <joshr@nvidia.com>

* Handle reduce_op in FuseResponses.

Signed-off-by: Josh Romero <joshr@nvidia.com>

* Fix gloo header.

Signed-off-by: Josh Romero <joshr@nvidia.com>

* Try slice instead of ellipsis in MXNet tests.

Signed-off-by: Josh Romero <joshr@nvidia.com>

* Update more asserts in MXNet mpi_ops.py.

Signed-off-by: Josh Romero <joshr@nvidia.com>

* Addressing comments.

Signed-off-by: Josh Romero <joshr@nvidia.com>

Signed-off-by: Josh Romero <joshr@nvidia.com>
  • Loading branch information
romerojosh committed Sep 19, 2022
1 parent 4f723bb commit 427b633
Show file tree
Hide file tree
Showing 31 changed files with 964 additions and 148 deletions.
3 changes: 3 additions & 0 deletions horovod/common/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ def __init__(self, pkg_path, *args):
self.Average = self.MPI_LIB_CTYPES.horovod_reduce_op_average()
self.Sum = self.MPI_LIB_CTYPES.horovod_reduce_op_sum()
self.Adasum = self.MPI_LIB_CTYPES.horovod_reduce_op_adasum()
self.Min = self.MPI_LIB_CTYPES.horovod_reduce_op_min()
self.Max = self.MPI_LIB_CTYPES.horovod_reduce_op_max()
self.Product = self.MPI_LIB_CTYPES.horovod_reduce_op_product()

# These must be kept in sync with operations.cc (this might also be possible via ctypes)
self.HOROVOD_PROCESS_SET_ERROR_INIT = -1
Expand Down
26 changes: 25 additions & 1 deletion horovod/common/controller.cc
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,28 @@ Response Controller::ConstructResponse(const std::string& name, int joined_size)
}
}

// If we are doing an allreduce, check that the reduction op is indentical
// across ranks.
ReduceOp reduce_op;
if (message_type == Request::ALLREDUCE) {
reduce_op = requests[0].reduce_op();
for (unsigned int i = 1; i < requests.size(); ++i) {
if (error) {
break;
}
ReduceOp request_reduce_op = requests[i].reduce_op();
if (reduce_op != request_reduce_op) {
error = true;
error_message_stream
<< "Mismatched reduction operation: "
<< "One rank sent reduction op " << reduce_op
<< ", but another rank sent reduction op "
<< request_reduce_op << ".";
break;
}
}
}

std::vector<int64_t> tensor_sizes;
if (message_type == Request::ALLGATHER ||
message_type == Request::ALLTOALL) {
Expand Down Expand Up @@ -768,6 +790,7 @@ Response Controller::ConstructResponse(const std::string& name, int joined_size)
response.set_tensor_type(data_type);
response.set_prescale_factor(prescale_factor);
response.set_postscale_factor(postscale_factor);
response.set_reduce_op(reduce_op);
} else if (message_type == Request::BROADCAST) {
response.set_response_type(Response::BROADCAST);
} else if (message_type == Request::ALLTOALL) {
Expand Down Expand Up @@ -877,7 +900,8 @@ void Controller::FuseResponses(std::deque<Response>& responses,
response.tensor_type() == new_response.tensor_type() &&
tensor_size + new_tensor_size <= TensorFusionThresholdBytes() &&
response.prescale_factor() == new_response.prescale_factor() &&
response.postscale_factor() == new_response.postscale_factor()) {
response.postscale_factor() == new_response.postscale_factor() &&
response.reduce_op() == new_response.reduce_op()) {
// These tensors will fuse together well.
tensor_size += new_tensor_size;
response.add_tensor_name(std::move(new_response.tensor_names()[0]));
Expand Down
71 changes: 71 additions & 0 deletions horovod/common/half.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

#include "half.h"

#include <algorithm>

#if __AVX__ && __F16C__
#include <cpuid.h>
#include <immintrin.h>
Expand Down Expand Up @@ -74,6 +76,75 @@ void float16_sum(void* invec, void* inoutvec, int* len,
Float2HalfBits(&inout_float, inout + i);
}
}

// float16 custom data type min operation.
void float16_min(void* invec, void* inoutvec, int* len,
MPI_Datatype* datatype) {
// cast invec and inoutvec to your float16 type
auto* in = (unsigned short*)invec;
auto* inout = (unsigned short*)inoutvec;

for (int i = 0; i < *len; ++i) {
float in_float;
float inout_float;
HalfBits2Float(in + i, &in_float);
HalfBits2Float(inout + i, &inout_float);
inout_float = std::min(inout_float, in_float);
Float2HalfBits(&inout_float, inout + i);
}
}

// float16 custom data type max operation.
void float16_max(void* invec, void* inoutvec, int* len,
MPI_Datatype* datatype) {
// cast invec and inoutvec to your float16 type
auto* in = (unsigned short*)invec;
auto* inout = (unsigned short*)inoutvec;

for (int i = 0; i < *len; ++i) {
float in_float;
float inout_float;
HalfBits2Float(in + i, &in_float);
HalfBits2Float(inout + i, &inout_float);
inout_float = std::max(inout_float, in_float);
Float2HalfBits(&inout_float, inout + i);
}
}

// float16 custom data type product operation.
void float16_prod(void* invec, void* inoutvec, int* len,
MPI_Datatype* datatype) {
// cast invec and inoutvec to your float16 type
auto* in = (unsigned short*)invec;
auto* inout = (unsigned short*)inoutvec;

int i = 0;
#if __AVX__ && __F16C__
if (is_avx_and_f16c()) {
for (; i < (*len / 8) * 8; i += 8) {
// convert in & inout to m256
__m256 in_m256 = _mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(in + i)));
__m256 inout_m256 =
_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(inout + i)));

// add them together to new_inout_m256
__m256 new_inout_m256 = _mm256_mul_ps(in_m256, inout_m256);

// convert back and store in inout
__m128i new_inout_m128i = _mm256_cvtps_ph(new_inout_m256, 0);
_mm_storeu_si128((__m128i*)(inout + i), new_inout_m128i);
}
}
#endif
for (; i < *len; ++i) {
float in_float;
float inout_float;
HalfBits2Float(in + i, &in_float);
HalfBits2Float(inout + i, &inout_float);
inout_float *= in_float;
Float2HalfBits(&inout_float, inout + i);
}
}
#endif

} // namespace common
Expand Down
3 changes: 3 additions & 0 deletions horovod/common/half.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ inline void Float2HalfBits(const float* src, unsigned short* dest) {

#if HAVE_MPI
void float16_sum(void* invec, void* inoutvec, int* len, MPI_Datatype* datatype);
void float16_min(void* invec, void* inoutvec, int* len, MPI_Datatype* datatype);
void float16_max(void* invec, void* inoutvec, int* len, MPI_Datatype* datatype);
void float16_prod(void* invec, void* inoutvec, int* len, MPI_Datatype* datatype);
#endif

} // namespace common
Expand Down
13 changes: 13 additions & 0 deletions horovod/common/message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ void Request::add_tensor_shape(int64_t value) {
tensor_shape_.push_back(value);
}

ReduceOp Request::reduce_op() const { return reduce_op_; };

void Request::set_reduce_op(ReduceOp reduce_op) { reduce_op_ = reduce_op; };


namespace {

void Request_ParseFromWire(Request& request,
Expand All @@ -189,6 +194,7 @@ void Request_ParseFromWire(Request& request,
obj->tensor_shape()->end()));
request.set_prescale_factor(obj->prescale_factor());
request.set_postscale_factor(obj->postscale_factor());
request.set_reduce_op((ReduceOp) obj->reduce_op());
}

void Request_SerializeToWire(const Request& request,
Expand All @@ -209,6 +215,7 @@ void Request_SerializeToWire(const Request& request,
request_builder.add_tensor_shape(tensor_shape_wire);
request_builder.add_prescale_factor(request.prescale_factor());
request_builder.add_postscale_factor(request.postscale_factor());
request_builder.add_reduce_op((wire::ReduceOp) request.reduce_op());
obj = request_builder.Finish();
}

Expand Down Expand Up @@ -419,6 +426,10 @@ void Response::set_last_joined_rank(int value) {
last_joined_rank_ = value;
}

ReduceOp Response::reduce_op() const { return reduce_op_; };

void Response::set_reduce_op(ReduceOp reduce_op) { reduce_op_ = reduce_op; };

void Response_ParseFromWire(Response& response,
const wire::Response* obj) {
response.set_response_type((Response::ResponseType) obj->response_type());
Expand All @@ -434,6 +445,7 @@ void Response_ParseFromWire(Response& response,
response.set_prescale_factor(obj->prescale_factor());
response.set_postscale_factor(obj->postscale_factor());
response.set_last_joined_rank(obj->last_joined_rank());
response.set_reduce_op((ReduceOp) obj->reduce_op());
}

void Response::ParseFromBytes(Response& response, const uint8_t* input) {
Expand Down Expand Up @@ -463,6 +475,7 @@ void Response_SerializeToWire(const Response& response,
response_builder.add_prescale_factor(response.prescale_factor());
response_builder.add_postscale_factor(response.postscale_factor());
response_builder.add_last_joined_rank(response.last_joined_rank());
response_builder.add_reduce_op((wire::ReduceOp) response.reduce_op());
obj = response_builder.Finish();
}

Expand Down
21 changes: 21 additions & 0 deletions horovod/common/message.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ enum DataType {
HOROVOD_BOOL = 9,
};

enum ReduceOp {
AVERAGE = 0,
SUM = 1,
ADASUM = 2,
MIN = 3,
MAX = 4,
PRODUCT = 5,
};

const std::string& DataType_Name(DataType value);

std::size_t DataType_Size(DataType value);
Expand Down Expand Up @@ -91,7 +100,9 @@ class Request {
void set_device(int32_t value);

int32_t group_id() const;

void set_group_id(int32_t value);

const std::vector<int64_t>& tensor_shape() const;

void set_tensor_shape(const std::vector<int64_t>& value);
Expand All @@ -106,6 +117,10 @@ class Request {

void set_postscale_factor(double postscale_factor);

ReduceOp reduce_op() const;

void set_reduce_op(ReduceOp reduce_op);

static void ParseFromBytes(Request& request, const uint8_t* input);

static void SerializeToString(const Request& request, std::string& output);
Expand All @@ -124,6 +139,7 @@ class Request {
std::vector<int64_t> tensor_shape_;
double prescale_factor_ = 1.0;
double postscale_factor_ = 1.0;
ReduceOp reduce_op_ = ReduceOp::SUM;
};

class RequestList {
Expand Down Expand Up @@ -226,6 +242,10 @@ class Response {

void set_last_joined_rank(int value);

ReduceOp reduce_op() const;

void set_reduce_op(ReduceOp reduce_op);

static void ParseFromBytes(Response& response, const uint8_t* input);

static void SerializeToString(const Response& response,
Expand All @@ -241,6 +261,7 @@ class Response {
double prescale_factor_ = 1.0;
double postscale_factor_ = 1.0;
int last_joined_rank_ = -1;
ReduceOp reduce_op_ = ReduceOp::SUM;
};

class ResponseList {
Expand Down
43 changes: 39 additions & 4 deletions horovod/common/mpi/mpi_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,18 @@ MPI_Op MPIContext::GetMPISumOp(DataType dtype) const {
return dtype == HOROVOD_FLOAT16 ? mpi_float16_sum : MPI_SUM;
}

MPI_Op MPIContext::GetMPIMinOp(DataType dtype) const {
return dtype == HOROVOD_FLOAT16 ? mpi_float16_min : MPI_MIN;
}

MPI_Op MPIContext::GetMPIMaxOp(DataType dtype) const {
return dtype == HOROVOD_FLOAT16 ? mpi_float16_max : MPI_MAX;
}

MPI_Op MPIContext::GetMPIProdOp(DataType dtype) const {
return dtype == HOROVOD_FLOAT16 ? mpi_float16_prod : MPI_PROD;
}

MPI_Comm MPIContext::GetMPICommunicator(Communicator comm) const {
switch (comm) {
case GLOBAL:
Expand All @@ -86,14 +98,26 @@ int MPIContext::GetMPITypeSize(DataType dtype) const {

namespace {

void CreateMPIFloat16TypeAndSumOp(MPI_Datatype& mpi_float16_t,
MPI_Op& mpi_float16_sum) {
void CreateMPIFloat16TypeAndOps(MPI_Datatype& mpi_float16_t,
MPI_Op& mpi_float16_sum,
MPI_Op& mpi_float16_min,
MPI_Op& mpi_float16_max,
MPI_Op& mpi_float16_prod) {
// Create custom MPI float16 data type.
MPI_Type_contiguous(2, MPI_BYTE, &mpi_float16_t);
MPI_Type_commit(&mpi_float16_t);

// Create custom MPI float16 summation op.
MPI_Op_create(&float16_sum, 1, &mpi_float16_sum);

// Create custom MPI float16 min op.
MPI_Op_create(&float16_min, 1, &mpi_float16_min);

// Create custom MPI float16 max op.
MPI_Op_create(&float16_max, 1, &mpi_float16_max);

// Create custom MPI float16 prod op.
MPI_Op_create(&float16_prod, 1, &mpi_float16_prod);
}

void CreateMPILocalAndCrossComm(MPI_Comm mpi_comm, MPI_Comm& local_comm,
Expand Down Expand Up @@ -176,7 +200,8 @@ void MPIContext::Initialize(MPIContextManager& ctx_manager) {

CreateMPILocalAndCrossComm(mpi_comm, local_comm, cross_comm);

CreateMPIFloat16TypeAndSumOp(mpi_float16_t, mpi_float16_sum);
CreateMPIFloat16TypeAndOps(mpi_float16_t, mpi_float16_sum, mpi_float16_min,
mpi_float16_max, mpi_float16_prod);
}

void MPIContext::InitializeForProcessSet(const MPIContext& global_context,
Expand Down Expand Up @@ -207,7 +232,8 @@ void MPIContext::InitializeForProcessSet(const MPIContext& global_context,
CreateMPILocalAndCrossComm(mpi_comm, local_comm, cross_comm);
}

CreateMPIFloat16TypeAndSumOp(mpi_float16_t, mpi_float16_sum);
CreateMPIFloat16TypeAndOps(mpi_float16_t, mpi_float16_sum, mpi_float16_min,
mpi_float16_max, mpi_float16_prod);
}

void MPIContext::Finalize(MPIContextManager& ctx_manager) {
Expand Down Expand Up @@ -243,6 +269,15 @@ void MPIContext::FinalizeWithoutEnv() {
if (mpi_float16_sum != MPI_OP_NULL) {
MPI_Op_free(&mpi_float16_sum);
}
if (mpi_float16_min != MPI_OP_NULL) {
MPI_Op_free(&mpi_float16_min);
}
if (mpi_float16_max != MPI_OP_NULL) {
MPI_Op_free(&mpi_float16_max);
}
if (mpi_float16_prod != MPI_OP_NULL) {
MPI_Op_free(&mpi_float16_prod);
}
}

void MPIContextManager::EnvInitialize(int mpi_threads_required) {
Expand Down
9 changes: 9 additions & 0 deletions horovod/common/mpi/mpi_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ struct MPIContext {

MPI_Op GetMPISumOp(DataType dtype) const;

MPI_Op GetMPIMinOp(DataType dtype) const;

MPI_Op GetMPIMaxOp(DataType dtype) const;

MPI_Op GetMPIProdOp(DataType dtype) const;

// Communicators handled here are restricted to a single process set.
// If the running process is not part of that set, these communicators
// remain MPI_COMM_NULL.
Expand All @@ -88,6 +94,9 @@ struct MPIContext {
// MPI custom data type for float16.
MPI_Datatype mpi_float16_t;
MPI_Op mpi_float16_sum;
MPI_Op mpi_float16_min;
MPI_Op mpi_float16_max;
MPI_Op mpi_float16_prod;

// Private MPI communicator for Horovod to ensure no collisions with other
// threads using MPI, incorporates all processes known to Horovod.
Expand Down

0 comments on commit 427b633

Please sign in to comment.