Move all prediction transforms to the objective. (#383)
* Many refactors.
* Remove multi_loglossova.
* Fix tests.
* Avoid using lambda functions.
* Fix some formatting.
* Reduce branching.

guolinke committed Apr 6, 2017
1 parent d4c4d9a commit bfb0217
Showing 25 changed files with 589 additions and 378 deletions.
4 changes: 3 additions & 1 deletion docs/Parameters.md
@@ -36,7 +36,9 @@ The parameter format is ```key1=value1 key2=value2 ... ``` . And parameters can
   * validation/test data, LightGBM will output metrics for these data
   * support multi validation data, separate by ```,```
 * ```num_iterations```, default=```10```, type=int, alias=```num_iteration```,```num_tree```,```num_trees```,```num_round```,```num_rounds```
-  * number of boosting iterations/trees
+  * number of boosting iterations
+  * note: ```num_tree``` here is equal to ```num_iterations```; for multi-class, it actually learns ```num_class * num_iterations``` trees
+  * note: for the Python/R packages, this parameter cannot be used to control the number of iterations
 * ```learning_rate```, default=```0.1```, type=double, alias=```shrinkage_rate```
   * shrinkage rate
   * in ```dart```, it also affects normalization weights of dropped trees
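For example (illustrative numbers, not from the diff): with ```num_class=3``` and ```num_iterations=100```, the trained model contains 3 * 100 = 300 trees, because each boosting iteration fits one tree per class.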
8 changes: 4 additions & 4 deletions include/LightGBM/boosting.h
@@ -26,13 +26,13 @@ class LIGHTGBM_EXPORT Boosting {
   * \brief Initialization logic
   * \param config Configs for boosting
   * \param train_data Training data
-  * \param object_function Training objective function
+  * \param objective_function Training objective function
   * \param training_metrics Training metric
   */
   virtual void Init(
     const BoostingConfig* config,
     const Dataset* train_data,
-    const ObjectiveFunction* object_function,
+    const ObjectiveFunction* objective_function,
     const std::vector<const Metric*>& training_metrics) = 0;

@@ -46,10 +46,10 @@ class LIGHTGBM_EXPORT Boosting {
   * \brief Reset training data for current boosting
   * \param config Configs for boosting
   * \param train_data Training data
-  * \param object_function Training objective function
+  * \param objective_function Training objective function
   * \param training_metrics Training metric
   */
-  virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0;
+  virtual void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function, const std::vector<const Metric*>& training_metrics) = 0;
 
   /*!
   * \brief Add a validation data
4 changes: 3 additions & 1 deletion include/LightGBM/metric.h
@@ -4,6 +4,7 @@
 #include <LightGBM/meta.h>
 #include <LightGBM/config.h>
 #include <LightGBM/dataset.h>
+#include <LightGBM/objective_function.h>
 
 #include <vector>
 
@@ -33,7 +34,8 @@ class Metric {
   * \brief Calculating and printing metric result
   * \param score Current prediction score
   */
-  virtual std::vector<double> Eval(const double* score) const = 0;
+  virtual std::vector<double> Eval(const double* score, const ObjectiveFunction* objective,
+                                   int num_tree_per_iteration) const = 0;
 
   Metric() = default;
   /*! \brief Disable copy */
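With the extra arguments, a metric can delegate the raw-score transform to the objective instead of carrying its own sigmoid/softmax. A hedged sketch of the idea — the function name, the `float*` label array, and the L1 loss are illustrative assumptions, not code from this commit:

```cpp
#include <cmath>

#include <LightGBM/objective_function.h>

// Hypothetical helper showing the new division of labor: the objective owns
// the prediction transform, so the metric just converts and measures.
double EvalL1OnConvertedScores(const double* score, const float* label,
                               int num_data,
                               const LightGBM::ObjectiveFunction* objective) {
  double sum_loss = 0.0;
  for (int i = 0; i < num_data; ++i) {
    // ConvertOutput applies the objective's transform (identity by default).
    const double pred =
        (objective != nullptr) ? objective->ConvertOutput(score[i]) : score[i];
    sum_loss += std::fabs(pred - label[i]);
  }
  return sum_loss / num_data;
}
```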
23 changes: 22 additions & 1 deletion include/LightGBM/objective_function.h
@@ -4,9 +4,9 @@
 #include <LightGBM/meta.h>
 #include <LightGBM/config.h>
 #include <LightGBM/dataset.h>
+#include <functional>
 
 namespace LightGBM {
-
 /*!
  * \brief The interface of Objective Function.
  */
@@ -35,6 +35,22 @@ class ObjectiveFunction {
 
   virtual bool IsConstantHessian() const { return false; }
 
+  virtual bool BoostFromAverage() const { return false; }
+
+  virtual bool SkipEmptyClass() const { return false; }
+
+  virtual int numTreePerIteration() const { return 1; }
+
+  virtual std::vector<double> ConvertOutput(std::vector<double>& input) const {
+    return input;
+  }
+
+  virtual double ConvertOutput(double input) const {
+    return input;
+  }
+
+  virtual std::string ToString() const = 0;
+
   ObjectiveFunction() = default;
   /*! \brief Disable copy */
   ObjectiveFunction& operator=(const ObjectiveFunction&) = delete;
@@ -48,6 +64,11 @@ class ObjectiveFunction {
   */
   LIGHTGBM_EXPORT static ObjectiveFunction* CreateObjectiveFunction(const std::string& type,
     const ObjectiveConfig& config);
+
+  /*!
+   * \brief Load objective function from string object
+   */
+  LIGHTGBM_EXPORT static ObjectiveFunction* CreateObjectiveFunction(const std::string& str);
 };
 
 } // namespace LightGBM
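The new `ConvertOutput` hooks are the heart of this commit: the objective now owns the raw-score-to-prediction transform (identity by default), so boosting and metric code no longer hard-code sigmoid or softmax at each call site. A minimal self-contained sketch of the pattern, with illustrative names that are not LightGBM's classes:

```cpp
#include <cmath>
#include <cstdio>

// Toy version of the pattern: callers convert raw scores through the
// objective's virtual hook instead of special-casing each objective type.
struct MiniObjective {
  virtual ~MiniObjective() = default;
  // Identity by default, mirroring the base-class behavior above.
  virtual double ConvertOutput(double raw) const { return raw; }
};

struct MiniBinaryObjective : MiniObjective {
  double sigmoid = 1.0;  // slope parameter, as a binary objective might carry
  double ConvertOutput(double raw) const override {
    return 1.0 / (1.0 + std::exp(-2.0 * sigmoid * raw));  // logistic transform
  }
};

int main() {
  MiniBinaryObjective obj;
  // A raw boosting score of 0.0 maps to probability 0.5 under the transform.
  std::printf("%f\n", obj.ConvertOutput(0.0));
  return 0;
}
```

With this shape, adding a new objective automatically yields consistent transformed predictions everywhere the boosting or metric code calls `ConvertOutput`.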
15 changes: 15 additions & 0 deletions include/LightGBM/utils/common.h
@@ -377,6 +377,21 @@ inline void Softmax(std::vector<double>* p_rec) {
   }
 }
 
+inline void Softmax(double* rec, int len) {
+  double wmax = rec[0];
+  for (int i = 1; i < len; ++i) {
+    wmax = std::max(rec[i], wmax);
+  }
+  double wsum = 0.0f;
+  for (int i = 0; i < len; ++i) {
+    rec[i] = std::exp(rec[i] - wmax);
+    wsum += rec[i];
+  }
+  for (int i = 0; i < len; ++i) {
+    rec[i] /= static_cast<double>(wsum);
+  }
+}
+
 template<typename T>
 std::vector<const T*> ConstPtrInVectorWrapper(const std::vector<std::unique_ptr<T>>& input) {
   std::vector<const T*> ret;
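The reason the overload subtracts `wmax` before exponentiating: for large raw scores `std::exp` overflows to `inf`, while `exp(rec[i] - wmax)` always lies in `(0, 1]` and normalizes to the same result. A standalone demo (a lightly tidied copy of the overload above; the `main` is illustrative only):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>

// Copy of the new overload so this demo compiles on its own.
inline void Softmax(double* rec, int len) {
  double wmax = rec[0];
  for (int i = 1; i < len; ++i) {
    wmax = std::max(rec[i], wmax);
  }
  double wsum = 0.0;
  for (int i = 0; i < len; ++i) {
    rec[i] = std::exp(rec[i] - wmax);  // shifted exponent stays in (0, 1]
    wsum += rec[i];
  }
  for (int i = 0; i < len; ++i) {
    rec[i] /= wsum;
  }
}

int main() {
  // Raw per-class scores large enough that a naive exp() would overflow.
  double scores[3] = {1000.0, 1001.0, 1002.0};
  Softmax(scores, 3);
  std::printf("%f %f %f\n", scores[0], scores[1], scores[2]);  // ~0.09 0.24 0.67
  return 0;
}
```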
2 changes: 1 addition & 1 deletion src/boosting/boosting.cpp
@@ -20,7 +20,7 @@ bool Boosting::LoadFileToBoosting(Boosting* boosting, const char* filename) {
       str_buf << line << '\n';
     }
     if (!boosting->LoadModelFromString(str_buf.str()))
-    return false;
+      return false;
   }
 
   return true;
36 changes: 18 additions & 18 deletions src/boosting/dart.hpp
@@ -28,20 +28,20 @@ class DART: public GBDT {
   * \brief Initialization logic
   * \param config Config for boosting
   * \param train_data Training data
-  * \param object_function Training objective function
+  * \param objective_function Training objective function
   * \param training_metrics Training metrics
   * \param output_model_filename Filename of output model
   */
-  void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
-            const std::vector<const Metric*>& training_metrics) override {
-    GBDT::Init(config, train_data, object_function, training_metrics);
+  void Init(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
+            const std::vector<const Metric*>& training_metrics) override {
+    GBDT::Init(config, train_data, objective_function, training_metrics);
     random_for_drop_ = Random(gbdt_config_->drop_seed);
     sum_weight_ = 0.0f;
   }
 
-  void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
-                         const std::vector<const Metric*>& training_metrics) override {
-    GBDT::ResetTrainingData(config, train_data, object_function, training_metrics);
+  void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* objective_function,
+                         const std::vector<const Metric*>& training_metrics) override {
+    GBDT::ResetTrainingData(config, train_data, objective_function, training_metrics);
   }
   /*!
   * \brief one training iteration

@@ -110,10 +110,10 @@
     }
     // drop trees
     for (auto i : drop_index_) {
-      for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
-        auto curr_tree = i * num_class_ + curr_class;
+      for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
+        auto curr_tree = i * num_tree_per_iteration_ + cur_tree_id;
         models_[curr_tree]->Shrinkage(-1.0);
-        train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
+        train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
       }
     }
     if (!gbdt_config_->xgboost_dart_mode) {

@@ -140,16 +140,16 @@
     double k = static_cast<double>(drop_index_.size());
     if (!gbdt_config_->xgboost_dart_mode) {
       for (auto i : drop_index_) {
-        for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
-          auto curr_tree = i * num_class_ + curr_class;
+        for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
+          auto curr_tree = i * num_tree_per_iteration_ + cur_tree_id;
           // update validation score
           models_[curr_tree]->Shrinkage(1.0f / (k + 1.0f));
           for (auto& score_updater : valid_score_updater_) {
-            score_updater->AddScore(models_[curr_tree].get(), curr_class);
+            score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
           }
           // update training score
           models_[curr_tree]->Shrinkage(-k);
-          train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
+          train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
         }
         if (!gbdt_config_->uniform_drop) {
           sum_weight_ -= tree_weight_[i] * (1.0f / (k + 1.0f));

@@ -158,16 +158,16 @@
       }
     } else {
      for (auto i : drop_index_) {
-        for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
-          auto curr_tree = i * num_class_ + curr_class;
+        for (int cur_tree_id = 0; cur_tree_id < num_tree_per_iteration_; ++cur_tree_id) {
+          auto curr_tree = i * num_tree_per_iteration_ + cur_tree_id;
           // update validation score
           models_[curr_tree]->Shrinkage(shrinkage_rate_);
           for (auto& score_updater : valid_score_updater_) {
-            score_updater->AddScore(models_[curr_tree].get(), curr_class);
+            score_updater->AddScore(models_[curr_tree].get(), cur_tree_id);
           }
           // update training score
           models_[curr_tree]->Shrinkage(-k / gbdt_config_->learning_rate);
-          train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
+          train_score_updater_->AddScore(models_[curr_tree].get(), cur_tree_id);
         }
         if (!gbdt_config_->uniform_drop) {
          sum_weight_ -= tree_weight_[i] * (1.0f / (k + gbdt_config_->learning_rate));
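The chained `Shrinkage` calls above are multiplicative and easy to misread; in the `!xgboost_dart_mode` branch they implement the standard DART normalization, where each of the `k` dropped trees ends at weight `k / (k + 1)` on both the training and validation scores. A self-contained toy model of that arithmetic (names like `MiniTree` are illustrative, not LightGBM's):

```cpp
#include <cstdio>

// A tree's contribution is leaf_scale times its original output; Shrinkage
// multiplies leaf_scale in place, and "AddScore" is modeled as adding the
// tree's current contribution to a running score total.
struct MiniTree {
  double leaf_scale = 1.0;
  void Shrinkage(double rate) { leaf_scale *= rate; }
};

int main() {
  const double k = 2.0;  // number of trees dropped this iteration
  MiniTree dropped;      // one dropped tree, initially at full weight 1
  double train = 1.0;    // this tree's contribution currently in train scores
  double valid = 1.0;    // this tree's contribution currently in valid scores

  dropped.Shrinkage(-1.0);             // drop step: scale becomes -1
  train += dropped.leaf_scale;         // train: 1 + (-1) = 0 (tree removed)

  dropped.Shrinkage(1.0 / (k + 1.0));  // scale becomes -1/(k+1)
  valid += dropped.leaf_scale;         // valid: 1 - 1/(k+1) = k/(k+1)

  dropped.Shrinkage(-k);               // scale becomes k/(k+1)
  train += dropped.leaf_scale;         // train: 0 + k/(k+1) = k/(k+1)

  // Both scores and the tree itself end at the DART weight k/(k+1) ~ 0.6667.
  std::printf("train=%f valid=%f scale=%f\n", train, valid, dropped.leaf_scale);
  return 0;
}
```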
