Skip to content

Commit

Permalink
[python] save all param values into model file (#2589)
Browse files Browse the repository at this point in the history
* save all param values into model file

* revert storing predict params

* do not save params for predict and convert tasks

* fixed test: 10 is found successfully for default 100

* specify more params as no-save
  • Loading branch information
StrikerRUS committed Mar 6, 2020
1 parent 2051223 commit ba15a16
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 28 deletions.
2 changes: 1 addition & 1 deletion helpers/parameter_generator.py
Expand Up @@ -315,7 +315,7 @@ def gen_parameter_code(config_hpp, config_out_cpp):
str_to_write += " std::stringstream str_buf;\n"
for x in infos:
for y in x:
if "[doc-only]" in y:
if "[doc-only]" in y or "[no-save]" in y:
continue
param_type = y["inner_type"][0]
name = y["name"][0]
Expand Down
25 changes: 23 additions & 2 deletions include/LightGBM/config.h
Expand Up @@ -3,8 +3,10 @@
* Licensed under the MIT License. See LICENSE file in the project root for license information.
*
* \note
* desc and descl2 fields must be written in reStructuredText format;
* nested sections can be placed only at the bottom of parent's section
* - desc and descl2 fields must be written in reStructuredText format;
* - nested sections can be placed only at the bottom of parent's section;
* - [doc-only] tag indicates that only documentation for this param should be generated and all other actions are performed manually;
* - [no-save] tag indicates that this param should not be saved into a model text representation.
*/
#ifndef LIGHTGBM_CONFIG_H_
#define LIGHTGBM_CONFIG_H_
Expand Down Expand Up @@ -83,12 +85,14 @@ struct Config {

#pragma region Core Parameters

// [no-save]
// [doc-only]
// alias = config_file
// desc = path of config file
// desc = **Note**: can be used only in CLI version
std::string config = "";

// [no-save]
// [doc-only]
// type = enum
// default = train
Expand Down Expand Up @@ -482,18 +486,21 @@ struct Config {
// desc = ``< 0``: Fatal, ``= 0``: Error (Warning), ``= 1``: Info, ``> 1``: Debug
int verbosity = 1;

// [no-save]
// alias = model_input, model_in
// desc = filename of input model
// desc = for ``prediction`` task, this model will be applied to prediction data
// desc = for ``train`` task, training will be continued from this model
// desc = **Note**: can be used only in CLI version
std::string input_model = "";

// [no-save]
// alias = model_output, model_out
// desc = filename of output model in training
// desc = **Note**: can be used only in CLI version
std::string output_model = "LightGBM_model.txt";

// [no-save]
// alias = save_period
// desc = frequency of saving model file snapshot
// desc = set this to positive value to enable this function. For example, the model file will be snapshotted at each iteration if ``snapshot_freq=1``
Expand Down Expand Up @@ -626,6 +633,7 @@ struct Config {
// desc = see `this file <https://github.com/microsoft/LightGBM/tree/master/examples/regression/forced_bins.json>`__ as an example
std::string forcedbins_filename = "";

// [no-save]
// alias = is_save_binary, is_save_binary_file
// desc = if ``true``, LightGBM will save the dataset (including validation data) to a binary file. This speeds up the data loading for the next time
// desc = **Note**: ``init_score`` is not saved in binary file
Expand All @@ -636,22 +644,26 @@ struct Config {

#pragma region Predict Parameters

// [no-save]
// desc = used only in ``prediction`` task
// desc = used to specify how many trained iterations will be used in prediction
// desc = ``<= 0`` means no limit
int num_iteration_predict = -1;

// [no-save]
// alias = is_predict_raw_score, predict_rawscore, raw_score
// desc = used only in ``prediction`` task
// desc = set this to ``true`` to predict only the raw scores
// desc = set this to ``false`` to predict transformed scores
bool predict_raw_score = false;

// [no-save]
// alias = is_predict_leaf_index, leaf_index
// desc = used only in ``prediction`` task
// desc = set this to ``true`` to predict with leaf index of all trees
bool predict_leaf_index = false;

// [no-save]
// alias = is_predict_contrib, contrib
// desc = used only in ``prediction`` task
// desc = set this to ``true`` to estimate `SHAP values <https://arxiv.org/abs/1706.06060>`__, which represent how each feature contributes to each prediction
Expand All @@ -660,25 +672,30 @@ struct Config {
// desc = **Note**: unlike the shap package, with ``predict_contrib`` we return a matrix with an extra column, where the last column is the expected value
bool predict_contrib = false;

// [no-save]
// desc = used only in ``prediction`` task
// desc = control whether or not LightGBM raises an error when you try to predict on data with a different number of features than the training data
// desc = if ``false`` (the default), a fatal error will be raised if the number of features in the dataset you predict on differs from the number seen during training
// desc = if ``true``, LightGBM will attempt to predict on whatever data you provide. This is dangerous because you might get incorrect predictions, but you could use it in situations where it is difficult or expensive to generate some features and you are very confident that they were never chosen for splits in the model
// desc = **Note**: be very careful setting this parameter to ``true``
bool predict_disable_shape_check = false;

// [no-save]
// desc = used only in ``prediction`` task
// desc = if ``true``, will use early-stopping to speed up the prediction. May affect the accuracy
bool pred_early_stop = false;

// [no-save]
// desc = used only in ``prediction`` task
// desc = the frequency of checking early-stopping prediction
int pred_early_stop_freq = 10;

// [no-save]
// desc = used only in ``prediction`` task
// desc = the threshold of margin in early-stopping prediction
double pred_early_stop_margin = 10.0;

// [no-save]
// alias = predict_result, prediction_result, predict_name, prediction_name, pred_name, name_pred
// desc = used only in ``prediction`` task
// desc = filename of prediction result
Expand All @@ -689,12 +706,14 @@ struct Config {

#pragma region Convert Parameters

// [no-save]
// desc = used only in ``convert_model`` task
// desc = only ``cpp`` is supported so far; for conversion model to other languages consider using `m2cgen <https://github.com/BayesWitnesses/m2cgen>`__ utility
// desc = if ``convert_model_language`` is set and ``task=train``, the model will be also converted
// desc = **Note**: can be used only in CLI version
std::string convert_model_language = "";

// [no-save]
// alias = convert_model_file
// desc = used only in ``convert_model`` task
// desc = output filename of converted model
Expand Down Expand Up @@ -820,12 +839,14 @@ struct Config {
// desc = support multiple metrics, separated by ``,``
std::vector<std::string> metric;

// [no-save]
// check = >0
// alias = output_freq
// desc = frequency for metric output
// desc = **Note**: can be used only in CLI version
int metric_freq = 1;

// [no-save]
// alias = training_metric, is_training_metric, train_metric
// desc = set this to ``true`` to output metric result over training dataset
// desc = **Note**: can be used only in CLI version
Expand Down
4 changes: 2 additions & 2 deletions python-package/lightgbm/basic.py
Expand Up @@ -1751,7 +1751,7 @@ def __init__(self, params=None, train_set=None, model_file=None, model_str=None,
self.set_network(machines,
local_listen_port=params.get("local_listen_port", 12400),
listen_time_out=params.get("listen_time_out", 120),
num_machines=params.get("num_machines", num_machines))
num_machines=params.setdefault("num_machines", num_machines))
break
# construct booster object
train_set.construct()
Expand Down Expand Up @@ -2641,7 +2641,7 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
train_set = Dataset(data, label, silent=True)
new_params = copy.deepcopy(self.params)
new_params['refit_decay_rate'] = decay_rate
new_booster = Booster(new_params, train_set, silent=True)
new_booster = Booster(new_params, train_set)
# Copy models
_safe_call(_LIB.LGBM_BoosterMerge(
new_booster.handle,
Expand Down
12 changes: 6 additions & 6 deletions python-package/lightgbm/engine.py
Expand Up @@ -146,13 +146,13 @@ def train(params, train_set, num_boost_round=100,
if alias in params:
num_boost_round = params.pop(alias)
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
break
params["num_iterations"] = num_boost_round
for alias in _ConfigAliases.get("early_stopping_round"):
if alias in params:
early_stopping_rounds = params.pop(alias)
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
break
first_metric_only = params.pop('first_metric_only', False)
params["early_stopping_round"] = early_stopping_rounds
first_metric_only = params.get('first_metric_only', False)

if num_boost_round <= 0:
raise ValueError("num_boost_round should be greater than zero.")
Expand Down Expand Up @@ -504,13 +504,13 @@ def cv(params, train_set, num_boost_round=100,
if alias in params:
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
num_boost_round = params.pop(alias)
break
params["num_iterations"] = num_boost_round
for alias in _ConfigAliases.get("early_stopping_round"):
if alias in params:
warnings.warn("Found `{}` in params. Will use it instead of argument".format(alias))
early_stopping_rounds = params.pop(alias)
break
first_metric_only = params.pop('first_metric_only', False)
params["early_stopping_round"] = early_stopping_rounds
first_metric_only = params.get('first_metric_only', False)

if num_boost_round <= 0:
raise ValueError("num_boost_round should be greater than zero.")
Expand Down
17 changes: 0 additions & 17 deletions src/io/config_auto.cpp
Expand Up @@ -641,9 +641,6 @@ std::string Config::SaveMembersToString() const {
str_buf << "[cegb_penalty_feature_lazy: " << Common::Join(cegb_penalty_feature_lazy, ",") << "]\n";
str_buf << "[cegb_penalty_feature_coupled: " << Common::Join(cegb_penalty_feature_coupled, ",") << "]\n";
str_buf << "[verbosity: " << verbosity << "]\n";
str_buf << "[input_model: " << input_model << "]\n";
str_buf << "[output_model: " << output_model << "]\n";
str_buf << "[snapshot_freq: " << snapshot_freq << "]\n";
str_buf << "[max_bin: " << max_bin << "]\n";
str_buf << "[max_bin_by_feature: " << Common::Join(max_bin_by_feature, ",") << "]\n";
str_buf << "[min_data_in_bin: " << min_data_in_bin << "]\n";
Expand All @@ -663,18 +660,6 @@ std::string Config::SaveMembersToString() const {
str_buf << "[ignore_column: " << ignore_column << "]\n";
str_buf << "[categorical_feature: " << categorical_feature << "]\n";
str_buf << "[forcedbins_filename: " << forcedbins_filename << "]\n";
str_buf << "[save_binary: " << save_binary << "]\n";
str_buf << "[num_iteration_predict: " << num_iteration_predict << "]\n";
str_buf << "[predict_raw_score: " << predict_raw_score << "]\n";
str_buf << "[predict_leaf_index: " << predict_leaf_index << "]\n";
str_buf << "[predict_contrib: " << predict_contrib << "]\n";
str_buf << "[predict_disable_shape_check: " << predict_disable_shape_check << "]\n";
str_buf << "[pred_early_stop: " << pred_early_stop << "]\n";
str_buf << "[pred_early_stop_freq: " << pred_early_stop_freq << "]\n";
str_buf << "[pred_early_stop_margin: " << pred_early_stop_margin << "]\n";
str_buf << "[output_result: " << output_result << "]\n";
str_buf << "[convert_model_language: " << convert_model_language << "]\n";
str_buf << "[convert_model: " << convert_model << "]\n";
str_buf << "[objective_seed: " << objective_seed << "]\n";
str_buf << "[num_class: " << num_class << "]\n";
str_buf << "[is_unbalance: " << is_unbalance << "]\n";
Expand All @@ -689,8 +674,6 @@ std::string Config::SaveMembersToString() const {
str_buf << "[lambdarank_truncation_level: " << lambdarank_truncation_level << "]\n";
str_buf << "[lambdarank_norm: " << lambdarank_norm << "]\n";
str_buf << "[label_gain: " << Common::Join(label_gain, ",") << "]\n";
str_buf << "[metric_freq: " << metric_freq << "]\n";
str_buf << "[is_provide_training_metric: " << is_provide_training_metric << "]\n";
str_buf << "[eval_at: " << Common::Join(eval_at, ",") << "]\n";
str_buf << "[multi_error_top_k: " << multi_error_top_k << "]\n";
str_buf << "[auc_mu_weights: " << Common::Join(auc_mu_weights, ",") << "]\n";
Expand Down
2 changes: 2 additions & 0 deletions tests/python_package_test/test_engine.py
Expand Up @@ -747,6 +747,8 @@ def train_and_predict(init_model=None, return_model=False):
ret_origin = train_and_predict(init_model=gbm)
other_ret = []
gbm.save_model('lgb.model')
with open('lgb.model') as f: # check all params are logged into model file correctly
self.assertNotEqual(f.read().find("[num_iterations: 10]"), -1)
other_ret.append(train_and_predict(init_model='lgb.model'))
gbm_load = lgb.Booster(model_file='lgb.model')
other_ret.append(train_and_predict(init_model=gbm_load))
Expand Down

0 comments on commit ba15a16

Please sign in to comment.