Skip to content

Commit

Permalink
simplify start_iteration param for predict in Python and some code cl…
Browse files Browse the repository at this point in the history
…eanup for start_iteration (#3288)

* simplify start_iteration param for predict in Python and some code cleanup for start_iteration

* revert docs changes about the prediction result shape
  • Loading branch information
StrikerRUS committed Aug 11, 2020
1 parent 97d5758 commit 877d58f
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 42 deletions.
4 changes: 2 additions & 2 deletions R-package/R/lgb.Booster.R
Expand Up @@ -491,11 +491,11 @@ Booster <- R6::R6Class(
header = FALSE,
reshape = FALSE, ...) {

# Check if number of iteration is non existent
# Check if number of iteration is non existent
if (is.null(num_iteration)) {
num_iteration <- self$best_iter
}
# Check if start iteration is non existent
# Check if start iteration is non existent
if (is.null(start_iteration)) {
start_iteration <- 0L
}
Expand Down
5 changes: 2 additions & 3 deletions R-package/tests/testthat/test_Predictor.R
Expand Up @@ -38,9 +38,8 @@ test_that("start_iteration works correctly", {
, label = train$label
, num_leaves = 4L
, learning_rate = 0.6
, nrounds = 100L
, nrounds = 50L
, objective = "binary"
, save_name = tempfile(fileext = ".model")
, valids = list("test" = dtest)
, early_stopping_rounds = 2L
)
Expand All @@ -50,7 +49,7 @@ test_that("start_iteration works correctly", {
pred2 <- rep(0.0, length(pred1))
pred_contrib2 <- rep(0.0, length(pred2))
step <- 11L
end_iter <- 99L
end_iter <- 49L
if (bst$best_iter != -1L) {
end_iter <- bst$best_iter - 1L
}
Expand Down
18 changes: 8 additions & 10 deletions python-package/lightgbm/basic.py
Expand Up @@ -2813,7 +2813,7 @@ def dump_model(self, num_iteration=None, start_iteration=0, importance_type='spl
default=json_default_with_numpy))
return ret

def predict(self, data, start_iteration=None, num_iteration=None,
def predict(self, data, start_iteration=0, num_iteration=None,
raw_score=False, pred_leaf=False, pred_contrib=False,
data_has_header=False, is_reshape=True, **kwargs):
"""Make a prediction.
Expand All @@ -2823,14 +2823,14 @@ def predict(self, data, start_iteration=None, num_iteration=None,
data : string, numpy array, pandas DataFrame, H2O DataTable's Frame or scipy.sparse
Data source for prediction.
If string, it represents the path to txt file.
start_iteration : int or None, optional (default=None)
start_iteration : int, optional (default=0)
Start index of the iteration to predict.
If None or <= 0, starts from the first iteration.
If <= 0, starts from the first iteration.
num_iteration : int or None, optional (default=None)
Limit number of iterations in the prediction.
If None, if the best iteration exists and start_iteration is None or <= 0, the best iteration is used;
otherwise, all iterations from start_iteration are used.
If <= 0, all iterations from start_iteration are used (no limits).
Total number of iterations used in the prediction.
If None, if the best iteration exists and start_iteration <= 0, the best iteration is used;
otherwise, all iterations from ``start_iteration`` are used (no limits).
If <= 0, all iterations from ``start_iteration`` are used (no limits).
raw_score : bool, optional (default=False)
Whether to predict raw scores.
pred_leaf : bool, optional (default=False)
Expand Down Expand Up @@ -2861,10 +2861,8 @@ def predict(self, data, start_iteration=None, num_iteration=None,
Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
"""
predictor = self._to_predictor(copy.deepcopy(kwargs))
if start_iteration is None or start_iteration < 0:
start_iteration = 0
if num_iteration is None:
if start_iteration == 0:
if start_iteration <= 0:
num_iteration = self.best_iteration
else:
num_iteration = -1
Expand Down
28 changes: 15 additions & 13 deletions python-package/lightgbm/sklearn.py
Expand Up @@ -612,7 +612,7 @@ def _get_meta_data(collection, name, i):
del train_set, valid_sets
return self

def predict(self, X, raw_score=False, start_iteration=None, num_iteration=None,
def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
pred_leaf=False, pred_contrib=False, **kwargs):
"""Return the predicted value for each sample.
Expand All @@ -622,13 +622,14 @@ def predict(self, X, raw_score=False, start_iteration=None, num_iteration=None,
Input features matrix.
raw_score : bool, optional (default=False)
Whether to predict raw scores.
start_iteration : int or None, optional (default=None)
start_iteration : int, optional (default=0)
Start index of the iteration to predict.
If None or <= 0, starts from the first iteration.
If <= 0, starts from the first iteration.
num_iteration : int or None, optional (default=None)
Limit number of iterations in the prediction.
If None, if the best iteration exists, it is used; otherwise, all trees are used.
If <= 0, all trees are used (no limits).
Total number of iterations used in the prediction.
If None, if the best iteration exists and start_iteration <= 0, the best iteration is used;
otherwise, all iterations from ``start_iteration`` are used (no limits).
If <= 0, all iterations from ``start_iteration`` are used (no limits).
pred_leaf : bool, optional (default=False)
Whether to predict leaf index.
pred_contrib : bool, optional (default=False)
Expand Down Expand Up @@ -835,7 +836,7 @@ def fit(self, X, y,

fit.__doc__ = LGBMModel.fit.__doc__

def predict(self, X, raw_score=False, start_iteration=None, num_iteration=None,
def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
pred_leaf=False, pred_contrib=False, **kwargs):
"""Docstring is inherited from the LGBMModel."""
result = self.predict_proba(X, raw_score, start_iteration, num_iteration,
Expand All @@ -848,7 +849,7 @@ def predict(self, X, raw_score=False, start_iteration=None, num_iteration=None,

predict.__doc__ = LGBMModel.predict.__doc__

def predict_proba(self, X, raw_score=False, start_iteration=None, num_iteration=None,
def predict_proba(self, X, raw_score=False, start_iteration=0, num_iteration=None,
pred_leaf=False, pred_contrib=False, **kwargs):
"""Return the predicted probability for each class for each sample.
Expand All @@ -858,13 +859,14 @@ def predict_proba(self, X, raw_score=False, start_iteration=None, num_iteration=
Input features matrix.
raw_score : bool, optional (default=False)
Whether to predict raw scores.
start_iteration : int or None, optional (default=None)
start_iteration : int, optional (default=0)
Start index of the iteration to predict.
If None or <= 0, starts from the first iteration.
If <= 0, starts from the first iteration.
num_iteration : int or None, optional (default=None)
Limit number of iterations in the prediction.
If None, if the best iteration exists, it is used; otherwise, all trees are used.
If <= 0, all trees are used (no limits).
Total number of iterations used in the prediction.
If None, if the best iteration exists and start_iteration <= 0, the best iteration is used;
otherwise, all iterations from ``start_iteration`` are used (no limits).
If <= 0, all iterations from ``start_iteration`` are used (no limits).
pred_leaf : bool, optional (default=False)
Whether to predict leaf index.
pred_contrib : bool, optional (default=False)
Expand Down
2 changes: 0 additions & 2 deletions src/application/predictor.hpp
Expand Up @@ -226,7 +226,6 @@ class Predictor {
data_size_t, const std::vector<std::string>& lines) {
std::vector<std::pair<int, double>> oneline_features;
std::vector<std::string> result_to_write(lines.size());
Log::Warning("before predict_fun_ is called");
OMP_INIT_EX();
#pragma omp parallel for schedule(static) firstprivate(oneline_features)
for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
Expand All @@ -241,7 +240,6 @@ class Predictor {
result_to_write[i] = str_result;
OMP_LOOP_EX_END();
}
Log::Warning("after predict_fun_ is called");
OMP_THROW_EX();
for (data_size_t i = 0; i < static_cast<data_size_t>(result_to_write.size()); ++i) {
writer->Write(result_to_write[i].c_str(), result_to_write[i].size());
Expand Down
2 changes: 1 addition & 1 deletion src/boosting/gbdt_prediction.cpp
Expand Up @@ -78,7 +78,7 @@ void GBDT::PredictByMap(const std::unordered_map<int, double>& features, double*

void GBDT::PredictLeafIndex(const double* features, double* output) const {
int start_tree = start_iteration_for_pred_ * num_tree_per_iteration_;
int num_trees = num_iteration_for_pred_ * num_tree_per_iteration_;
int num_trees = num_iteration_for_pred_ * num_tree_per_iteration_;
const auto* models_ptr = models_.data() + start_tree;
for (int i = 0; i < num_trees; ++i) {
output[i] = models_ptr[i]->PredictLeafIndex(features);
Expand Down
20 changes: 9 additions & 11 deletions tests/python_package_test/test_engine.py
Expand Up @@ -2321,7 +2321,7 @@ def inner_test(X, y, params, early_stopping_rounds):
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
train_data = lgb.Dataset(X_train, label=y_train)
valid_data = lgb.Dataset(X_test, label=y_test)
booster = lgb.train(params, train_data, num_boost_round=100, early_stopping_rounds=early_stopping_rounds, valid_sets=[valid_data])
booster = lgb.train(params, train_data, num_boost_round=50, early_stopping_rounds=early_stopping_rounds, valid_sets=[valid_data])

# test that the predict once with all iterations equals summed results with start_iteration and num_iteration
all_pred = booster.predict(X, raw_score=True)
Expand All @@ -2330,17 +2330,15 @@ def inner_test(X, y, params, early_stopping_rounds):
for step in steps:
pred = np.zeros_like(all_pred)
pred_contrib = np.zeros_like(all_pred_contrib)
for start_iter in range(0, 100, step):
pred += booster.predict(X, num_iteration=step, start_iteration=start_iter, raw_score=True)
pred_contrib += booster.predict(X, num_iteration=step, start_iteration=start_iter, pred_contrib=True)
for start_iter in range(0, 50, step):
pred += booster.predict(X, start_iteration=start_iter, num_iteration=step, raw_score=True)
pred_contrib += booster.predict(X, start_iteration=start_iter, num_iteration=step, pred_contrib=True)
np.testing.assert_allclose(all_pred, pred)
np.testing.assert_allclose(all_pred_contrib, pred_contrib)
# test the case where start_iteration <= 0, and num_iteration is None
pred1 = booster.predict(X, start_iteration=-1)
pred2 = booster.predict(X, num_iteration=booster.best_iteration)
pred3 = booster.predict(X, num_iteration=booster.best_iteration, start_iteration=0)
np.testing.assert_allclose(pred1, pred2)
np.testing.assert_allclose(pred1, pred3)

# test the case where start_iteration > 0, and num_iteration <= 0
pred4 = booster.predict(X, start_iteration=10, num_iteration=-1)
Expand All @@ -2351,14 +2349,14 @@ def inner_test(X, y, params, early_stopping_rounds):

# test the case where start_iteration > 0, and num_iteration <= 0, with pred_leaf=True
pred4 = booster.predict(X, start_iteration=10, num_iteration=-1, pred_leaf=True)
pred5 = booster.predict(X, start_iteration=10, num_iteration=90, pred_leaf=True)
pred5 = booster.predict(X, start_iteration=10, num_iteration=40, pred_leaf=True)
pred6 = booster.predict(X, start_iteration=10, num_iteration=0, pred_leaf=True)
np.testing.assert_allclose(pred4, pred5)
np.testing.assert_allclose(pred4, pred6)

# test the case where start_iteration > 0, and num_iteration <= 0, with pred_contrib=True
pred4 = booster.predict(X, start_iteration=10, num_iteration=-1, pred_contrib=True)
pred5 = booster.predict(X, start_iteration=10, num_iteration=90, pred_contrib=True)
pred5 = booster.predict(X, start_iteration=10, num_iteration=40, pred_contrib=True)
pred6 = booster.predict(X, start_iteration=10, num_iteration=0, pred_contrib=True)
np.testing.assert_allclose(pred4, pred5)
np.testing.assert_allclose(pred4, pred6)
Expand All @@ -2373,7 +2371,7 @@ def inner_test(X, y, params, early_stopping_rounds):
}
# test both with and without early stopping
inner_test(X, y, params, early_stopping_rounds=1)
inner_test(X, y, params, early_stopping_rounds=10)
inner_test(X, y, params, early_stopping_rounds=5)
inner_test(X, y, params, early_stopping_rounds=None)

# test for multi-class
Expand All @@ -2387,7 +2385,7 @@ def inner_test(X, y, params, early_stopping_rounds):
}
# test both with and without early stopping
inner_test(X, y, params, early_stopping_rounds=1)
inner_test(X, y, params, early_stopping_rounds=10)
inner_test(X, y, params, early_stopping_rounds=5)
inner_test(X, y, params, early_stopping_rounds=None)

# test for binary
Expand All @@ -2400,5 +2398,5 @@ def inner_test(X, y, params, early_stopping_rounds):
}
# test both with and without early stopping
inner_test(X, y, params, early_stopping_rounds=1)
inner_test(X, y, params, early_stopping_rounds=10)
inner_test(X, y, params, early_stopping_rounds=5)
inner_test(X, y, params, early_stopping_rounds=None)

0 comments on commit 877d58f

Please sign in to comment.