Add multi-tensor hvd.grouped_allreduce API. (#2453)
Signed-off-by: Josh Romero <joshr@nvidia.com>
romerojosh committed Nov 19, 2020
1 parent c7848ca commit 775959d
Showing 35 changed files with 2,049 additions and 274 deletions.
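The headline change is a grouped allreduce entry point that negotiates and reduces a list of tensors as a single unit instead of tensor by tensor. A minimal usage sketch, assuming the PyTorch frontend; keyword arguments of hvd.grouped_allreduce beyond the tensor list are not shown in this diff and are omitted here:

import torch
import horovod.torch as hvd

hvd.init()

# Eight small gradient-like tensors reduced as one logical group:
# a single negotiation step for the whole list rather than one per tensor.
tensors = [torch.rand(1024) for _ in range(8)]
averaged = hvd.grouped_allreduce(tensors)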
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -49,6 +49,7 @@ include_directories("third_party/HTTPRequest/include"
list(APPEND SOURCES "${PROJECT_SOURCE_DIR}/horovod/common/common.cc"
"${PROJECT_SOURCE_DIR}/horovod/common/controller.cc"
"${PROJECT_SOURCE_DIR}/horovod/common/fusion_buffer_manager.cc"
"${PROJECT_SOURCE_DIR}/horovod/common/group_table.cc"
"${PROJECT_SOURCE_DIR}/horovod/common/half.cc"
"${PROJECT_SOURCE_DIR}/horovod/common/logging.cc"
"${PROJECT_SOURCE_DIR}/horovod/common/message.cc"
6 changes: 4 additions & 2 deletions horovod/_keras/__init__.py
@@ -28,7 +28,8 @@
def create_distributed_optimizer(keras, optimizer, name, device_dense, device_sparse,
compression, sparse_as_dense, gradient_predivide_factor,
op, backward_passes_per_step=1,
average_aggregated_gradients=False):
average_aggregated_gradients=False,
num_groups=0):
class _DistributedOptimizer(keras.optimizers.Optimizer):
_HAS_AGGREGATE_GRAD = True

@@ -43,7 +44,8 @@ def __init__(self, **kwargs):
compression,
sparse_as_dense,
op,
gradient_predivide_factor)
gradient_predivide_factor,
num_groups)

self._agg_helper = None
if backward_passes_per_step > 1:
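For context, a hypothetical end-user view of this plumbing, assuming the public horovod.keras DistributedOptimizer wrapper (not shown in this hunk) forwards the new num_groups argument to create_distributed_optimizer:

import tensorflow as tf
import horovod.keras as hvd

hvd.init()

opt = tf.keras.optimizers.SGD(0.01 * hvd.size())
# num_groups > 0 asks Horovod to split the gradient tensors into that many
# groups and reduce each group through the grouped allreduce path.
opt = hvd.DistributedOptimizer(opt, num_groups=2)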
1 change: 1 addition & 0 deletions horovod/common/common.h
@@ -89,6 +89,7 @@ namespace common {
#define HOROVOD_GLOO "GLOO"
#define HOROVOD_ADASUM_MPI_CHUNK_SIZE "HOROVOD_ADASUM_MPI_CHUNK_SIZE"
#define HOROVOD_THREAD_AFFINITY "HOROVOD_THREAD_AFFINITY"
#define HOROVOD_DISABLE_GROUP_FUSION "HOROVOD_DISABLE_GROUP_FUSION"

// String constant for gloo interface.
#define GLOO_DEFAULT_IFACE ""
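The new HOROVOD_DISABLE_GROUP_FUSION constant names an environment variable; the controller changes below consult state.disable_group_fusion to keep each registered group's tensors fused only with one another. A sketch of opting in from a training script, assuming the variable is read from the process environment during initialization:

import os
os.environ["HOROVOD_DISABLE_GROUP_FUSION"] = "1"

import horovod.torch as hvd
hvd.init()  # grouped tensors are now fused only with their own group members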
103 changes: 94 additions & 9 deletions horovod/common/controller.cc
@@ -53,10 +53,11 @@ void Controller::SynchronizeParameters() {
}

Controller::Controller(ResponseCache& response_cache, TensorQueue& tensor_queue,
Timeline& timeline, ParameterManager& parameter_manager)
Timeline& timeline, ParameterManager& parameter_manager,
GroupTable& group_table)
: stall_inspector_(response_cache), tensor_queue_(tensor_queue),
timeline_(timeline), response_cache_(response_cache),
parameter_manager_(parameter_manager) {}
parameter_manager_(parameter_manager), group_table_(group_table) {}

void Controller::Initialize() {
response_cache_.clear();
@@ -194,6 +195,34 @@ ResponseList Controller::ComputeResponseList(std::atomic_bool& shut_down,
// If all messages in queue have responses in cache, use fast path with
// no additional coordination.

// If group fusion is disabled, fuse tensors in groups separately
if (state.disable_group_fusion && !group_table_.empty()) {
// Note: need group order to be based on position in cache for global consistency
std::vector<int> common_ready_groups;
std::unordered_set<int> processed;
for (auto bit : cache_coordinator.cache_hits()) {
const auto& tensor_name = response_cache_.peek_response(bit).tensor_names()[0];
int group_id = group_table_.GetGroupIDFromTensorName(tensor_name);
if (group_id != NULL_GROUP_ID && processed.find(group_id) == processed.end()) {
common_ready_groups.push_back(group_id);
processed.insert(group_id);
}
}

for (auto id : common_ready_groups) {
std::deque<Response> responses;
for (const auto &tensor_name : group_table_.GetGroupTensorNames(id)) {
auto bit = response_cache_.peek_cache_bit(tensor_name);
responses.push_back(response_cache_.get_response(bit));
// Erase cache hit to avoid processing a second time.
cache_coordinator.erase_hit(bit);
}

FuseResponses(responses, state, response_list);
}
}


std::deque<Response> responses;
// Convert cache hits to responses. Populate so that least
// recently used responses get priority. All workers call the code
@@ -204,7 +233,7 @@ ResponseList Controller::ComputeResponseList(std::atomic_bool& shut_down,
}

// Fuse responses as normal.
response_list = FuseResponses(responses, state);
FuseResponses(responses, state, response_list);
response_list.set_shutdown(cache_coordinator.should_shut_down());
} else {
// There are uncached messages coming in, need communication to figure out
@@ -278,6 +307,56 @@ ResponseList Controller::ComputeResponseList(std::atomic_bool& shut_down,
}
}

// Fuse tensors in groups before processing others.
if (state.disable_group_fusion && !group_table_.empty()) {

// Extract set of common groups from coordinator tensor list and cache hits.
std::vector<int> common_ready_groups;
std::unordered_set<int> processed;

for (const auto& tensor_name : ready_to_reduce) {
int group_id = group_table_.GetGroupIDFromTensorName(tensor_name);
if (group_id != NULL_GROUP_ID && processed.find(group_id) == processed.end()) {
common_ready_groups.push_back(group_id);
processed.insert(group_id);
// Leaving name in list, to be skipped later.
}
}

if (response_cache_.capacity() > 0) {
for (auto bit : cache_coordinator.cache_hits()) {
const auto& tensor_name = response_cache_.peek_response(bit).tensor_names()[0];
int group_id = group_table_.GetGroupIDFromTensorName(tensor_name);
if (group_id != NULL_GROUP_ID && processed.find(group_id) == processed.end()) {
common_ready_groups.push_back(group_id);
processed.insert(group_id);
}
}
}

// For each ready group, form and fuse response lists independently
for (auto id : common_ready_groups) {
std::deque<Response> responses;
for (const auto &tensor_name : group_table_.GetGroupTensorNames(id)) {
if (message_table_.find(tensor_name) != message_table_.end()) {
// Uncached message
Response response = ConstructResponse(tensor_name, state.joined_size);
responses.push_back(std::move(response));

} else {
// Cached message
auto bit = response_cache_.peek_cache_bit(tensor_name);
responses.push_back(response_cache_.get_response(bit));
// Erase cache hit to avoid processing a second time.
cache_coordinator.erase_hit(bit);
}
}

FuseResponses(responses, state, response_list);
}
}


// At this point, rank zero should have a fully updated tensor count
// table and should know all the tensors that need to be reduced or
// gathered, and everyone else should have sent all their information
@@ -300,6 +379,13 @@ ResponseList Controller::ComputeResponseList(std::atomic_bool& shut_down,
}

for (auto& tensor_name : ready_to_reduce) {
// Skip tensors in group that were handled earlier.
if (state.disable_group_fusion &&
!group_table_.empty() &&
group_table_.GetGroupIDFromTensorName(tensor_name) != NULL_GROUP_ID) {
continue;
}

Response response = ConstructResponse(tensor_name, state.joined_size);
responses.push_back(std::move(response));
}
@@ -311,7 +397,7 @@ ResponseList Controller::ComputeResponseList(std::atomic_bool& shut_down,
responses.push_back(std::move(join_response));
state.joined_size = 0;
}
response_list = FuseResponses(responses, state);
FuseResponses(responses, state, response_list);
response_list.set_shutdown(should_shut_down);

// Broadcast final results to other ranks.
@@ -382,7 +468,7 @@ int64_t Controller::TensorFusionThresholdBytes() {
return proposed_fusion_threshold;
}

Response Controller::ConstructResponse(std::string& name, int joined_size) {
Response Controller::ConstructResponse(const std::string& name, int joined_size) {
bool error = false;
auto it = message_table_.find(name);
assert(it != message_table_.end());
@@ -688,9 +774,9 @@ void Controller::CoordinateCacheAndState(CacheCoordinator& cache_coordinator) {
}
}

ResponseList Controller::FuseResponses(std::deque<Response>& responses,
HorovodGlobalState& state) {
ResponseList response_list;
void Controller::FuseResponses(std::deque<Response>& responses,
HorovodGlobalState& state,
ResponseList& response_list) {
while (!responses.empty()) {

auto response = responses.front();
@@ -825,7 +911,6 @@ ResponseList Controller::FuseResponses(std::deque<Response>& responses,
response_list.add_response(std::move(response));
LOG(TRACE) << "Created response of size " << tensor_size;
}
return response_list;
}

int64_t Controller::TotalByteSizeOfAllgatherOutput(
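The new paths above run before the normal per-tensor handling: any group with at least one ready (or cached) member is identified, all of that group's tensors are pulled from the message table or response cache and fused together, and grouped names are then skipped in the regular loop. Below is a simplified Python model of that selection step; the names are stand-ins for the C++ types, and this illustrates only the control flow, not Horovod's implementation.

NULL_GROUP_ID = -1

def select_ready_groups(ready_to_reduce, name_to_group):
    """Return group ids (in first-seen order) with at least one ready
    member, plus the ready tensors that belong to no group."""
    ready_groups, seen, ungrouped = [], set(), []
    for name in ready_to_reduce:
        gid = name_to_group.get(name, NULL_GROUP_ID)
        if gid == NULL_GROUP_ID:
            ungrouped.append(name)
        elif gid not in seen:
            seen.add(gid)
            ready_groups.append(gid)
    return ready_groups, ungrouped

groups, rest = select_ready_groups(["a/grad", "c/grad"],
                                   {"a/grad": 0, "b/grad": 0})
# Group 0 would then be fused from its full member list (a/grad and b/grad);
# c/grad falls through to the normal, ungrouped fusion path.
assert groups == [0] and rest == ["c/grad"]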
13 changes: 9 additions & 4 deletions horovod/common/controller.h
@@ -22,6 +22,7 @@
#include <vector>

#include "global_state.h"
#include "group_table.h"
#include "parameter_manager.h"
#include "response_cache.h"
#include "stall_inspector.h"
@@ -36,7 +37,8 @@ using MessageTable = std::unordered_map<std::string, std::vector<Request>>;
class Controller : public std::enable_shared_from_this<Controller> {
public:
Controller(ResponseCache& response_cache, TensorQueue& tensor_queue,
Timeline& timeline, ParameterManager& parameter_manager);
Timeline& timeline, ParameterManager& parameter_manager,
GroupTable& group_table);

Controller(const Controller&) = delete;

@@ -155,15 +157,16 @@ class Controller : public std::enable_shared_from_this<Controller> {
// also contains error messages in case the submitted Requests were not
// valid (for example, contained mismatched shapes or types).
// Constructing the Response, thus, requires a whole lot of error checking.
Response ConstructResponse(std::string& name, int joined_size = 0);
Response ConstructResponse(const std::string& name, int joined_size = 0);

// Routine to sync cache hit and invalid bit sets across workers.
// Also determines global shutdown state and whether uncached requests
// exist on any worker.
void CoordinateCacheAndState(CacheCoordinator& cache_coordinator);

ResponseList FuseResponses(std::deque<Response>& responses,
HorovodGlobalState& state);
void FuseResponses(std::deque<Response>& responses,
HorovodGlobalState& state,
ResponseList& response_list);

// Return the total byte size of the final allgathered output tensor
int64_t
@@ -215,6 +218,8 @@ class Controller : public std::enable_shared_from_this<Controller> {
ResponseCache& response_cache_;

ParameterManager& parameter_manager_;

GroupTable& group_table_;
};

} // namespace common
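Note the FuseResponses signature change: instead of returning a fresh ResponseList, it now appends into a caller-owned list, so the group-first pass and the normal pass in ComputeResponseList can accumulate into the same result. A tiny Python analogue of that accumulation pattern; the names are illustrative only.

def fuse_responses(responses, response_list):
    # Stand-in for the real fusion logic: drain the queue, append results.
    while responses:
        response_list.append(responses.pop(0))

response_list = []
fuse_responses(["group_0_fused"], response_list)     # group-first pass
fuse_responses(["a/grad", "c/grad"], response_list)  # normal pass
assert response_list == ["group_0_fused", "a/grad", "c/grad"]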
7 changes: 7 additions & 0 deletions horovod/common/global_state.h
@@ -21,6 +21,7 @@
#include <thread>

#include "fusion_buffer_manager.h"
#include "group_table.h"
#include "parameter_manager.h"
#include "response_cache.h"
#include "tensor_queue.h"
@@ -93,6 +94,9 @@ struct HorovodGlobalState {
// Index of current GPU stream to use
int current_nccl_stream = 0;

// Information on registered groups.
GroupTable group_table;

// A LibType indicating what framework we are using to perform CPU operations.
LibType cpu_operation;

@@ -113,6 +117,9 @@ struct HorovodGlobalState {
// Enable use of batched d2d memcopy kernel on GPU
bool batch_d2d_memcopies = true;

// Flag indicating whether to prohibit groups from fusing
bool disable_group_fusion = false;

~HorovodGlobalState() {
// Make sure that the destructor of the background thread is safe to
// call. If a thread is still joinable (not detached or complete) its
5 changes: 3 additions & 2 deletions horovod/common/gloo/gloo_controller.h
@@ -27,8 +27,9 @@ class GlooController : public Controller {
public:
GlooController(ResponseCache& response_cache, TensorQueue& tensor_queue,
Timeline& timeline, ParameterManager& parameter_manager,
GlooContext& gloo_context)
: Controller(response_cache, tensor_queue, timeline, parameter_manager),
GroupTable& group_table, GlooContext& gloo_context)
: Controller(response_cache, tensor_queue, timeline, parameter_manager,
group_table),
gloo_context_(gloo_context) {};

int GetTypeSize(DataType dtype) override;
82 changes: 82 additions & 0 deletions horovod/common/group_table.cc
@@ -0,0 +1,82 @@
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "group_table.h"

#include <assert.h>

namespace horovod {
namespace common {

int32_t GroupTable::GetGroupIDFromTensorName(const std::string& tensor_name) const {
std::lock_guard<std::mutex> guard(mutex_);
auto it = tensor_name_to_id_.find(tensor_name);
if (it != tensor_name_to_id_.end())
return it->second;
else {
return NULL_GROUP_ID;
}
}

const std::vector<std::string>& GroupTable::GetGroupTensorNames(int32_t group_id) const {
std::lock_guard<std::mutex> guard(mutex_);
return id_to_tensor_names_.at(group_id);
}

bool GroupTable::empty(void) const {
std::lock_guard<std::mutex> guard(mutex_);
return tensor_name_to_id_.empty();
}

int32_t GroupTable::RegisterGroup(std::vector<std::string>&& tensor_names) {
std::lock_guard<std::mutex> guard(mutex_);

int32_t group_id;
if (!free_ids_.empty()) {
// Reuse old group_id
group_id = free_ids_.front();
free_ids_.pop();
} else {
// Create a new group_id
group_id = next_group_id_++;
}

for (auto& name : tensor_names) {
tensor_name_to_id_.emplace(name, group_id);
}
id_to_tensor_names_.emplace(group_id, std::move(tensor_names));

return group_id;
}

void GroupTable::DeregisterGroups(const std::vector<std::string>& tensor_names) {
std::lock_guard<std::mutex> guard(mutex_);

for (auto& name : tensor_names) {
auto it = tensor_name_to_id_.find(name);
if (it != tensor_name_to_id_.end()) {
auto group_id = it->second;
for (auto& entry : id_to_tensor_names_[group_id]) {
tensor_name_to_id_.erase(entry);
}
id_to_tensor_names_.erase(group_id);

free_ids_.push(group_id);
}
}
}

} // namespace common
} // namespace horovod
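For illustration, a Python mirror of the bookkeeping this class implements: registering a group maps every member name to one id (reusing freed ids), lookups go by tensor name, and deregistering any member removes the whole group. The sketch omits the mutex that guards every method in the C++ version.

NULL_GROUP_ID = -1

class PyGroupTable:
    def __init__(self):
        self._name_to_id = {}
        self._id_to_names = {}
        self._free_ids = []
        self._next_id = 0

    def register_group(self, tensor_names):
        if self._free_ids:
            gid = self._free_ids.pop(0)   # reuse a freed id
        else:
            gid = self._next_id           # or mint a new one
            self._next_id += 1
        for name in tensor_names:
            self._name_to_id[name] = gid
        self._id_to_names[gid] = list(tensor_names)
        return gid

    def group_id(self, tensor_name):
        return self._name_to_id.get(tensor_name, NULL_GROUP_ID)

    def deregister_groups(self, tensor_names):
        # Removing any member removes the whole group; its id is recycled.
        for name in tensor_names:
            gid = self._name_to_id.get(name)
            if gid is not None:
                for member in self._id_to_names.pop(gid):
                    self._name_to_id.pop(member, None)
                self._free_ids.append(gid)

t = PyGroupTable()
gid = t.register_group(["layer1/grad", "layer2/grad"])
assert t.group_id("layer2/grad") == gid
t.deregister_groups(["layer1/grad"])
assert t.group_id("layer2/grad") == NULL_GROUP_ID
assert t.register_group(["layer3/grad"]) == gid  # freed id is reused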
