Skip to content

Commit

Permalink
fix the objective init issues in distributed mode (#2420)
Browse files Browse the repository at this point in the history
* fix bug

* fix include
  • Loading branch information
guolinke committed Sep 19, 2019
1 parent 0237492 commit a119639
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 19 deletions.
40 changes: 22 additions & 18 deletions include/LightGBM/network.h
Expand Up @@ -188,7 +188,6 @@ class Network {
});
return global;
}

template<class T>
static T GlobalSyncUpByMax(T& local) {
T global = local;
Expand All @@ -214,25 +213,30 @@ class Network {
}

template<class T>
static T GlobalSyncUpByMean(T& local) {
static T GlobalSyncUpBySum(T& local) {
T global = (T)0;
Allreduce(reinterpret_cast<char*>(&local),
sizeof(local), sizeof(local),
reinterpret_cast<char*>(&global),
[](const char* src, char* dst, int type_size, comm_size_t len) {
comm_size_t used_size = 0;
const T *p1;
T *p2;
while (used_size < len) {
p1 = reinterpret_cast<const T *>(src);
p2 = reinterpret_cast<T *>(dst);
*p2 += *p1;
src += type_size;
dst += type_size;
used_size += type_size;
}
});
return static_cast<T>(global / num_machines_);
sizeof(local), sizeof(local),
reinterpret_cast<char*>(&global),
[](const char* src, char* dst, int type_size, comm_size_t len) {
comm_size_t used_size = 0;
const T* p1;
T* p2;
while (used_size < len) {
p1 = reinterpret_cast<const T*>(src);
p2 = reinterpret_cast<T*>(dst);
*p2 += *p1;
src += type_size;
dst += type_size;
used_size += type_size;
}
});
return static_cast<T>(global);
}

template<class T>
static T GlobalSyncUpByMean(T& local) {
return static_cast<T>(GlobalSyncUpBySum(local) / num_machines_);
}

template<class T>
Expand Down
7 changes: 6 additions & 1 deletion src/objective/binary_objective.hpp
Expand Up @@ -5,6 +5,7 @@
#ifndef LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_

#include <LightGBM/network.h>
#include <LightGBM/objective_function.h>

#include <string>
Expand Down Expand Up @@ -72,6 +73,11 @@ class BinaryLogloss: public ObjectiveFunction {
++cnt_negative;
}
}
num_pos_data_ = cnt_positive;
if (Network::num_machines() > 1) {
cnt_positive = Network::GlobalSyncUpBySum(cnt_positive);
cnt_negative = Network::GlobalSyncUpBySum(cnt_negative);
}
need_train_ = true;
if (cnt_negative == 0 || cnt_positive == 0) {
Log::Warning("Contains only one class");
Expand All @@ -96,7 +102,6 @@ class BinaryLogloss: public ObjectiveFunction {
}
}
label_weights_[1] *= scale_pos_weight_;
num_pos_data_ = cnt_positive;
}

void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
Expand Down
7 changes: 7 additions & 0 deletions src/objective/multiclass_objective.hpp
Expand Up @@ -5,6 +5,7 @@
#ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_

#include <LightGBM/network.h>
#include <LightGBM/objective_function.h>

#include <string>
Expand Down Expand Up @@ -66,6 +67,12 @@ class MulticlassSoftmax: public ObjectiveFunction {
if (weights_ == nullptr) {
sum_weight = num_data_;
}
if (Network::num_machines() > 1) {
sum_weight = Network::GlobalSyncUpBySum(sum_weight);
for (int i = 0; i < num_class_; ++i) {
class_init_probs_[i] = Network::GlobalSyncUpBySum(class_init_probs_[i]);
}
}
for (int i = 0; i < num_class_; ++i) {
class_init_probs_[i] /= sum_weight;
}
Expand Down

0 comments on commit a119639

Please sign in to comment.