Fix objective functions with zero hessian (#1199)
guolinke committed Jan 16, 2018
1 parent d90369a commit 5392c9e
Showing 22 changed files with 462 additions and 219 deletions.
16 changes: 15 additions & 1 deletion docs/Features.rst
@@ -205,10 +205,24 @@ Support following metrics:

- NDCG

- MAP

- Multi class log loss

- Multi class error rate

- Fair

- Huber

- Poisson

- Quantile

- MAPE

- Kullback-Leibler

For more details, please refer to `Parameters <./Parameters.rst#metric-parameters>`__.

Other Features
@@ -269,7 +283,7 @@ References

.. _LightGBM\: A Highly Efficient Gradient Boosting Decision Tree: https://papers.nips.cc/paper/6907-lightgbm-a-highly-efficient-gradient-boosting-decision-tree.pdf

-.. _On Grouping for Maximum Homogeneity: http://amstat.tandfonline.com/doi/abs/10.1080/01621459.1958.10501479
+.. _On Grouping for Maximum Homogeneity: http://www.csiss.org/SPACE/workshops/2004/SAC/files/fisher.pdf

.. _Optimization of collective communication operations in MPICH: http://wwwi10.lrr.in.tum.de/~gerndt/home/Teaching/HPCSeminar/mpich_multi_coll.pdf

12 changes: 6 additions & 6 deletions docs/Parameters.rst
@@ -54,7 +54,7 @@ Core Parameters
- **Note**: Only can be used in CLI version.

- ``application``, default=\ ``regression``, type=enum,
-  options=\ ``regression``, ``regression_l1``, ``huber``, ``fair``, ``poisson``, ``quantile``, ``quantile_l2``,
+  options=\ ``regression``, ``regression_l1``, ``huber``, ``fair``, ``poisson``, ``quantile``, ``mape``,
``binary``, ``multiclass``, ``multiclassova``, ``xentropy``, ``xentlambda``, ``lambdarank``,
alias=\ ``objective``, ``app``

@@ -72,7 +72,7 @@

- ``quantile``, `Quantile regression`_

-- ``quantile_l2``, like the ``quantile``, but L2 loss is used instead
+- ``mape``, `MAPE loss`_

- ``binary``, binary `log loss`_ classification application

@@ -513,10 +513,6 @@ Objective Parameters

- parameter for `Fair loss`_. Will be used in ``regression`` task

-- ``gaussian_eta``, default=\ ``1.0``, type=double
-
-  - parameter to control the width of Gaussian function. Will be used in ``regression_l1`` and ``huber`` losses

- ``poisson_max_delta_step``, default=\ ``0.7``, type=double

- parameter for `Poisson regression`_ to safeguard optimization
@@ -573,6 +569,8 @@ Metric Parameters
- ``l2_root``, root square loss, alias=\ ``root_mean_squared_error``, ``rmse``

- ``quantile``, `Quantile regression`_

+- ``mape``, `MAPE loss`_

- ``huber``, `Huber loss`_

@@ -744,6 +742,8 @@ You can specific query/group id in data file now. Please refer to parameter ``gr

.. _Quantile regression: https://en.wikipedia.org/wiki/Quantile_regression

+.. _MAPE loss: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error

.. _Fair loss: https://www.kaggle.com/c/allstate-claims-severity/discussion/24520

.. _Poisson regression: https://en.wikipedia.org/wiki/Poisson_regression
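The new ``mape`` objective is exactly the kind of zero-hessian loss this commit targets. A minimal sketch of its per-point derivatives (an assumed form for illustration, not LightGBM's exact code; the capped ``1 / max(1, |label|)`` weighting is an assumption):

```cpp
#include <cmath>

// |score - label| / |label| has zero second derivative almost everywhere,
// so a plain Newton step gets no usable hessian -- hence the leaf-output
// renewal added by this commit.
inline int Sign(double x) { return (x > 0.0) - (x < 0.0); }

inline void MapeGradHess(double score, double label, double* grad, double* hess) {
  const double w = 1.0 / std::fmax(1.0, std::fabs(label));  // capped per-label weight (assumed)
  *grad = Sign(score - label) * w;  // subgradient of w * |score - label|
  *hess = w;  // surrogate weight stored in the hessian slot; the true hessian is 0 a.e.
}
```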
6 changes: 4 additions & 2 deletions docs/Quick-Start.rst
@@ -68,7 +68,7 @@ Some important parameters:
- ``convert_model``, for converting model file into if-else format, see more information in `Convert model parameters <./Parameters.rst#convert-model-parameters>`__

- ``application``, default=\ ``regression``, type=enum,
-  options=\ ``regression``, ``regression_l1``, ``huber``, ``fair``, ``poisson``, ``quantile``, ``quantile_l2``,
+  options=\ ``regression``, ``regression_l1``, ``huber``, ``fair``, ``poisson``, ``quantile``, ``mape``,
``binary``, ``multiclass``, ``multiclassova``, ``xentropy``, ``xentlambda``, ``lambdarank``,
alias=\ ``objective``, ``app``

@@ -86,7 +86,7 @@ Some important parameters:

- ``quantile``, `Quantile regression`_

-- ``quantile_l2``, like the ``quantile``, but L2 loss is used instead
+- ``mape``, `MAPE loss`_

- ``binary``, binary `log loss`_ classification application

@@ -234,6 +234,8 @@ Examples

.. _Quantile regression: https://en.wikipedia.org/wiki/Quantile_regression

+.. _MAPE loss: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error

.. _log loss: https://en.wikipedia.org/wiki/Cross_entropy

.. _softmax: https://en.wikipedia.org/wiki/Softmax_function
4 changes: 1 addition & 3 deletions include/LightGBM/config.h
@@ -164,8 +164,6 @@ struct ObjectiveConfig: public ConfigBase {
virtual ~ObjectiveConfig() {}
double sigmoid = 1.0f;
double fair_c = 1.0f;
-  // for Approximate Hessian With Gaussian
-  double gaussian_eta = 1.0f;
double poisson_max_delta_step = 0.7f;
// for lambdarank
std::vector<double> label_gain;
@@ -473,7 +471,7 @@ struct ParameterAlias {
"convert_model", "convert_model_language",
"feature_fraction_seed", "enable_bundle", "data_filename", "valid_data_filenames",
"snapshot_freq", "verbosity", "sparse_threshold", "enable_load_from_binary_file",
"max_conflict_rate", "poisson_max_delta_step", "gaussian_eta",
"max_conflict_rate", "poisson_max_delta_step",
"histogram_pool_size", "is_provide_training_metric", "machine_list_filename", "machines",
"zero_as_missing", "init_score_file", "valid_init_score_file", "is_predict_contrib",
"max_cat_threshold", "cat_smooth", "min_data_per_group", "cat_l2", "max_cat_to_onehot",
46 changes: 46 additions & 0 deletions include/LightGBM/network.h
@@ -210,6 +210,52 @@ class Network {
    return global;
  }

+  template<class T>
+  static T GlobalSyncUpByMean(T& local) {
+    T global = (T)0;
+    Allreduce(reinterpret_cast<char*>(&local),
+              sizeof(local), sizeof(local),
+              reinterpret_cast<char*>(&global),
+              [](const char* src, char* dst, int type_size, comm_size_t len) {
+      comm_size_t used_size = 0;
+      const T *p1;
+      T *p2;
+      while (used_size < len) {
+        p1 = reinterpret_cast<const T *>(src);
+        p2 = reinterpret_cast<T *>(dst);
+        *p2 += *p1;
+        src += type_size;
+        dst += type_size;
+        used_size += type_size;
+      }
+    });
+    return static_cast<T>(global / num_machines_);
+  }
+
+  template<class T>
+  static void GlobalSum(std::vector<T>& local) {
+    std::vector<T> global(local.size(), (T)0);
+    Allreduce(reinterpret_cast<char*>(local.data()),
+              static_cast<comm_size_t>(sizeof(T) * local.size()), sizeof(T),
+              reinterpret_cast<char*>(global.data()),
+              [](const char* src, char* dst, int type_size, comm_size_t len) {
+      comm_size_t used_size = 0;
+      const T *p1;
+      T *p2;
+      while (used_size < len) {
+        p1 = reinterpret_cast<const T *>(src);
+        p2 = reinterpret_cast<T *>(dst);
+        *p2 += *p1;
+        src += type_size;
+        dst += type_size;
+        used_size += type_size;
+      }
+    });
+    for (size_t i = 0; i < local.size(); ++i) {
+      local[i] = global[i];
+    }
+  }

private:

static void AllgatherBruck(char* input, const comm_size_t* block_start, const comm_size_t* block_len, char* output, comm_size_t all_size);
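Both helpers above hand the same element-wise reducer to ``Allreduce``. A stand-alone sketch of what that lambda computes over raw byte buffers (illustrative only; in LightGBM the buffers come from the network layer):

```cpp
#include <iostream>
#include <vector>

// Elementwise sum of T-typed values stored in raw char buffers -- the same
// loop as the lambda passed to Allreduce above.
template <class T>
void ReduceSum(const char* src, char* dst, int type_size, int len) {
  int used_size = 0;
  while (used_size < len) {
    const T* p1 = reinterpret_cast<const T*>(src);
    T* p2 = reinterpret_cast<T*>(dst);
    *p2 += *p1;  // accumulate the src element into the dst element
    src += type_size;
    dst += type_size;
    used_size += type_size;
  }
}

int main() {
  // Two "machines" each hold local sums; after reduction b holds the global
  // sums, and GlobalSyncUpByMean would then divide by the machine count.
  std::vector<double> a = {1.0, 2.0};
  std::vector<double> b = {3.0, 5.0};
  ReduceSum<double>(reinterpret_cast<const char*>(a.data()),
                    reinterpret_cast<char*>(b.data()),
                    sizeof(double),
                    static_cast<int>(sizeof(double) * b.size()));
  std::cout << b[0] << " " << b[1] << "\n";  // prints: 4 7
}
```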
9 changes: 7 additions & 2 deletions include/LightGBM/objective_function.h
@@ -35,9 +35,14 @@ class ObjectiveFunction {

  virtual bool IsConstantHessian() const { return false; }

-  virtual bool BoostFromAverage() const { return false; }
+  virtual bool IsRenewTreeOutput() const { return false; }

-  virtual bool GetCustomAverage(double *) const { return false; }
+  virtual double RenewTreeOutput(double ori_output, const double*,
+                                 const data_size_t*,
+                                 const data_size_t*,
+                                 data_size_t) const { return ori_output; }
+
+  virtual double BoostFromScore() const { return 0.0f; }

  virtual bool SkipEmptyClass() const { return false; }

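The ``RenewTreeOutput`` hook is the core of the fix: an objective whose hessian is degenerate can overwrite a leaf's gradient-fitted constant with the exact minimizer over the rows in that leaf. A hedged sketch of an L1-style implementation (simplified signature; parameter meanings and label access are assumptions, not the actual LightGBM code):

```cpp
#include <algorithm>
#include <vector>

using data_size_t = int;

// For absolute loss the optimal leaf constant is the median of the
// residuals (label - current score) of the rows in the leaf.
double RenewTreeOutputL1(double ori_output,
                         const double* score,         // current model prediction per row
                         const double* label,         // true targets (assumed accessible)
                         const data_size_t* indices,  // rows routed to this leaf
                         data_size_t cnt) {
  if (cnt <= 0) return ori_output;  // empty leaf: keep the fitted output
  std::vector<double> residual(cnt);
  for (data_size_t i = 0; i < cnt; ++i) {
    residual[i] = label[indices[i]] - score[indices[i]];
  }
  // The median minimizes sum_i |residual_i - output| exactly.
  std::nth_element(residual.begin(), residual.begin() + cnt / 2, residual.end());
  return residual[cnt / 2];
}
```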
4 changes: 4 additions & 0 deletions include/LightGBM/tree_learner.h
@@ -12,6 +12,7 @@ namespace LightGBM {
/*! \brief forward declaration */
class Tree;
class Dataset;
+class ObjectiveFunction;

/*!
* \brief Interface for tree learner
@@ -67,6 +68,9 @@ class TreeLearner {
*/
virtual void AddPredictionToScore(const Tree* tree, double* out_score) const = 0;

+  virtual void RenewTreeOutput(Tree* tree, const ObjectiveFunction* obj, const double* prediction,
+                               data_size_t total_num_data, const data_size_t* bag_indices, data_size_t bag_cnt) const = 0;

TreeLearner() = default;
/*! \brief Disable copy */
TreeLearner& operator=(const TreeLearner&) = delete;
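On the learner side the contract is simple: once a tree is grown, walk its leaves and let the objective recompute each constant. A sketch of that call pattern with stand-in types (the real ``Tree`` and ``ObjectiveFunction`` classes, and the bagging index mapping, are richer than shown):

```cpp
#include <vector>

using data_size_t = int;

struct Leaf {
  double output;                       // value fitted from gradients/hessians
  std::vector<data_size_t> row_index;  // training rows routed to this leaf
};

// Stand-in mirroring the default in objective_function.h: keep the output.
struct Objective {
  virtual double RenewTreeOutput(double ori_output, const double* /*pred*/,
                                 const data_size_t* /*rows*/,
                                 data_size_t /*cnt*/) const {
    return ori_output;
  }
  virtual ~Objective() = default;
};

void RenewAllLeaves(std::vector<Leaf>* leaves, const Objective& obj,
                    const double* prediction) {
  for (auto& leaf : *leaves) {
    leaf.output = obj.RenewTreeOutput(
        leaf.output, prediction, leaf.row_index.data(),
        static_cast<data_size_t>(leaf.row_index.size()));
  }
}
```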
26 changes: 5 additions & 21 deletions include/LightGBM/utils/common.h
@@ -640,27 +640,6 @@ inline static void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, s

}

-/*
-* approximate hessians of absolute loss with Gaussian function
-* cf. https://en.wikipedia.org/wiki/Gaussian_function
-*
-* y is a prediction.
-* t means true target.
-* g means gradient.
-* eta is a parameter to control the width of Gaussian function.
-* w means weights.
-*/
-inline static double ApproximateHessianWithGaussian(const double y, const double t, const double g,
-                                                    const double eta, const double w=1.0f) {
-  const double diff = y - t;
-  const double pi = 4.0 * std::atan(1.0);
-  const double x = std::fabs(diff);
-  const double a = 2.0 * std::fabs(g) * w; // difference of two first derivatives, (zero to inf) and (zero to -inf).
-  const double b = 0.0;
-  const double c = std::max((std::fabs(y) + std::fabs(t)) * eta, 1.0e-10);
-  return w * std::exp(-(x - b) * (x - b) / (2.0 * c * c)) * a / (c * std::sqrt(2 * pi));
-}
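Reading the deleted helper: it smoothed the distributional second derivative of the absolute loss with a Gaussian bump of data-dependent width ``c`` (with ``x`` the absolute residual, ``a`` the jump of the first derivative, ``b = 0``, and ``w`` the weight, as in its own comments):

```latex
\[
\frac{d^2}{dx^2}\,|x| = 2\,\delta(x)
\approx \frac{2}{c\sqrt{2\pi}}\, e^{-x^2/(2c^2)},
\qquad\text{so the helper returned}\quad
w \cdot \frac{a}{c\sqrt{2\pi}}\, e^{-(x-b)^2/(2c^2)}.
\]
```

The commit drops this approximation in favor of the exact per-leaf renewal introduced in ``objective_function.h``.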

template <typename T>
inline static std::vector<T*> Vector2Ptr(std::vector<std::vector<T>>& data) {
std::vector<T*> ptr(data.size());
@@ -882,6 +861,11 @@ inline static const char* SkipNewLine(const char* str) {
return str;
}

+template <typename T>
+static int Sign(T x) {
+  return (x > T(0)) - (x < T(0));
+}

} // namespace Common

} // namespace LightGBM
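The new ``Common::Sign`` helper is what the zero-hessian objectives use for their subgradients. For example, a pinball (quantile) loss gradient could look like this (an illustrative sketch, not the exact LightGBM objective code):

```cpp
// Subgradient of the pinball loss at quantile level alpha:
//   L(score) = alpha * max(label - score, 0) + (1 - alpha) * max(score - label, 0)
template <typename T>
static int Sign(T x) {
  return (x > T(0)) - (x < T(0));
}

inline double QuantileGrad(double score, double label, double alpha) {
  const double delta = score - label;
  if (delta == 0.0) return 0.0;  // any value in [-alpha, 1 - alpha] is a valid subgradient
  return Sign(delta) > 0 ? (1.0 - alpha) : -alpha;
}
```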
2 changes: 1 addition & 1 deletion python-package/lightgbm/basic.py
@@ -1956,7 +1956,7 @@ def __get_eval_info(self):
self.__name_inner_eval = \
[string_buffers[i].value.decode() for i in range_(self.__num_inner_eval)]
self.__higher_better_inner_eval = \
-            [name.startswith(('auc', 'ndcg', 'map')) for name in self.__name_inner_eval]
+            [name.startswith(('auc', 'ndcg@', 'map@')) for name in self.__name_inner_eval]

def attr(self, key):
"""Get attribute string from the Booster.
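The tightened prefixes guard against the new ``mape`` metric: under the old check any name starting with ``map`` (including ``mape``, where lower is better) would have been flagged as higher-is-better. Matching the at-k forms ``ndcg@`` and ``map@`` avoids that collision.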
68 changes: 22 additions & 46 deletions src/boosting/gbdt.cpp
@@ -295,42 +295,15 @@ void GBDT::Bagging(int iter) {
* (i) and (ii) could be selected as say "auto_init_score" = 0 or 1 etc..
*
*/
-double ObtainAutomaticInitialScore(const ObjectiveFunction* fobj, const label_t* label, data_size_t num_data) {
+double ObtainAutomaticInitialScore(const ObjectiveFunction* fobj) {
  double init_score = 0.0f;
-  bool got_custom = false;
  if (fobj != nullptr) {
-    got_custom = fobj->GetCustomAverage(&init_score);
-  }
-  if (!got_custom) {
-    double sum_label = 0.0f;
-    #pragma omp parallel for schedule(static) reduction(+:sum_label)
-    for (data_size_t i = 0; i < num_data; ++i) {
-      sum_label += label[i];
-    }
-    init_score = sum_label / num_data;
+    init_score = fobj->BoostFromScore();
  }
  if (Network::num_machines() > 1) {
-    double global_init_score = 0.0f;
-    Network::Allreduce(reinterpret_cast<char*>(&init_score),
-                       sizeof(init_score), sizeof(init_score),
-                       reinterpret_cast<char*>(&global_init_score),
-                       [](const char* src, char* dst, int type_size, comm_size_t len) {
-      comm_size_t used_size = 0;
-      const double *p1;
-      double *p2;
-      while (used_size < len) {
-        p1 = reinterpret_cast<const double *>(src);
-        p2 = reinterpret_cast<double *>(dst);
-        *p2 += *p1;
-        src += type_size;
-        dst += type_size;
-        used_size += type_size;
-      }
-    });
-    return global_init_score / Network::num_machines();
-  } else {
-    return init_score;
+    init_score = Network::GlobalSyncUpByMean(init_score);
  }
+  return init_score;
}

void GBDT::Train(int snapshot_freq, const std::string& model_output_path) {
@@ -379,21 +352,23 @@ void GBDT::RefitTree(const std::vector<std::vector<int>>& tree_leaf_prediction)

double GBDT::BoostFromAverage() {
  // boosting from average label; or customized "average" if implemented for the current objective
-  if (models_.empty()
-      && gbdt_config_->boost_from_average
-      && !train_score_updater_->has_init_score()
+  if (models_.empty() && !train_score_updater_->has_init_score()
      && num_class_ <= 1
-      && objective_function_ != nullptr
-      && objective_function_->BoostFromAverage()) {
-
-    auto label = train_data_->metadata().label();
-    double init_score = ObtainAutomaticInitialScore(objective_function_, label, num_data_);
-    if (std::fabs(init_score) > kEpsilon) {
-      train_score_updater_->AddScore(init_score, 0);
-      for (auto& score_updater : valid_score_updater_) {
-        score_updater->AddScore(init_score, 0);
+      && objective_function_ != nullptr) {
+    if (gbdt_config_->boost_from_average) {
+      double init_score = ObtainAutomaticInitialScore(objective_function_);
+      if (std::fabs(init_score) > kEpsilon) {
+        train_score_updater_->AddScore(init_score, 0);
+        for (auto& score_updater : valid_score_updater_) {
+          score_updater->AddScore(init_score, 0);
      }
      Log::Info("Start training from score %lf", init_score);
+        return init_score;
      }
-    return init_score;
+    } else if (std::string(objective_function_->GetName()) == std::string("regression_l1")
+        || std::string(objective_function_->GetName()) == std::string("quantile")
+        || std::string(objective_function_->GetName()) == std::string("mape")) {
+      Log::Warning("Disable boost_from_average in %s may cause the slow convergence.", objective_function_->GetName());
+    }
  }
  return 0.0f;
@@ -434,10 +409,9 @@ bool GBDT::TrainOneIter(const score_t* gradients, const score_t* hessians) {
#ifdef TIMETAG
start_time = std::chrono::steady_clock::now();
#endif

+    const size_t bias = static_cast<size_t>(cur_tree_id) * num_data_;
std::unique_ptr<Tree> new_tree(new Tree(2));
if (class_need_train_[cur_tree_id]) {
-      size_t bias = static_cast<size_t>(cur_tree_id)* num_data_;
auto grad = gradients + bias;
auto hess = hessians + bias;

@@ -460,6 +434,8 @@

if (new_tree->num_leaves() > 1) {
should_continue = true;
+      tree_learner_->RenewTreeOutput(new_tree.get(), objective_function_, train_score_updater_->score() + bias,
+                                     num_data_, bag_data_indices_.data(), bag_data_cnt_);
// shrinkage by learning rate
new_tree->Shrinkage(shrinkage_rate_);
// update score
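The gbdt.cpp changes tie the pieces together: the initial score now comes from the objective's ``BoostFromScore()``, and every finished tree passes through ``RenewTreeOutput`` before shrinkage. What ``BoostFromScore`` returns is objective-specific; a hedged sketch of two natural choices (simplified, assumed signatures; the real objectives read LightGBM's internal label buffers):

```cpp
#include <algorithm>
#include <numeric>
#include <vector>

// For squared error the best constant model is the label mean; for absolute
// error it is the label median. Both assume a non-empty label vector.
double BoostFromScoreL2(const std::vector<float>& label) {
  const double sum = std::accumulate(label.begin(), label.end(), 0.0);
  return sum / static_cast<double>(label.size());  // mean minimizes L2
}

double BoostFromScoreL1(std::vector<float> label) {  // by value: nth_element reorders
  const size_t mid = label.size() / 2;
  std::nth_element(label.begin(), label.begin() + mid, label.end());
  return label[mid];  // median minimizes L1
}
```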
5 changes: 0 additions & 5 deletions src/boosting/goss.hpp
@@ -206,11 +206,6 @@ class GOSS: public GBDT {
}
}

-  /*!
-  * \brief Get Type name of this boosting object
-  */
-  const char* SubModelName() const override { return "tree"; }

private:
std::vector<data_size_t> tmp_indice_right_;
};
