Skip to content

Commit

Permalink
Add prediction early stopping (#550)
Browse files Browse the repository at this point in the history
* Add early stopping for prediction

* Fix GBDT if-else prediction with early stopping

* Small C++ embelishments to early stopping API and functions

* Fix early stopping efficiency issue by creating a singleton for no early stopping

* Python improvements to early stopping API

* Add assertion check for binary and multiclass prediction score length

* Update vcxproj and vcxproj.filters with new early stopping files

* Remove inline from PredictRaw(), the linker was not able to find it otherwise
  • Loading branch information
cbecker authored and guolinke committed May 29, 2017
1 parent 2cca828 commit 993bbd5
Show file tree
Hide file tree
Showing 16 changed files with 375 additions and 76 deletions.
9 changes: 7 additions & 2 deletions include/LightGBM/boosting.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace LightGBM {
class Dataset;
class ObjectiveFunction;
class Metric;
class PredictionEarlyStopInstance;

/*!
* \brief The interface for Boosting
Expand Down Expand Up @@ -116,15 +117,19 @@ class LIGHTGBM_EXPORT Boosting {
* \brief Prediction for one record, not sigmoid transform
* \param feature_values Feature value on this record
* \param output Prediction result for this record
* \param earlyStop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
*/
virtual void PredictRaw(const double* features, double* output) const = 0;
virtual void PredictRaw(const double* features, double* output,
const PredictionEarlyStopInstance* earlyStop = nullptr) const = 0;

/*!
* \brief Prediction for one record, sigmoid transformation will be used if needed
* \param feature_values Feature value on this record
* \param output Prediction result for this record
* \param earlyStop Early stopping instance. If nullptr, no early stopping is applied and all trees are evaluated.
*/
virtual void Predict(const double* features, double* output) const = 0;
virtual void Predict(const double* features, double* output,
const PredictionEarlyStopInstance* earlyStop = nullptr) const = 0;

/*!
* \brief Prediction for one record with leaf index
Expand Down
27 changes: 27 additions & 0 deletions include/LightGBM/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

typedef void* DatasetHandle;
typedef void* BoosterHandle;
typedef void* PredictionEarlyStoppingHandle;

#define C_API_DTYPE_FLOAT32 (0)
#define C_API_DTYPE_FLOAT64 (1)
Expand Down Expand Up @@ -521,6 +522,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForFile(BoosterHandle handle,
int data_has_header,
int predict_type,
int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle,
const char* result_filename);

/*!
Expand Down Expand Up @@ -560,6 +562,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterCalcNumPredict(BoosterHandle handle,
* C_API_PREDICT_RAW_SCORE: raw score
* C_API_PREDICT_LEAF_INDEX: leaf index
* \param num_iteration number of iteration for prediction, <= 0 means no limit
* \param early_stop_handle early stopping to use for prediction. If null, no early stopping is applied
* \param out_len len of output result
* \param out_result used to set a pointer to array, should allocate memory before call this function
* \return 0 when succeed, -1 when failure happens
Expand All @@ -575,6 +578,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
int64_t num_col,
int predict_type,
int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle,
int64_t* out_len,
double* out_result);

Expand All @@ -597,6 +601,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSR(BoosterHandle handle,
* C_API_PREDICT_RAW_SCORE: raw score
* C_API_PREDICT_LEAF_INDEX: leaf index
* \param num_iteration number of iteration for prediction, <= 0 means no limit
* \param early_stop_handle early stopping to use for prediction. If null, no early stopping is applied
* \param out_len len of output result
* \param out_result used to set a pointer to array, should allocate memory before call this function
* \return 0 when succeed, -1 when failure happens
Expand All @@ -612,6 +617,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
int64_t num_row,
int predict_type,
int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle,
int64_t* out_len,
double* out_result);

Expand All @@ -631,6 +637,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle,
* C_API_PREDICT_RAW_SCORE: raw score
* C_API_PREDICT_LEAF_INDEX: leaf index
* \param num_iteration number of iteration for prediction, <= 0 means no limit
* \param early_stop_handle early stopping to use for prediction. If null, no early stopping is applied
* \param out_len len of output result
* \param out_result used to set a pointer to array, should allocate memory before call this function
* \return 0 when succeed, -1 when failure happens
Expand All @@ -643,6 +650,7 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForMat(BoosterHandle handle,
int is_row_major,
int predict_type,
int num_iteration,
const PredictionEarlyStoppingHandle early_stop_handle,
int64_t* out_len,
double* out_result);

Expand Down Expand Up @@ -713,6 +721,25 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterSetLeafValue(BoosterHandle handle,
int leaf_idx,
double val);


/*!
* \brief create an new prediction early stopping instance that can be used to speed up prediction
* \param type early stopping type: "none", "multiclass" or "binary"
* \param round_period how often the classifier score is checked for the early stopping condition
* \param margin_threshold when the margin exceeds this value, early stopping kicks in and no more trees are evaluated
* \param out handle of created instance
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_PredictionEarlyStopInstanceCreate(const char* type,
int round_period,
double margin_threshold,
PredictionEarlyStoppingHandle* out);
/*!
\brief free prediction early stop instance
\return 0 when succeed
*/
LIGHTGBM_C_EXPORT int LGBM_PredictionEarlyStopInstanceFree(const PredictionEarlyStoppingHandle handle);

#if defined(_MSC_VER)
// exception handle and error msg
static char* LastErrorMsg() { static __declspec(thread) char err_msg[512] = "Everything is fine"; return err_msg; }
Expand Down
34 changes: 34 additions & 0 deletions include/LightGBM/prediction_early_stop.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef LIGHTGBM_PREDICTION_EARLY_STOP_H_
#define LIGHTGBM_PREDICTION_EARLY_STOP_H_

#include <functional>
#include <string>

#include <LightGBM/export.h>

namespace LightGBM
{
struct PredictionEarlyStopInstance
{
/// Callback function type for early stopping.
/// Takes current prediction and number of elements in prediction
/// @returns true if prediction should stop according to criterion
using FunctionType = std::function<bool(const double*, int)>;

FunctionType callbackFunction; // callback function itself
int roundPeriod; // call callbackFunction every `runPeriod` iterations
};

struct PredictionEarlyStopConfig
{
int roundPeriod;
double marginThreshold;
};

/// Create an early stopping algorithm of type `type`, with given roundPeriod and margin threshold
LIGHTGBM_EXPORT PredictionEarlyStopInstance createPredictionEarlyStopInstance(const std::string& type,
const PredictionEarlyStopConfig& config);

} // namespace LightGBM

#endif // LIGHTGBM_PREDICTION_EARLY_STOP_H_
4 changes: 2 additions & 2 deletions python-package/lightgbm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from __future__ import absolute_import

from .basic import Booster, Dataset
from .basic import Booster, Dataset, PredictionEarlyStopInstance
from .callback import (early_stopping, print_evaluation, record_evaluation,
reset_parameter)
from .engine import cv, train
Expand All @@ -23,7 +23,7 @@

__version__ = 0.2

__all__ = ['Dataset', 'Booster',
__all__ = ['Dataset', 'Booster', 'PredictionEarlyStopInstance',
'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'print_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
Expand Down
59 changes: 53 additions & 6 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ class _InnerPredictor(object):
Only used for prediction, usually used for continued-train
Note: Can convert from Booster, but cannot convert to Booster
"""
def __init__(self, model_file=None, booster_handle=None):
def __init__(self, model_file=None, booster_handle=None, early_stop_instance=None):
"""Initialize the _InnerPredictor. Not expose to user
Parameters
Expand All @@ -305,6 +305,8 @@ def __init__(self, model_file=None, booster_handle=None):
Path to the model file.
booster_handle : Handle of Booster
use handle to init
early_stop_instance: object of type PredictionEarlyStopInstance
If None, no early stopping is applied
"""
self.handle = ctypes.c_void_p()
self.__is_manage_handle = True
Expand Down Expand Up @@ -339,6 +341,11 @@ def __init__(self, model_file=None, booster_handle=None):
else:
raise TypeError('Need Model file or Booster handle to create a predictor')

if early_stop_instance is None:
self.early_stop_instance = PredictionEarlyStopInstance("none")
else:
self.early_stop_instance = early_stop_instance

def __del__(self):
if self.__is_manage_handle:
_safe_call(_LIB.LGBM_BoosterFree(self.handle))
Expand Down Expand Up @@ -385,6 +392,7 @@ def predict(self, data, num_iteration=-1,
int_data_has_header = 1 if data_has_header else 0
if num_iteration > self.num_total_iteration:
num_iteration = self.num_total_iteration

if isinstance(data, string_type):
with _temp_file() as f:
_safe_call(_LIB.LGBM_BoosterPredictForFile(
Expand All @@ -393,6 +401,7 @@ def predict(self, data, num_iteration=-1,
ctypes.c_int(int_data_has_header),
ctypes.c_int(predict_type),
ctypes.c_int(num_iteration),
self.early_stop_instance.handle,
c_str(f.name)))
lines = f.readlines()
nrow = len(lines)
Expand All @@ -409,7 +418,7 @@ def predict(self, data, num_iteration=-1,
predict_type)
elif isinstance(data, DataFrame):
preds, nrow = self.__pred_for_np2d(data.values, num_iteration,
predict_type)
predict_type, early_stop_instance_handle)

This comment has been minimized.

Copy link
@wxchan

wxchan Jun 16, 2017

Contributor

@cbecker just notice this: why is a early_stop_instance_handle here? look like a typo

This comment has been minimized.

Copy link
@cbecker

cbecker Jun 16, 2017

Author Contributor

Indeed, this is a leftover which probably makes this code crash. I guess there is no test for this function; otherwise tests would fail (we're passing more arguments than what it expects)

This comment has been minimized.

Copy link
@wxchan

wxchan Jun 16, 2017

Contributor

actually some preprocessing ensures this branch will never be called, I will remove it.

else:
try:
csr = scipy.sparse.csr_matrix(data)
Expand Down Expand Up @@ -466,6 +475,7 @@ def __pred_for_np2d(self, mat, num_iteration, predict_type):
ctypes.c_int(C_API_IS_ROW_MAJOR),
ctypes.c_int(predict_type),
ctypes.c_int(num_iteration),
self.early_stop_instance.handle,
ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
if n_preds != out_num_preds.value:
Expand Down Expand Up @@ -496,6 +506,7 @@ def __pred_for_csr(self, csr, num_iteration, predict_type):
ctypes.c_int64(csr.shape[1]),
ctypes.c_int(predict_type),
ctypes.c_int(num_iteration),
self.early_stop_instance.handle,
ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
if n_preds != out_num_preds.value:
Expand Down Expand Up @@ -526,6 +537,7 @@ def __pred_for_csc(self, csc, num_iteration, predict_type):
ctypes.c_int64(csc.shape[0]),
ctypes.c_int(predict_type),
ctypes.c_int(num_iteration),
self.early_stop_instance.handle,
ctypes.byref(out_num_preds),
preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double))))
if n_preds != out_num_preds.value:
Expand Down Expand Up @@ -1568,7 +1580,8 @@ def dump_model(self, num_iteration=-1):
ptr_string_buffer))
return json.loads(string_buffer.value.decode())

def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True):
def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data_has_header=False, is_reshape=True,
early_stop_instance=None):
"""
Predict logic
Expand All @@ -1587,19 +1600,21 @@ def predict(self, data, num_iteration=-1, raw_score=False, pred_leaf=False, data
Used for txt data
is_reshape : bool
Reshape to (nrow, ncol) if true
early_stop_instance: object of type PredictionEarlyStopInstance.
If None, no early stopping is applied
Returns
-------
Prediction result
"""
predictor = self._to_predictor()
predictor = self._to_predictor(early_stop_instance)
if num_iteration <= 0:
num_iteration = self.best_iteration
return predictor.predict(data, num_iteration, raw_score, pred_leaf, data_has_header, is_reshape)

def _to_predictor(self):
def _to_predictor(self, early_stop_instance=None):
"""Convert to predictor"""
predictor = _InnerPredictor(booster_handle=self.handle)
predictor = _InnerPredictor(booster_handle=self.handle, early_stop_instance=early_stop_instance)
predictor.pandas_categorical = self.pandas_categorical
return predictor

Expand Down Expand Up @@ -1785,3 +1800,35 @@ def set_attr(self, **kwargs):
self.__attr[key] = value
else:
self.__attr.pop(key, None)


class PredictionEarlyStopInstance(object):
""""PredictionEarlyStopInstance in LightGBM."""
def __init__(self, early_stop_type="none", round_period=20, margin_threshold=1.5):
"""
Create an early stopping object
Parameters
----------
early_stop_type: string
None, "none", "binary" or "multiclass". Regression is not supported.
round_period : int
The score will be checked every round_period to check if the early stopping criteria is met
margin_threshold : double
Early stopping will kick in when the margin is greater than margin_threshold
"""
self.handle = ctypes.c_void_p(0)
self.__attr = {}

if early_stop_type is None:
early_stop_type = "none"

_safe_call(_LIB.LGBM_PredictionEarlyStopInstanceCreate(
c_str(early_stop_type),
ctypes.c_int(round_period),
ctypes.c_double(margin_threshold),
ctypes.byref(self.handle)))

def __del__(self):
if self.handle is not None:
_safe_call(_LIB.LGBM_PredictionEarlyStopInstanceFree(self.handle))
11 changes: 6 additions & 5 deletions src/application/predictor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ class Predictor {
* \param is_predict_leaf_index True if output leaf index instead of prediction score
*/
Predictor(Boosting* boosting, int num_iteration,
bool is_raw_score, bool is_predict_leaf_index) {
bool is_raw_score, bool is_predict_leaf_index,
const PredictionEarlyStopInstance* earlyStop = nullptr) {
#pragma omp parallel
#pragma omp master
{
Expand All @@ -54,17 +55,17 @@ class Predictor {

} else {
if (is_raw_score) {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
predict_fun_ = [this, earlyStop](const std::vector<std::pair<int, double>>& features, double* output) {
int tid = omp_get_thread_num();
CopyToPredictBuffer(predict_buf_[tid].data(), features);
boosting_->PredictRaw(predict_buf_[tid].data(), output);
boosting_->PredictRaw(predict_buf_[tid].data(), output, earlyStop);
ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
};
} else {
predict_fun_ = [this](const std::vector<std::pair<int, double>>& features, double* output) {
predict_fun_ = [this, earlyStop](const std::vector<std::pair<int, double>>& features, double* output) {
int tid = omp_get_thread_num();
CopyToPredictBuffer(predict_buf_[tid].data(), features);
boosting_->Predict(predict_buf_[tid].data(), output);
boosting_->Predict(predict_buf_[tid].data(), output, earlyStop);
ClearPredictBuffer(predict_buf_[tid].data(), predict_buf_[tid].size(), features);
};
}
Expand Down

0 comments on commit 993bbd5

Please sign in to comment.