Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rabit support of large cluster #68

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ before_install:
- source ${TRAVIS}/travis_setup_env.sh

install:
- pip install --user cpplint pylint
- pip install --user cpplint pylint kubernetes urllib3

script: scripts/travis_script.sh

Expand Down
62 changes: 31 additions & 31 deletions src/allreduce_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ utils::TCPSocket AllreduceBase::ConnectTracker(void) const {
#ifdef _MSC_VER
Sleep(1);
#else
sleep(1);
sleep(retry);
#endif
continue;
}
Expand Down Expand Up @@ -454,47 +454,47 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (i == parent_index) {
if (size_down_in != total_size) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
// only watch for exception in live channels
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
finished = false;
}
if (size_up_out != total_size && size_up_out < size_up_reduce) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
} else {
if (links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
}
// size_write <= size_read
if (links[i].size_write != total_size) {
if (links[i].size_write < size_down_in) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
// only watch for exception in live channels
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
finished = false;
}
}
}
// finish running allreduce
if (finished) break;
// select must return
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
// read data from children
for (int i = 0; i < nlink; ++i) {
if (i != parent_index && selecter.CheckRead(links[i].sock)) {
if (i != parent_index && watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToRingBuffer(size_up_out, total_size);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
Expand Down Expand Up @@ -551,7 +551,7 @@ AllreduceBase::TryAllreduceTree(void *sendrecvbuf_,
}
}
// read data from parent
if (selecter.CheckRead(links[parent_index].sock) &&
if (watcher.CheckRead(links[parent_index].sock) &&
total_size > size_down_in) {
ssize_t len = links[parent_index].sock.
Recv(sendrecvbuf + size_down_in, total_size - size_down_in);
Expand Down Expand Up @@ -620,37 +620,37 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
while (true) {
bool finished = true;
// select helper
utils::SelectHelper selecter;
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (in_link == -2) {
selecter.WatchRead(links[i].sock); finished = false;
watcher.WatchRead(links[i].sock); finished = false;
}
if (i == in_link && links[i].size_read != total_size) {
selecter.WatchRead(links[i].sock); finished = false;
watcher.WatchRead(links[i].sock); finished = false;
}
if (in_link != -2 && i != in_link && links[i].size_write != total_size) {
if (links[i].size_write < size_in) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
finished = false;
}
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
}
// finish running
if (finished) break;
// select
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
if (in_link == -2) {
// probe in-link
for (int i = 0; i < nlink; ++i) {
if (selecter.CheckRead(links[i].sock)) {
if (watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) {
return ReportError(&links[i], ret);
Expand All @@ -663,7 +663,7 @@ AllreduceBase::TryBroadcast(void *sendrecvbuf_, size_t total_size, int root) {
}
} else {
// read from in link
if (in_link >= 0 && selecter.CheckRead(links[in_link].sock)) {
if (in_link >= 0 && watcher.CheckRead(links[in_link].sock)) {
ReturnType ret = links[in_link].ReadToArray(sendrecvbuf_, total_size);
if (ret != kSuccess) {
return ReportError(&links[in_link], ret);
Expand Down Expand Up @@ -717,20 +717,20 @@ AllreduceBase::TryAllgatherRing(void *sendrecvbuf_, size_t total_size,
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
if (read_ptr != stop_read) {
selecter.WatchRead(next.sock);
watcher.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < read_ptr) {
selecter.WatchWrite(prev.sock);
watcher.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) break;
selecter.Select();
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
watcher.Poll();
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
size_t size = stop_read - read_ptr;
size_t start = read_ptr % total_size;
if (start + size > total_size) {
Expand Down Expand Up @@ -811,20 +811,20 @@ AllreduceBase::TryReduceScatterRing(void *sendrecvbuf_,
while (true) {
// select helper
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
if (read_ptr != stop_read) {
selecter.WatchRead(next.sock);
watcher.WatchRead(next.sock);
finished = false;
}
if (write_ptr != stop_write) {
if (write_ptr < reduce_ptr) {
selecter.WatchWrite(prev.sock);
watcher.WatchWrite(prev.sock);
}
finished = false;
}
if (finished) break;
selecter.Select();
if (read_ptr != stop_read && selecter.CheckRead(next.sock)) {
watcher.Poll();
if (read_ptr != stop_read && watcher.CheckRead(next.sock)) {
ReturnType ret = next.ReadToRingBuffer(reduce_ptr, stop_read);
if (ret != kSuccess) {
return ReportError(&next, ret);
Expand Down
18 changes: 9 additions & 9 deletions src/allreduce_robust-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,29 +70,29 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
utils::Assert(stage != 2 && stage != 1, "invalie stage id");
}
// select helper
utils::SelectHelper selecter;
utils::PollHelper watcher;
bool done = (stage == 3);
for (int i = 0; i < nlink; ++i) {
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
switch (stage) {
case 0:
if (i != parent_index && links[i].size_read != sizeof(EdgeType)) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
}
break;
case 1:
if (i == parent_index) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
break;
case 2:
if (i == parent_index) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
}
break;
case 3:
if (i != parent_index && links[i].size_write != sizeof(EdgeType)) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
done = false;
}
break;
Expand All @@ -101,11 +101,11 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
}
// finish all the stages, and write out message
if (done) break;
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
// receive OOB message from some link
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
Expand All @@ -114,7 +114,7 @@ AllreduceRobust::MsgPassing(const NodeType &node_value,
// read data from children
for (int i = 0; i < nlink; ++i) {
if (i != parent_index) {
if (selecter.CheckRead(links[i].sock)) {
if (watcher.CheckRead(links[i].sock)) {
ReturnType ret = links[i].ReadToArray(&edge_in[i], sizeof(EdgeType));
if (ret != kSuccess) return ReportError(&links[i], ret);
}
Expand Down
44 changes: 22 additions & 22 deletions src/allreduce_robust.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
if (len == sizeof(sig)) all_links[i].size_write = 2;
}
}
utils::SelectHelper rsel;
utils::PollHelper rsel;
bool finished = true;
for (int i = 0; i < nlink; ++i) {
if (all_links[i].size_write != 2 && !all_links[i].sock.BadSocket()) {
Expand All @@ -343,23 +343,23 @@ AllreduceRobust::ReturnType AllreduceRobust::TryResetLinks(void) {
}
if (finished) break;
// wait to read from the channels to discard data
rsel.Select();
rsel.Poll();
}
for (int i = 0; i < nlink; ++i) {
if (!all_links[i].sock.BadSocket()) {
utils::SelectHelper::WaitExcept(all_links[i].sock);
utils::PollHelper::WaitExcept(all_links[i].sock);
}
}
while (true) {
utils::SelectHelper rsel;
utils::PollHelper rsel;
bool finished = true;
for (int i = 0; i < nlink; ++i) {
if (all_links[i].size_read == 0 && !all_links[i].sock.BadSocket()) {
rsel.WatchRead(all_links[i].sock); finished = false;
}
}
if (finished) break;
rsel.Select();
rsel.Poll();
for (int i = 0; i < nlink; ++i) {
if (all_links[i].sock.BadSocket()) continue;
if (all_links[i].size_read == 0) {
Expand Down Expand Up @@ -624,32 +624,32 @@ AllreduceRobust::TryRecoverData(RecoverType role,
}
while (true) {
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
for (int i = 0; i < nlink; ++i) {
if (i == recv_link && links[i].size_read != size) {
selecter.WatchRead(links[i].sock);
watcher.WatchRead(links[i].sock);
finished = false;
}
if (req_in[i] && links[i].size_write != size) {
if (role == kHaveData ||
(links[recv_link].size_read != links[i].size_write)) {
selecter.WatchWrite(links[i].sock);
watcher.WatchWrite(links[i].sock);
}
finished = false;
}
selecter.WatchException(links[i].sock);
watcher.WatchException(links[i].sock);
}
if (finished) break;
selecter.Select();
watcher.Poll();
// exception handling
for (int i = 0; i < nlink; ++i) {
if (selecter.CheckExcept(links[i].sock)) {
if (watcher.CheckExcept(links[i].sock)) {
return ReportError(&links[i], kGetExcept);
}
}
if (role == kRequestData) {
const int pid = recv_link;
if (selecter.CheckRead(links[pid].sock)) {
if (watcher.CheckRead(links[pid].sock)) {
ReturnType ret = links[pid].ReadToArray(sendrecvbuf_, size);
if (ret != kSuccess) {
return ReportError(&links[pid], ret);
Expand Down Expand Up @@ -677,7 +677,7 @@ AllreduceRobust::TryRecoverData(RecoverType role,
if (role == kPassData) {
const int pid = recv_link;
const size_t buffer_size = links[pid].buffer_size;
if (selecter.CheckRead(links[pid].sock)) {
if (watcher.CheckRead(links[pid].sock)) {
size_t min_write = size;
for (int i = 0; i < nlink; ++i) {
if (req_in[i]) min_write = std::min(links[i].size_write, min_write);
Expand Down Expand Up @@ -1144,22 +1144,22 @@ AllreduceRobust::RingPassing(void *sendrecvbuf_,
char *buf = reinterpret_cast<char*>(sendrecvbuf_);
while (true) {
bool finished = true;
utils::SelectHelper selecter;
utils::PollHelper watcher;
if (read_ptr != read_end) {
selecter.WatchRead(prev.sock);
watcher.WatchRead(prev.sock);
finished = false;
}
if (write_ptr < read_ptr && write_ptr != write_end) {
selecter.WatchWrite(next.sock);
watcher.WatchWrite(next.sock);
finished = false;
}
selecter.WatchException(prev.sock);
selecter.WatchException(next.sock);
watcher.WatchException(prev.sock);
watcher.WatchException(next.sock);
if (finished) break;
selecter.Select();
if (selecter.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept);
if (selecter.CheckExcept(next.sock)) return ReportError(&next, kGetExcept);
if (read_ptr != read_end && selecter.CheckRead(prev.sock)) {
watcher.Poll();
if (watcher.CheckExcept(prev.sock)) return ReportError(&prev, kGetExcept);
if (watcher.CheckExcept(next.sock)) return ReportError(&next, kGetExcept);
if (read_ptr != read_end && watcher.CheckRead(prev.sock)) {
ssize_t len = prev.sock.Recv(buf + read_ptr, read_end - read_ptr);
if (len == 0) {
prev.sock.Close(); return ReportError(&prev, kRecvZeroLen);
Expand Down
Loading