Skip to content

Commit

Permalink
Refine allgather (#1175)
Browse files Browse the repository at this point in the history
* refine allgather.

* fix a bug.
  • Loading branch information
guolinke committed Jan 5, 2018
1 parent 7d35bee commit 5afffa7
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
7 changes: 7 additions & 0 deletions include/LightGBM/network.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,13 @@ class Network {
*/
static void Allgather(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);

static void AllgatherBruck(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);

static void AllgatherRecursiveDoubling(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);

static void AllgatherRing(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);


/*!
* \brief Perform reduce scatter by using recursive halving algorithm.
Communication times is O(log(n)), and communication cost is O(input_size)
Expand Down
57 changes: 57 additions & 0 deletions src/network/network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,19 @@ void Network::Allgather(char* input, const comm_size_t* block_start, const comm_
if (allgather_ext_fun_ != nullptr) {
return allgather_ext_fun_(input, block_len[rank_], block_start, block_len, num_machines_, output, all_size);
}
const comm_size_t kRecursiveDoublingThreshold = 1024 * 1024; // 1MB
const comm_size_t kBruckThreshold = 512 * 1024; // 512KB
const bool is_power_of2 = (num_machines_ & (num_machines_ - 1)) == 0;
if (is_power_of2 && all_size < kRecursiveDoublingThreshold) {
AllgatherRecursiveDoubling(input, block_start, block_len, output, all_size);
} else if (all_size < kBruckThreshold) {
AllgatherBruck(input, block_start, block_len, output, all_size);
} else {
AllgatherRing(input, block_start, block_len, output, all_size);
}
}

void Network::AllgatherBruck(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size) {
comm_size_t write_pos = 0;
// use output as receive buffer
std::memcpy(output, input, block_len[rank_]);
Expand Down Expand Up @@ -168,6 +181,50 @@ void Network::Allgather(char* input, const comm_size_t* block_start, const comm_
std::reverse<char*>(output + block_start[rank_], output + all_size);
}

void Network::AllgatherRecursiveDoubling(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t) {
// use output as receive buffer
std::memcpy(output + block_start[rank_], input, block_len[rank_]);
for (int i = 0; i < bruck_map_.k; ++i) {
// get current local block size
int cur_step = 1 << i;
const int vgroup = rank_ / cur_step;
const int vrank = vgroup * cur_step;
int target = rank_ + cur_step;
int target_vrank = (vgroup + 1) * cur_step;
if (vgroup & 1) {
target = rank_ - cur_step;
target_vrank = (vgroup - 1) * cur_step;
}
// get send information
comm_size_t need_send_len = 0;
// get recv information
comm_size_t need_recv_len = 0;
for (int j = 0; j < cur_step; ++j) {
need_send_len += block_len[(vrank + j)];
need_recv_len += block_len[(target_vrank + j)];
}
// send and recv at same time
linkers_->SendRecv(target, output + block_start[vrank], need_send_len,
target, output + block_start[target_vrank], need_recv_len);
}
}

void Network::AllgatherRing(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t) {
// use output as receive buffer
std::memcpy(output + block_start[rank_], input, block_len[rank_]);
int out_rank = (rank_ + 1) % num_machines_;
int in_rank = (rank_ - 1 + num_machines_) % num_machines_;
int out_place = rank_;
int in_place = in_rank;
for (int i = 1; i < num_machines_; ++i) {
// send and recv at same time
linkers_->SendRecv(out_rank, output + block_start[out_place], block_len[out_place],
in_rank, output + block_start[in_place], block_len[in_place]);
out_place = (out_place - 1 + num_machines_) % num_machines_;
in_place = (in_place - 1 + num_machines_) % num_machines_;
}
}

void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size, const ReduceFunction& reducer) {
if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first");
Expand Down

0 comments on commit 5afffa7

Please sign in to comment.