Skip to content

Commit

Permalink
ordered_gradients and ordered_hessians now only allocated once in the
splitting context
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Oct 31, 2018
1 parent 24a8e70 commit 1b073b1
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 28 deletions.
53 changes: 27 additions & 26 deletions pygbm/splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def __init__(self, gain=0, feature_idx=0, bin_idx=0,
('binned_features', uint8[::1, :]),
('n_bins', uint32),
('all_gradients', float32[::1]),
('ordered_gradients', float32[::1]),
('all_hessians', float32[::1]),
('ordered_hessians', float32[::1]),
('constant_hessian', uint8),
('constant_hessian_value', float32),
('l2_regularization', float32),
Expand All @@ -55,7 +57,9 @@ def __init__(self, n_features, binned_features, n_bins,
self.binned_features = binned_features
self.n_bins = n_bins
self.all_gradients = all_gradients
self.ordered_gradients = np.empty_like(all_gradients)
self.all_hessians = all_hessians
self.ordered_hessians = np.empty_like(all_hessians)
self.constant_hessian = all_hessians.shape[0] == 1
self.l2_regularization = l2_regularization
self.min_hessian_to_split = min_hessian_to_split
Expand Down Expand Up @@ -202,31 +206,30 @@ def find_node_split(context, sample_indices):
if sample_indices.shape[0] == context.all_gradients.shape[0]:
# Root node: the ordering of sample_indices and all_gradients
# are expected to be consistent in this case.
ordered_gradients = context.all_gradients
ordered_hessians = context.all_hessians
context.ordered_gradients = context.all_gradients
context.ordered_hessians = context.all_hessians
else:
ordered_gradients = np.empty_like(sample_indices, dtype=loss_dtype)
context.ordered_gradients = np.empty_like(sample_indices, dtype=loss_dtype)
if context.constant_hessian:
ordered_hessians = context.all_hessians
context.ordered_hessians = context.all_hessians
for i, sample_idx in enumerate(sample_indices):
ordered_gradients[i] = context.all_gradients[sample_idx]
context.ordered_gradients[i] = context.all_gradients[sample_idx]
else:
ordered_hessians = np.empty_like(sample_indices,
context.ordered_hessians = np.empty_like(sample_indices,
dtype=loss_dtype)
for i, sample_idx in enumerate(sample_indices):
ordered_gradients[i] = context.all_gradients[sample_idx]
ordered_hessians[i] = context.all_hessians[sample_idx]
context.ordered_gradients[i] = context.all_gradients[sample_idx]
context.ordered_hessians[i] = context.all_hessians[sample_idx]

sum_gradients = ordered_gradients.sum()
sum_gradients = context.ordered_gradients.sum()
if context.constant_hessian:
n_samples = sample_indices.shape[0]
sum_hessians = context.constant_hessian_value * float32(n_samples)
else:
sum_hessians = ordered_hessians.sum()
sum_hessians = context.ordered_hessians.sum()

return _parallel_find_split(
context, sample_indices, ordered_gradients, ordered_hessians,
sum_gradients, sum_hessians)
return _parallel_find_split(context, sample_indices, sum_gradients,
sum_hessians)


@njit()
Expand Down Expand Up @@ -270,8 +273,7 @@ def _find_best_feature_to_split_helper(n_features, n_bins, split_infos):


@njit(parallel=True)
def _parallel_find_split(splitter, sample_indices, ordered_gradients,
ordered_hessians, sum_gradients, sum_hessians):
def _parallel_find_split(context, sample_indices, sum_gradients, sum_hessians):
"""For each feature, find the best bin to split on by scanning data.
This is done by calling _find_histogram_split that compute histograms
Expand All @@ -284,15 +286,15 @@ def _parallel_find_split(splitter, sample_indices, ordered_gradients,
# Pre-allocate the results datastructure to be able to use prange:
# numba jitclass does not seem to properly support default values for kwargs.
split_infos = [SplitInfo(0, 0, 0, 0., 0., 0., 0.)
for i in range(splitter.n_features)]
for feature_idx in prange(splitter.n_features):
for i in range(context.n_features)]
for feature_idx in prange(context.n_features):
split_info = _find_histogram_split(
splitter, feature_idx, sample_indices,
ordered_gradients, ordered_hessians, sum_gradients, sum_hessians)
context, feature_idx, sample_indices,
sum_gradients, sum_hessians)
split_infos[feature_idx] = split_info

return _find_best_feature_to_split_helper(
splitter.n_features, splitter.n_bins, split_infos)
context.n_features, context.n_bins, split_infos)


@njit(parallel=True)
Expand Down Expand Up @@ -324,7 +326,6 @@ def _parallel_find_split_subtraction(context, parent_histograms,

@njit(fastmath=True)
def _find_histogram_split(context, feature_idx, sample_indices,
ordered_gradients, ordered_hessians,
sum_gradients, sum_hessians):
"""Compute the histogram for a given feature and return the best bin."""
binned_feature = context.binned_features.T[feature_idx]
Expand All @@ -333,20 +334,20 @@ def _find_histogram_split(context, feature_idx, sample_indices,
if root_node:
if context.constant_hessian:
histogram = _build_histogram_root_no_hessian(
context.n_bins, binned_feature, ordered_gradients)
context.n_bins, binned_feature, context.ordered_gradients)
else:
histogram = _build_histogram_root(
context.n_bins, binned_feature, ordered_gradients,
ordered_hessians)
context.n_bins, binned_feature, context.ordered_gradients,
context.ordered_hessians)
else:
if context.constant_hessian:
histogram = _build_histogram_no_hessian(
context.n_bins, sample_indices, binned_feature,
ordered_gradients)
context.ordered_gradients)
else:
histogram = _build_histogram(
context.n_bins, sample_indices, binned_feature,
ordered_gradients, ordered_hessians)
context.ordered_gradients, context.ordered_hessians)

return _find_best_bin_to_split_helper(
context, feature_idx, histogram, sum_gradients, sum_hessians)
Expand Down
6 changes: 4 additions & 2 deletions tests/test_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,12 @@ def test_histogram_split(n_bins):
all_gradients, all_hessians,
l2_regularization,
min_hessian_to_split)
context.ordered_gradients = ordered_gradients
context.ordered_hessians = ordered_hessians

split_info = _find_histogram_split(
context, feature_idx, sample_indices, ordered_gradients,
ordered_hessians, ordered_gradients.sum(),
context, feature_idx, sample_indices,
ordered_gradients.sum(),
ordered_hessians.sum())

assert split_info.bin_idx == true_bin
Expand Down

0 comments on commit 1b073b1

Please sign in to comment.