diff --git a/byteps/common/communicator.cc b/byteps/common/communicator.cc index ae21789a5..f1ae6aa2c 100644 --- a/byteps/common/communicator.cc +++ b/byteps/common/communicator.cc @@ -90,8 +90,13 @@ void BytePSCommSocket::init(int* rank, int* size, int* local_rank, int* local_si *my_role = (_local_rank == _root) ? LOCAL_ROOT : LOCAL_WORKER; bool is_root = (*my_role == LOCAL_ROOT) ? true : false; - _send_path = std::string(BASE_SOCKET_PATH_SEND); - _recv_path = std::string(BASE_SOCKET_PATH_RECV); + if (getenv("BYTEPS_SOCKET_PATH")) { + _send_path = std::string(getenv("BYTEPS_SOCKET_PATH")) + std::string("/socket_send_"); + _recv_path = std::string(getenv("BYTEPS_SOCKET_PATH")) + std::string("/socket_recv_"); + } else { + _send_path = std::string(DEFAULT_BASE_SOCKET_PATH_SEND); + _recv_path = std::string(DEFAULT_BASE_SOCKET_PATH_RECV); + } _send_fd = initSocket(_local_rank, _send_path); _recv_fd = initSocket(_local_rank, _recv_path); diff --git a/byteps/common/communicator.h b/byteps/common/communicator.h index ced73132c..252c9f0d0 100644 --- a/byteps/common/communicator.h +++ b/byteps/common/communicator.h @@ -31,8 +31,8 @@ #include #include "logging.h" -#define BASE_SOCKET_PATH_RECV "/usr/local/socket_recv_" -#define BASE_SOCKET_PATH_SEND "/usr/local/socket_send_" +#define DEFAULT_BASE_SOCKET_PATH_RECV "/tmp/socket_recv_" +#define DEFAULT_BASE_SOCKET_PATH_SEND "/tmp/socket_send_" #define MAX_LINE 8000 namespace byteps { diff --git a/byteps/common/global.cc b/byteps/common/global.cc index 8bdba7c4a..a3d39f248 100644 --- a/byteps/common/global.cc +++ b/byteps/common/global.cc @@ -98,7 +98,6 @@ void BytePSGlobal::Init() { _partition_bytes = AlignTo(_partition_bytes, (8 * _local_size)); BPS_CHECK(getenv("DMLC_NUM_WORKER")) << "error: env DMLC_NUM_WORKER not set"; - BPS_CHECK(getenv("DMLC_NUM_SERVER")) << "error: env DMLC_NUM_SERVER not set"; _num_worker = atoi(getenv("DMLC_NUM_WORKER")); @@ -107,6 +106,10 @@ void BytePSGlobal::Init() { } _is_distributed_job = (_num_worker>1) ? true : _is_distributed_job; + if (_is_distributed_job) { + BPS_CHECK(getenv("DMLC_NUM_SERVER")) << "error: launch distributed job, but env DMLC_NUM_SERVER not set"; + } + BPS_LOG(DEBUG) << "Number of worker=" << _num_worker << ", launching " << (IsDistributed() ? "" : "non-") << "distributed job";