Skip to content

Commit

Permalink
[python] Dataset params back up before training (#786)
Browse files Browse the repository at this point in the history
* params back up

* refine logic
  • Loading branch information
wxchan authored and guolinke committed Aug 18, 2017
1 parent 2367b46 commit ed1e4f8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
9 changes: 8 additions & 1 deletion python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""Wrapper c_api of LightGBM"""
from __future__ import absolute_import

import copy
import ctypes
import os
import warnings
Expand Down Expand Up @@ -591,11 +592,12 @@ def __init__(self, data, label=None, max_bin=255, reference=None,
self.silent = silent
self.feature_name = feature_name
self.categorical_feature = categorical_feature
self.params = params
self.params = copy.deepcopy(params)
self.free_raw_data = free_raw_data
self.used_indices = None
self._predictor = None
self.pandas_categorical = None
self.params_back_up = None

def __del__(self):
self._free_handle()
Expand Down Expand Up @@ -872,8 +874,13 @@ def _update_params(self, params):
if not self.params:
self.params = params
else:
self.params_back_up = copy.deepcopy(self.params)
self.params.update(params)

def _reverse_update_params(self):
self.params = copy.deepcopy(self.params_back_up)
self.params_back_up = None

def set_field(self, field_name, data):
"""Set property into the Dataset.
Expand Down
18 changes: 11 additions & 7 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,13 @@ def train(params, train_set, num_boost_round=100,
continue
if not isinstance(valid_data, Dataset):
raise TypeError("Traninig only accepts Dataset object")
valid_data._update_params(params)
valid_data.set_reference(train_set)
reduced_valid_sets.append(valid_data)
if valid_names is not None and len(valid_names) > i:
name_valid_sets.append(valid_names[i])
else:
name_valid_sets.append('valid_' + str(i))
for valid_data in valid_sets:
valid_data._update_params(params)
"""process callbacks"""
if callbacks is None:
callbacks = set()
Expand Down Expand Up @@ -165,11 +164,16 @@ def train(params, train_set, num_boost_round=100,
callbacks_after_iter = sorted(callbacks_after_iter, key=attrgetter('order'))

"""construct booster"""
booster = Booster(params=params, train_set=train_set)
if is_valid_contain_train:
booster.set_train_data_name(train_data_name)
for valid_set, name_valid_set in zip(reduced_valid_sets, name_valid_sets):
booster.add_valid(valid_set, name_valid_set)
try:
booster = Booster(params=params, train_set=train_set)
if is_valid_contain_train:
booster.set_train_data_name(train_data_name)
for valid_set, name_valid_set in zip(reduced_valid_sets, name_valid_sets):
booster.add_valid(valid_set, name_valid_set)
finally:
train_set._reverse_update_params()
for valid_set in reduced_valid_sets:
valid_set._reverse_update_params()
booster.best_iteration = 0

"""start training"""
Expand Down

0 comments on commit ed1e4f8

Please sign in to comment.