Skip to content

Commit

Permalink
Support early stopping of prediction in CLI (#565)
Browse files Browse the repository at this point in the history
* fix multi-threading.

* fix name style.

* support in CLI version.

* remove warnings.

* Not default parameters.

* fix if...else... .

* fix bug.

* fix warning.

* refine c_api.

* fix R-package.

* fix R's warning.

* fix tests.

* fix pep8 .
  • Loading branch information
guolinke committed May 30, 2017
1 parent e04a8bb commit 6d4c7b0
Show file tree
Hide file tree
Showing 30 changed files with 296 additions and 309 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ if(USE_GPU)
endif(USE_GPU)

if(UNIX OR MINGW OR CYGWIN)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread -O3 -Wextra -Wall -std=c++11 -Wno-ignored-attributes")
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -pthread -O3 -Wextra -Wall -Wno-ignored-attributes -Wno-unknown-pragmas")
endif()

if(MSVC)
Expand Down
8 changes: 4 additions & 4 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -391,15 +391,15 @@ Booster <- R6Class(
rawscore = FALSE,
predleaf = FALSE,
header = FALSE,
reshape = FALSE) {
reshape = FALSE, ...) {

# Check if number of iteration is non existent
if (is.null(num_iteration)) {
num_iteration <- self$best_iter
}

# Predict on new data
predictor <- Predictor$new(private$handle)
predictor <- Predictor$new(private$handle, ...)
predictor$predict(data, num_iteration, rawscore, predleaf, header, reshape)

},
Expand Down Expand Up @@ -645,7 +645,7 @@ predict.lgb.Booster <- function(object, data,
rawscore = FALSE,
predleaf = FALSE,
header = FALSE,
reshape = FALSE) {
reshape = FALSE, ...) {

# Check booster existence
if (!lgb.is.Booster(object)) {
Expand All @@ -658,7 +658,7 @@ predict.lgb.Booster <- function(object, data,
rawscore,
predleaf,
header,
reshape)
reshape, ...)
}

#' Load LightGBM model
Expand Down
15 changes: 10 additions & 5 deletions R-package/R/lgb.Predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,9 @@ Predictor <- R6Class(
},

# Initialize will create a starter model
initialize = function(modelfile) {

initialize = function(modelfile, ...) {
params <- list(...)
private$params <- lgb.params2str(params)
# Create new lgb handle
handle <- lgb.new.handle()

Expand Down Expand Up @@ -86,6 +87,7 @@ Predictor <- R6Class(
as.integer(rawscore),
as.integer(predleaf),
as.integer(num_iteration),
private$params,
lgb.c_str(tmp_filename))

# Get predictions from file
Expand Down Expand Up @@ -121,7 +123,8 @@ Predictor <- R6Class(
as.integer(ncol(data)),
as.integer(rawscore),
as.integer(predleaf),
as.integer(num_iteration))
as.integer(num_iteration),
private$params)

} else if (is(data, "dgCMatrix")) {

Expand All @@ -137,7 +140,8 @@ Predictor <- R6Class(
nrow(data),
as.integer(rawscore),
as.integer(predleaf),
as.integer(num_iteration))
as.integer(num_iteration),
private$params)

} else {

Expand Down Expand Up @@ -178,5 +182,6 @@ Predictor <- R6Class(

),
private = list(handle = NULL,
need_free_handle = FALSE)
need_free_handle = FALSE,
params = "")
)
1 change: 1 addition & 0 deletions R-package/src/lightgbm-all.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "../../src/boosting/boosting.cpp"
#include "../../src/boosting/gbdt.cpp"
#include "../../src/boosting/gbdt_prediction.cpp"
#include "../../src/boosting/prediction_early_stop.cpp"

// io
#include "../../src/io/bin.cpp"
Expand Down
1 change: 1 addition & 0 deletions R-package/src/lightgbm-fullcode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "./src/boosting/boosting.cpp"
#include "./src/boosting/gbdt.cpp"
#include "./src/boosting/gbdt_prediction.cpp"
#include "./src/boosting/prediction_early_stop.cpp"

// io
#include "./src/io/bin.cpp"
Expand Down
9 changes: 6 additions & 3 deletions R-package/src/lightgbm_R.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,12 +498,13 @@ LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE result_filename,
LGBM_SE call_state) {
R_API_BEGIN();
int pred_type = GetPredictType(is_rawscore, is_leafidx);
CHECK_CALL(LGBM_BoosterPredictForFile(R_GET_PTR(handle), R_CHAR_PTR(data_filename),
R_AS_INT(data_has_header), pred_type, R_AS_INT(num_iteration),
R_AS_INT(data_has_header), pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter),
R_CHAR_PTR(result_filename)));
R_API_END();
}
Expand Down Expand Up @@ -534,6 +535,7 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result,
LGBM_SE call_state) {

Expand All @@ -552,7 +554,7 @@ LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
CHECK_CALL(LGBM_BoosterPredictForCSC(R_GET_PTR(handle),
p_indptr, C_API_DTYPE_INT32, p_indices,
p_data, C_API_DTYPE_FLOAT64, nindptr, ndata,
nrow, pred_type, R_AS_INT(num_iteration), &out_len, ptr_ret));
nrow, pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));
R_API_END();
}

Expand All @@ -563,6 +565,7 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result,
LGBM_SE call_state) {

Expand All @@ -577,7 +580,7 @@ LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
int64_t out_len;
CHECK_CALL(LGBM_BoosterPredictForMat(R_GET_PTR(handle),
p_mat, C_API_DTYPE_FLOAT64, nrow, ncol, COL_MAJOR,
pred_type, R_AS_INT(num_iteration), &out_len, ptr_ret));
pred_type, R_AS_INT(num_iteration), R_CHAR_PTR(parameter), &out_len, ptr_ret));

R_API_END();
}
Expand Down
3 changes: 3 additions & 0 deletions R-package/src/lightgbm_R.h
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForFile_R(LGBM_SE handle,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE result_filename,
LGBM_SE call_state);

Expand Down Expand Up @@ -438,6 +439,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForCSC_R(LGBM_SE handle,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result,
LGBM_SE call_state);

Expand All @@ -463,6 +465,7 @@ LIGHTGBM_C_EXPORT LGBM_SE LGBM_BoosterPredictForMat_R(LGBM_SE handle,
LGBM_SE is_rawscore,
LGBM_SE is_leafidx,
LGBM_SE num_iteration,
LGBM_SE parameter,
LGBM_SE out_result,
LGBM_SE call_state);

Expand Down
6 changes: 6 additions & 0 deletions docs/Parameters.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ The parameter format is `key1=value1 key2=value2 ... ` . And parameters can be s
* `num_iteration_predict`, default=`-1`, type=int
* only used in prediction task; specifies how many trained iterations will be used in prediction.
* `<= 0` means no limit
* `pred_early_stop`, default=`false`, type=bool
* Setting this to `true` enables early stopping to speed up prediction. May affect the accuracy.
* `pred_early_stop_freq`, default=`10`, type=int
* The frequency of checking early-stopping prediction.
* `pred_early_stop_margin`, default=`10.0`, type=double
* The margin threshold used by early-stopping prediction.
* `use_missing`, default=`true`, type=bool
* Setting this to `false` will disable the special handling of missing values.

Expand Down
11 changes: 7 additions & 4 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,19 +117,19 @@ class LIGHTGBM_EXPORT Boosting {
* \brief Prediction for one record, not sigmoid transform
* \param feature_values Feature value on this record
* \param output Prediction result for this record
* \param earlyStop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
* \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
*/
virtual void PredictRaw(const double* features, double* output,
const PredictionEarlyStopInstance* earlyStop = nullptr) const = 0;
const PredictionEarlyStopInstance* early_stop) const = 0;

/*!
* \brief Prediction for one record, sigmoid transformation will be used if needed
* \param feature_values Feature value on this record
* \param output Prediction result for this record
* \param earlyStop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
* \param early_stop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
*/
virtual void Predict(const double* features, double* output,
const PredictionEarlyStopInstance* earlyStop = nullptr) const = 0;
const PredictionEarlyStopInstance* early_stop) const = 0;

/*!
* \brief Prediction for one record with leaf index
Expand Down Expand Up @@ -220,6 +220,9 @@ class LIGHTGBM_EXPORT Boosting {
*/
virtual int NumberOfClasses() const = 0;

/*! \brief Whether the prediction must be exact. Returning true disables early stopping for prediction. */
virtual bool NeedAccuratePrediction() const = 0;

/*!
* \brief Initial work for the prediction
* \param num_iteration number of used iteration
Expand Down
29 changes: 5 additions & 24 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

typedef void* DatasetHandle;
typedef void* BoosterHandle;
typedef void* PredictionEarlyStoppingHandle;

#define C_API_DTYPE_FLOAT32 (0)
#define C_API_DTYPE_FLOAT64 (1)
Expand Down Expand Up @@ -522,7 +521,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
int data_has_header,
int predict_type,
int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle,
const char* parameter,
const char* result_filename);

/*!
Expand Down Expand Up @@ -578,7 +577,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t num_col,
int predict_type,
int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle,
const char* parameter,
int64_t* out_len,
double* out_result);

Expand Down Expand Up @@ -617,7 +616,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
int64_t num_row,
int predict_type,
int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle,
const char* parameter,
int64_t* out_len,
double* out_result);

Expand Down Expand Up @@ -650,7 +649,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
int is_row_major,
int predict_type,
int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle,
const char* parameter,
int64_t* out_len,
double* out_result);

Expand Down Expand Up @@ -721,32 +720,14 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
int leaf_idx,
double val);


/*!
* \brief create an new prediction early stopping instance that can be used to speed up prediction
* \param type early stopping type: "none", "multiclass" or "binary"
* \param round_period how often the classifier score is checked for the early stopping condition
* \param margin_threshold when the margin exceeds this value, early stopping kicks in and no more trees are evaluated
* \param out handle of created instance
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_PredictionEarlyStopInstanceCreate(const char* type,
int round_period,
double margin_threshold,
PredictionEarlyStoppingHandle* out);
/*!
\brief free prediction early stop instance
\return 0 when succeed
*/
LIGHTGBM_C_EXPORT int LGBM_PredictionEarlyStopInstanceFree(const PredictionEarlyStoppingHandle handle);

/*! \brief Capacity, including the terminating NUL, of the per-thread last-error buffer. */
static const std::size_t kLGBMLastErrorMsgSize = 512;

#if defined(_MSC_VER)
// exception handle and error msg
/*! \brief Returns the calling thread's last-error buffer (MSVC thread-local storage). */
static char* LastErrorMsg() { static __declspec(thread) char err_msg[kLGBMLastErrorMsgSize] = "Everything is fine"; return err_msg; }
#else
/*! \brief Returns the calling thread's last-error buffer (C++11 thread_local storage). */
static char* LastErrorMsg() { static thread_local char err_msg[kLGBMLastErrorMsgSize] = "Everything is fine"; return err_msg; }
#endif

#pragma warning(disable : 4996)
/*!
* \brief Store msg as the last error message of the calling thread.
*        Messages that do not fit in the buffer are truncated instead of
*        overrunning it; the stored string is always NUL-terminated.
* \param msg NUL-terminated message; nullptr stores an empty message
*/
inline void LGBM_SetLastError(const char* msg) {
  char* buffer = LastErrorMsg();
  if (msg == nullptr) {
    buffer[0] = '\0';
    return;
  }
  // Bounded copy: an unbounded strcpy would write past the fixed-size buffer
  // for any message of kLGBMLastErrorMsgSize characters or more.
  std::strncpy(buffer, msg, kLGBMLastErrorMsgSize - 1);
  // strncpy does not terminate the destination when the source is truncated.
  buffer[kLGBMLastErrorMsgSize - 1] = '\0';
}
Expand Down
8 changes: 8 additions & 0 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,14 @@ struct IOConfig: public ConfigBase {
* Note: when using Index, it doesn't count the label index */
std::string categorical_column = "";
std::string device_type = "cpu";

/*! \brief Set to true if want to use early stop for the prediction */
bool pred_early_stop = false;
/*! \brief Frequency of checking the pred_early_stop */
int pred_early_stop_freq = 10;
/*! \brief Threshold of margin of pred_early_stop */
double pred_early_stop_margin = 10.0f;

LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params) override;
private:
void GetDeviceType(const std::unordered_map<std::string,
Expand Down
3 changes: 3 additions & 0 deletions include/LightGBM/objective_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ class ObjectiveFunction {

/*! \brief Base-class default: returns 1. NOTE(review): presumably the number of values predicted per row -- confirm against overrides. */
virtual int NumPredictOneRow() const { return 1; }

/*!
* \brief Whether the prediction must be exact. Returning true disables early stopping for prediction.
*        This base implementation returns true; an objective has to override it to return false.
*/
virtual bool NeedAccuratePrediction() const { return true; }

/*!
* \brief Default output conversion: copies the first input value to the first output slot unchanged.
* \param input Source values; only input[0] is read by this implementation
* \param output Destination; only output[0] is written by this implementation
*/
virtual void ConvertOutput(const double* input, double* output) const {
output[0] = input[0];
}
Expand Down
44 changes: 22 additions & 22 deletions include/LightGBM/prediction_early_stop.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,28 +6,28 @@

#include <LightGBM/export.h>

namespace LightGBM
{
struct PredictionEarlyStopInstance
{
/// Callback function type for early stopping.
/// Takes current prediction and number of elements in prediction
/// @returns true if prediction should stop according to criterion
using FunctionType = std::function<bool(const double*, int)>;

FunctionType callbackFunction; // callback function itself
int roundPeriod; // call callbackFunction every `runPeriod` iterations
};

struct PredictionEarlyStopConfig
{
int roundPeriod;
double marginThreshold;
};

/// Create an early stopping algorithm of type `type`, with given roundPeriod and margin threshold
LIGHTGBM_EXPORT PredictionEarlyStopInstance createPredictionEarlyStopInstance(const std::string& type,
const PredictionEarlyStopConfig& config);
namespace LightGBM {

// NOTE(review): 4099 is an MSVC warning number; presumably this silences a
// class/struct mismatch with a forward declaration elsewhere -- confirm.
#pragma warning(disable : 4099)
/// Early-stopping hook used during prediction: a callback plus how often to invoke it.
struct PredictionEarlyStopInstance {
/// Callback function type for early stopping.
/// Takes current prediction and number of elements in prediction
/// @returns true if prediction should stop according to criterion
using FunctionType = std::function<bool(const double*, int)>;

FunctionType callback_function; // callback function itself
int round_period; // call callback_function every `round_period` iterations
};

#pragma warning(disable : 4099)
/// Settings passed to CreatePredictionEarlyStopInstance.
struct PredictionEarlyStopConfig {
int round_period; // period, in iterations, between checks of the early-stopping criterion
double margin_threshold; // margin threshold of the early-stopping criterion
};

/// Create an early stopping algorithm of type `type`, with given round_period and margin threshold
/// @param type name of the early stopping algorithm to create
/// @param config round period and margin threshold for the created instance
/// @returns the configured early stopping instance
LIGHTGBM_EXPORT PredictionEarlyStopInstance CreatePredictionEarlyStopInstance(const std::string& type,
const PredictionEarlyStopConfig& config);

} // namespace LightGBM

Expand Down
2 changes: 1 addition & 1 deletion include/LightGBM/utils/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) {
// Buffer for merge.
std::vector<_VTRanIt> temp_buf(len);
_RanIt buf = temp_buf.begin();
int s = inner_size;
size_t s = inner_size;
// Recursive merge
while (s < len) {
int loop_size = static_cast<int>((len + s * 2 - 1) / (s * 2));
Expand Down

0 comments on commit 6d4c7b0

Please sign in to comment.