From 5b5b98235e4fa8c1eed67f57bf5c409bbe1a09e0 Mon Sep 17 00:00:00 2001 From: Nikita Titov Date: Sat, 13 Apr 2019 13:47:33 +0300 Subject: [PATCH] [python] make possibility to create Booster from string official (#2098) --- python-package/lightgbm/basic.py | 15 +++++++++------ tests/python_package_test/test_engine.py | 2 +- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index b0c0f1563c2..643bc5fa6bf 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -397,7 +397,7 @@ def __init__(self, model_file=None, booster_handle=None, pred_parameter=None): self.num_total_iteration = out_num_iterations.value self.pandas_categorical = None else: - raise TypeError('Need Model file or Booster handle to create a predictor') + raise TypeError('Need model_file or booster_handle to create a predictor') pred_parameter = {} if pred_parameter is None else pred_parameter self.pred_parameter = param_dict_to_str(pred_parameter) @@ -1578,7 +1578,7 @@ def dump_text(self, filename): class Booster(object): """Booster in LightGBM.""" - def __init__(self, params=None, train_set=None, model_file=None, silent=False): + def __init__(self, params=None, train_set=None, model_file=None, model_str=None, silent=False): """Initialize the Booster. Parameters @@ -1589,6 +1589,8 @@ def __init__(self, params=None, train_set=None, model_file=None, silent=False): Training dataset. model_file : string or None, optional (default=None) Path to the model file. + model_str : string or None, optional (default=None) + Model will be loaded from this string. silent : bool, optional (default=False) Whether to print messages during construction. """ @@ -1666,10 +1668,11 @@ def __init__(self, params=None, train_set=None, model_file=None, silent=False): ctypes.byref(out_num_class))) self.__num_class = out_num_class.value self.pandas_categorical = _load_pandas_categorical(file_name=model_file) - elif 'model_str' in params: - self.model_from_string(params['model_str'], False) + elif model_str is not None: + self.model_from_string(model_str, not silent) else: - raise TypeError('Need at least one training dataset or model file to create booster instance') + raise TypeError('Need at least one training dataset or model file or model string ' + 'to create Booster instance') self.params = params def __del__(self): @@ -1689,7 +1692,7 @@ def __copy__(self): def __deepcopy__(self, _): model_str = self.model_to_string(num_iteration=-1) - booster = Booster({'model_str': model_str}) + booster = Booster(model_str=model_str) return booster def __getstate__(self): diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 86cb403bfd6..1dc8f959ff0 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -583,7 +583,7 @@ def test_pandas_categorical(self): model_str = gbm4.model_to_string() gbm4.model_from_string(model_str, False) pred5 = gbm4.predict(X_test) - gbm5 = lgb.Booster({'model_str': model_str}) + gbm5 = lgb.Booster(model_str=model_str) pred6 = gbm5.predict(X_test) np.testing.assert_almost_equal(pred0, pred1) np.testing.assert_almost_equal(pred0, pred2)