Skip to content

Commit

Permalink
allreduce_robust.cc: Allow num_global_replica to be 0 (#38)
Browse files Browse the repository at this point in the history
In some cases, users may not want to keep any global replica of
the data being broadcast/all-reduced. In such cases, set
result_buffer_round to -1 as a flag that no replica is needed,
and check for that flag before using result_buffer_round as a modulus.
  • Loading branch information
AbdealiLoKo authored and tqchen committed Nov 24, 2016
1 parent 032152a commit 21b5e12
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/allreduce_robust.cc
Expand Up @@ -33,7 +33,11 @@ AllreduceRobust::AllreduceRobust(void) {
}
/*!
 * \brief initialize the robust engine
 *
 *  Runs the base-class initialization first, then derives
 *  result_buffer_round from world_size and num_global_replica.
 *
 * \param argc number of entries in argv
 * \param argv arguments, forwarded unchanged to AllreduceBase::Init
 */
void AllreduceRobust::Init(int argc, char* argv[]) {
  AllreduceBase::Init(argc, argv);
  // Previously this was computed unconditionally as
  //   std::max(world_size / num_global_replica, 1)
  // which divides by zero when num_global_replica is 0. The division must
  // only happen behind the zero check below.
  if (num_global_replica == 0) {
    // -1 is a flag value meaning "no global replica requested"; code that
    // uses result_buffer_round as a modulus tests for -1 first.
    result_buffer_round = -1;
  } else {
    // never let the round drop below 1, even if there are more replicas
    // than workers
    result_buffer_round = std::max(world_size / num_global_replica, 1);
  }
}
/*! \brief shutdown the engine */
void AllreduceRobust::Shutdown(void) {
Expand Down Expand Up @@ -86,7 +90,8 @@ void AllreduceRobust::Allreduce(void *sendrecvbuf_,
bool recovered = RecoverExec(sendrecvbuf_, type_nbytes * count, 0, seq_counter);
// now we are free to remove the last result, if any
if (resbuf.LastSeqNo() != -1 &&
(resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
(result_buffer_round == -1 ||
resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
resbuf.DropLast();
}
if (!recovered && prepare_fun != NULL) prepare_fun(prepare_arg);
Expand Down Expand Up @@ -118,7 +123,8 @@ void AllreduceRobust::Broadcast(void *sendrecvbuf_, size_t total_size, int root)
bool recovered = RecoverExec(sendrecvbuf_, total_size, 0, seq_counter);
// now we are free to remove the last result, if any
if (resbuf.LastSeqNo() != -1 &&
(resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
(result_buffer_round == -1 ||
resbuf.LastSeqNo() % result_buffer_round != rank % result_buffer_round)) {
resbuf.DropLast();
}
void *temp = resbuf.AllocTemp(1, total_size);
Expand Down

0 comments on commit 21b5e12

Please sign in to comment.