Skip to content

Commit

Permalink
[python] add keep_training_booster (#673)
Browse files Browse the repository at this point in the history
* add keep_training_booster

* use model string

* reset handle; free dataset
  • Loading branch information
wxchan authored and guolinke committed Jul 7, 2017
1 parent ac73638 commit f7d190a
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
11 changes: 6 additions & 5 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1245,7 +1245,7 @@ def __init__(self, params=None, train_set=None, model_file=None, silent=False):
self.__num_class = out_num_class.value
self.pandas_categorical = _load_pandas_categorical(model_file)
elif 'model_str' in params:
self.__load_model_from_string(params['model_str'])
self._load_model_from_string(params['model_str'])
else:
raise TypeError('Need at least one training dataset or model file to create booster instance')

Expand All @@ -1257,7 +1257,7 @@ def __copy__(self):
return self.__deepcopy__(None)

def __deepcopy__(self, _):
model_str = self.__save_model_to_string()
model_str = self._save_model_to_string()
booster = Booster({'model_str': model_str})
booster.pandas_categorical = self.pandas_categorical
return booster
Expand All @@ -1268,7 +1268,7 @@ def __getstate__(self):
this.pop('train_set', None)
this.pop('valid_sets', None)
if handle is not None:
this["handle"] = self.__save_model_to_string()
this["handle"] = self._save_model_to_string()
return this

def __setstate__(self, state):
Expand All @@ -1286,6 +1286,7 @@ def __setstate__(self, state):
def free_dataset(self):
self.__dict__.pop('train_set', None)
self.__dict__.pop('valid_sets', None)
self.__num_dataset = 0

def set_train_data_name(self, name):
self.__train_data_name = name
Expand Down Expand Up @@ -1505,7 +1506,7 @@ def save_model(self, filename, num_iteration=-1):
c_str(filename)))
_save_pandas_categorical(filename, self.pandas_categorical)

def __load_model_from_string(self, model_str):
def _load_model_from_string(self, model_str):
"""[Private] Load model from string"""
out_num_iterations = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterLoadModelFromString(
Expand All @@ -1518,7 +1519,7 @@ def __load_model_from_string(self, model_str):
ctypes.byref(out_num_class)))
self.__num_class = out_num_class.value

def __save_model_to_string(self, num_iteration=-1):
def _save_model_to_string(self, num_iteration=-1):
"""[Private] Save model to string"""
if num_iteration <= 0:
num_iteration = self.best_iteration
Expand Down
10 changes: 9 additions & 1 deletion python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def train(params, train_set, num_boost_round=100,
fobj=None, feval=None, init_model=None,
feature_name='auto', categorical_feature='auto',
early_stopping_rounds=None, evals_result=None,
verbose_eval=True, learning_rates=None, callbacks=None):
verbose_eval=True, learning_rates=None,
keep_training_booster=False, callbacks=None):
"""
Train with given parameters.
Expand Down Expand Up @@ -80,6 +81,10 @@ def train(params, train_set, num_boost_round=100,
in terms of current number of round (e.g. yields learning rate decay)
- list l: learning_rate = l[current_round]
- function f: learning_rate = f(current_round)
keep_training_booster : boolean
Whether the return booster will be used to keep training.
If false, will convert into _InnerPredictor before return.
You can still use _InnerPredictor as init_model for future continue training.
callbacks : list of callback functions
List of callback functions that are applied at each iteration.
See Callbacks in Python-API.md for more information.
Expand Down Expand Up @@ -200,6 +205,9 @@ def train(params, train_set, num_boost_round=100,
booster.best_score = collections.defaultdict(dict)
for dataset_name, eval_name, score, _ in evaluation_result_list:
booster.best_score[dataset_name][eval_name] = score
if not keep_training_booster:
booster._load_model_from_string(booster._save_model_to_string())
booster.free_dataset()
return booster


Expand Down

0 comments on commit f7d190a

Please sign in to comment.