Skip to content

Commit

Permalink
update network for non powers of 2 workers (#1178)
Browse files Browse the repository at this point in the history
* update network.h

Improve training speed in parallel learning where workers is not powers of 2

* update linders_socket.cpp

Improve training speed in parallel learning where workers is not powers of 2

* update linder_topo.cpp

Improve training speed in parallel learning where workers is not powers of 2

* update network.cpp

Improve training speed in parallel learning where workers is not powers of 2

* update linder_topo.cpp

fix a bug
  • Loading branch information
qrqpjxq authored and guolinke committed Jan 5, 2018
1 parent 5afffa7 commit ee8a65a
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 31 deletions.
9 changes: 7 additions & 2 deletions include/LightGBM/network.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,14 @@ class BruckMap {
/*! \brief Network structure for recursive halving algorithm */
class RecursiveHalvingMap {
public:
bool need_pairwise;
/*! \brief If number workers is powers of 2 */
bool is_prof2;
/*! \brief Communication times for one recursize halving algorithm */
int k;
/*! \brief Number workers subtract powers of 2 */
int num_remain;
/*! \brief Virtual rank for recursize halving algorithm */
int virtual_rank;
/*! \brief ranks[i] means the machines that will communicate with on i-th communication*/
std::vector<int> ranks;
/*! \brief send_block_start[i] means send block start index at i-th communication*/
Expand All @@ -54,7 +59,7 @@ class RecursiveHalvingMap {

RecursiveHalvingMap();

RecursiveHalvingMap(int k, bool in_need_pairwise);
RecursiveHalvingMap(int k, int num_remain, int virtual_rank, bool is_prof2);

/*!
* \brief Create the object of recursive halving map
Expand Down
55 changes: 34 additions & 21 deletions src/network/linker_topo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,22 @@ BruckMap BruckMap::Construct(int rank, int num_machines) {

RecursiveHalvingMap::RecursiveHalvingMap() {
k = 0;
need_pairwise = true;
is_prof2 = true;
num_remain = 0;
}

RecursiveHalvingMap::RecursiveHalvingMap(int in_k, bool in_need_pairwise) {
need_pairwise = in_need_pairwise;
RecursiveHalvingMap::RecursiveHalvingMap(int in_k, int in_remain, int in_rank, bool is_power_of2) {
k = in_k;
if (!need_pairwise) {
for (int i = 0; i < k; ++i) {
// defalut set as -1
ranks.push_back(-1);
send_block_start.push_back(-1);
send_block_len.push_back(-1);
recv_block_start.push_back(-1);
recv_block_len.push_back(-1);
}
is_prof2 = is_power_of2;
num_remain = in_remain;
virtual_rank = in_rank;
for (int i = 0; i < k; ++i) {
// defalut set as -1
ranks.push_back(-1);
send_block_start.push_back(-1);
send_block_len.push_back(-1);
recv_block_start.push_back(-1);
recv_block_len.push_back(-1);
}
}

Expand All @@ -74,28 +75,40 @@ RecursiveHalvingMap RecursiveHalvingMap::Construct(int rank, int num_machines) {
distance.push_back(1 << (k - 1 - i));
}

if ((1 << k) == num_machines) {
RecursiveHalvingMap rec_map(k, false);
// if num_machines = 2^k, don't need to group machines
int remain = num_machines - (1 << k);
int virtual_rank = rank;
// if virtual_rank not -1 will not excute recursize halving algorithm
if (rank < 2 * remain) {
if (rank % 2 == 0) {
virtual_rank = -1;
} else {
virtual_rank = rank / 2;
}
} else {
virtual_rank = rank - remain;
}

bool is_power_of2 = false;
if ((1 << k) == num_machines) { is_power_of2 = true; }
RecursiveHalvingMap rec_map(k, remain, virtual_rank, is_power_of2);
if (virtual_rank != -1) {
for (int i = 0; i < k; ++i) {
// communication direction, %2 == 0 is positive
const int dir = ((rank / distance[i]) % 2 == 0) ? 1 : -1;
const int dir = ((virtual_rank / distance[i]) % 2 == 0) ? 1 : -1;
// neighbor at k-th communication
const int next_node_idx = rank + dir * distance[i];
const int next_node_idx = virtual_rank + dir * distance[i];
rec_map.ranks[i] = next_node_idx;
// receive data block at k-th communication
const int recv_block_start = rank / distance[i];
const int recv_block_start = virtual_rank / distance[i];
rec_map.recv_block_start[i] = recv_block_start * distance[i];
rec_map.recv_block_len[i] = distance[i];
// send data block at k-th communication
const int send_block_start = next_node_idx / distance[i];
rec_map.send_block_start[i] = send_block_start * distance[i];
rec_map.send_block_len[i] = distance[i];
}
return rec_map;
} else {
return RecursiveHalvingMap(k, true);
}
return rec_map;
}

} // namespace LightGBM
Expand Down
2 changes: 1 addition & 1 deletion src/network/linkers_socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ void Linkers::ListenThread(int incoming_cnt) {
void Linkers::Construct() {
// save ranks that need to connect with
std::unordered_map<int, int> need_connect;
if (recursive_halving_map_.need_pairwise) {
if (!recursive_halving_map_.is_prof2) {
for (int i = 0; i < num_machines_; ++i) {
if (i != rank_) {
need_connect[i] = 1;
Expand Down
68 changes: 61 additions & 7 deletions src/network/network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,19 +225,73 @@ void Network::AllgatherRing(char* input, const comm_size_t* block_start, const c
}
}

void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t output_size, const ReduceFunction& reducer) {
void Network::ReduceScatter(char* input, comm_size_t input_size, int type_size, const comm_size_t* block_start,
const comm_size_t* block_len, char* output, comm_size_t output_size, const ReduceFunction& reducer) {
if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first");
}
if (reduce_scatter_ext_fun_ != nullptr) {
return reduce_scatter_ext_fun_(input, input_size, type_size, block_start, block_len, num_machines_, output, output_size, reducer);
}
if (recursive_halving_map_.need_pairwise) {
for (int i = 1; i < num_machines_; ++i) {
int out_rank = (rank_ + i) % num_machines_;
int in_rank = (rank_ - i + num_machines_) % num_machines_;
linkers_->SendRecv(out_rank, input + block_start[out_rank], block_len[out_rank], in_rank, output, block_len[rank_]);
reducer(output, input + block_start[rank_], type_size, block_len[rank_]);
if (!recursive_halving_map_.is_prof2) {
int remain = recursive_halving_map_.num_remain;
std::vector<int> rcsv_block_start(1 << recursive_halving_map_.k);
std::vector<int> rcsv_block_len(1 << recursive_halving_map_.k);
std::vector<int> real_ranks;
int brush = 0;
// build block_start and block_len for remain powers of 2 workers
for (int i = 0; i < num_machines_; ++i) {
if ((i < 2 * remain) && (i % 2 != 0)) {
real_ranks.push_back(i);
rcsv_block_start[i - 1 - brush] = block_start[i - 1];
rcsv_block_len[i - 1 - brush] = block_len[i] + block_len[i - 1];
brush++;
}
if (i >= 2 * remain) {
real_ranks.push_back(i);
rcsv_block_start[i - remain] = block_start[i];
rcsv_block_len[i - remain] = block_len[i];
}
}
// if local rank is remain, send local data to rank+1
if (rank_ < 2 * remain) {
if (rank_ % 2 == 0) {
linkers_->Send(rank_ + 1, input, input_size);
} else {
linkers_->Recv(rank_ - 1, output, input_size);
reducer(output, input, type_size, input_size);
}
}
// excute recursize halving algorithm for powers of 2 workers
if (recursive_halving_map_.virtual_rank != -1) {
for (int i = 0; i < recursive_halving_map_.k; ++i) {
int virtual_rank = recursive_halving_map_.ranks[i];
int target = real_ranks[virtual_rank];
int send_block_start = recursive_halving_map_.send_block_start[i];
int recv_block_start = recursive_halving_map_.recv_block_start[i];
// get send information
int send_size = 0;
for (int j = 0; j < recursive_halving_map_.send_block_len[i]; ++j) {
send_size += rcsv_block_len[send_block_start + j];
}
// get recv information
int need_recv_cnt = 0;
for (int j = 0; j < recursive_halving_map_.recv_block_len[i]; ++j) {
need_recv_cnt += rcsv_block_len[recv_block_start + j];
}
// send and recv at same time
linkers_->SendRecv(target, input + rcsv_block_start[send_block_start], send_size, target, output, need_recv_cnt);
// reduce
reducer(output, input + rcsv_block_start[recv_block_start], type_size, need_recv_cnt);
}
}
// send result back to remain workers
if (rank_ < 2 * remain) {
if (rank_ % 2 != 0) {
linkers_->Send(rank_ - 1, input + block_start[rank_ - 1], block_len[rank_ - 1]);
} else {
linkers_->Recv(rank_ + 1, input + block_start[rank_], block_len[rank_]);
}
}
} else {
for (int i = 0; i < recursive_halving_map_.k; ++i) {
Expand Down

0 comments on commit ee8a65a

Please sign in to comment.