implement 2D torus allreduce using NCCL #3608

Merged: 4 commits, Aug 17, 2022
Changes from 2 commits
1 change: 1 addition & 0 deletions horovod/common/common.h
@@ -129,6 +129,7 @@ namespace common {
#define HOROVOD_STALL_SHUTDOWN_TIME_SECONDS "HOROVOD_STALL_SHUTDOWN_TIME_SECONDS"
#define HOROVOD_HIERARCHICAL_ALLREDUCE "HOROVOD_HIERARCHICAL_ALLREDUCE"
#define HOROVOD_HIERARCHICAL_ALLGATHER "HOROVOD_HIERARCHICAL_ALLGATHER"
#define HOROVOD_TORUS_ALLREDUCE "HOROVOD_TORUS_ALLREDUCE"
#define HOROVOD_CACHE_CAPACITY "HOROVOD_CACHE_CAPACITY"
#define HOROVOD_BATCH_D2D_MEMCOPIES "HOROVOD_BATCH_D2D_MEMCOPIES"
#define HOROVOD_NUM_NCCL_STREAMS "HOROVOD_NUM_NCCL_STREAMS"
25 changes: 25 additions & 0 deletions horovod/common/operations.cc
@@ -129,6 +129,8 @@ GPUContext gpu_context;

#if HAVE_NCCL
NCCLContext nccl_context;
NCCLContext local_nccl_context;
NCCLContext cross_nccl_context;
#endif

#if HAVE_DDL
@@ -190,6 +192,8 @@ OperationManager* CreateOperationManager(HorovodGlobalState& state) {
#endif

#if HAVE_NCCL && HOROVOD_GPU_ALLREDUCE == 'N'
allreduce_ops.push_back(std::shared_ptr<AllreduceOp>(
new NCCLTorusAllreduce(&local_nccl_context, &cross_nccl_context, &gpu_context, &state)));
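// Note (illustrative comment, not part of this diff): operation
// implementations are tried in registration order, so registering
// NCCLTorusAllreduce (above) ahead of the plain NCCLAllreduce (below)
// lets it take precedence whenever its Enabled() check, driven by
// HOROVOD_TORUS_ALLREDUCE, passes.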
allreduce_ops.push_back(std::shared_ptr<AllreduceOp>(
new NCCLAllreduce(&nccl_context, &gpu_context, &state)));
#endif
@@ -462,7 +466,11 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {

#if HAVE_NCCL
nccl_context.nccl_comms.resize(state.num_nccl_streams);
local_nccl_context.nccl_comms.resize(state.num_nccl_streams);
cross_nccl_context.nccl_comms.resize(state.num_nccl_streams);
SetBoolFromEnv(HOROVOD_ELASTIC, nccl_context.elastic, true);
SetBoolFromEnv(HOROVOD_ELASTIC, local_nccl_context.elastic, true);
SetBoolFromEnv(HOROVOD_ELASTIC, cross_nccl_context.elastic, true);
#endif
gpu_context.streams.resize(state.num_nccl_streams);

@@ -577,6 +585,21 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {
"allgather and hierarchical allreduce.";
}

// Set flag for torus allreduce. Ignore if Horovod is running on a
// single node.
auto horovod_torus_allreduce =
std::getenv(HOROVOD_TORUS_ALLREDUCE);
state.parameter_manager.SetTorusAllreduce(false);
if (horovod_torus_allreduce != nullptr) {
bool value = std::strtol(horovod_torus_allreduce, nullptr, 10) > 0 &&
(size != local_size);
state.parameter_manager.SetTorusAllreduce(value, true);
}
#if HOROVOD_GPU_ALLREDUCE != 'N' && HOROVOD_GPU_ALLREDUCE != 'D'
// Torus allreduce is not supported without NCCL or DDL
state.parameter_manager.SetTorusAllreduce(false, true);
#endif
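// Illustration (not part of this change): with HOROVOD_TORUS_ALLREDUCE=1
// exported on every worker, the torus path is enabled only for multi-node
// jobs (size != local_size); on a single node, or when the variable is
// unset or set to "0", the regular NCCL allreduce path is used instead.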

// Set flag to control use of batched memcopy kernel on GPU
auto horovod_batch_d2d_memcopies = std::getenv(HOROVOD_BATCH_D2D_MEMCOPIES);
if (horovod_batch_d2d_memcopies != nullptr &&
@@ -675,6 +698,8 @@ void BackgroundThreadLoop(HorovodGlobalState& state) {
// Finalize all contexts
#if HAVE_NCCL
nccl_context.ShutDown();
local_nccl_context.ShutDown();
cross_nccl_context.ShutDown();
#endif

LOG(DEBUG, horovod_global.global_controller->GetRank())
256 changes: 253 additions & 3 deletions horovod/common/ops/nccl_operations.cc
@@ -149,9 +149,8 @@ void NCCLOpContext::PopulateNCCLCommStrategy(int& nccl_rank, int& nccl_size,
nccl_rank = process_set.controller->GetLocalRank();
nccl_size = process_set.controller->GetLocalSize();
} else {
throw std::logic_error("Communicator type " +
std::to_string(communicator_type_) +
" is not supported in NCCL mode.");
nccl_rank = process_set.controller->GetCrossRank();
Collaborator:

Can you modify this to check for else if (communicator_type_ == Communicator::CROSS) explicitly to make this more clear?

Contributor Author:

Thank you so much for the review! I've updated the diff with the suggestions and tested with BERT pretraining, and it all seems correct.

nccl_size = process_set.controller->GetCrossSize();
}
nccl_id_bcast_comm = communicator_type_;
}
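For reference, the explicit check the reviewer suggests might look roughly like the following sketch (the exact code merged in the later commits may differ):

    if (communicator_type_ == Communicator::LOCAL) {
      nccl_rank = process_set.controller->GetLocalRank();
      nccl_size = process_set.controller->GetLocalSize();
    } else if (communicator_type_ == Communicator::CROSS) {
      nccl_rank = process_set.controller->GetCrossRank();
      nccl_size = process_set.controller->GetCrossSize();
    } else {
      throw std::logic_error("Communicator type " +
                             std::to_string(communicator_type_) +
                             " is not supported in NCCL mode.");
    }
    nccl_id_bcast_comm = communicator_type_;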
@@ -516,6 +515,257 @@ bool NCCLHierarchicalAllreduce::Enabled(
}
#endif

void NCCLTorusAllreduce::WaitForData(std::vector<TensorTableEntry>& entries) {
if (global_state_->timeline.Initialized()) {
// If timeline is initialized, need to use normal CPU syncing path
HorovodOp::WaitForData(entries);
} else {
// Push events to set to deduplicate entries
std::unordered_set<gpuEvent_t> event_set;
for (auto& e : entries) {
e.ready_event_list.PushEventsToSet(event_set);
}
for (auto& ev : event_set) {
HVD_GPU_CHECK(gpuStreamWaitEvent(*gpu_op_context_.stream, ev, 0));
}
}
}

Status NCCLTorusAllreduce::Execute(std::vector<TensorTableEntry>& entries,
const Response& response) {
assert(!entries.empty());
auto& first_entry = entries.at(0);
auto& process_set =
global_state_->process_set_table.Get(first_entry.process_set_id);

// Determine GPU IDs of the devices participating in this communicator.
std::vector<int32_t> local_nccl_device_map;
local_nccl_device_map.reserve(process_set.controller->GetLocalCommRanks().size());
for (int rank : process_set.controller->GetLocalCommRanks()) {
int32_t device = response.devices().at(rank);
local_nccl_device_map.push_back(device);
}
gpu_op_context_.InitGPU(entries);
local_nccl_op_context_.InitNCCLComm(entries, local_nccl_device_map);

std::vector<int32_t> cross_nccl_device_map({response.devices()[process_set.controller->GetCrossRank()]});
cross_nccl_op_context_.InitNCCLComm(entries, cross_nccl_device_map);
gpu_op_context_.InitGPUQueue(entries, response);

WaitForData(entries);

const void* fused_input_data;
void* buffer_data;
size_t buffer_len;

// Copy memory into the fusion buffer.
if (entries.size() > 1) {
MemcpyInFusionBuffer(entries, fused_input_data, buffer_data, buffer_len);

if (global_state_->timeline.Initialized()) {
gpu_context_->RecordEvent(gpu_op_context_.event_queue,
MEMCPY_IN_FUSION_BUFFER,
*gpu_op_context_.stream);
}
} else {
fused_input_data = first_entry.tensor->data();
buffer_data = (void*)first_entry.output->data();
buffer_len = (size_t)first_entry.output->size();
}

int64_t num_elements =
buffer_len / DataType_Size(first_entry.tensor->dtype());

if (response.prescale_factor() != 1.0) {
// Execute prescaling op
ScaleBuffer(response.prescale_factor(), entries, fused_input_data,
buffer_data, num_elements);
fused_input_data = buffer_data; // for unfused, scale is done out of place
}

// Do allreduce.
int element_size = DataType_Size(first_entry.tensor->dtype());
int local_size = process_set.controller->GetLocalSize();
int local_rank = process_set.controller->GetLocalRank();

// If the cluster is homogeneous and we are using the fusion buffer,
// include dummy elements from the buffer (if necessary) to make sure
// the data is divisible by local_size. This is always possible since
// the fusion buffer size is set to be divisible by local_size.
if (process_set.controller->IsHomogeneous() && entries.size() > 1) {
// Making sure the number of elements is divisible by
// FUSION_BUFFER_ATOMIC_UNIT for improved performance
int div = local_size * FUSION_BUFFER_ATOMIC_UNIT;
num_elements = ((num_elements + div - 1) / div) * div;
buffer_len = num_elements * element_size;
}
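// Worked example (illustrative, not part of this change): assuming
// FUSION_BUFFER_ATOMIC_UNIT is 64 and local_size is 8, div = 512, so a
// fused buffer of 1000 elements is padded up to num_elements = 1024 and
// buffer_len grows accordingly; the padding elements are reduced along
// with the real data and simply ignored when copying out of the fusion
// buffer.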

// Split the elements into two groups: num_elements_per_rank*local_size,
// and num_elements_remaining. Cross-node reduction for the first group
// is done by all local_rank's in parallel, while for the second group
// it is only done by the root_rank. If the cluster is not homogeneous,
// the first group is empty and root_rank is 0.

// Homogeneous case:
// For the part of data divisible by local_size, perform NCCL
// ReduceScatter - parallelized cross-node NCCL Allreduce - NCCL
// Allgather. For the non-divisible part (if any), do NCCL Reduce (at
// rank local_size-1), cross-node NCCL Allreduce (across the rank
// (local_size-1)'s), and NCCL Bcast.
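// Worked example (illustrative, not part of this change): with
// local_size = 4 and num_elements = 10 in the homogeneous unfused case,
// num_elements_per_rank = 2 and num_elements_remaining = 2, so each of
// the four local ranks reduces a 2-element shard across nodes, while
// root_rank = 3 additionally handles the 2 remaining elements via
// ncclReduce, cross-node ncclAllReduce, and ncclBroadcast.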

int64_t num_elements_per_rank =
process_set.controller->IsHomogeneous() ? num_elements / local_size : 0;

size_t buffer_len_per_rank = element_size * num_elements_per_rank;

void* buffer_data_at_rank_offset =
(uint8_t*)buffer_data + buffer_len_per_rank * local_rank;

int64_t num_elements_remaining = process_set.controller->IsHomogeneous()
? num_elements % local_size
: num_elements;

size_t buffer_len_remaining = element_size * num_elements_remaining;

void* buffer_data_remainder =
(uint8_t*)buffer_data + buffer_len_per_rank * local_size;

void* fused_input_data_remainder =
(uint8_t*)fused_input_data + buffer_len_per_rank * local_size;

int root_rank = process_set.controller->IsHomogeneous() ? local_size - 1 : 0;
bool is_root_rank = local_rank == root_rank;

int64_t total_num_elements =
is_root_rank ? num_elements_per_rank + num_elements_remaining
: num_elements_per_rank;
int64_t total_buffer_len = is_root_rank
? buffer_len_per_rank + buffer_len_remaining
: buffer_len_per_rank;

auto& timeline = global_state_->timeline;
if (num_elements_per_rank > 0) {
auto nccl_result = ncclReduceScatter(
fused_input_data, buffer_data_at_rank_offset,
(size_t)num_elements_per_rank, GetNCCLDataType(first_entry.tensor),
ncclSum, *local_nccl_op_context_.nccl_comm_, *gpu_op_context_.stream);
local_nccl_context_->ErrorCheck("ncclReduceScatter", nccl_result,
*local_nccl_op_context_.nccl_comm_);
if (global_state_->timeline.Initialized()) {
gpu_context_->RecordEvent(gpu_op_context_.event_queue, NCCL_REDUCESCATTER,
*gpu_op_context_.stream);
}
}

if (num_elements_remaining > 0) {
// Reduce the remaining data at local_size-1 to append to
// existing buffer
auto nccl_result =
ncclReduce(fused_input_data_remainder, buffer_data_remainder,
(size_t)num_elements_remaining,
GetNCCLDataType(first_entry.tensor), ncclSum, root_rank,
*local_nccl_op_context_.nccl_comm_, *gpu_op_context_.stream);
local_nccl_context_->ErrorCheck("ncclReduce", nccl_result,
*local_nccl_op_context_.nccl_comm_);
if (global_state_->timeline.Initialized()) {
gpu_context_->RecordEvent(gpu_op_context_.event_queue, NCCL_REDUCE,
*gpu_op_context_.stream);
}
}

if (process_set.controller->IsHomogeneous() || is_root_rank) {
// Synchronize.
if (global_state_->elastic_enabled) {
gpu_context_->WaitForEventsElastic(
gpu_op_context_.event_queue, entries, timeline,
local_nccl_op_context_.error_check_callback_);
} else {
gpu_context_->WaitForEvents(gpu_op_context_.event_queue, entries,
timeline,
local_nccl_op_context_.error_check_callback_);
}
timeline.ActivityStartAll(entries, NCCL_ALLREDUCE);
auto cross_nccl_result = ncclAllReduce(buffer_data_at_rank_offset, buffer_data_at_rank_offset,
(size_t) total_num_elements, GetNCCLDataType(first_entry.tensor),
ncclSum, *cross_nccl_op_context_.nccl_comm_, *gpu_op_context_.stream);
cross_nccl_context_->ErrorCheck("ncclAllReduce", cross_nccl_result, *cross_nccl_op_context_.nccl_comm_);
timeline.ActivityEndAll(entries);
}
// We need to make sure the cross-node ncclAllReduce doesn't raise async error
if (global_state_->elastic_enabled) {
gpu_context_->WaitForEventsElastic(
gpu_op_context_.event_queue, entries, timeline,
cross_nccl_op_context_.error_check_callback_);
} else {
gpu_context_->WaitForEvents(gpu_op_context_.event_queue, entries,
timeline,
cross_nccl_op_context_.error_check_callback_);
}
if (num_elements_per_rank > 0) {
local_nccl_context_->ErrorCheck(
"ncclAllGather",
ncclAllGather(buffer_data_at_rank_offset, buffer_data,
(size_t)num_elements_per_rank,
GetNCCLDataType(first_entry.tensor),
*local_nccl_op_context_.nccl_comm_, *gpu_op_context_.stream),
*local_nccl_op_context_.nccl_comm_);
if (global_state_->timeline.Initialized()) {
gpu_context_->RecordEvent(gpu_op_context_.event_queue, NCCL_ALLGATHER,
*gpu_op_context_.stream);
}
}
if (num_elements_remaining > 0) {
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 2, 12)
local_nccl_context_->ErrorCheck(
"ncclBroadcast",
ncclBroadcast(buffer_data_remainder, buffer_data_remainder,
(size_t)num_elements_remaining,
GetNCCLDataType(first_entry.tensor), root_rank,
*local_nccl_op_context_.nccl_comm_, *gpu_op_context_.stream),
*local_nccl_op_context_.nccl_comm_);
#else
local_nccl_context_->ErrorCheck(
"ncclBcast",
ncclBcast(buffer_data_remainder, (size_t)num_elements_remaining,
GetNCCLDataType(first_entry.tensor), root_rank,
*local_nccl_op_context_.nccl_comm_, *gpu_op_context_.stream),
*local_nccl_op_context_.nccl_comm_);
#endif
if (global_state_->timeline.Initialized()) {
gpu_context_->RecordEvent(gpu_op_context_.event_queue, NCCL_BCAST,
*gpu_op_context_.stream);
}
}

if (response.postscale_factor() != 1.0) {
// Execute postscaling op
ScaleBuffer(response.postscale_factor(), entries, buffer_data, buffer_data,
num_elements);
}

// Copy memory out of the fusion buffer.
if (entries.size() > 1) {
MemcpyOutFusionBuffer(buffer_data, entries);

if (global_state_->timeline.Initialized()) {
gpu_context_->RecordEvent(gpu_op_context_.event_queue,
MEMCPY_OUT_FUSION_BUFFER,
*gpu_op_context_.stream);
}
}

return gpu_op_context_.FinalizeGPUQueue(
entries, true, local_nccl_op_context_.error_check_callback_);
}

bool NCCLTorusAllreduce::Enabled(const ParameterManager& param_manager,
const std::vector<TensorTableEntry>& entries,
const Response& response) const {
if (!GPUAllreduce::Enabled(param_manager, entries, response)) {
return false;
}
return param_manager.TorusAllreduce();
}

void NCCLBroadcast::WaitForData(std::vector<TensorTableEntry>& entries) {
if (global_state_->timeline.Initialized()) {
// If timeline is initialized, need to use normal CPU syncing path
28 changes: 28 additions & 0 deletions horovod/common/ops/nccl_operations.h
@@ -240,6 +240,34 @@ class NCCLHierarchicalAllreduce : public NCCLAllreduce {
};
#endif

class NCCLTorusAllreduce : public GPUAllreduce {
public:
NCCLTorusAllreduce(NCCLContext* local_nccl_context, NCCLContext* cross_nccl_context,
GPUContext* gpu_context, HorovodGlobalState* global_state)
: GPUAllreduce(gpu_context, global_state),
local_nccl_context_(local_nccl_context),
cross_nccl_context_(cross_nccl_context),
local_nccl_op_context_(local_nccl_context, global_state, Communicator::LOCAL),
cross_nccl_op_context_(cross_nccl_context, global_state, Communicator::CROSS),
global_state_(global_state){};

Status Execute(std::vector<TensorTableEntry>& entries,
const Response& response) override;

bool Enabled(const ParameterManager& param_manager,
const std::vector<TensorTableEntry>& entries,
const Response& response) const override;

protected:
void WaitForData(std::vector<TensorTableEntry>& entries) override;

NCCLContext* local_nccl_context_;
NCCLOpContext local_nccl_op_context_;
NCCLContext* cross_nccl_context_;
NCCLOpContext cross_nccl_op_context_;
HorovodGlobalState* global_state_;
};

class NCCLAllgather : public GPUAllgather {
public:
NCCLAllgather(NCCLContext* nccl_context, GPUContext* gpu_context,