Add hvd.grouped_allgather and hvd.grouped_reducescatter (#3594)
Signed-off-by: Max H. Gerlach <git@maxgerlach.de>
maxhgerlach committed Aug 2, 2022
1 parent 757883b commit 8f450ab
Showing 30 changed files with 2,460 additions and 165 deletions.
11 changes: 8 additions & 3 deletions CHANGELOG.md
@@ -9,7 +9,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
### Added

- Added support for batched memory copies in GPUAllgather. ([#3590](https://github.com/horovod/horovod/pull/3590))
- Added support for batched memory copies in GPUReducescatter.
- Added support for batched memory copies in GPUReducescatter. ([#3621](https://github.com/horovod/horovod/pull/3621))
- Added `hvd.grouped_allgather()` and `hvd.grouped_reducescatter()` operations. ([#3594](https://github.com/horovod/horovod/pull/3594))
- TensorFlow: Added doc string for `hvd.grouped_allreduce()`. ([#3594](https://github.com/horovod/horovod/pull/3594))
- Added warning messages if output tensor memory allocations fail. ([#3594](https://github.com/horovod/horovod/pull/3594))

### Changed

@@ -19,8 +22,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

### Fixed

- Updated Eigen submodule to fix build on macOS with aarch64.
- Fix FuseResponses() on BATCHED_D2D_PADDING edge cases for Reducescatter and/or ROCm.
- Updated Eigen submodule to fix build on macOS with aarch64. ([#3619](https://github.com/horovod/horovod/pull/3619))
- Fix FuseResponses() on BATCHED_D2D_PADDING edge cases for Reducescatter and/or ROCm. ([#3621](https://github.com/horovod/horovod/pull/3621))
- PyTorch: Fixed Reducescatter functions to raise `HorovodInternalError` rather than `RuntimeError`. ([#3594](https://github.com/horovod/horovod/pull/3594))
- PyTorch on GPUs without GPU operations: Fixed grouped allreduce to set CPU device in tensor table. ([#3594](https://github.com/horovod/horovod/pull/3594))


## [v0.25.0] - 2022-06-20
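
For reference, a minimal usage sketch of the two new Python-level ops listed in the Added section above. This assumes they follow the same list-in/list-out convention as the existing `hvd.grouped_allreduce()`; the Python signatures are not part of the hunks shown here, so the calls below are illustrative rather than definitive.

```python
# Hedged sketch: assumes hvd.grouped_allgather()/hvd.grouped_reducescatter()
# accept a list of tensors and return a list of results, mirroring
# hvd.grouped_allreduce(). Not copied from this commit.
import torch
import horovod.torch as hvd

hvd.init()

tensors = [torch.rand(hvd.size() * 2, 5), torch.rand(hvd.size(), 3)]

# One request is enqueued per tensor, but all of them share a single group,
# so the backend can negotiate and fuse them together.
gathered = hvd.grouped_allgather(tensors)      # each tensor gathered along dim 0
reduced = hvd.grouped_reducescatter(tensors)   # each tensor reduced, then split along dim 0
```
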
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -1,5 +1,6 @@
sphinx>=3,<4
sphinxcontrib-napoleon
alabaster
jinja2<3.1
nbsphinx
pyyaml
3 changes: 3 additions & 0 deletions horovod/common/common.h
@@ -364,6 +364,9 @@ struct TensorTableEntry {
std::shared_ptr<Tensor> tensor;
// Pre-allocated output tensor.
std::shared_ptr<Tensor> output;
// Grouped Reducescatter or Allgather ops will need to allocate memory for
// a specific output_index >= 0.
int32_t output_index = 0;
// Identifier for the subset of Horovod processes partaking in this operation.
int32_t process_set_id = 0;
// Root rank for broadcast operation (relative to process set).
2 changes: 2 additions & 0 deletions horovod/common/nvtx_op_range.cc
@@ -12,9 +12,11 @@ NvtxOpsHandle::NvtxOpsHandle() noexcept
REGISTER_STRING(HorovodAllreduce);
REGISTER_STRING(HorovodGroupedAllreduce);
REGISTER_STRING(HorovodAllgather);
REGISTER_STRING(HorovodGroupedAllgather);
REGISTER_STRING(HorovodBroadcast);
REGISTER_STRING(HorovodAlltoall);
REGISTER_STRING(HorovodReducescatter);
REGISTER_STRING(HorovodGroupedReducescatter);
#undef REGISTER_STRING
}

2 changes: 2 additions & 0 deletions horovod/common/nvtx_op_range.h
@@ -13,9 +13,11 @@ enum class RegisteredNvtxOp {
HorovodAllreduce = 0,
HorovodGroupedAllreduce,
HorovodAllgather,
HorovodGroupedAllgather,
HorovodBroadcast,
HorovodAlltoall,
HorovodReducescatter,
HorovodGroupedReducescatter,
// Insert values for new ops above this line. Also add corresponding
// REGISTER_STRING lines in the constructor NvtxOpsHandle::NvtxOpsHandle().
END,
221 changes: 178 additions & 43 deletions horovod/common/operations.cc
@@ -1534,6 +1534,32 @@ Status EnqueueTensorAllgather(std::shared_ptr<OpContext> context,
ReadyEventList ready_event_list,
const std::string& name, const int device,
StatusCallback callback, int32_t process_set_id) {
// Wrap inputs in std::vector and pass onto multi tensor implementation
std::vector<std::shared_ptr<OpContext>> contexts;
std::vector<std::shared_ptr<Tensor>> tensors;
std::vector<ReadyEventList> ready_event_lists;
std::vector<std::string> names;
std::vector<StatusCallback> callbacks;

contexts.emplace_back(std::move(context));
tensors.emplace_back(std::move(tensor));
ready_event_lists.emplace_back(std::move(ready_event_list));
names.emplace_back(std::move(name));
callbacks.emplace_back(std::move(callback));

return EnqueueTensorAllgathers(contexts, tensors, ready_event_lists, names,
device, callbacks, process_set_id);
}

// Contexts and controller must be initialized and the background thread
// must be running before this function is called.
Status
EnqueueTensorAllgathers(std::vector<std::shared_ptr<OpContext>>& contexts,
std::vector<std::shared_ptr<Tensor>>& tensors,
std::vector<ReadyEventList>& ready_event_lists,
std::vector<std::string>& names, int device,
std::vector<StatusCallback>& callbacks,
int32_t process_set_id) {
if (horovod_global.cpu_operation == LibType::CCL && process_set_id > 0 &&
device == CPU_DEVICE_ID) {
return Status::InvalidArgument(
@@ -1555,31 +1581,72 @@ Status EnqueueTensorAllgather(std::shared_ptr<OpContext> context,
" is not a member of the provided process set.");
}

Request message;
message.set_request_rank(process_set.controller->GetRank());
message.set_tensor_name(name);
message.set_tensor_type(tensor->dtype());
message.set_device(device);
message.set_request_type(Request::ALLGATHER);
for (int i = 0; i < tensor->shape().dims(); ++i) {
message.add_tensor_shape((int64_t)tensor->shape().dim_size(i));
std::vector<Request> messages;
std::vector<TensorTableEntry> entries;
messages.reserve(tensors.size());
entries.reserve(tensors.size());

for (int n = 0; n < (int)tensors.size(); ++n) {
Request message;
message.set_request_rank(process_set.controller->GetRank());
message.set_tensor_name(names[n]);
message.set_tensor_type(tensors[n]->dtype());
message.set_device(device);
message.set_request_type(Request::ALLGATHER);
message.set_tensor_shape(tensors[n]->shape().to_vector());

messages.push_back(std::move(message));

TensorTableEntry e;
e.tensor_name = names[n];
e.context = contexts[n];
e.tensor = tensors[n];
e.output_index = n;
e.process_set_id = process_set_id;
e.ready_event_list = std::move(ready_event_lists[n]);
e.device = device;
e.callback = std::move(callbacks[n]);

entries.push_back(std::move(e));
}

TensorTableEntry e;
e.tensor_name = name;
e.context = context;
e.tensor = tensor;
e.process_set_id = process_set_id;
e.ready_event_list = ready_event_list;
e.device = device;
e.callback = callback;
e.nvtx_op_range.Start(RegisteredNvtxOp::HorovodAllgather, e.tensor->size());
// Start appropriate NVTX range
if (tensors.size() == 1) {
auto& e = entries[0];
e.nvtx_op_range.Start(RegisteredNvtxOp::HorovodAllgather, e.tensor->size());
} else {
auto total_size =
std::accumulate(entries.begin(), entries.end(), 0ll,
[](int64_t size_sum, const TensorTableEntry& e) {
return size_sum + e.tensor->size();
});
SharedNvtxOpRange range;
range.Start(RegisteredNvtxOp::HorovodGroupedAllgather, total_size);
for (auto& e : entries) {
e.nvtx_op_range = range;
}
}

Status status = process_set.tensor_queue.AddToTensorQueue(e, message);
if (status.ok()) {
LOG(TRACE, horovod_global.global_controller->GetRank())
<< "Enqueued " << name;
std::string tensors_enqueued;
for (const auto& n : names) {
tensors_enqueued += n + "; ";
}
LOG(TRACE, horovod_global.global_controller->GetRank())
<< "Enqueued tensors for Allgather: " << tensors_enqueued;

// Only create groups larger than 1 tensor, unless disable_group_fusion is
// requested. In that case, even single tensor groups are created to enforce
// disabling fusion.
if (tensors.size() > 1 || horovod_global.disable_group_fusion) {
auto group_id = process_set.group_table.RegisterGroup(std::move(names));
for (auto& message : messages) {
message.set_group_id(group_id);
}
}

Status status =
process_set.tensor_queue.AddToTensorQueueMulti(entries, messages);

return status;
}

@@ -1661,6 +1728,33 @@ Status EnqueueTensorReducescatter(std::shared_ptr<OpContext> context,
const std::string& name, const int device,
StatusCallback callback, ReduceOp reduce_op,
int32_t process_set_id) {
// Wrap inputs in std::vector and pass onto multi tensor implementation
std::vector<std::shared_ptr<OpContext>> contexts;
std::vector<std::shared_ptr<Tensor>> tensors;
std::vector<ReadyEventList> ready_event_lists;
std::vector<std::string> names;
std::vector<StatusCallback> callbacks;

contexts.emplace_back(std::move(context));
tensors.emplace_back(std::move(tensor));
ready_event_lists.emplace_back(std::move(ready_event_list));
names.emplace_back(std::move(name));
callbacks.emplace_back(std::move(callback));

return EnqueueTensorReducescatters(contexts, tensors, ready_event_lists,
names, device, callbacks, reduce_op,
process_set_id);
}

// Contexts and controller must be initialized and the background thread
// must be running before this function is called.
Status
EnqueueTensorReducescatters(std::vector<std::shared_ptr<OpContext>>& contexts,
std::vector<std::shared_ptr<Tensor>>& tensors,
std::vector<ReadyEventList>& ready_event_lists,
std::vector<std::string>& names, int device,
std::vector<StatusCallback>& callbacks,
ReduceOp reduce_op, int32_t process_set_id) {
if (horovod_global.cpu_operation == LibType::CCL && device == CPU_DEVICE_ID) {
return Status::InvalidArgument(
"Reducescatter is not supported yet with oneCCL operations.");
@@ -1670,6 +1764,7 @@ Status EnqueueTensorReducescatter(std::shared_ptr<OpContext> context,
"Reducescatter: Process set provided does not "
"exist, or has not been registered.");
}

if (reduce_op != ReduceOp::SUM) {
// Note: AVERAGE is supported by enqueuing SUM and performing divide at the
// framework level.
@@ -1689,32 +1784,72 @@ Status EnqueueTensorReducescatter(std::shared_ptr<OpContext> context,
" is not a member of the provided process set.");
}

Request message;
message.set_request_rank(process_set.controller->GetRank());
message.set_tensor_name(name);
message.set_tensor_type(tensor->dtype());
message.set_device(device);
message.set_request_type(Request::REDUCESCATTER);
for (int i = 0; i < tensor->shape().dims(); ++i) {
message.add_tensor_shape((int64_t)tensor->shape().dim_size(i));
std::vector<Request> messages;
std::vector<TensorTableEntry> entries;
messages.reserve(tensors.size());
entries.reserve(tensors.size());

for (int n = 0; n < (int)tensors.size(); ++n) {
Request message;
message.set_request_rank(process_set.controller->GetRank());
message.set_tensor_name(names[n]);
message.set_tensor_type(tensors[n]->dtype());
message.set_device(device);
message.set_request_type(Request::REDUCESCATTER);
message.set_tensor_shape(tensors[n]->shape().to_vector());
messages.push_back(std::move(message));

TensorTableEntry e;
e.tensor_name = names[n];
e.context = std::move(contexts[n]);
e.tensor = tensors[n];
e.output_index = n;
e.process_set_id = process_set_id;
e.ready_event_list = std::move(ready_event_lists[n]);
e.device = device;
e.callback = std::move(callbacks[n]);

entries.push_back(std::move(e));
}

TensorTableEntry e;
e.tensor_name = name;
e.context = context;
e.tensor = tensor;
e.process_set_id = process_set_id;
e.ready_event_list = ready_event_list;
e.device = device;
e.callback = callback;
e.nvtx_op_range.Start(RegisteredNvtxOp::HorovodReducescatter,
e.tensor->size());
// Start appropriate NVTX range
if (tensors.size() == 1) {
auto& e = entries[0];
e.nvtx_op_range.Start(RegisteredNvtxOp::HorovodReducescatter,
e.tensor->size());
} else {
auto total_size =
std::accumulate(entries.begin(), entries.end(), 0ll,
[](int64_t size_sum, const TensorTableEntry& e) {
return size_sum + e.tensor->size();
});
SharedNvtxOpRange range;
range.Start(RegisteredNvtxOp::HorovodGroupedReducescatter, total_size);
for (auto& e : entries) {
e.nvtx_op_range = range;
}
}

Status status = process_set.tensor_queue.AddToTensorQueue(e, message);
if (status.ok()) {
LOG(TRACE, horovod_global.global_controller->GetRank())
<< "Enqueued " << name;
std::string tensors_enqueued;
for (const auto& n : names) {
tensors_enqueued += n + "; ";
}
LOG(TRACE, horovod_global.global_controller->GetRank())
<< "Enqueued tensors for Reducescatter: " << tensors_enqueued;

// Only create groups larger than 1 tensor, unless disable_group_fusion is
// requested. In that case, even single tensor groups are created to enforce
// disabling fusion.
if (tensors.size() > 1 || horovod_global.disable_group_fusion) {
auto group_id = process_set.group_table.RegisterGroup(std::move(names));
for (auto& message : messages) {
message.set_group_id(group_id);
}
}

Status status =
process_set.tensor_queue.AddToTensorQueueMulti(entries, messages);

return status;
}

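A note on the group-registration branches above: groups of more than one tensor always receive a group id, while a single-tensor request is only wrapped in a group when group fusion is disabled globally. Assuming the existing `HOROVOD_DISABLE_GROUP_FUSION` environment variable is what drives `horovod_global.disable_group_fusion` (an assumption; the wiring is not shown in this diff), forcing that path from Python looks roughly like this:

```python
# Hedged sketch: disable group fusion before Horovod initializes, so that even
# single-tensor allgather/reducescatter requests are registered as one-tensor
# groups (the disable_group_fusion branch in the enqueue functions above).
import os

os.environ["HOROVOD_DISABLE_GROUP_FUSION"] = "1"  # assumed variable name

import horovod.torch as hvd

hvd.init()
```
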
17 changes: 17 additions & 0 deletions horovod/common/operations.h
@@ -210,6 +210,14 @@ Status EnqueueTensorAllgather(std::shared_ptr<OpContext> context,
StatusCallback callback,
int32_t process_set_id = 0);

Status
EnqueueTensorAllgathers(std::vector<std::shared_ptr<OpContext>>& contexts,
std::vector<std::shared_ptr<Tensor>>& tensors,
std::vector<ReadyEventList>& ready_event_lists,
std::vector<std::string>& names, int device,
std::vector<StatusCallback>& callbacks,
int32_t process_set_id = 0);

Status EnqueueTensorBroadcast(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> tensor,
std::shared_ptr<Tensor> output, int root_rank,
@@ -234,6 +242,15 @@ Status EnqueueTensorReducescatter(std::shared_ptr<OpContext> context,
ReduceOp reduce_op = ReduceOp::SUM,
int32_t process_set_id = 0);

Status
EnqueueTensorReducescatters(std::vector<std::shared_ptr<OpContext>>& contexts,
std::vector<std::shared_ptr<Tensor>>& tensors,
std::vector<ReadyEventList>& ready_event_lists,
std::vector<std::string>& names, int device,
std::vector<StatusCallback>& callbacks,
ReduceOp reduce_op = ReduceOp::SUM,
int32_t process_set_id = 0);

Status EnqueueJoin(std::shared_ptr<OpContext> context,
std::shared_ptr<Tensor> output_last_joined_rank,
ReadyEventList ready_event_list,
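
Both Reducescatter entry points declared above default to `ReduceOp::SUM`, and the comment in operations.cc notes that AVERAGE is supported by enqueuing SUM and dividing at the framework level. A hedged Python-level sketch of that contract, using the single-tensor `hvd.reducescatter()` API:

```python
# Hedged sketch: an averaged reducescatter is a SUM reducescatter followed by
# a framework-side division by the process-set size.
import torch
import horovod.torch as hvd

hvd.init()

x = torch.rand(hvd.size() * 2, 8)

summed = hvd.reducescatter(x, op=hvd.Sum)
averaged = summed / hvd.size()  # what an averaging op would return
```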
