From 2d4b01b78448982e964c4b04b5895277ebf5d4fc Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Wed, 26 Dec 2018 11:50:07 -0500
Subject: [PATCH] Reset grower and splitter instead of instantiating new ones

---
 benchmarks/bench_higgs_boson.py |  5 ++-
 pygbm/gradient_boosting.py      | 21 ++++++-----
 pygbm/grower.py                 | 22 +++++++----
 pygbm/splitting.py              | 65 +++++++++++++++++++++------------
 4 files changed, 71 insertions(+), 42 deletions(-)

diff --git a/benchmarks/bench_higgs_boson.py b/benchmarks/bench_higgs_boson.py
index d7cb38e..6057a9e 100644
--- a/benchmarks/bench_higgs_boson.py
+++ b/benchmarks/bench_higgs_boson.py
@@ -86,7 +86,10 @@ def load_data():
                                            n_iter_no_change=None,
                                            random_state=0,
                                            verbose=1)
-pygbm_model.fit(data_train, target_train)
+@profile
+def fit():
+    pygbm_model.fit(data_train, target_train)
+fit()
 toc = time()
 predicted_test = pygbm_model.predict(data_test)
 roc_auc = roc_auc_score(target_test, predicted_test)
diff --git a/pygbm/gradient_boosting.py b/pygbm/gradient_boosting.py
index 0e37918..6ada9ce 100644
--- a/pygbm/gradient_boosting.py
+++ b/pygbm/gradient_boosting.py
@@ -228,6 +228,17 @@ def fit(self, X, y):
                                                      y_train, raw_predictions)
 
             predictors.append([])
+            grower = TreeGrower(
+                X_binned_train,
+                max_bins=self.max_bins,
+                n_bins_per_feature=n_bins_per_feature,
+                max_leaf_nodes=self.max_leaf_nodes,
+                max_depth=self.max_depth,
+                min_samples_leaf=self.min_samples_leaf,
+                l2_regularization=self.l2_regularization,
+                shrinkage=self.learning_rate,
+                hessian_is_constant=self.loss_.hessian_is_constant
+            )
 
             # Build `n_trees_per_iteration` trees.
             for k, (gradients_at_k, hessians_at_k) in enumerate(zip(
@@ -238,15 +249,7 @@ def fit(self, X, y):
                 # n_trees_per_iteration is 1 and xxxx_at_k is equivalent to the
                 # whole array.
 
-                grower = TreeGrower(
-                    X_binned_train, gradients_at_k, hessians_at_k,
-                    max_bins=self.max_bins,
-                    n_bins_per_feature=n_bins_per_feature,
-                    max_leaf_nodes=self.max_leaf_nodes,
-                    max_depth=self.max_depth,
-                    min_samples_leaf=self.min_samples_leaf,
-                    l2_regularization=self.l2_regularization,
-                    shrinkage=self.learning_rate)
+                grower.reset(gradients_at_k, hessians_at_k)
                 grower.grow()
 
                 acc_apply_split_time += grower.total_apply_split_time
diff --git a/pygbm/grower.py b/pygbm/grower.py
index c77d000..fdbf2f1 100644
--- a/pygbm/grower.py
+++ b/pygbm/grower.py
@@ -160,10 +160,11 @@ class TreeGrower:
         The shrinkage parameter to apply to the leaves values, also known as
         learning rate.
""" - def __init__(self, X_binned, gradients, hessians, max_leaf_nodes=None, - max_depth=None, min_samples_leaf=20, min_gain_to_split=0., - max_bins=256, n_bins_per_feature=None, l2_regularization=0., - min_hessian_to_split=1e-3, shrinkage=1.): + def __init__(self, X_binned, max_leaf_nodes=None, max_depth=None, + min_samples_leaf=20, min_gain_to_split=0., max_bins=256, + n_bins_per_feature=None, l2_regularization=0., + min_hessian_to_split=1e-3, shrinkage=1., + hessian_is_constant=False): self._validate_parameters(X_binned, max_leaf_nodes, max_depth, min_samples_leaf, min_gain_to_split, @@ -178,15 +179,20 @@ def __init__(self, X_binned, gradients, hessians, max_leaf_nodes=None, dtype=np.uint32) self.splitting_context = SplittingContext( - X_binned, max_bins, n_bins_per_feature, gradients, - hessians, l2_regularization, min_hessian_to_split, - min_samples_leaf, min_gain_to_split) + X_binned, max_bins, n_bins_per_feature, l2_regularization, + min_hessian_to_split, min_samples_leaf, min_gain_to_split, + hessian_is_constant) + self.max_leaf_nodes = max_leaf_nodes self.max_depth = max_depth self.min_samples_leaf = min_samples_leaf self.X_binned = X_binned self.min_gain_to_split = min_gain_to_split self.shrinkage = shrinkage + self.hessian_is_constant = hessian_is_constant + + def reset(self, gradients, hessians): + self.splitting_context.reset(gradients, hessians) self.splittable_nodes = [] self.finalized_leaves = [] self.total_find_split_time = 0. # time spent finding the best splits @@ -237,7 +243,7 @@ def _intilialize_root(self): """Initialize root node and finalize it if needed.""" n_samples = self.X_binned.shape[0] depth = 0 - if self.splitting_context.constant_hessian: + if self.hessian_is_constant: hessian = self.splitting_context.hessians[0] * n_samples else: hessian = self.splitting_context.hessians.sum() diff --git a/pygbm/splitting.py b/pygbm/splitting.py index 56ae412..65f668f 100644 --- a/pygbm/splitting.py +++ b/pygbm/splitting.py @@ -81,7 +81,7 @@ def __init__(self, gain=-1., feature_idx=0, bin_idx=0, ('ordered_hessians', float32[::1]), ('sum_gradients', float32), ('sum_hessians', float32), - ('constant_hessian', uint8), + ('hessian_is_constant', uint8), ('constant_hessian_value', float32), ('l2_regularization', float32), ('min_hessian_to_split', float32), @@ -126,9 +126,9 @@ class SplittingContext: be ignored. 
""" def __init__(self, X_binned, max_bins, n_bins_per_feature, - gradients, hessians, l2_regularization, - min_hessian_to_split=1e-3, min_samples_leaf=20, - min_gain_to_split=0.): + l2_regularization, min_hessian_to_split=1e-3, + min_samples_leaf=20, min_gain_to_split=0., + hessian_is_constant=False): self.X_binned = X_binned self.n_features = X_binned.shape[1] @@ -136,22 +136,17 @@ def __init__(self, X_binned, max_bins, n_bins_per_feature, # last bins may be unused if n_bins_per_feature[f] < max_bins self.max_bins = max_bins self.n_bins_per_feature = n_bins_per_feature - self.gradients = gradients - self.hessians = hessians - # for root node, gradients and hessians are already ordered - self.ordered_gradients = gradients.copy() - self.ordered_hessians = hessians.copy() - self.sum_gradients = self.gradients.sum() - self.sum_hessians = self.hessians.sum() - self.constant_hessian = hessians.shape[0] == 1 self.l2_regularization = l2_regularization self.min_hessian_to_split = min_hessian_to_split self.min_samples_leaf = min_samples_leaf self.min_gain_to_split = min_gain_to_split - if self.constant_hessian: - self.constant_hessian_value = self.hessians[0] # 1 scalar + + self.hessian_is_constant = hessian_is_constant + self.ordered_gradients = np.empty(X_binned.shape[0], dtype=np.float32) + if self.hessian_is_constant: + self.ordered_hessians = np.empty(1, dtype=np.float32) # won't be used anyway else: - self.constant_hessian_value = float32(1.) # won't be used anyway + self.ordered_hessians = np.empty(X_binned.shape[0], dtype=np.float32) # The partition array maps each sample index into the leaves of the # tree (a leaf in this context is a node that isn't splitted yet, not @@ -162,10 +157,32 @@ def __init__(self, X_binned, max_bins, n_bins_per_feature, # partition = [cef|abdghijkl] # we have 2 leaves, the left one is at position 0 and the second one at # position 3. The order of the samples is irrelevant. - self.partition = np.arange(0, X_binned.shape[0], 1, np.uint32) + self.partition = np.empty(X_binned.shape[0], dtype=np.uint32) # buffers used in split_indices to support parallel splitting. - self.left_indices_buffer = np.empty_like(self.partition) - self.right_indices_buffer = np.empty_like(self.partition) + self.left_indices_buffer = np.empty(X_binned.shape[0], dtype=np.uint32) + self.right_indices_buffer = np.empty(X_binned.shape[0], dtype=np.uint32) + + # TODO: parallelize this + def reset(self, gradients, hessians): + self.gradients = gradients + self.hessians = hessians + + # for root node, gradients and hessians are already ordered + self.sum_gradients = self.gradients.sum() + self.sum_hessians = self.hessians.sum() + + n_samples = gradients.shape[0] + for i in range(n_samples): + self.ordered_gradients[i] = gradients[i] + if self.hessian_is_constant: + self.constant_hessian_value = self.hessians[0] # 1 scalar + else: + self.constant_hessian_value = float32(1.) 
+            for i in range(n_samples):
+                self.ordered_hessians[i] = hessians[i]
+
+        for i in range(n_samples):
+            self.partition[i] = i
 
 
 @njit(parallel=True,
@@ -345,7 +362,7 @@ def find_node_split(context, sample_indices):
     #     ctx.ordered_gradients[i] = ctx.gradients[samples_indices[i]]
     if sample_indices.shape[0] != ctx.gradients.shape[0]:
         starts, ends, n_threads = get_threads_chunks(n_samples)
-        if ctx.constant_hessian:
+        if ctx.hessian_is_constant:
             for thread_idx in prange(n_threads):
                 for i in range(starts[thread_idx], ends[thread_idx]):
                     ordered_gradients[i] = ctx.gradients[sample_indices[i]]
@@ -356,7 +373,7 @@ def find_node_split(context, sample_indices):
                     ordered_hessians[i] = ctx.hessians[sample_indices[i]]
 
     ctx.sum_gradients = ctx.ordered_gradients[:n_samples].sum()
-    if ctx.constant_hessian:
+    if ctx.hessian_is_constant:
         ctx.sum_hessians = ctx.constant_hessian_value * float32(n_samples)
     else:
         ctx.sum_hessians = ctx.ordered_hessians[:n_samples].sum()
@@ -426,7 +443,7 @@ def find_node_split_subtraction(context, sample_indices, parent_histograms,
                      sibling_histograms[0]['sum_gradients'].sum())
 
     n_samples = sample_indices.shape[0]
-    if context.constant_hessian:
+    if context.hessian_is_constant:
         context.sum_hessians = \
             context.constant_hessian_value * float32(n_samples)
     else:
@@ -476,7 +493,7 @@ def _find_histogram_split(context, feature_idx, sample_indices):
         ordered_hessians = context.ordered_hessians[:n_samples]
 
     if root_node:
-        if context.constant_hessian:
+        if context.hessian_is_constant:
             histogram = _build_histogram_root_no_hessian(
                 context.max_bins, X_binned, ordered_gradients)
         else:
@@ -484,7 +501,7 @@ def _find_histogram_split(context, feature_idx, sample_indices):
             histogram = _build_histogram_root(
                 context.max_bins, X_binned, ordered_gradients,
                 context.ordered_hessians)
     else:
-        if context.constant_hessian:
+        if context.hessian_is_constant:
             histogram = _build_histogram_no_hessian(
                 context.max_bins, sample_indices, X_binned,
                 ordered_gradients)
@@ -537,7 +554,7 @@ def _find_best_bin_to_split_helper(context, feature_idx, histogram, n_samples):
             n_samples_left += histogram[bin_idx]['count']
         n_samples_right = n_samples - n_samples_left
 
-        if context.constant_hessian:
+        if context.hessian_is_constant:
             hessian_left += (histogram[bin_idx]['count'] *
                              context.constant_hessian_value)
         else:
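
Note on the pattern: taken together, these changes switch tree construction to an allocate-once, reset-many scheme. `fit` now builds a single `TreeGrower`, and through it a single `SplittingContext` whose large buffers (`ordered_gradients`, `ordered_hessians`, `partition`, and the two index buffers) are sized by the number of training samples, and each boosting iteration only rewrites that per-tree state in place via `reset`. Below is a minimal, self-contained sketch of the same pattern in plain NumPy; `ToyGrower` and its methods are illustrative stand-ins, not pygbm's actual classes.

    import numpy as np

    class ToyGrower:
        """Illustrative stand-in: owns buffers allocated once per fit()."""

        def __init__(self, n_samples):
            # Large buffers allocated a single time, like SplittingContext's
            # ordered_gradients / partition arrays in the patch above.
            self.ordered_gradients = np.empty(n_samples, dtype=np.float32)
            self.partition = np.empty(n_samples, dtype=np.uint32)

        def reset(self, gradients):
            # Per-tree state is rewritten in place, with no new allocations,
            # mirroring what SplittingContext.reset() does in the patch.
            n = gradients.shape[0]
            self.ordered_gradients[:n] = gradients
            self.partition[:n] = np.arange(n, dtype=np.uint32)

        def grow(self):
            # Placeholder for the actual tree-growing work.
            return float(self.ordered_gradients.sum())

    rng = np.random.RandomState(0)
    n_samples, n_iterations = 1000, 5

    grower = ToyGrower(n_samples)       # built once, as in the patched fit()
    for iteration in range(n_iterations):
        gradients = rng.normal(size=n_samples).astype(np.float32)
        grower.reset(gradients)         # cheap per-iteration reset
        grower.grow()

One design note: the numba type spec edited above (the `('hessian_is_constant', uint8)` entry) suggests `SplittingContext` is a numba jitclass, which would explain why `reset` copies with explicit element-wise loops rather than NumPy slicing: the `# TODO: parallelize this` comment marks those loops as candidates for `prange`. Re-instantiating a jitclass on every boosting iteration also carries construction overhead that reusing one instance avoids.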