[MRG] Renamed features_data and binned_features into X_binned (#58)
* Renamed features_data and binned_features into X_binned

* Also renamed in the tests

* fixed typo
NicolasHug committed Dec 6, 2018
1 parent 8284e12 commit 874d5ea
Showing 4 changed files with 82 additions and 90 deletions.
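For callers, the rename only changes the first positional argument of TreeGrower (and SplittingContext now derives n_features from X_binned.shape[1] instead of taking it as a separate argument, see pygbm/splitting.py below). A minimal sketch of the updated calling convention against this revision, with synthetic pre-binned data in the spirit of _make_training_data from the tests:

import numpy as np
from pygbm.grower import TreeGrower

rng = np.random.RandomState(42)
n_samples, n_bins = 1000, 256

# TreeGrower expects pre-binned uint8 data in Fortran (column-major) order.
X_binned = np.asfortranarray(
    rng.randint(0, n_bins - 1, size=(n_samples, 2), dtype=np.uint8))

# Gradients of a squared loss for an initial model that always predicts 0;
# a single-element hessians array selects the constant-hessian fast path.
gradients = rng.normal(size=n_samples).astype(np.float32)
hessians = np.ones(shape=1, dtype=np.float32)

grower = TreeGrower(X_binned, gradients, hessians, max_bins=n_bins,
                    max_leaf_nodes=3, min_samples_leaf=1)
grower.grow()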
34 changes: 16 additions & 18 deletions pygbm/grower.py
@@ -122,7 +122,7 @@ class TreeGrower:
Parameters
----------
-features_data : array-like of int, shape=(n_samples, n_features)
+X_binned : array-like of int, shape=(n_samples, n_features)
The binned input samples. Must be Fortran-aligned.
gradients : array-like, shape=(n_samples,)
The gradients of each training sample. Those are the gradients of the
@@ -158,13 +158,12 @@ class TreeGrower:
The shrinkage parameter to apply to the leaves values, also known as
learning rate.
"""
-def __init__(self, features_data, gradients, hessians,
-max_leaf_nodes=None, max_depth=None, min_samples_leaf=20,
-min_gain_to_split=0., max_bins=256, n_bins_per_feature=None,
-l2_regularization=0., min_hessian_to_split=1e-3,
-shrinkage=1.):
+def __init__(self, X_binned, gradients, hessians, max_leaf_nodes=None,
+max_depth=None, min_samples_leaf=20, min_gain_to_split=0.,
+max_bins=256, n_bins_per_feature=None, l2_regularization=0.,
+min_hessian_to_split=1e-3, shrinkage=1.):

-self._validate_parameters(features_data, max_leaf_nodes, max_depth,
+self._validate_parameters(X_binned, max_leaf_nodes, max_depth,
min_samples_leaf, min_gain_to_split,
l2_regularization, min_hessian_to_split)

@@ -173,18 +172,17 @@ def __init__(self, features_data, gradients, hessians,

if isinstance(n_bins_per_feature, int):
n_bins_per_feature = np.array(
-[n_bins_per_feature] * features_data.shape[1],
+[n_bins_per_feature] * X_binned.shape[1],
dtype=np.uint32)

self.splitting_context = SplittingContext(
-features_data.shape[1], features_data, max_bins,
-n_bins_per_feature, gradients, hessians,
-l2_regularization, min_hessian_to_split, min_samples_leaf,
-min_gain_to_split)
+X_binned, max_bins, n_bins_per_feature, gradients,
+hessians, l2_regularization, min_hessian_to_split,
+min_samples_leaf, min_gain_to_split)
self.max_leaf_nodes = max_leaf_nodes
self.max_depth = max_depth
self.min_samples_leaf = min_samples_leaf
-self.features_data = features_data
+self.X_binned = X_binned
self.min_gain_to_split = min_gain_to_split
self.shrinkage = shrinkage
self.splittable_nodes = []
@@ -194,20 +192,20 @@ def __init__(self, features_data, gradients, hessians,
self._intilialize_root()
self.n_nodes = 1

-def _validate_parameters(self, features_data, max_leaf_nodes, max_depth,
+def _validate_parameters(self, X_binned, max_leaf_nodes, max_depth,
min_samples_leaf, min_gain_to_split,
l2_regularization, min_hessian_to_split):
"""Validate parameters passed to __init__.
Also validate parameters passed to SplittingContext because we cannot
raise exceptions in a jitclass.
"""
-if features_data.dtype != np.uint8:
+if X_binned.dtype != np.uint8:
raise NotImplementedError(
"Explicit feature binning required for now")
-if not features_data.flags.f_contiguous:
+if not X_binned.flags.f_contiguous:
raise ValueError(
"Binned data should be passed as Fortran contiguous "
"X_binned should be passed as Fortran contiguous "
"array for maximum efficiency.")
if max_leaf_nodes is not None and max_leaf_nodes < 1:
raise ValueError(f'max_leaf_nodes={max_leaf_nodes} should not be'
@@ -235,7 +233,7 @@ def grow(self):

def _intilialize_root(self):
"""Initialize root node and finalize it if needed."""
-n_samples = self.features_data.shape[0]
+n_samples = self.X_binned.shape[0]
depth = 0
if self.splitting_context.constant_hessian:
hessian = self.splitting_context.hessians[0] * n_samples
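The two checks in _validate_parameters above define the only input format the grower accepts: uint8 (i.e. already binned) and Fortran-contiguous. A short sketch of the conversion a caller would perform on plain C-ordered data (toy array, plain NumPy):

import numpy as np

X = np.random.RandomState(0).randint(0, 255, size=(100, 3))  # C-ordered, int64
X_binned = np.asfortranarray(X.astype(np.uint8))  # uint8 + column-major copy
assert X_binned.dtype == np.uint8 and X_binned.flags.f_contiguous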
36 changes: 17 additions & 19 deletions pygbm/splitting.py
@@ -62,7 +62,7 @@ def __init__(self, gain=-1., feature_idx=0, bin_idx=0,

@jitclass([
('n_features', uint32),
-('binned_features', uint8[::1, :]),
+('X_binned', uint8[::1, :]),
('max_bins', uint32),
('n_bins_per_feature', uint32[::1]),
('min_samples_leaf', uint32),
@@ -91,9 +91,7 @@ class SplittingContext:
Parameters
----------
-n_features : int
-The number of features.
-binned_features : array of int
+X_binned : array of int
The binned input samples. Must be Fortran-aligned.
max_bins : int, optional(default=256)
The maximum number of bins. Used to define the shape of the
@@ -119,13 +117,13 @@
The minimum gain needed to split a node. Splits with lower gain will
be ignored.
"""
-def __init__(self, n_features, binned_features, max_bins,
-n_bins_per_feature, gradients, hessians,
-l2_regularization, min_hessian_to_split=1e-3,
-min_samples_leaf=20, min_gain_to_split=0.):
+def __init__(self, X_binned, max_bins, n_bins_per_feature,
+gradients, hessians, l2_regularization,
+min_hessian_to_split=1e-3, min_samples_leaf=20,
+min_gain_to_split=0.):

-self.n_features = n_features
-self.binned_features = binned_features
+self.X_binned = X_binned
+self.n_features = X_binned.shape[1]
# Note: all histograms will have <max_bins> bins, but some of the
# last bins may be unused if n_bins_per_feature[f] < max_bins
self.max_bins = max_bins
@@ -156,7 +154,7 @@ def __init__(self, n_features, binned_features, max_bins,
# partition = [cef|abdghijkl]
# we have 2 leaves, the left one is at position 0 and the second one at
# position 3. The order of the samples is irrelevant.
-self.partition = np.arange(0, binned_features.shape[0], 1, np.uint32)
+self.partition = np.arange(0, X_binned.shape[0], 1, np.uint32)
# buffers used in split_indices to support parallel splitting.
self.left_indices_buffer = np.empty_like(self.partition)
self.right_indices_buffer = np.empty_like(self.partition)
@@ -232,7 +230,7 @@ def split_indices(context, split_info, sample_indices):
# sample_indices for simplicity, but in reality they are of the same size
# as partition.

-binned_feature = context.binned_features.T[split_info.feature_idx]
+X_binned = context.X_binned.T[split_info.feature_idx]

n_threads = numba.config.NUMBA_DEFAULT_NUM_THREADS
n_samples = sample_indices.shape[0]
@@ -264,7 +262,7 @@ def split_indices(context, split_info, sample_indices):
stop = start + sizes[thread_idx]
for i in range(start, stop):
sample_idx = sample_indices[i]
-if binned_feature[sample_idx] <= split_info.bin_idx:
+if X_binned[sample_idx] <= split_info.bin_idx:
left_indices_buffer[start + left_count] = sample_idx
left_count += 1
else:
@@ -474,28 +472,28 @@ def _find_histogram_split(context, feature_idx, sample_indices):
Returns the best SplitInfo among all the possible bins of the feature.
"""
n_samples = sample_indices.shape[0]
-binned_feature = context.binned_features.T[feature_idx]
+X_binned = context.X_binned.T[feature_idx]

-root_node = binned_feature.shape[0] == n_samples
+root_node = X_binned.shape[0] == n_samples
ordered_gradients = context.ordered_gradients[:n_samples]
ordered_hessians = context.ordered_hessians[:n_samples]

if root_node:
if context.constant_hessian:
histogram = _build_histogram_root_no_hessian(
-context.max_bins, binned_feature, ordered_gradients)
+context.max_bins, X_binned, ordered_gradients)
else:
histogram = _build_histogram_root(
-context.max_bins, binned_feature, ordered_gradients,
+context.max_bins, X_binned, ordered_gradients,
context.ordered_hessians)
else:
if context.constant_hessian:
histogram = _build_histogram_no_hessian(
-context.max_bins, sample_indices, binned_feature,
+context.max_bins, sample_indices, X_binned,
ordered_gradients)
else:
histogram = _build_histogram(
-context.max_bins, sample_indices, binned_feature,
+context.max_bins, sample_indices, X_binned,
ordered_gradients, ordered_hessians)

return _find_best_bin_to_split_helper(context, feature_idx, histogram,
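The X_binned.T[feature_idx] pattern used in split_indices and _find_histogram_split is also the reason the Fortran layout is enforced: transposing an F-contiguous array is free, and each row of the transpose (one feature's bins for all samples) is a C-contiguous view, so histogram fills scan memory sequentially. A small illustration with toy shapes:

import numpy as np

X_binned = np.asfortranarray(np.arange(12, dtype=np.uint8).reshape(4, 3))
column = X_binned.T[1]  # binned values of feature 1 for all 4 samples
assert column.flags.c_contiguous  # contiguous view, no copy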
37 changes: 18 additions & 19 deletions tests/test_grower.py
@@ -14,9 +14,8 @@ def _make_training_data(n_bins=256, constant_hessian=True):

# Generate some test data directly binned so as to test the grower code
# independently of the binning logic.
-features_data = rng.randint(0, n_bins - 1, size=(n_samples, 2),
-dtype=np.uint8)
-features_data = np.asfortranarray(features_data)
+X_binned = rng.randint(0, n_bins - 1, size=(n_samples, 2), dtype=np.uint8)
+X_binned = np.asfortranarray(X_binned)

def true_decision_function(input_features):
"""Ground truth decision function
@@ -33,7 +32,7 @@ def true_decision_function(input_features):
else:
return 1

-target = np.array([true_decision_function(x) for x in features_data],
+target = np.array([true_decision_function(x) for x in X_binned],
dtype=np.float32)

# Assume a square loss applied to an initial model that always predicts 0
@@ -43,7 +42,7 @@ def true_decision_function(input_features):
all_hessians = np.ones(shape=1, dtype=np.float32)
else:
all_hessians = np.ones_like(all_gradients)
-return features_data, all_gradients, all_hessians
+return X_binned, all_gradients, all_hessians


def _check_children_consistency(parent, left, right):
@@ -76,16 +75,16 @@ def _check_children_consistency(parent, left, right):
]
)
def test_grow_tree(n_bins, constant_hessian, stopping_param, shrinkage):
-features_data, all_gradients, all_hessians = _make_training_data(
+X_binned, all_gradients, all_hessians = _make_training_data(
n_bins=n_bins, constant_hessian=constant_hessian)
-n_samples = features_data.shape[0]
+n_samples = X_binned.shape[0]

if stopping_param == "max_leaf_nodes":
stopping_param = {"max_leaf_nodes": 3}
else:
stopping_param = {"min_gain_to_split": 0.01}

-grower = TreeGrower(features_data, all_gradients, all_hessians,
+grower = TreeGrower(X_binned, all_gradients, all_hessians,
max_bins=n_bins, shrinkage=shrinkage,
min_samples_leaf=1, **stopping_param)

@@ -145,9 +144,9 @@ def test_grow_tree(n_bins, constant_hessian, stopping_param, shrinkage):
def test_predictor_from_grower():
# Build a tree on the toy 3-leaf dataset to extract the predictor.
n_bins = 256
-features_data, all_gradients, all_hessians = _make_training_data(
+X_binned, all_gradients, all_hessians = _make_training_data(
n_bins=n_bins)
-grower = TreeGrower(features_data, all_gradients, all_hessians,
+grower = TreeGrower(X_binned, all_gradients, all_hessians,
max_bins=n_bins, shrinkage=1.,
max_leaf_nodes=3, min_samples_leaf=5)
grower.grow()
@@ -178,7 +177,7 @@ def test_predictor_from_grower():
assert_array_almost_equal(predictions, expected_targets, decimal=5)

# Check that training set can be recovered exactly:
-predictions = predictor.predict_binned(features_data)
+predictions = predictor.predict_binned(X_binned)
assert_array_almost_equal(predictions, -all_gradients, decimal=5)


@@ -259,32 +258,32 @@ def test_min_samples_leaf_root(n_samples, min_samples_leaf):

def test_init_parameters_validation():

-features_data, all_gradients, all_hessians = _make_training_data()
+X_binned, all_gradients, all_hessians = _make_training_data()

-features_data_float = features_data.astype(np.float32)
+X_binned_float = X_binned.astype(np.float32)
assert_raises_regex(
NotImplementedError,
"Explicit feature binning required for now",
-TreeGrower, features_data_float, all_gradients, all_hessians
+TreeGrower, X_binned_float, all_gradients, all_hessians
)

-features_data_C_array = np.ascontiguousarray(features_data)
+X_binned_C_array = np.ascontiguousarray(X_binned)
assert_raises_regex(
ValueError,
"Binned data should be passed as Fortran contiguous array",
TreeGrower, features_data_C_array, all_gradients, all_hessians
"X_binned should be passed as Fortran contiguous array",
TreeGrower, X_binned_C_array, all_gradients, all_hessians
)

assert_raises_regex(
ValueError,
"min_gain_to_split=-1 must be positive",
-TreeGrower, features_data, all_gradients, all_hessians,
+TreeGrower, X_binned, all_gradients, all_hessians,
min_gain_to_split=-1
)

assert_raises_regex(
ValueError,
"min_hessian_to_split=-1 must be positive",
-TreeGrower, features_data, all_gradients, all_hessians,
+TreeGrower, X_binned, all_gradients, all_hessians,
min_hessian_to_split=-1
)

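The round-trip assertion at the end of test_predictor_from_grower is the key invariant behind these tests: with shrinkage=1, no regularization and a squared loss, each leaf value is essentially the negative mean gradient of its samples, so predictions on the training set recover -all_gradients up to numerical precision. A condensed sketch, assuming the grower exposes a make_predictor() helper as the predictor variable in the test suggests (the helper itself is not shown in this diff):

X_binned, all_gradients, all_hessians = _make_training_data(n_bins=256)
grower = TreeGrower(X_binned, all_gradients, all_hessians, max_bins=256,
                    shrinkage=1., max_leaf_nodes=3, min_samples_leaf=5)
grower.grow()
predictor = grower.make_predictor()  # assumed helper, not part of this diff
np.testing.assert_array_almost_equal(
    predictor.predict_binned(X_binned), -all_gradients, decimal=5)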