Skip to content

Commit

Permalink
[python] make possibility to create Booster from string official (#2098)
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Apr 13, 2019
1 parent 0a4a7a8 commit 5b5b982
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
15 changes: 9 additions & 6 deletions python-package/lightgbm/basic.py
Expand Up @@ -397,7 +397,7 @@ def __init__(self, model_file=None, booster_handle=None, pred_parameter=None):
self.num_total_iteration = out_num_iterations.value
self.pandas_categorical = None
else:
raise TypeError('Need Model file or Booster handle to create a predictor')
raise TypeError('Need model_file or booster_handle to create a predictor')

pred_parameter = {} if pred_parameter is None else pred_parameter
self.pred_parameter = param_dict_to_str(pred_parameter)
Expand Down Expand Up @@ -1578,7 +1578,7 @@ def dump_text(self, filename):
class Booster(object):
"""Booster in LightGBM."""

def __init__(self, params=None, train_set=None, model_file=None, silent=False):
def __init__(self, params=None, train_set=None, model_file=None, model_str=None, silent=False):
"""Initialize the Booster.
Parameters
Expand All @@ -1589,6 +1589,8 @@ def __init__(self, params=None, train_set=None, model_file=None, silent=False):
Training dataset.
model_file : string or None, optional (default=None)
Path to the model file.
model_str : string or None, optional (default=None)
Model will be loaded from this string.
silent : bool, optional (default=False)
Whether to print messages during construction.
"""
Expand Down Expand Up @@ -1666,10 +1668,11 @@ def __init__(self, params=None, train_set=None, model_file=None, silent=False):
ctypes.byref(out_num_class)))
self.__num_class = out_num_class.value
self.pandas_categorical = _load_pandas_categorical(file_name=model_file)
elif 'model_str' in params:
self.model_from_string(params['model_str'], False)
elif model_str is not None:
self.model_from_string(model_str, not silent)
else:
raise TypeError('Need at least one training dataset or model file to create booster instance')
raise TypeError('Need at least one training dataset or model file or model string '
'to create Booster instance')
self.params = params

def __del__(self):
Expand All @@ -1689,7 +1692,7 @@ def __copy__(self):

def __deepcopy__(self, _):
model_str = self.model_to_string(num_iteration=-1)
booster = Booster({'model_str': model_str})
booster = Booster(model_str=model_str)
return booster

def __getstate__(self):
Expand Down
2 changes: 1 addition & 1 deletion tests/python_package_test/test_engine.py
Expand Up @@ -583,7 +583,7 @@ def test_pandas_categorical(self):
model_str = gbm4.model_to_string()
gbm4.model_from_string(model_str, False)
pred5 = gbm4.predict(X_test)
gbm5 = lgb.Booster({'model_str': model_str})
gbm5 = lgb.Booster(model_str=model_str)
pred6 = gbm5.predict(X_test)
np.testing.assert_almost_equal(pred0, pred1)
np.testing.assert_almost_equal(pred0, pred2)
Expand Down

0 comments on commit 5b5b982

Please sign in to comment.