Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add barrier call to torch module to support easy synchronization for process sets #3139

Merged
merged 7 commits into from Oct 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Expand Up @@ -12,6 +12,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

- Added terminate_on_nan flag to Spark Lightning estimator. [#3088](https://github.com/horovod/horovod/issues/3088)

- Added barrier() API to torch module to support simple synchronization among ranks and to achieve parity with PyTorch DDP and similar frameworks. [#3139](https://github.com/horovod/horovod/pull/3139)

### Changed

- By default, RayExecutor will now use the current placement group instead of always creating a new one. ([#3134](https://github.com/horovod/horovod/pull/3134))
Expand Down
3 changes: 2 additions & 1 deletion horovod/common/basics.py
Expand Up @@ -150,7 +150,8 @@ def shutdown(self):

def is_initialized(self):
    """Return True if Horovod is initialized, False otherwise.

    The underlying C library returns an int (for a stable ctypes surface);
    convert it to a proper bool for Python callers.
    """
    # The diff-rendered original kept the old `return` line above the new
    # code, leaving it unreachable; this keeps only the intended behavior.
    return bool(self.MPI_LIB_CTYPES.horovod_is_initialized())

def start_timeline(self, file_path, mark_cycles=False):
"""Creates a timeline file at `file_path` and begins recording.
Expand Down
3 changes: 3 additions & 0 deletions horovod/common/common.h
Expand Up @@ -153,6 +153,9 @@ namespace common {
// Temporary tensor name for ranks that did Join().
#define JOIN_TENSOR_NAME "join.noname"

// Fixed tensor name for all barrier operations
#define BARRIER_TENSOR_NAME "barrier.noname"

// List of supported frameworks.
enum Framework { TENSORFLOW, PYTORCH, MXNET, XLA };

Expand Down
28 changes: 26 additions & 2 deletions horovod/common/controller.cc
Expand Up @@ -97,6 +97,12 @@ ResponseList Controller::ComputeResponseList(bool this_process_requested_shutdow
continue;
}

// Never cache a barrier request; when all ranks report ready for this message, the barrier will be released.
if(message.request_type() == Request::BARRIER) {
cache_coordinator.set_uncached_in_queue(true);
continue;
}

// Keep track of cache hits
if (response_cache_.capacity() > 0) {
auto cache_ = response_cache_.cached(message);
Expand Down Expand Up @@ -266,6 +272,13 @@ ResponseList Controller::ComputeResponseList(bool this_process_requested_shutdow
}

bool reduce = IncrementTensorCount(message, process_set.joined_size);

// For barrier request, if not ready to reduce, we add it back to tensor queue
// to process in the next cycle.
if(!reduce && message.request_type() == Request::BARRIER) {
tensor_queue_.PushMessageToQueue(message);
}

stall_inspector_.RecordUncachedTensorStart(
message.tensor_name(), message.request_rank(), size_);
if (reduce) {
Expand All @@ -283,7 +296,6 @@ ResponseList Controller::ComputeResponseList(bool this_process_requested_shutdow
auto received_message_list = ready_list[i];
for (auto& received_message : received_message_list.requests()) {
auto& received_name = received_message.tensor_name();

if (received_message.request_type() == Request::JOIN) {
process_set.joined_size++;
process_set.last_joined_rank = global_ranks_[i];
Expand All @@ -292,6 +304,7 @@ ResponseList Controller::ComputeResponseList(bool this_process_requested_shutdow

bool reduce =
IncrementTensorCount(received_message, process_set.joined_size);

stall_inspector_.RecordUncachedTensorStart(
received_message.tensor_name(), received_message.request_rank(),
size_);
Expand Down Expand Up @@ -498,6 +511,7 @@ Response Controller::ConstructResponse(const std::string& name, int joined_size)
// Check that all data types of tensors being processed
// are identical.
auto data_type = requests[0].tensor_type();

for (unsigned int i = 1; i < requests.size(); ++i) {
auto request_type = requests[i].tensor_type();
if (data_type != request_type) {
Expand Down Expand Up @@ -756,6 +770,8 @@ Response Controller::ConstructResponse(const std::string& name, int joined_size)
response.set_tensor_type(data_type);
response.set_prescale_factor(prescale_factor);
response.set_postscale_factor(postscale_factor);
} else if (message_type == Request::BARRIER) {
response.set_response_type(Response::BARRIER);
}
response.set_devices(devices);

Expand Down Expand Up @@ -971,13 +987,21 @@ bool Controller::IncrementTensorCount(const Request& msg, int joined_size) {
timeline_.NegotiateStart(name, msg.request_type());
} else {
std::vector<Request>& messages = table_iter->second;
messages.push_back(msg);
if(msg.request_type() == Request::BARRIER) {
if(tensor_queue_.IsTensorPresentInTable(name)) {
messages.push_back(msg);
}
}
else {
messages.push_back(msg);
}
}

timeline_.NegotiateRankReady(name, msg.request_rank());

std::vector<Request>& messages = table_iter->second;
int count = (int)messages.size();

bool ready_to_reduce = count == (size_ - joined_size);
if (ready_to_reduce) {
timeline_.NegotiateEnd(name);
Expand Down
6 changes: 6 additions & 0 deletions horovod/common/message.cc
Expand Up @@ -111,6 +111,9 @@ const std::string& Request::RequestType_Name(RequestType value) {
case RequestType::ALLTOALL:
static const std::string alltoall("ALLTOALL");
return alltoall;
case RequestType::BARRIER:
static const std::string barrier("BARRIER");
return barrier;
default:
static const std::string unknown("<unknown>");
return unknown;
Expand Down Expand Up @@ -300,6 +303,9 @@ const std::string& Response::ResponseType_Name(ResponseType value) {
case ResponseType::ALLTOALL:
static const std::string alltoall("ALLTOALL");
return alltoall;
case ResponseType::BARRIER:
static const std::string barrier("BARRIER");
return barrier;
case ResponseType::ERROR:
static const std::string error("ERROR");
return error;
Expand Down
4 changes: 2 additions & 2 deletions horovod/common/message.h
Expand Up @@ -50,7 +50,7 @@ std::size_t DataType_Size(DataType value);
class Request {
public:
// Collective request kinds a rank can submit for negotiation.
enum RequestType {
  ALLREDUCE = 0,
  ALLGATHER = 1,
  BROADCAST = 2,
  JOIN = 3,
  ADASUM = 4,
  ALLTOALL = 5,
  BARRIER = 6
};


Expand Down Expand Up @@ -153,7 +153,7 @@ class RequestList {
class Response {
public:
enum ResponseType {
ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2, JOIN = 3, ADASUM = 4, ALLTOALL= 5, ERROR = 6
ALLREDUCE = 0, ALLGATHER = 1, BROADCAST = 2, JOIN = 3, ADASUM = 4, ALLTOALL= 5, BARRIER=6, ERROR = 7
};

static const std::string& ResponseType_Name(ResponseType value);
Expand Down
45 changes: 41 additions & 4 deletions horovod/common/operations.cc
Expand Up @@ -245,11 +245,12 @@ OperationManager* CreateOperationManager(HorovodGlobalState& state) {
#endif

std::shared_ptr<JoinOp> join_op(new JoinOp(&state));
std::shared_ptr<BarrierOp> barrier_op(new BarrierOp(&state));
std::shared_ptr<ErrorOp> error_op(new ErrorOp(&state));

return new OperationManager(&state.parameter_manager, allreduce_ops,
allgather_ops, broadcast_ops, alltoall_ops,
join_op, adasum_ops, error_op);
join_op, adasum_ops, barrier_op, error_op);
}

// Process a Response by doing a reduction, a gather, a broadcast, or
Expand All @@ -259,7 +260,9 @@ void PerformOperation(Response response, ProcessSet& process_set) {
auto& timeline = horovod_global.timeline;
process_set.tensor_queue.GetTensorEntriesFromResponse(response, entries,
process_set.joined);
if (response.response_type() != Response::JOIN) {

if (response.response_type() != Response::JOIN &&
response.response_type() != Response::BARRIER) {
for (auto& e : entries) {
timeline.Start(e.tensor_name, response.response_type(), e.tensor->size());
}
Expand Down Expand Up @@ -1016,8 +1019,8 @@ void horovod_shutdown() {
}
}

// Returns 1 if Horovod initialization has completed, 0 otherwise.
// NOTE(review): returns int rather than bool — presumably to keep the
// exported C surface friendly to ctypes callers (the Python wrapper converts
// it back to bool); confirm against the language bindings.
int horovod_is_initialized() {
  // Explicit atomic load + named cast instead of the functional-style cast.
  return static_cast<int>(horovod_global.initialization_done.load());
}

int horovod_start_timeline(const char* file_name, bool mark_cycles) {
Expand Down Expand Up @@ -1755,5 +1758,39 @@ Status EnqueueJoin(std::shared_ptr<OpContext> context,
return status;
}

// Contexts and controller must be initialized and the background thread
// must be running before this function is called.
//
// Enqueues a BARRIER request for the given process set. The barrier is
// released once every rank in the process set has submitted a matching
// request; `callback` is invoked when this rank is released.
Status EnqueueBarrier(StatusCallback callback, int32_t process_set_id) {
  auto& process_set = horovod_global.process_set_table.Get(process_set_id);

  if (!process_set.IsCurrentProcessIncluded()) {
    return Status::InvalidArgument(
        "Barrier: Rank " +
        std::to_string(horovod_global.global_controller->GetRank()) +
        " is not a member of the provided process set.");
  }

  // Bail out before building the request/entry: nothing below should run
  // once shutdown has begun.
  if (horovod_global.shut_down) {
    return SHUT_DOWN_ERROR;
  }

  Request message;
  // Barrier doesn't need a tensor; we set a fixed name for tracing purposes.
  message.set_tensor_name(BARRIER_TENSOR_NAME);
  message.set_request_rank(process_set.controller->GetRank());
  message.set_request_type(Request::BARRIER);

  TensorTableEntry e;
  e.tensor_name = BARRIER_TENSOR_NAME;
  e.process_set_id = process_set_id;
  // `callback` is a sink parameter taken by value; move it into the entry
  // instead of copying.
  e.callback = std::move(callback);

  Status status = process_set.tensor_queue.AddToTensorQueue(e, message);
  if (status.ok()) {
    LOG(TRACE, horovod_global.global_controller->GetRank())
        << "Enqueued barrier op";
  }

  return status;
}

} // namespace common
} // namespace horovod
3 changes: 3 additions & 0 deletions horovod/common/operations.h
Expand Up @@ -233,6 +233,9 @@ Status EnqueueJoin(std::shared_ptr<OpContext> context,
StatusCallback callback,
int32_t process_set_id = 0);

// Enqueues a barrier operation on the given process set; `callback` fires
// when this rank is released from the barrier. Defaults to the global
// process set (id 0).
Status EnqueueBarrier(StatusCallback callback,
int32_t process_set_id = 0);

} // namespace common
} // namespace horovod

Expand Down
16 changes: 16 additions & 0 deletions horovod/common/ops/collective_operations.cc
Expand Up @@ -312,6 +312,22 @@ Status JoinOp::Execute(std::vector<TensorTableEntry>& entries,
return Status::OK();
}

// Barrier
BarrierOp::BarrierOp(HorovodGlobalState* global_state) : HorovodOp(global_state) {}

// Executes a negotiated barrier: blocks this rank on the process set's
// controller barrier until all ranks in the set have arrived.
Status BarrierOp::Execute(std::vector<TensorTableEntry>& entries,
                          const Response& response) {
  // Barrier responses always carry exactly one placeholder entry
  // (named BARRIER_TENSOR_NAME).
  assert(entries.size() == 1);
  // Read-only value: bind by const value, not non-const reference.
  const int process_set_id = entries[0].process_set_id;
  auto& process_set = global_state_->process_set_table.Get(process_set_id);

  // NOTE(review): Communicator::GLOBAL here is interpreted relative to the
  // process set's own controller — confirm for non-global process sets.
  process_set.controller->Barrier(Communicator::GLOBAL);
  LOG(TRACE, global_state_->global_controller->GetRank()) << "Released from barrier.";

  return Status::OK();
}

// Error
ErrorOp::ErrorOp(HorovodGlobalState* global_state) : HorovodOp(global_state) {}

Expand Down
9 changes: 9 additions & 0 deletions horovod/common/ops/collective_operations.h
Expand Up @@ -289,6 +289,15 @@ class JoinOp : public HorovodOp {
const Response& response, ProcessSet& process_set);
};

// Collective op backing barrier(): releases only when every rank in the
// response's process set has reached the barrier.
class BarrierOp : public HorovodOp {
public:
explicit BarrierOp(HorovodGlobalState* global_state);

virtual ~BarrierOp() = default;

// Executes the barrier described by `response` for the single placeholder
// entry in `entries`. NOTE(review): declared `virtual` without `override`;
// confirm it matches a virtual in the HorovodOp base.
virtual Status Execute(std::vector<TensorTableEntry>& entries, const Response& response);
};

class ErrorOp : public HorovodOp {
public:
explicit ErrorOp(HorovodGlobalState* global_state);
Expand Down
9 changes: 9 additions & 0 deletions horovod/common/ops/operation_manager.cc
Expand Up @@ -27,6 +27,7 @@ OperationManager::OperationManager(ParameterManager* param_manager,
std::vector<std::shared_ptr<AlltoallOp>> alltoall_ops,
std::shared_ptr<JoinOp> join_op,
std::vector<std::shared_ptr<AllreduceOp>> adasum_ops,
std::shared_ptr<BarrierOp> barrier_op,
std::shared_ptr<ErrorOp> error_op)
: param_manager_(param_manager),
allreduce_ops_(std::move(allreduce_ops)),
Expand All @@ -35,6 +36,7 @@ OperationManager::OperationManager(ParameterManager* param_manager,
alltoall_ops_(std::move(alltoall_ops)),
join_op_(std::move(join_op)),
adasum_ops_(std::move(adasum_ops)),
barrier_op_(std::move(barrier_op)),
error_op_(std::move(error_op)) {}

Status OperationManager::ExecuteAllreduce(std::vector<TensorTableEntry>& entries,
Expand Down Expand Up @@ -93,6 +95,11 @@ Status OperationManager::ExecuteAdasum(std::vector<TensorTableEntry>& entries,
throw std::logic_error("No Adasum operation enabled");
}

// Dispatches a BARRIER response to the single BarrierOp instance.
Status OperationManager::ExecuteBarrier(std::vector<TensorTableEntry>& entries,
const Response& response) const {
return barrier_op_->Execute(entries, response);
}

Status OperationManager::ExecuteError(std::vector<TensorTableEntry>& entries,
const Response& response) const {
return error_op_->Execute(entries, response);
Expand All @@ -113,6 +120,8 @@ Status OperationManager::ExecuteOperation(std::vector<TensorTableEntry>& entries
return ExecuteJoin(entries, response, process_set);
} else if (response.response_type() == Response::ADASUM) {
return ExecuteAdasum(entries, response);
} else if (response.response_type() == Response::BARRIER) {
return ExecuteBarrier(entries, response);
} else if (response.response_type() == Response::ERROR) {
return ExecuteError(entries, response);
} else {
Expand Down
4 changes: 4 additions & 0 deletions horovod/common/ops/operation_manager.h
Expand Up @@ -33,6 +33,7 @@ class OperationManager {
std::vector<std::shared_ptr<AlltoallOp>> alltoall_ops,
std::shared_ptr<JoinOp> join_op,
std::vector<std::shared_ptr<AllreduceOp>> adasum_ops,
std::shared_ptr<BarrierOp> barrier_op,
std::shared_ptr<ErrorOp> error_op);

virtual ~OperationManager() = default;
Expand All @@ -52,6 +53,8 @@ class OperationManager {

Status ExecuteAdasum(std::vector<TensorTableEntry>& entries, const Response& response) const;

Status ExecuteBarrier(std::vector<TensorTableEntry>& entries, const Response& response) const;

Status ExecuteOperation(std::vector<TensorTableEntry>& entries,
const Response& response,
ProcessSet& process_set) const;
Expand All @@ -65,6 +68,7 @@ class OperationManager {
std::vector<std::shared_ptr<AlltoallOp>> alltoall_ops_;
std::shared_ptr<JoinOp> join_op_;
std::vector<std::shared_ptr<AllreduceOp>> adasum_ops_;
std::shared_ptr<BarrierOp> barrier_op_;
std::shared_ptr<ErrorOp> error_op_;
};

Expand Down
2 changes: 2 additions & 0 deletions horovod/common/response_cache.cc
Expand Up @@ -39,6 +39,8 @@ static Response::ResponseType RequestTypeToResponseType(Request::RequestType val
return Response::ResponseType::ADASUM;
case Request::RequestType::ALLTOALL:
return Response::ResponseType::ALLTOALL;
case Request::RequestType::BARRIER:
return Response::ResponseType::BARRIER;
default:
throw std::logic_error("No corresponding ResponseType for provided RequestType.");
}
Expand Down
9 changes: 8 additions & 1 deletion horovod/common/tensor_queue.cc
Expand Up @@ -109,17 +109,18 @@ void TensorQueue::GetTensorEntriesFromResponse(
response.response_type() == Response::BROADCAST ||
response.response_type() == Response::ALLTOALL ||
response.response_type() == Response::ADASUM ||
response.response_type() == Response::BARRIER ||
response.response_type() == Response::ERROR);

if (!joined) {
// We should never fail at finding this key in the tensor table.
auto iter = tensor_table_.find(name);
assert(iter != tensor_table_.end());

entries.push_back(std::move(iter->second));

// Clear the tensor table of this tensor.
tensor_table_.erase(iter);

} else if (response.response_type() != Response::ERROR) {

// Find Join tensor to use its context.
Expand Down Expand Up @@ -152,6 +153,12 @@ TensorQueue::GetTensorEntry(const std::string& tensor_name) const{
return iter;
}

// Thread-safe check for whether an entry named `tensor_name` currently
// exists in the tensor table.
bool TensorQueue::IsTensorPresentInTable(const std::string& tensor_name) const {
  // Hold the table lock for the duration of the lookup.
  std::lock_guard<std::mutex> table_guard(mutex_);
  return tensor_table_.count(tensor_name) > 0;
}

// Pop out all the messages from the queue
void TensorQueue::PopMessagesFromQueue(
std::deque<Request>& message_queue_buffer) {
Expand Down
2 changes: 2 additions & 0 deletions horovod/common/tensor_queue.h
Expand Up @@ -43,6 +43,8 @@ class TensorQueue {

const TensorTableEntry& GetTensorEntry(const std::string& tensor_name) const;

bool IsTensorPresentInTable (const std::string& tensor_name) const;

void PopMessagesFromQueue(std::deque<Request>& message_queue_buffer);

void PushMessageToQueue(Request& message);
Expand Down