Skip to content

Commit

Permalink
clean code for network functions
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Dec 14, 2017
1 parent 0a7a408 commit 159e9a1
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 20 deletions.
10 changes: 5 additions & 5 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -757,11 +757,11 @@ LIGHTGBM_C_EXPORT int LGBM_NetworkInit(const char* machines,
*/
LIGHTGBM_C_EXPORT int LGBM_NetworkFree();

LIGHTGBM_C_EXPORT int LGBM_GetFuncions(void* AllreduceFuncPtr,
void* ReduceScatterFuncPtr,
void* AllgatherFuncPtr,
int num_machines,
int rank);
LIGHTGBM_C_EXPORT int LGBM_NetworkInitWithFunctions(void* AllreduceFuncPtr,
void* ReduceScatterFuncPtr,
void* AllgatherFuncPtr,
int num_machines,
int rank);

// exception handle and error msg
static char* LastErrorMsg() { static THREAD_LOCAL char err_msg[512] = "Everything is fine"; return err_msg; }
Expand Down
6 changes: 3 additions & 3 deletions include/LightGBM/network.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,9 @@ class Network {
/*! \brief set variables and function ptrs */
static void SetRank(int rank) { rank_ = rank;}
static void SetNumMachines(int num_machines) { num_machines_ = num_machines; }
static void SetAllReduce(AllreduceFunction AllreduceFuncPtr) { AllreduceFuncPtr_ = AllreduceFuncPtr;}
static void SetReduceScatter(ReduceScatterFunction ReduceScatterFuncPtr) { ReduceScatterFuncPtr_ = ReduceScatterFuncPtr; }
static void SetAllgather(AllgatherFunction AllgatherFuncPtr) { AllgatherFuncPtr_ = AllgatherFuncPtr; }
static void SetAllReduceFunction(AllreduceFunction AllreduceFuncPtr) { AllreduceFuncPtr_ = AllreduceFuncPtr;}
static void SetReduceScatterFunction(ReduceScatterFunction ReduceScatterFuncPtr) { ReduceScatterFuncPtr_ = ReduceScatterFuncPtr; }
static void SetAllgatherFunction(AllgatherFunction AllgatherFuncPtr) { AllgatherFuncPtr_ = AllgatherFuncPtr; }

private:
/*! \brief Number of all machines */
Expand Down
23 changes: 11 additions & 12 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1220,31 +1220,30 @@ int LGBM_NetworkFree() {
API_END();
}

int LGBM_GetFuncions(void* AllreduceFuncPtr,
void* ReduceScatterFuncPtr,
void* AllgatherFuncPtr,
int num_machines,
int rank) {
int LGBM_NetworkInitWithFunctions(void* AllreduceFuncPtr,
void* ReduceScatterFuncPtr,
void* AllgatherFuncPtr,
int num_machines,
int rank) {
API_BEGIN();
if(num_machines > 1) {
auto func1 = [AllreduceFuncPtr](char* arg1, int arg2, int arg3, char* arg4, const ReduceFunction& func) {
if (num_machines > 1) {
auto allreduce_fun = [AllreduceFuncPtr](char* arg1, int arg2, int arg3, char* arg4, const ReduceFunction& func) {
auto ptr = *func.target<ReduceFunctionInC>();
auto tmp = (void(*)(char*, int, int, char*, const ReduceFunctionInC&))AllreduceFuncPtr;
return tmp(arg1, arg2, arg3, arg4, ptr);
};
Network::SetAllReduce(func1);
auto func2 = [ReduceScatterFuncPtr](char* arg1, int arg2, const int* arg3, const int* arg4, char* arg5, const ReduceFunction& func) {
Network::SetAllReduceFunction(allreduce_fun);
auto reduce_scatter_fun = [ReduceScatterFuncPtr](char* arg1, int arg2, const int* arg3, const int* arg4, char* arg5, const ReduceFunction& func) {
auto ptr = *func.target<ReduceFunctionInC>();
auto tmp = (void(*)(char*, int, const int*, const int*, char*, const ReduceFunctionInC&))ReduceScatterFuncPtr;
return tmp(arg1, arg2, arg3, arg4, arg5, ptr);
};
Network::SetReduceScatter(func2);
Network::SetAllgather((void(*)(char*, int, char*))AllgatherFuncPtr);
Network::SetReduceScatterFunction(reduce_scatter_fun);
Network::SetAllgatherFunction((void(*)(char*, int, char*))AllgatherFuncPtr);
Network::SetNumMachines(num_machines);
Network::SetRank(rank);
}
API_END();

}
// ---- start of some help functions

Expand Down

0 comments on commit 159e9a1

Please sign in to comment.