Skip to content

Commit

Permalink
fix bug in LGBM_NetworkInitWithFunctions
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Dec 15, 2017
1 parent 159e9a1 commit 72b5495
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 34 deletions.
6 changes: 3 additions & 3 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -757,9 +757,9 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines,
*/
LIGHTGBM_C_EXPORT int LGBM_NetworkFree();

LIGHTGBM_C_EXPORT int LGBM_NetworkInitWithFunctions(void* AllreduceFuncPtr,
void* ReduceScatterFuncPtr,
void* AllgatherFuncPtr,
LIGHTGBM_C_EXPORT int LGBM_NetworkInitWithFunctions(void* allreduce_fun_ptr,
void* reduce_scatter_fun_ptr,
void* allgather_fun_ptr,
int num_machines,
int rank);

Expand Down
4 changes: 1 addition & 3 deletions include/LightGBM/meta.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,14 @@ const double kZeroThreshold = 1e-35f;

using ReduceFunction = std::function<void(const char*, char*, int)>;

typedef void(*ReduceFunctionInC)(const char*, char*, int);

using PredictFunction =
std::function<void(const std::vector<std::pair<int, double>>&, double* output)>;

using AllreduceFunction = std::function<void(char*, int, int, char*, const ReduceFunction&)>;

using ReduceScatterFunction = std::function<void(char*, int, const int*, const int*, char*, const ReduceFunction&)>;

using AllgatherFunction = std::function<void(char*, int, char*)>;
using AllgatherFunction = std::function<void(char*, int, const int*, const int*, char*)>;

#define NO_SPECIFIC (-1)

Expand Down
12 changes: 6 additions & 6 deletions include/LightGBM/network.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,9 @@ class Network {
/*! \brief set variables and function ptrs */
static void SetRank(int rank) { rank_ = rank;}
static void SetNumMachines(int num_machines) { num_machines_ = num_machines; }
static void SetAllReduceFunction(AllreduceFunction AllreduceFuncPtr) { AllreduceFuncPtr_ = AllreduceFuncPtr;}
static void SetReduceScatterFunction(ReduceScatterFunction ReduceScatterFuncPtr) { ReduceScatterFuncPtr_ = ReduceScatterFuncPtr; }
static void SetAllgatherFunction(AllgatherFunction AllgatherFuncPtr) { AllgatherFuncPtr_ = AllgatherFuncPtr; }
static void SetAllReduceFunction(AllreduceFunction allreduce_ext_fun) { allreduce_ext_fun_ = allreduce_ext_fun;}
static void SetReduceScatterFunction(ReduceScatterFunction reduce_scatter_ext_fun) { reduce_scatter_ext_fun_ = reduce_scatter_ext_fun; }
static void SetAllgatherFunction(AllgatherFunction allgather_ext_fun) { allgather_ext_fun_ = allgather_ext_fun; }

private:
/*! \brief Number of all machines */
Expand All @@ -215,9 +215,9 @@ class Network {
/*! \brief Size of buffer_ */
static THREAD_LOCAL int buffer_size_;
/*! \brief Funcs*/
static THREAD_LOCAL AllreduceFunction AllreduceFuncPtr_;
static THREAD_LOCAL ReduceScatterFunction ReduceScatterFuncPtr_;
static THREAD_LOCAL AllgatherFunction AllgatherFuncPtr_;
static THREAD_LOCAL AllreduceFunction allreduce_ext_fun_;
static THREAD_LOCAL ReduceScatterFunction reduce_scatter_ext_fun_;
static THREAD_LOCAL AllgatherFunction allgather_ext_fun_;
};

inline int Network::rank() {
Expand Down
25 changes: 13 additions & 12 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1220,26 +1220,27 @@ int LGBM_NetworkFree() {
API_END();
}

int LGBM_NetworkInitWithFunctions(void* AllreduceFuncPtr,
void* ReduceScatterFuncPtr,
void* AllgatherFuncPtr,
int LGBM_NetworkInitWithFunctions(void* allreduce_fun_ptr,
void* reduce_scatter_fun_ptr,
void* allgather_fun_ptr,
int num_machines,
int rank) {
API_BEGIN();
typedef void(*ReduceFunctionPtr)(const char* input, char* output, int array_size);
if (num_machines > 1) {
auto allreduce_fun = [AllreduceFuncPtr](char* arg1, int arg2, int arg3, char* arg4, const ReduceFunction& func) {
auto ptr = *func.target<ReduceFunctionInC>();
auto tmp = (void(*)(char*, int, int, char*, const ReduceFunctionInC&))AllreduceFuncPtr;
return tmp(arg1, arg2, arg3, arg4, ptr);
auto allreduce_fun = [allreduce_fun_ptr](char* arg1, int arg2, int arg3, char* arg4, const ReduceFunction& reduce_fun) {
auto reduce_fun_ptr = *reduce_fun.target<ReduceFunctionPtr>();
auto tmp = (void(*)(char*, int, int, char*, const ReduceFunctionPtr&))allreduce_fun_ptr;
return tmp(arg1, arg2, arg3, arg4, reduce_fun_ptr);
};
Network::SetAllReduceFunction(allreduce_fun);
auto reduce_scatter_fun = [ReduceScatterFuncPtr](char* arg1, int arg2, const int* arg3, const int* arg4, char* arg5, const ReduceFunction& func) {
auto ptr = *func.target<ReduceFunctionInC>();
auto tmp = (void(*)(char*, int, const int*, const int*, char*, const ReduceFunctionInC&))ReduceScatterFuncPtr;
return tmp(arg1, arg2, arg3, arg4, arg5, ptr);
auto reduce_scatter_fun = [reduce_scatter_fun_ptr](char* arg1, int arg2, const int* arg3, const int* arg4, char* arg5, const ReduceFunction& reduce_fun) {
auto reduce_fun_ptr = *reduce_fun.target<ReduceFunctionPtr>();
auto tmp = (void(*)(char*, int, const int*, const int*, char*, const ReduceFunctionPtr&))reduce_scatter_fun_ptr;
return tmp(arg1, arg2, arg3, arg4, arg5, reduce_fun_ptr);
};
Network::SetReduceScatterFunction(reduce_scatter_fun);
Network::SetAllgatherFunction((void(*)(char*, int, char*))AllgatherFuncPtr);
Network::SetAllgatherFunction((void(*)(char*, int, const int*, const int*, char*))allgather_fun_ptr);
Network::SetNumMachines(num_machines);
Network::SetRank(rank);
}
Expand Down
20 changes: 10 additions & 10 deletions src/network/network.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ THREAD_LOCAL std::vector<int> Network::block_start_;
THREAD_LOCAL std::vector<int> Network::block_len_;
THREAD_LOCAL int Network::buffer_size_;
THREAD_LOCAL std::vector<char> Network::buffer_;
THREAD_LOCAL AllreduceFunction Network::AllreduceFuncPtr_ = NULL;
THREAD_LOCAL ReduceScatterFunction Network::ReduceScatterFuncPtr_ = NULL;
THREAD_LOCAL AllgatherFunction Network::AllgatherFuncPtr_ = NULL;
THREAD_LOCAL AllreduceFunction Network::allreduce_ext_fun_ = NULL;
THREAD_LOCAL ReduceScatterFunction Network::reduce_scatter_ext_fun_ = NULL;
THREAD_LOCAL AllgatherFunction Network::allgather_ext_fun_ = NULL;


void Network::Init(NetworkConfig config) {
Expand Down Expand Up @@ -49,8 +49,8 @@ void Network::Allreduce(char* input, int input_size, int type_size, char* output
if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first");
}
if (AllreduceFuncPtr_ != NULL) {
return AllreduceFuncPtr_(input, input_size, type_size, output, reducer);
if (allreduce_ext_fun_ != NULL) {
return allreduce_ext_fun_(input, input_size, type_size, output, reducer);
}
int count = input_size / type_size;
// if small package or small count , do it by all gather.(reduce the communication times.)
Expand Down Expand Up @@ -106,9 +106,6 @@ void Network::Allgather(char* input, int send_size, char* output) {
Log::Fatal("Please initilize the network interface first");
}
if (num_machines_ <= 1) { return; }
if (AllgatherFuncPtr_ != NULL) {
return AllgatherFuncPtr_(input, send_size, output);
}
// assign blocks
block_start_[0] = 0;
block_len_[0] = send_size;
Expand All @@ -124,6 +121,9 @@ void Network::Allgather(char* input, int all_size, const int* block_start, const
if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first");
}
if (allgather_ext_fun_ != NULL) {
return allgather_ext_fun_(input, all_size, block_start, block_len, output);
}
int write_pos = 0;
// use output as receive buffer
std::memcpy(output, input, block_len[rank_]);
Expand Down Expand Up @@ -159,8 +159,8 @@ void Network::ReduceScatter(char* input, int input_size, const int* block_start,
if (num_machines_ <= 1) {
Log::Fatal("Please initilize the network interface first");
}
if (ReduceScatterFuncPtr_ != NULL) {
return ReduceScatterFuncPtr_(input, input_size, block_start, block_len, output, reducer);
if (reduce_scatter_ext_fun_ != NULL) {
return reduce_scatter_ext_fun_(input, input_size, block_start, block_len, output, reducer);
}
if (recursive_halving_map_.need_pairwise) {
for (int i = 1; i < num_machines_; ++i) {
Expand Down

0 comments on commit 72b5495

Please sign in to comment.