added barrier request to process it in background thread

Signed-off-by: TJ <tix@uber.com>
TJ committed Oct 4, 2021
1 parent e6d2190 commit b0f0edb
Showing 15 changed files with 144 additions and 20 deletions.
3 changes: 3 additions & 0 deletions horovod/common/common.h
@@ -153,6 +153,9 @@ namespace common {
// Temporary tensor name for ranks that did Join().
#define JOIN_TENSOR_NAME "join.noname"

// Fixed tensor name for all barrier operations
#define BARRIER_TENSOR_NAME "barrier.noname"

// List of supported frameworks.
enum Framework { TENSORFLOW, PYTORCH, MXNET, XLA };

34 changes: 32 additions & 2 deletions horovod/common/controller.cc
@@ -97,6 +97,12 @@ ResponseList Controller::ComputeResponseList(bool this_process_requested_shutdow
continue;
}

// Never cache a barrier request; once every rank reports ready for this
// message, the barrier is released.
if (message.request_type() == Request::BARRIER) {
cache_coordinator.set_uncached_in_queue(true);
continue;
}

// Keep track of cache hits
if (response_cache_.capacity() > 0) {
auto cache_ = response_cache_.cached(message);
@@ -266,6 +272,13 @@ ResponseList Controller::ComputeResponseList(bool this_process_requested_shutdow
}

bool reduce = IncrementTensorCount(message, process_set.joined_size);

// A barrier request that is not yet ready to reduce is pushed back onto the
// tensor queue to be processed again in the next cycle.
if (!reduce && message.request_type() == Request::BARRIER) {
tensor_queue_.PushMessageToQueue(message);
}

stall_inspector_.RecordUncachedTensorStart(
message.tensor_name(), message.request_rank(), size_);
if (reduce) {
@@ -283,7 +296,6 @@ ResponseList Controller::ComputeResponseList(bool this_process_requested_shutdow
auto received_message_list = ready_list[i];
for (auto& received_message : received_message_list.requests()) {
auto& received_name = received_message.tensor_name();

if (received_message.request_type() == Request::JOIN) {
process_set.joined_size++;
process_set.last_joined_rank = global_ranks_[i];
@@ -292,6 +304,13 @@ ResponseList Controller::ComputeResponseList(bool this_process_requested_shutdow

bool reduce =
IncrementTensorCount(received_message, process_set.joined_size);
// A barrier request that is not yet ready to reduce is pushed back onto the
// tensor queue to be processed again in the next cycle.
if (!reduce && received_message.request_type() == Request::BARRIER) {
Request barrier_msg = received_message;
tensor_queue_.PushMessageToQueue(barrier_msg);
}

stall_inspector_.RecordUncachedTensorStart(
received_message.tensor_name(), received_message.request_rank(),
size_);
@@ -498,6 +517,7 @@ Response Controller::ConstructResponse(const std::string& name, int joined_size)
// Check that all data types of tensors being processed
// are identical.
auto data_type = requests[0].tensor_type();

for (unsigned int i = 1; i < requests.size(); ++i) {
auto request_type = requests[i].tensor_type();
if (data_type != request_type) {
@@ -756,6 +776,8 @@ Response Controller::ConstructResponse(const std::string& name, int joined_size)
response.set_tensor_type(data_type);
response.set_prescale_factor(prescale_factor);
response.set_postscale_factor(postscale_factor);
} else if (message_type == Request::BARRIER) {
response.set_response_type(Response::BARRIER);
}
response.set_devices(devices);

@@ -971,13 +993,21 @@ bool Controller::IncrementTensorCount(const Request& msg, int joined_size) {
timeline_.NegotiateStart(name, msg.request_type());
} else {
std::vector<Request>& messages = table_iter->second;
messages.push_back(msg);
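// Barrier messages are appended only while this process still holds the
// barrier entry in its tensor table (see TensorQueue::IsTensorPresentInTable),
// which guards against counting a stale or requeued barrier message twice.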
if (msg.request_type() == Request::BARRIER) {
if (tensor_queue_.IsTensorPresentInTable(name)) {
messages.push_back(msg);
}
} else {
messages.push_back(msg);
}
}

timeline_.NegotiateRankReady(name, msg.request_rank());

std::vector<Request>& messages = table_iter->second;
int count = (int)messages.size();

bool ready_to_reduce = count == (size_ - joined_size);
if (ready_to_reduce) {
timeline_.NegotiateEnd(name);
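Taken together, the controller hunks above implement the barrier negotiation: every rank submits a BARRIER request under the shared name barrier.noname, the coordinator counts distinct ranks, and the request is requeued each cycle until the count reaches size_ - joined_size, at which point a BARRIER response is constructed. A minimal standalone sketch of that count-and-requeue idea (hypothetical names, not the Horovod code):

#include <deque>
#include <string>
#include <unordered_map>
#include <unordered_set>

struct BarrierRequest {
  std::string name;
  int rank;
};

class BarrierNegotiator {
 public:
  explicit BarrierNegotiator(int world_size) : size_(world_size) {}

  // Returns true once all size_ ranks have submitted this barrier. A request
  // seen before all ranks arrive goes back on retry_queue to be examined
  // again in the next background-loop cycle.
  bool Increment(const BarrierRequest& req,
                 std::deque<BarrierRequest>& retry_queue) {
    auto& ranks = pending_[req.name];
    ranks.insert(req.rank);  // set insertion keeps retries idempotent
    if (static_cast<int>(ranks.size()) == size_) {
      pending_.erase(req.name);  // all ranks arrived: release the barrier
      return true;
    }
    retry_queue.push_back(req);
    return false;
  }

 private:
  int size_;
  std::unordered_map<std::string, std::unordered_set<int>> pending_;
};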
6 changes: 6 additions & 0 deletions horovod/common/message.cc
@@ -111,6 +111,9 @@ const std::string& Request::RequestType_Name(RequestType value) {
case RequestType::ALLTOALL:
static const std::string alltoall("ALLTOALL");
return alltoall;
case RequestType::BARRIER:
static const std::string barrier("BARRIER");
return barrier;
default:
static const std::string unknown("<unknown>");
return unknown;
@@ -300,6 +303,9 @@ const std::string& Response::ResponseType_Name(ResponseType value) {
case ResponseType::ALLTOALL:
static const std::string alltoall("ALLTOALL");
return alltoall;
case ResponseType::BARRIER:
static const std::string barrier("BARRIER");
return barrier;
case ResponseType::ERROR:
static const std::string error("ERROR");
return error;
4 changes: 2 additions & 2 deletions horovod/common/message.h
@@ -50,7 +50,7 @@ std::size_t DataType_Size(DataType value);
class Request {
public:
enum RequestType {
ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2, JOIN = 3, ADASUM = 4, ALLTOALL = 5
ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2, JOIN = 3, ADASUM = 4, ALLTOALL = 5, BARRIER = 6
};


@@ -153,7 +153,7 @@ class RequestList {
class Response {
public:
enum ResponseType {
ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2, JOIN = 3, ADASUM = 4, ALLTOALL= 5, ERROR = 6
ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2, JOIN = 3, ADASUM = 4, ALLTOALL = 5, BARRIER = 6, ERROR = 7
};

static const std::string& ResponseType_Name(ResponseType value);
33 changes: 27 additions & 6 deletions horovod/common/operations.cc
@@ -245,11 +245,12 @@ OperationManager* CreateOperationManager(HorovodGlobalState& state) {
#endif

std::shared_ptr<JoinOp> join_op(new JoinOp(&state));
std::shared_ptr<BarrierOp> barrier_op(new BarrierOp(&state));
std::shared_ptr<ErrorOp> error_op(new ErrorOp(&state));

return new OperationManager(&state.parameter_manager, allreduce_ops,
allgather_ops, broadcast_ops, alltoall_ops,
join_op, adasum_ops, error_op);
join_op, adasum_ops, barrier_op, error_op);
}

// Process a Response by doing a reduction, a gather, a broadcast, or
@@ -259,7 +260,9 @@ void PerformOperation(Response response, ProcessSet& process_set) {
auto& timeline = horovod_global.timeline;
process_set.tensor_queue.GetTensorEntriesFromResponse(response, entries,
process_set.joined);
if (response.response_type() != Response::JOIN) {

if (response.response_type() != Response::JOIN &&
response.response_type() != Response::BARRIER) {
for (auto& e : entries) {
timeline.Start(e.tensor_name, response.response_type(), e.tensor->size());
}
@@ -1757,18 +1760,36 @@ Status EnqueueJoin(std::shared_ptr<OpContext> context,

// Contexts and controller must be initialized and the background thread
// must be running before this function is called.
Status CallBarrier(int32_t process_set_id) {
Status EnqueueBarrier(StatusCallback callback, int32_t process_set_id) {
auto& process_set = horovod_global.process_set_table.Get(process_set_id);

if (!process_set.IsCurrentProcessIncluded()) {
return Status::InvalidArgument(
"Barrier: Rank " +
std::to_string(horovod_global.global_controller->GetRank()) +
" is not a member of the provided process set.");
}

process_set.controller->Barrier(Communicator::GLOBAL);
LOG(TRACE, horovod_global.global_controller->GetRank()) << "Released from barrier.";
return Status::OK();
Request message;
// Barrier does not need a tensor; we set a fixed name for tracing purposes.
message.set_tensor_name(BARRIER_TENSOR_NAME);
message.set_request_rank(process_set.controller->GetRank());
message.set_request_type(Request::BARRIER);

TensorTableEntry e;
e.tensor_name = BARRIER_TENSOR_NAME;
e.process_set_id = process_set_id;
e.callback = callback;

if (horovod_global.shut_down) {
return SHUT_DOWN_ERROR;
}
Status status = process_set.tensor_queue.AddToTensorQueue(e, message);
if (status.ok()) {
LOG(TRACE, horovod_global.global_controller->GetRank()) << "Enqueued barrier op";
}

return status;
}

} // namespace common
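EnqueueBarrier replaces the blocking CallBarrier: it only queues the request, and completion is reported through the StatusCallback once the background thread has executed the barrier. A hypothetical caller-side helper, not part of this commit, that turns the asynchronous barrier back into a blocking wait (it assumes the entry's callback is invoked with the final Status after the op completes, as for the other Enqueue* operations):

#include <future>

#include "horovod/common/operations.h"  // declares EnqueueBarrier (added here)

using horovod::common::Status;

// Hypothetical helper: block the calling thread until the background
// thread releases the barrier.
Status BlockingBarrier(int32_t process_set_id = 0) {
  std::promise<Status> done;
  std::future<Status> released = done.get_future();

  // The callback fires on the background thread after the barrier op runs.
  Status enqueued = horovod::common::EnqueueBarrier(
      [&done](const Status& status) { done.set_value(status); },
      process_set_id);
  if (!enqueued.ok()) {
    return enqueued;  // e.g. shutting down, or rank not in the process set
  }
  return released.get();  // wait for all ranks to reach the barrier
}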
3 changes: 2 additions & 1 deletion horovod/common/operations.h
@@ -233,7 +233,8 @@ Status EnqueueJoin(std::shared_ptr<OpContext> context,
StatusCallback callback,
int32_t process_set_id = 0);

Status CallBarrier(int32_t process_set_id = 0);
Status EnqueueBarrier(StatusCallback callback,
int32_t process_set_id = 0);

} // namespace common
} // namespace horovod
16 changes: 16 additions & 0 deletions horovod/common/ops/collective_operations.cc
@@ -312,6 +312,22 @@ Status JoinOp::Execute(std::vector<TensorTableEntry>& entries,
return Status::OK();
}

// Barrier
BarrierOp::BarrierOp(HorovodGlobalState* global_state) : HorovodOp(global_state) {}

Status BarrierOp::Execute(std::vector<TensorTableEntry>& entries,
const Response& response) {
assert(entries.size() == 1);
int& process_set_id = entries[0].process_set_id;
auto& process_set = global_state_->process_set_table.Get(process_set_id);

process_set.controller->Barrier(Communicator::GLOBAL);
LOG(TRACE, global_state_->global_controller->GetRank()) << "Released from barrier.";

return Status::OK();
}

// Error
ErrorOp::ErrorOp(HorovodGlobalState* global_state) : HorovodOp(global_state) {}

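BarrierOp::Execute runs on the background thread, so the blocking controller Barrier call no longer stalls framework threads directly; they wait on the op's completion callback instead. For the MPI controller, a controller-level barrier typically boils down to MPI_Barrier; an illustrative sketch of that assumption (the controller implementation itself is not part of this commit):

#include <mpi.h>

// Illustrative only: the kind of call a controller-level Barrier() reduces
// to under MPI.
void ControllerBarrier(MPI_Comm comm) {
  // Blocks the background thread until every rank in comm has entered the
  // barrier; framework threads keep running until they synchronize on the
  // op's completion callback.
  MPI_Barrier(comm);
}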
9 changes: 9 additions & 0 deletions horovod/common/ops/collective_operations.h
@@ -289,6 +289,15 @@ class JoinOp : public HorovodOp {
const Response& response, ProcessSet& process_set);
};

class BarrierOp : public HorovodOp {
public:
explicit BarrierOp(HorovodGlobalState* global_state);

virtual ~BarrierOp() = default;

virtual Status Execute(std::vector<TensorTableEntry>& entries, const Response& response);
};

class ErrorOp : public HorovodOp {
public:
explicit ErrorOp(HorovodGlobalState* global_state);
9 changes: 9 additions & 0 deletions horovod/common/ops/operation_manager.cc
@@ -27,6 +27,7 @@ OperationManager::OperationManager(ParameterManager* param_manager,
std::vector<std::shared_ptr<AlltoallOp>> alltoall_ops,
std::shared_ptr<JoinOp> join_op,
std::vector<std::shared_ptr<AllreduceOp>> adasum_ops,
std::shared_ptr<BarrierOp> barrier_op,
std::shared_ptr<ErrorOp> error_op)
: param_manager_(param_manager),
allreduce_ops_(std::move(allreduce_ops)),
@@ -35,6 +36,7 @@ OperationManager::OperationManager(ParameterManager* param_manager,
alltoall_ops_(std::move(alltoall_ops)),
join_op_(std::move(join_op)),
adasum_ops_(std::move(adasum_ops)),
barrier_op_(std::move(barrier_op)),
error_op_(std::move(error_op)) {}

Status OperationManager::ExecuteAllreduce(std::vector<TensorTableEntry>& entries,
@@ -93,6 +95,11 @@ Status OperationManager::ExecuteAdasum(std::vector<TensorTableEntry>& entries,
throw std::logic_error("No Adasum operation enabled");
}

Status OperationManager::ExecuteBarrier(std::vector<TensorTableEntry>& entries,
const Response& response) const {
return barrier_op_->Execute(entries, response);
}

Status OperationManager::ExecuteError(std::vector<TensorTableEntry>& entries,
const Response& response) const {
return error_op_->Execute(entries, response);
@@ -113,6 +120,8 @@ Status OperationManager::ExecuteOperation(std::vector<TensorTableEntry>& entries
return ExecuteJoin(entries, response, process_set);
} else if (response.response_type() == Response::ADASUM) {
return ExecuteAdasum(entries, response);
} else if (response.response_type() == Response::BARRIER) {
return ExecuteBarrier(entries, response);
} else if (response.response_type() == Response::ERROR) {
return ExecuteError(entries, response);
} else {
4 changes: 4 additions & 0 deletions horovod/common/ops/operation_manager.h
@@ -33,6 +33,7 @@ class OperationManager {
std::vector<std::shared_ptr<AlltoallOp>> alltoall_ops,
std::shared_ptr<JoinOp> join_op,
std::vector<std::shared_ptr<AllreduceOp>> adasum_ops,
std::shared_ptr<BarrierOp> barrier_op,
std::shared_ptr<ErrorOp> error_op);

virtual ~OperationManager() = default;
@@ -52,6 +53,8 @@

Status ExecuteAdasum(std::vector<TensorTableEntry>& entries, const Response& response) const;

Status ExecuteBarrier(std::vector<TensorTableEntry>& entries, const Response& response) const;

Status ExecuteOperation(std::vector<TensorTableEntry>& entries,
const Response& response,
ProcessSet& process_set) const;
@@ -65,6 +68,7 @@
std::vector<std::shared_ptr<AlltoallOp>> alltoall_ops_;
std::shared_ptr<JoinOp> join_op_;
std::vector<std::shared_ptr<AllreduceOp>> adasum_ops_;
std::shared_ptr<BarrierOp> barrier_op_;
std::shared_ptr<ErrorOp> error_op_;
};

2 changes: 2 additions & 0 deletions horovod/common/response_cache.cc
@@ -39,6 +39,8 @@ static Response::ResponseType RequestTypeToResponseType(Request::RequestType val
return Response::ResponseType::ADASUM;
case Request::RequestType::ALLTOALL:
return Response::ResponseType::ALLTOALL;
case Request::RequestType::BARRIER:
return Response::ResponseType::BARRIER;
default:
throw std::logic_error("No corresponding ResponseType for provided RequestType.");
}
9 changes: 8 additions & 1 deletion horovod/common/tensor_queue.cc
@@ -109,17 +109,18 @@ void TensorQueue::GetTensorEntriesFromResponse(
response.response_type() == Response::BROADCAST ||
response.response_type() == Response::ALLTOALL ||
response.response_type() == Response::ADASUM ||
response.response_type() == Response::BARRIER ||
response.response_type() == Response::ERROR);

if (!joined) {
// We should never fail at finding this key in the tensor table.
auto iter = tensor_table_.find(name);
assert(iter != tensor_table_.end());

entries.push_back(std::move(iter->second));

// Clear the tensor table of this tensor.
tensor_table_.erase(iter);

} else if (response.response_type() != Response::ERROR) {

// Find Join tensor to use its context.
@@ -152,6 +153,12 @@ TensorQueue::GetTensorEntry(const std::string& tensor_name) const{
return iter;
}

bool TensorQueue::IsTensorPresentInTable(const std::string& tensor_name) const {
// Lock on the tensor table.
std::lock_guard<std::mutex> guard(mutex_);
return tensor_table_.find(tensor_name) != tensor_table_.end();
}

// Pop out all the messages from the queue
void TensorQueue::PopMessagesFromQueue(
std::deque<Request>& message_queue_buffer) {
2 changes: 2 additions & 0 deletions horovod/common/tensor_queue.h
@@ -43,6 +43,8 @@ class TensorQueue {

const TensorTableEntry& GetTensorEntry(const std::string& tensor_name) const;

bool IsTensorPresentInTable(const std::string& tensor_name) const;

void PopMessagesFromQueue(std::deque<Request>& message_queue_buffer);

void PushMessageToQueue(Request& message);
6 changes: 5 additions & 1 deletion horovod/torch/mpi_ops.py
@@ -983,6 +983,10 @@ def barrier(process_set=global_process_set):
"""

try:
mpi_lib.horovod_torch_barrier(process_set.process_set_id)
handle = mpi_lib.horovod_torch_barrier(process_set.process_set_id)
except RuntimeError as e:
raise HorovodInternalError(e)

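# Barrier has no associated tensors; register the bare handle so
# synchronize() below can block until the background thread releases
# the barrier.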
_handle_map[handle] = (None, None)

synchronize(handle)
