Added Gloo controller #1181

Merged: 8 commits, Aug 10, 2019

Changes from 1 commit
30 changes: 30 additions & 0 deletions NOTICE
@@ -131,3 +131,33 @@
The derived work can be found in the files:

- examples/keras_spark_rossmann.py

elnormous/HTTPRequest
Copyright (c) 2017, Elviss Strazdiņš
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

The derived work can be found in the files:

- horovod/common/rendezvous/http_rendezvous.h
- horovod/common/rendezvous/http_rendezvous.cc
2 changes: 2 additions & 0 deletions horovod/common/basics.py
@@ -44,6 +44,7 @@ def init(self, comm=None):
if not bool(mpi_enabled):
raise ValueError(
'Horovod MPI is not enabled; Please make sure it\'s installed and enabled.')

from mpi4py import MPI
if MPI._sizeof(MPI.Comm) == ctypes.sizeof(ctypes.c_int):
MPI_Comm = ctypes.c_int
@@ -126,6 +127,7 @@ def mpi_threads_supported(self):
if not bool(mpi_enabled):
raise ValueError(
'Horovod MPI is not enabled; Please make sure it\'s installed and enabled.')

mpi_threads_supported = self.MPI_LIB_CTYPES.horovod_mpi_threads_supported()
if mpi_threads_supported == -1:
raise ValueError(
46 changes: 24 additions & 22 deletions horovod/common/gloo_context.cc
@@ -29,7 +29,7 @@ namespace common {

#if HAVE_MPI
void GlooContext::InitializeFromMPI(MPIContext& mpi_ctx,
const char* gloo_iface) {
const std::string& gloo_iface) {
if (!enabled_) {
return;
}
@@ -67,7 +67,7 @@ void GlooContext::Finalize() {
local_ctx.reset();
}

void GlooContext::Initialize(const char* gloo_iface) {
void GlooContext::Initialize(const std::string& gloo_iface) {
if (!enabled_) {
return;
}
@@ -79,23 +79,25 @@ void GlooContext::Initialize(const char* gloo_iface) {
attr.ai_family = AF_UNSPEC;
auto dev = gloo::transport::tcp::CreateDevice(attr);

auto rendezvous_server_addr = std::getenv("HOROVOD_RENDEZVOUS_ADDR");
const std::string rendezvous_server_addr = std::getenv
(HOROVOD_GLOO_RENDEZVOUS_ADDR);
auto rendezvous_server_port =
std::atoi(std::getenv("HOROVOD_RENDEZVOUS_PORT"));
std::strtol(std::getenv(HOROVOD_GLOO_RENDEZVOUS_PORT), nullptr, 10);

LOG(DEBUG) << "Rendezvous server addr " << rendezvous_server_addr;
LOG(DEBUG) << "Rendezvous server address " << rendezvous_server_addr;

// Get rendezvous info from env
int rank = std::atoi(getenv("HOROVOD_RANK"));
int size = std::atoi(getenv("HOROVOD_SIZE"));
int local_rank = std::atoi(getenv("HOROVOD_LOCAL_RANK"));
int local_size = std::atoi(getenv("HOROVOD_LOCAL_SIZE"));
int cross_rank = std::atoi(getenv("HOROVOD_CROSS_RANK"));
int cross_size = std::atoi(getenv("HOROVOD_CROSS_SIZE"));
int rank = std::strtol(getenv(HOROVOD_RANK), nullptr, 10);
int size = std::strtol(getenv(HOROVOD_SIZE), nullptr, 10);
int local_rank = std::strtol(getenv(HOROVOD_LOCAL_RANK), nullptr, 10);
int local_size = std::strtol(getenv(HOROVOD_LOCAL_SIZE), nullptr, 10);
int cross_rank = std::strtol(getenv(HOROVOD_CROSS_RANK), nullptr, 10);
int cross_size = std::strtol(getenv(HOROVOD_CROSS_SIZE), nullptr, 10);

// Global rendezvous
auto rendezvous =
HTTPStore(rendezvous_server_addr, rendezvous_server_port, "global", rank);
const std::string global_scope(HOROVOD_GLOO_GLOBAL_PREFIX);
auto rendezvous = HTTPStore(rendezvous_server_addr, rendezvous_server_port,
global_scope, rank);
LOG(DEBUG) << "Global Rendezvous started for rank " << rank
<< ", total size of " << size;
auto context = std::make_shared<gloo::rendezvous::Context>(rank, size);
@@ -104,10 +106,10 @@ void GlooContext::Initialize(const char* gloo_iface) {
LOG(DEBUG) << "Global Gloo context initialized.";

// Local rendezvous
std::string local_scope = std::string("local") + std::to_string(cross_rank);
auto local_rendezvous =
HTTPStore(rendezvous_server_addr, rendezvous_server_port,
local_scope.c_str(), rank);
const std::string local_scope =
std::string(HOROVOD_GLOO_LOCAL_PREFIX) + std::to_string(cross_rank);
auto local_rendezvous = HTTPStore(rendezvous_server_addr,
rendezvous_server_port, local_scope, rank);
LOG(DEBUG) << "Local Rendezvous started for rank " << rank
<< ", total size of " << local_size;
auto local_context =
@@ -117,11 +119,11 @@ void GlooContext::Initialize(const char* gloo_iface) {
LOG(DEBUG) << "Local Gloo context initialized.";

// Cross rendezvous
std::string cross_scope = std::string("cross") + std::to_string(local_rank);
auto cross_rendezvous =
HTTPStore(rendezvous_server_addr, rendezvous_server_port,
cross_scope.c_str(), rank);
LOG(DEBUG) << "Cross Rendezvous started for rank " << rank
const std::string cross_scope =
std::string(HOROVOD_GLOO_CROSS_PREFIX) + std::to_string(local_rank);
auto cross_rendezvous = HTTPStore(rendezvous_server_addr,
rendezvous_server_port, cross_scope, rank);
LOG(DEBUG) << "Cross-node Rendezvous started for rank " << rank
<< ", total size of " << size;
auto cross_context =
std::make_shared<gloo::rendezvous::Context>(cross_rank, cross_size);
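A minimal sketch of the environment parsing done in GlooContext::Initialize above, with explicit null checks added for illustration. GetEnvOrDie and ParseIntEnv are hypothetical helpers, not part of this PR; the hunk itself passes std::getenv() results directly to std::string and std::strtol, which assumes the launcher always sets the HOROVOD_* variables.

#include <cstdlib>
#include <stdexcept>
#include <string>

namespace {

// Return the value of an environment variable, or throw if it is unset.
std::string GetEnvOrDie(const char* name) {
  const char* value = std::getenv(name);
  if (value == nullptr) {
    throw std::runtime_error(std::string(name) + " is not set.");
  }
  return std::string(value);
}

// Parse an integer-valued environment variable.
int ParseIntEnv(const char* name) {
  return static_cast<int>(std::strtol(GetEnvOrDie(name).c_str(), nullptr, 10));
}

} // namespace

// Usage mirroring the hunk above:
//   const std::string addr = GetEnvOrDie(HOROVOD_GLOO_RENDEZVOUS_ADDR);
//   int port = ParseIntEnv(HOROVOD_GLOO_RENDEZVOUS_PORT);
//   int rank = ParseIntEnv(HOROVOD_RANK);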
17 changes: 15 additions & 2 deletions horovod/common/gloo_context.h
@@ -26,13 +26,26 @@
namespace horovod {
namespace common {

// Horovod Gloo rendezvous knobs.
#define HOROVOD_GLOO_RENDEZVOUS_ADDR "HOROVOD_GLOO_RENDEZVOUS_ADDR"
#define HOROVOD_GLOO_RENDEZVOUS_PORT "HOROVOD_GLOO_RENDEZVOUS_PORT"
#define HOROVOD_GLOO_GLOBAL_PREFIX "global"
#define HOROVOD_GLOO_LOCAL_PREFIX "local"
#define HOROVOD_GLOO_CROSS_PREFIX "cross"
#define HOROVOD_RANK "HOROVOD_RANK"
#define HOROVOD_SIZE "HOROVOD_SIZE"
#define HOROVOD_LOCAL_RANK "HOROVOD_LOCAL_RANK"
#define HOROVOD_LOCAL_SIZE "HOROVOD_LOCAL_SIZE"
#define HOROVOD_CROSS_RANK "HOROVOD_CROSS_RANK"
#define HOROVOD_CROSS_SIZE "HOROVOD_CROSS_SIZE"

struct GlooContext {

#if HAVE_MPI
void InitializeFromMPI(MPIContext& mpi_ctx, const char* gloo_iface);
void InitializeFromMPI(MPIContext& mpi_ctx, const std::string& gloo_iface);
#endif

void Initialize(const char* gloo_iface);
void Initialize(const std::string& gloo_iface);

void Finalize();

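For context, a sketch of the environment a launcher process might export for each worker so that GlooContext::Initialize can reach the HTTP rendezvous server through the knobs defined above. ExportGlooEnv is a hypothetical helper and the values are illustrative; setenv is the POSIX call, which is an assumption here rather than anything specified by this PR.

#include <cstdlib>
#include <string>

// Hypothetical launcher-side helper: export the rendezvous and topology
// variables consumed by GlooContext::Initialize before starting a worker.
// setenv() is POSIX; error handling is omitted for brevity.
void ExportGlooEnv(const std::string& addr, int port,
                   int rank, int size,
                   int local_rank, int local_size,
                   int cross_rank, int cross_size) {
  setenv("HOROVOD_GLOO_RENDEZVOUS_ADDR", addr.c_str(), 1);
  setenv("HOROVOD_GLOO_RENDEZVOUS_PORT", std::to_string(port).c_str(), 1);
  setenv("HOROVOD_RANK", std::to_string(rank).c_str(), 1);
  setenv("HOROVOD_SIZE", std::to_string(size).c_str(), 1);
  setenv("HOROVOD_LOCAL_RANK", std::to_string(local_rank).c_str(), 1);
  setenv("HOROVOD_LOCAL_SIZE", std::to_string(local_size).c_str(), 1);
  setenv("HOROVOD_CROSS_RANK", std::to_string(cross_rank).c_str(), 1);
  setenv("HOROVOD_CROSS_SIZE", std::to_string(cross_size).c_str(), 1);
}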
47 changes: 33 additions & 14 deletions horovod/common/mpi_controller.cc
@@ -88,7 +88,8 @@ void MPIController::CrossRankBitwiseAnd(std::vector<long long>& bitvector,
int ret_code = MPI_Allreduce(MPI_IN_PLACE, bitvector.data(), count,
MPI_LONG_LONG_INT, MPI_BAND, mpi_ctx_.mpi_comm);
if (ret_code != MPI_SUCCESS) {
throw std::logic_error("MPI_AllReduce failed, see MPI output for details.");
throw std::runtime_error(
"MPI_AllReduce failed, see MPI output for details.");
}
}

@@ -97,7 +98,8 @@ void MPIController::CrossRankBitwiseOr(std::vector<long long>& bitvector,
int ret_code = MPI_Allreduce(MPI_IN_PLACE, bitvector.data(), count,
MPI_LONG_LONG_INT, MPI_BOR, mpi_ctx_.mpi_comm);
if (ret_code != MPI_SUCCESS) {
throw std::logic_error("MPI_AllReduce failed, see MPI output for details.");
throw std::runtime_error(
"MPI_AllReduce failed, see MPI output for details.");
}
}

@@ -161,38 +163,55 @@ void MPIController::SendReadyTensors(RequestList& message_list) {
std::string encoded_message;
RequestList::SerializeToString(message_list, encoded_message);
int encoded_message_length = (int)encoded_message.length() + 1;
MPI_Gather(&encoded_message_length, 1, MPI_INT, nullptr, 1, MPI_INT,
RANK_ZERO, mpi_ctx_.mpi_comm);
MPI_Gatherv((void*)encoded_message.c_str(), encoded_message_length, MPI_BYTE,
nullptr, nullptr, nullptr, MPI_BYTE, RANK_ZERO,
mpi_ctx_.mpi_comm);
int ret_code = MPI_Gather(&encoded_message_length, 1, MPI_INT, nullptr, 1,
MPI_INT, RANK_ZERO, mpi_ctx_.mpi_comm);
if (ret_code != MPI_SUCCESS) {
throw std::runtime_error("MPI_Gather failed, see MPI output for details.");
}

ret_code = MPI_Gatherv((void*)encoded_message.c_str(), encoded_message_length,
MPI_BYTE, nullptr, nullptr, nullptr, MPI_BYTE,
RANK_ZERO, mpi_ctx_.mpi_comm);
if (ret_code != MPI_SUCCESS) {
throw std::runtime_error("MPI_Gather failed, see MPI output for details.");
}
}

void MPIController::RecvFinalTensors(ResponseList& response_list) {
int msg_length;
MPI_Bcast(&msg_length, 1, MPI_INT, RANK_ZERO, mpi_ctx_.mpi_comm);
int ret_code =
MPI_Bcast(&msg_length, 1, MPI_INT, RANK_ZERO, mpi_ctx_.mpi_comm);
if (ret_code != MPI_SUCCESS) {
throw std::runtime_error(
"MPI_Broadcast failed, see MPI output for details.");
}

auto buffer = new uint8_t[msg_length];
MPI_Bcast(buffer, msg_length, MPI_BYTE, RANK_ZERO, mpi_ctx_.mpi_comm);
ret_code =
MPI_Bcast(buffer, msg_length, MPI_BYTE, RANK_ZERO, mpi_ctx_.mpi_comm);
if (ret_code != MPI_SUCCESS) {
throw std::runtime_error(
"MPI_Broadcast failed, see MPI output for details.");
}
ResponseList::ParseFromBytes(response_list, buffer);
delete[] buffer;
}

void MPIController::Bcast(void* buffer, size_t size, int root_rank, Communicator communicator) {
void MPIController::Bcast(void* buffer, size_t size, int root_rank,
Communicator communicator) {
MPI_Comm comm = mpi_ctx_.GetMPICommunicator(communicator);
int ret_code = MPI_Bcast(buffer, size, MPI_BYTE, root_rank, comm);

if (ret_code != MPI_SUCCESS) {
throw std::logic_error("MPI_Broadcast failed, see MPI output for details.");
throw std::runtime_error(
"MPI_Broadcast failed, see MPI output for details.");
}
}

void MPIController::Barrier(Communicator communicator) {
MPI_Comm comm = mpi_ctx_.GetMPICommunicator(communicator);
int ret_code = MPI_Barrier(comm);

if (ret_code != MPI_SUCCESS) {
throw std::logic_error("MPI_Barrier failed, see MPI output for details.");
throw std::runtime_error("MPI_Barrier failed, see MPI output for details.");
}
}

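The edits in this file repeat one pattern: capture the MPI return code and throw std::runtime_error on failure (a runtime condition, unlike the std::logic_error used previously). A sketch of a helper macro that could factor out that pattern is below; MPI_CHECK is hypothetical and not part of this PR.

#include <mpi.h>
#include <stdexcept>
#include <string>

// Evaluate an MPI call once and throw std::runtime_error if it fails,
// embedding the stringified call in the error message.
#define MPI_CHECK(call)                                                     \
  do {                                                                      \
    int mpi_check_rc = (call);                                              \
    if (mpi_check_rc != MPI_SUCCESS) {                                      \
      throw std::runtime_error(std::string(#call) +                         \
                               " failed, see MPI output for details.");     \
    }                                                                       \
  } while (0)

// Usage mirroring the hunks above:
//   MPI_CHECK(MPI_Bcast(&msg_length, 1, MPI_INT, RANK_ZERO, mpi_ctx_.mpi_comm));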
2 changes: 1 addition & 1 deletion horovod/common/ops/mpi_cuda_operations.cc
@@ -60,7 +60,7 @@ Status MPI_CUDAAllreduce::Execute(std::vector<TensorTableEntry>& entries, const
mpi_context_->GetMPISumOp(first_entry.tensor->dtype()),
mpi_context_->GetMPICommunicator(Communicator::GLOBAL));
if (op != MPI_SUCCESS) {
throw std::logic_error("MPI_Allreduce failed, see MPI output for details.");
throw std::runtime_error("MPI_Allreduce failed, see MPI output for details.");
}
timeline.ActivityEndAll(entries);

10 changes: 5 additions & 5 deletions horovod/common/ops/mpi_operations.cc
@@ -51,7 +51,7 @@ Status MPIAllreduce::Execute(std::vector<TensorTableEntry>& entries, const Respo
mpi_context_->GetMPISumOp(first_entry.tensor->dtype()),
mpi_context_->GetMPICommunicator(Communicator::GLOBAL));
if (op != MPI_SUCCESS) {
throw std::logic_error("MPI_Allreduce failed, see MPI output for details.");
throw std::runtime_error("MPI_Allreduce failed, see MPI output for details.");
}
timeline.ActivityEndAll(entries);

@@ -137,7 +137,7 @@ Status MPIAllgather::Execute(std::vector<TensorTableEntry>& entries, const Respo
dtype,
mpi_context_->GetMPICommunicator(Communicator::GLOBAL));
if (op != MPI_SUCCESS) {
throw std::logic_error("MPI_Allgatherv failed, see MPI output for details.");
throw std::runtime_error("MPI_Allgatherv failed, see MPI output for details.");
}
global_state_->timeline.ActivityEndAll(entries);

@@ -287,7 +287,7 @@ Status MPIHierarchicalAllgather::Execute(std::vector<TensorTableEntry>& entries,
mpi_context_->GetMPIDataType(first_entry.tensor->dtype()),
mpi_context_->GetMPICommunicator(Communicator::CROSS));
if (op != MPI_SUCCESS) {
throw std::logic_error("MPI_Allgatherv failed, see MPI output for details.");
throw std::runtime_error("MPI_Allgatherv failed, see MPI output for details.");
}
}
Barrier();
@@ -316,7 +316,7 @@ bool MPIHierarchicalAllgather::Enabled(const ParameterManager& param_manager,
void MPIHierarchicalAllgather::Barrier() {
int op = MPI_Barrier(mpi_context_->GetMPICommunicator(Communicator::GLOBAL));
if (op != MPI_SUCCESS) {
throw std::logic_error("MPI_Barrier failed, see MPI output for details.");
throw std::runtime_error("MPI_Barrier failed, see MPI output for details.");
}
}

@@ -342,7 +342,7 @@ Status MPIBroadcast::Execute(std::vector<TensorTableEntry>& entries, const Respo
e.root_rank,
mpi_context_->GetMPICommunicator(Communicator::GLOBAL));
if (op != MPI_SUCCESS) {
throw std::logic_error("MPI_Broadcast failed, see MPI output for details.");
throw std::runtime_error("MPI_Broadcast failed, see MPI output for details.");
}
global_state_->timeline.ActivityEndAll(entries);

2 changes: 1 addition & 1 deletion horovod/common/ops/nccl_operations.cc
@@ -300,7 +300,7 @@ NCCLHierarchicalAllreduce::Execute(std::vector<TensorTableEntry>& entries,
mpi_context_->GetMPISumOp(first_entry.tensor->dtype()),
mpi_context_->GetMPICommunicator(Communicator::CROSS));
if (op != MPI_SUCCESS) {
throw std::logic_error("MPI_Allreduce failed, see MPI output for details.");
throw std::runtime_error("MPI_Allreduce failed, see MPI output for details.");
}
timeline.ActivityEndAll(entries);
