Skip to content

Commit

Permalink
clean code for Boosting.
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Aug 29, 2017
1 parent 3db907c commit 6d0eae0
Show file tree
Hide file tree
Showing 14 changed files with 743 additions and 646 deletions.
40 changes: 21 additions & 19 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class LIGHTGBM_EXPORT Boosting {

/*!
* \brief Merge model from other boosting object
Will insert to the front of current boosting object
Will insert to the front of current boosting object
* \param other
*/
virtual void MergeFrom(const Boosting* other) = 0;
Expand All @@ -54,18 +54,17 @@ class LIGHTGBM_EXPORT Boosting {
* \param valid_metrics Metric for validation data
*/
virtual void AddValidDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) = 0;
const std::vector<const Metric*>& valid_metrics) = 0;

virtual void Train(int snapshot_freq, const std::string& model_output_path) = 0;

/*!
* \brief Training logic
* \param gradient nullptr for using default objective, otherwise use self-defined boosting
* \param hessian nullptr for using default objective, otherwise use self-defined boosting
* \param is_eval true if need evaluation or early stop
* \return True if meet early stopping or cannot boosting
* \param gradients nullptr for using default objective, otherwise use self-defined boosting
* \param hessians nullptr for using default objective, otherwise use self-defined boosting
* \return True if training cannot continue
*/
virtual bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) = 0;
virtual bool TrainOneIter(const score_t* gradients, const score_t* hessians) = 0;

/*!
* \brief Rollback one iteration
Expand All @@ -77,10 +76,6 @@ class LIGHTGBM_EXPORT Boosting {
*/
virtual int GetCurrentIteration() const = 0;

/*!
* \brief Eval metrics and check is met early stopping or not
*/
virtual bool EvalAndCheckEarlyStopping() = 0;
/*!
* \brief Get evaluation result at data_idx data
* \param data_idx 0: training data, 1: 1st validation data
Expand All @@ -101,6 +96,7 @@ class LIGHTGBM_EXPORT Boosting {
* \return out_len length of returned score
*/
virtual int64_t GetNumPredictAt(int data_idx) const = 0;

/*!
* \brief Get prediction result at data_idx data
* \param data_idx 0: training data, 1: 1st validation data
Expand All @@ -115,7 +111,7 @@ class LIGHTGBM_EXPORT Boosting {
* \brief Prediction for one record, without sigmoid transformation
* \param features Feature values of this record
* \param output Prediction result for this record
* \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
* \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated.
*/
virtual void PredictRaw(const double* features, double* output,
const PredictionEarlyStopInstance* early_stop) const = 0;
Expand All @@ -124,7 +120,7 @@ class LIGHTGBM_EXPORT Boosting {
* \brief Prediction for one record, sigmoid transformation will be used if needed
* \param features Feature values of this record
* \param output Prediction result for this record
* \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
* \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated.
*/
virtual void Predict(const double* features, double* output,
const PredictionEarlyStopInstance* early_stop) const = 0;
Expand All @@ -137,14 +133,14 @@ class LIGHTGBM_EXPORT Boosting {
virtual void PredictLeafIndex(
const double* features, double* output) const = 0;

/*!
/*!
* \brief Feature contributions for the model's prediction of one record
* \param features Feature values of this record
* \param output Prediction result for this record
* \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
* \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all models are evaluated.
*/
virtual void PredictContrib(const double* features, double* output,
const PredictionEarlyStopInstance* early_stop) const = 0;
const PredictionEarlyStopInstance* early_stop) const = 0;

/*!
* \brief Dump model to json format string
Expand Down Expand Up @@ -224,10 +220,10 @@ class LIGHTGBM_EXPORT Boosting {
virtual int NumberOfTotalModel() const = 0;

/*!
* \brief Get number of trees per iteration
* \return Number of trees per iteration
* \brief Get number of models per iteration
* \return Number of models per iteration
*/
virtual int NumTreePerIteration() const = 0;
virtual int NumModelPerIteration() const = 0;

/*!
* \brief Get number of classes
Expand Down Expand Up @@ -275,6 +271,12 @@ class LIGHTGBM_EXPORT Boosting {

};

/*!
* \brief Abstract interface that extends Boosting with per-leaf value accessors.
*        Both members are pure virtual, so a concrete subclass must implement them.
*/
class GBDTBase : public Boosting {
public:
/*!
* \brief Read the value stored for one leaf
* \param tree_idx Index selecting the tree (NOTE(review): presumably an index into the full model list - confirm against the implementation)
* \param leaf_idx Index selecting the leaf inside that tree
* \return Value of the selected leaf
*/
virtual double GetLeafValue(int tree_idx, int leaf_idx) const = 0;
/*!
* \brief Overwrite the value stored for one leaf
* \param tree_idx Index selecting the tree (NOTE(review): presumably an index into the full model list - confirm against the implementation)
* \param leaf_idx Index selecting the leaf inside that tree
* \param val New value for the selected leaf
*/
virtual void SetLeafValue(int tree_idx, int leaf_idx, double val) = 0;
};

} // namespace LightGBM

#endif // LightGBM_BOOSTING_H_
2 changes: 1 addition & 1 deletion include/LightGBM/objective_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class ObjectiveFunction {

virtual bool SkipEmptyClass() const { return false; }

virtual int NumTreePerIteration() const { return 1; }
virtual int NumModelPerIteration() const { return 1; }

virtual int NumPredictOneRow() const { return 1; }

Expand Down
9 changes: 9 additions & 0 deletions include/LightGBM/tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,15 @@ class Tree {
shrinkage_ *= rate;
}

/*!
* \brief Shift the output of every leaf by a constant offset
* \param val Offset added to each leaf value
*/
inline void AddBias(double val) {
  const int leaf_count = num_leaves_;
  // only go parallel when there are enough leaves to make it worthwhile
  #pragma omp parallel for schedule(static, 1024) if (leaf_count >= 2048)
  for (int leaf = 0; leaf < leaf_count; ++leaf) {
    leaf_value_[leaf] += val;
  }
  // shrinkage is reset to 1.0 regardless of its previous value
  shrinkage_ = 1.0f;
}

inline void AsConstantTree(double val) {
num_leaves_ = 1;
shrinkage_ = 1.0f;
Expand Down
13 changes: 6 additions & 7 deletions src/boosting/dart.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,20 +48,19 @@ class DART: public GBDT {
/*!
* \brief one training iteration
*/
bool TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) override {
bool TrainOneIter(const score_t* gradient, const score_t* hessian) override {
is_update_score_cur_iter_ = false;
GBDT::TrainOneIter(gradient, hessian, false);
bool ret = GBDT::TrainOneIter(gradient, hessian);
if (ret) {
return ret;
}
// normalize
Normalize();
if (!gbdt_config_->uniform_drop) {
tree_weight_.push_back(shrinkage_rate_);
sum_weight_ += shrinkage_rate_;
}
if (is_eval) {
return EvalAndCheckEarlyStopping();
} else {
return false;
}
return false;
}

/*!
Expand Down

0 comments on commit 6d0eae0

Please sign in to comment.