Skip to content

Commit

Permalink
Add Random Forest Mode (#678)
Browse files Browse the repository at this point in the history
* add draft of RF.

* fix score bugs.

* fix scores.

* fix tests.

* update document

* fix GetPredictAt
  • Loading branch information
guolinke committed Jul 11, 2017
1 parent c05cfa8 commit 6a7470a
Show file tree
Hide file tree
Showing 11 changed files with 305 additions and 20 deletions.
3 changes: 2 additions & 1 deletion docs/Parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ The parameter format is `key1=value1 key2=value2 ... ` . And parameters can be s
* `binary`, binary classification application
* `lambdarank`, [lambdarank](https://pdfs.semanticscholar.org/fc9a/e09f9ced555558fdf1e997c0a5411fb51f15.pdf) application
* `multiclass`, multi-class classification application, should set `num_class` as well
* `boosting`, default=`gbdt`, type=enum, options=`gbdt`,`dart`, alias=`boost`,`boosting_type`
* `boosting`, default=`gbdt`, type=enum, options=`gbdt`,`rf`,`dart`,`goss`, alias=`boost`,`boosting_type`
* `gbdt`, traditional Gradient Boosting Decision Tree
* `rf`, Random Forest
* `dart`, [Dropouts meet Multiple Additive Regression Trees](https://arxiv.org/abs/1505.01866)
* `goss`, Gradient-based One-Side Sampling
* `data`, default=`""`, type=string, alias=`train`,`train_data`
Expand Down
8 changes: 5 additions & 3 deletions docs/Quick-Start.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,11 @@ Some important parameters:
* ```binary```, binary classification application
* ```lambdarank```, lambdarank application
* ```multiclass```, multi-class classification application, should set ```num_class``` as well
* ```boosting```, default=```gbdt```, type=enum, options=```gbdt```,```dart```, alias=```boost```,```boosting_type```
* ```gbdt```, traditional Gradient Boosting Decision Tree
* ```dart```, [Dropouts meet Multiple Additive Regression Trees](https://arxiv.org/abs/1505.01866)
* `boosting`, default=`gbdt`, type=enum, options=`gbdt`,`rf`,`dart`,`goss`, alias=`boost`,`boosting_type`
* `gbdt`, traditional Gradient Boosting Decision Tree
* `rf`, Random Forest
* `dart`, [Dropouts meet Multiple Additive Regression Trees](https://arxiv.org/abs/1505.01866)
* `goss`, Gradient-based One-Side Sampling
* ```data```, default=```""```, type=string, alias=```train```,```train_data```
* training data, LightGBM will train from this data
* ```valid```, default=```""```, type=multi-string, alias=```test```,```valid_data```,```test_data```
Expand Down
5 changes: 5 additions & 0 deletions src/boosting/boosting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "gbdt.h"
#include "dart.hpp"
#include "goss.hpp"
#include "rf.hpp"

namespace LightGBM {

Expand Down Expand Up @@ -34,6 +35,8 @@ Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename
return new DART();
} else if (type == std::string("goss")) {
return new GOSS();
} else if (type == std::string("rf")) {
return new RF();
} else {
return nullptr;
}
Expand All @@ -47,6 +50,8 @@ Boosting* Boosting::CreateBoosting(const std::string& type, const char* filename
ret.reset(new DART());
} else if (type == std::string("goss")) {
ret.reset(new GOSS());
} else if (type == std::string("rf")) {
return new RF();
} else {
Log::Fatal("unknown boosting type %s", type.c_str());
}
Expand Down
6 changes: 6 additions & 0 deletions src/boosting/dart.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class DART: public GBDT {
sum_weight_ = 0.0f;
}

void ResetConfig(const BoostingConfig* config) override {
  // Let the base GBDT apply the new configuration first so that
  // gbdt_config_ is up to date before we read from it below.
  GBDT::ResetConfig(config);
  // Re-seed the dropout RNG from the (possibly changed) drop_seed and
  // clear the accumulated drop weight so DART restarts from a clean state.
  random_for_drop_ = Random(gbdt_config_->drop_seed);
  sum_weight_ = 0.0f;
}

/*!
* \brief one training iteration
*/
Expand Down
42 changes: 29 additions & 13 deletions src/boosting/gbdt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ GBDT::GBDT()
boost_from_average_(false) {
#pragma omp parallel
#pragma omp master
{
num_threads_ = omp_get_num_threads();
}
{
num_threads_ = omp_get_num_threads();
}
average_output_ = false;
}

GBDT::~GBDT() {
Expand Down Expand Up @@ -164,10 +165,9 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
}

objective_function_ = objective_function;
num_tree_per_iteration_ = num_class_;
if (objective_function_ != nullptr) {
is_constant_hessian_ = objective_function_->IsConstantHessian();
num_tree_per_iteration_ = objective_function_->NumTreePerIteration();
CHECK(num_tree_per_iteration_ == objective_function_->NumTreePerIteration());
} else {
is_constant_hessian_ = false;
}
Expand Down Expand Up @@ -608,6 +608,10 @@ void GBDT::UpdateScore(const Tree* tree, const int cur_tree_id) {
#endif
}

/*!
* \brief Evaluate a single metric over a full score buffer.
* \param metric Metric to run (training or validation metric).
* \param score Raw model scores for the corresponding dataset.
* \return One value per sub-metric name reported by the metric.
* \note The model's objective is forwarded so the metric can convert raw
*       scores if it needs to; whether/how it is used is up to the metric
*       implementation (not visible here). Virtual so subclasses (e.g. RF)
*       can evaluate on transformed scores instead.
*/
std::vector<double> GBDT::EvalOneMetric(const Metric* metric, const double* score) const {
return metric->Eval(score, objective_function_);
}

std::string GBDT::OutputMetric(int iter) {
bool need_output = (iter % gbdt_config_->output_freq) == 0;
std::string ret = "";
Expand All @@ -617,7 +621,7 @@ std::string GBDT::OutputMetric(int iter) {
if (need_output) {
for (auto& sub_metric : training_metrics_) {
auto name = sub_metric->GetName();
auto scores = sub_metric->Eval(train_score_updater_->score(), objective_function_);
auto scores = EvalOneMetric(sub_metric, train_score_updater_->score());
for (size_t k = 0; k < name.size(); ++k) {
std::stringstream tmp_buf;
tmp_buf << "Iteration:" << iter
Expand All @@ -634,8 +638,7 @@ std::string GBDT::OutputMetric(int iter) {
if (need_output || early_stopping_round_ > 0) {
for (size_t i = 0; i < valid_metrics_.size(); ++i) {
for (size_t j = 0; j < valid_metrics_[i].size(); ++j) {
auto test_scores = valid_metrics_[i][j]->Eval(valid_score_updater_[i]->score(),
objective_function_);
auto test_scores = EvalOneMetric(valid_metrics_[i][j], valid_score_updater_[i]->score());
auto name = valid_metrics_[i][j]->GetName();
for (size_t k = 0; k < name.size(); ++k) {
std::stringstream tmp_buf;
Expand Down Expand Up @@ -674,16 +677,15 @@ std::vector<double> GBDT::GetEvalAt(int data_idx) const {
std::vector<double> ret;
if (data_idx == 0) {
for (auto& sub_metric : training_metrics_) {
auto scores = sub_metric->Eval(train_score_updater_->score(), objective_function_);
auto scores = EvalOneMetric(sub_metric, train_score_updater_->score());
for (auto score : scores) {
ret.push_back(score);
}
}
} else {
auto used_idx = data_idx - 1;
for (size_t j = 0; j < valid_metrics_[used_idx].size(); ++j) {
auto test_scores = valid_metrics_[used_idx][j]->Eval(valid_score_updater_[used_idx]->score(),
objective_function_);
auto test_scores = EvalOneMetric(valid_metrics_[used_idx][j], valid_score_updater_[used_idx]->score());
for (auto score : test_scores) {
ret.push_back(score);
}
Expand Down Expand Up @@ -712,7 +714,7 @@ void GBDT::GetPredictAt(int data_idx, double* out_result, int64_t* out_len) {
num_data = valid_score_updater_[used_idx]->num_data();
*out_len = static_cast<int64_t>(num_data) * num_class_;
}
if (objective_function_ != nullptr) {
if (objective_function_ != nullptr && !average_output_) {
#pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data; ++i) {
std::vector<double> tree_pred(num_tree_per_iteration_);
Expand Down Expand Up @@ -842,7 +844,12 @@ std::string GBDT::ModelToIfElse(int num_iteration) const {
// Predict
str_buf << "void GBDT::Predict(const double* features, double *output, const PredictionEarlyStopInstance* early_stop) const {" << std::endl;
str_buf << "\t" << "PredictRaw(features, output, early_stop);" << std::endl;
str_buf << "\t" << "if (objective_function_ != nullptr) {" << std::endl;
str_buf << "\t" << "if (average_output_) {" << std::endl;
str_buf << "\t\t" << "for (int k = 0; k < num_tree_per_iteration_; ++k) {" << std::endl;
str_buf << "\t\t\t" << "output[k] /= num_iteration_for_pred_;" << std::endl;
str_buf << "\t\t" << "}" << std::endl;
str_buf << "\t" << "}" << std::endl;
str_buf << "\t" << "else if (objective_function_ != nullptr) {" << std::endl;
str_buf << "\t\t" << "objective_function_->ConvertOutput(output, output);" << std::endl;
str_buf << "\t" << "}" << std::endl;
str_buf << "}" << std::endl;
Expand Down Expand Up @@ -920,6 +927,10 @@ std::string GBDT::SaveModelToString(int num_iteration) const {
ss << "boost_from_average" << std::endl;
}

if (average_output_) {
ss << "average_output" << std::endl;
}

ss << "feature_names=" << Common::Join(feature_names_, " ") << std::endl;

ss << "feature_infos=" << Common::Join(feature_infos_, " ") << std::endl;
Expand Down Expand Up @@ -999,6 +1010,11 @@ bool GBDT::LoadModelFromString(const std::string& model_str) {
if (line.size() > 0) {
boost_from_average_ = true;
}
// get average_output
line = Common::FindFromLines(lines, "average_output");
if (line.size() > 0) {
average_output_ = true;
}
// get feature names
line = Common::FindFromLines(lines, "feature_names=");
if (line.size() > 0) {
Expand Down
7 changes: 5 additions & 2 deletions src/boosting/gbdt.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,17 +275,19 @@ class GBDT: public Boosting {
* \param tree Trained tree of this iteration
* \param cur_tree_id Current tree for multiclass training
*/
void UpdateScoreOutOfBag(const Tree* tree, const int cur_tree_id);
virtual void UpdateScoreOutOfBag(const Tree* tree, const int cur_tree_id);
/*!
* \brief calculate the object function
*/
void Boosting();
virtual void Boosting();
/*!
* \brief updating score after tree was trained
* \param tree Trained tree of this iteration
* \param cur_tree_id Current tree for multiclass training
*/
virtual void UpdateScore(const Tree* tree, const int cur_tree_id);

virtual std::vector<double> EvalOneMetric(const Metric* metric, const double* score) const;
/*!
* \brief Print metric result of current iteration
* \param iter Current iteration
Expand Down Expand Up @@ -373,6 +375,7 @@ class GBDT: public Boosting {
std::vector<double> class_default_output_;
bool is_constant_hessian_;
std::unique_ptr<ObjectiveFunction> loaded_objective_;
bool average_output_;
};

} // namespace LightGBM
Expand Down
6 changes: 5 additions & 1 deletion src/boosting/gbdt_prediction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ void GBDT::PredictRaw(const double* features, double* output, const PredictionEa

void GBDT::Predict(const double* features, double* output, const PredictionEarlyStopInstance* early_stop) const {
  // Accumulate the raw ensemble scores for this single row into `output`.
  PredictRaw(features, output, early_stop);
  if (!average_output_) {
    // Boosting mode: convert raw scores through the objective when one is
    // attached (e.g. sigmoid for binary objectives).
    if (objective_function_ != nullptr) {
      objective_function_->ConvertOutput(output, output);
    }
  } else {
    // Averaging mode (random forest): turn the accumulated sums into means,
    // one slot per tree of the iteration (multiclass has several).
    for (int tree_id = 0; tree_id < num_tree_per_iteration_; ++tree_id) {
      output[tree_id] /= num_iteration_for_pred_;
    }
  }
}
Expand Down

0 comments on commit 6a7470a

Please sign in to comment.