Skip to content

Commit

Permalink
fix #822
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Aug 18, 2017
1 parent da00522 commit 203df1b
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 69 deletions.
7 changes: 0 additions & 7 deletions include/LightGBM/application.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,6 @@ class Application {
inline void Run();

private:
/*!
* \brief Global sync by minimum; returns the minimal T across nodes
* \param local Local data
* \return Minimal value across nodes
*/
template<typename T>
T GlobalSyncUpByMin(T& local);

/*! \brief Load parameters from command line and config file*/
void LoadParameters(int argc, char** argv);
Expand Down
50 changes: 50 additions & 0 deletions include/LightGBM/network.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,56 @@ class Network {
const int* block_start, const int* block_len, char* output,
const ReduceFunction& reducer);

/*!
* \brief Run an Allreduce over a single value of type T, keeping the smaller
*        element at every reduction step
* \param local Value contributed by the caller
* \return The value left in the output buffer after the reduction
*/
template<class T>
static T GlobalSyncUpByMin(T& local) {
  // the output buffer starts out as a copy of the local value
  T global = local;
  // reducer: walk both byte buffers in steps of sizeof(T) and overwrite the
  // destination element whenever the source element compares smaller
  const auto keep_smaller = [] (const char* src, char* dst, int len) {
    const int step = sizeof(T);
    for (int offset = 0; offset < len; offset += step) {
      const T* incoming = reinterpret_cast<const T*>(src + offset);
      T* current = reinterpret_cast<T*>(dst + offset);
      if (*incoming < *current) {
        std::memcpy(dst + offset, src + offset, step);
      }
    }
  };
  Allreduce(reinterpret_cast<char*>(&local),
            sizeof(local), sizeof(local),
            reinterpret_cast<char*>(&global),
            keep_smaller);
  return global;
}

/*!
* \brief Run an Allreduce over a single value of type T, keeping the larger
*        element at every reduction step
* \param local Value contributed by the caller
* \return The value left in the output buffer after the reduction
*/
template<class T>
static T GlobalSyncUpByMax(T& local) {
  // the output buffer starts out as a copy of the local value
  T global = local;
  // reducer: walk both byte buffers in steps of sizeof(T) and overwrite the
  // destination element whenever the source element compares larger
  const auto keep_larger = [] (const char* src, char* dst, int len) {
    const int step = sizeof(T);
    for (int done = 0; done < len; done += step, src += step, dst += step) {
      const T* incoming = reinterpret_cast<const T*>(src);
      T* current = reinterpret_cast<T*>(dst);
      if (*incoming > *current) {
        std::memcpy(dst, src, step);
      }
    }
  };
  Allreduce(reinterpret_cast<char*>(&local),
            sizeof(local), sizeof(local),
            reinterpret_cast<char*>(&global),
            keep_larger);
  return global;
}

private:
/*! \brief Number of all machines */
static int num_machines_;
Expand Down
38 changes: 4 additions & 34 deletions src/application/application.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ void Application::LoadData() {

// sync up random seed for data partition
if (config_.is_parallel_find_bin) {
config_.io_config.data_random_seed =
GlobalSyncUpByMin<int>(config_.io_config.data_random_seed);
config_.io_config.data_random_seed = Network::GlobalSyncUpByMin(config_.io_config.data_random_seed);
}

DatasetLoader dataset_loader(config_.io_config, predict_fun,
Expand Down Expand Up @@ -190,13 +189,12 @@ void Application::InitTrain() {
// need init network
Network::Init(config_.network_config);
Log::Info("Finished initializing network");
// sync global random seed for feature partition
config_.boosting_config.tree_config.feature_fraction_seed =
GlobalSyncUpByMin<int>(config_.boosting_config.tree_config.feature_fraction_seed);
Network::GlobalSyncUpByMin(config_.boosting_config.tree_config.feature_fraction_seed);
config_.boosting_config.tree_config.feature_fraction =
GlobalSyncUpByMin<double>(config_.boosting_config.tree_config.feature_fraction);
Network::GlobalSyncUpByMin(config_.boosting_config.tree_config.feature_fraction);
config_.boosting_config.drop_seed =
GlobalSyncUpByMin<int>(config_.boosting_config.drop_seed);
Network::GlobalSyncUpByMin(config_.boosting_config.drop_seed);
}

// create boosting
Expand Down Expand Up @@ -255,33 +253,5 @@ void Application::ConvertModel() {
boosting_->SaveModelToIfElse(-1, config_.io_config.convert_model.c_str());
}

/*!
* \brief Global sync by minimum: reduces a value with Network::Allreduce,
*        keeping the smaller element at every reduction step
* \param local Local data
* \return The reduced value, or a copy of local when config_.is_parallel is false
*/
template<typename T>
T Application::GlobalSyncUpByMin(T& local) {
  T global = local;
  // no need to sync if not parallel learning
  if (!config_.is_parallel) {
    return global;
  }
  // reducer: step through both byte buffers one T at a time and copy the
  // source element over the destination element when it compares smaller
  const auto take_min = [](const char* src, char* dst, int len) {
    const int elem_bytes = sizeof(T);
    for (int pos = 0; pos < len; pos += elem_bytes) {
      const T* candidate = reinterpret_cast<const T*>(src + pos);
      T* best = reinterpret_cast<T*>(dst + pos);
      if (*candidate < *best) {
        std::memcpy(dst + pos, src + pos, elem_bytes);
      }
    }
  };
  Network::Allreduce(reinterpret_cast<char*>(&local),
                     sizeof(local), sizeof(local),
                     reinterpret_cast<char*>(&global),
                     take_min);
  return global;
}

} // namespace LightGBM
36 changes: 8 additions & 28 deletions src/io/dataset_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -738,35 +738,8 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
// if have multi-machines, need to find bin distributed
// different machines will find bin for different features

int max_bin = 0;
int total_num_feature = static_cast<int>(sample_values.size());
for (int i = 0; i < total_num_feature; ++i) {
max_bin = std::max(max_bin, bin_mappers[i]->num_bin());
}
std::pair<int, int> local_sync_info(max_bin, total_num_feature);
std::pair<int, int> global_sync_info(max_bin, total_num_feature);
// sync global max_bin and total_num_feature
Network::Allreduce(reinterpret_cast<char*>(&local_sync_info),
sizeof(local_sync_info), sizeof(global_sync_info),
reinterpret_cast<char*>(&global_sync_info),
[] (const char* src, char* dst, int len) {
int used_size = 0;
const int type_size = sizeof(std::pair<int, int>);
const std::pair<int, int> *p1;
std::pair<int, int> *p2;
while (used_size < len) {
p1 = reinterpret_cast<const std::pair<int, int> *>(src);
p2 = reinterpret_cast<std::pair<int, int> *>(dst);
p2->first = std::max(p1->first, p2->first);
// ignore the rare features
p2->second = std::min(p1->second, p2->second);
src += type_size;
dst += type_size;
used_size += type_size;
}
});
max_bin = global_sync_info.first;
total_num_feature = global_sync_info.second;
total_num_feature = Network::GlobalSyncUpByMin(total_num_feature);
// start and len will store the process feature indices for different machines
// machine i will find bins for features in [ start[i], start[i] + len[i] )
std::vector<int> start(num_machines);
Expand Down Expand Up @@ -797,6 +770,13 @@ void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
OMP_LOOP_EX_END();
}
OMP_THROW_EX();
int max_bin = 0;
for (int i = 0; i < len[rank]; ++i) {
if (bin_mappers[i] != nullptr) {
max_bin = std::max(max_bin, bin_mappers[i]->num_bin());
}
}
max_bin = Network::GlobalSyncUpByMax(max_bin);
// get size of bin mapper with max_bin size
int type_size = BinMapper::SizeForSpecificBin(max_bin);
// since sizes of different feature may not be same, we expand all bin mapper to type_size
Expand Down

0 comments on commit 203df1b

Please sign in to comment.