[MRG] Initialize predictions to average value or class probabilities (#…
NicolasHug authored and ogrisel committed Dec 14, 2018
1 parent f128df2 commit 9d97d9f
Showing 5 changed files with 133 additions and 26 deletions.
16 changes: 11 additions & 5 deletions pygbm/gradient_boosting.py
@@ -155,18 +155,23 @@ def fit(self, X, y):
print("Fitting gradient boosted rounds:")

n_samples = X_binned_train.shape[0]
# values predicted by the trees. Used as-is in regression, and
# transformed into probas and / or classes for classification
self.baseline_prediction_ = self.loss_.get_baseline_prediction(
y_train, self.n_trees_per_iteration_)
# raw_predictions are the accumulated values predicted by the trees
# for the training data.
raw_predictions = np.zeros(
shape=(n_samples, self.n_trees_per_iteration_),
dtype=y_train.dtype
dtype=self.baseline_prediction_.dtype
)
raw_predictions += self.baseline_prediction_

# gradients and hessians are 1D arrays of size
# n_samples * n_trees_per_iteration
gradients, hessians = self.loss_.init_gradients_and_hessians(
n_samples=n_samples,
n_trees_per_iteration=self.n_trees_per_iteration_
prediction_dim=self.n_trees_per_iteration_
)

# predictors_ is a matrix of TreePredictor objects with shape
# (n_iter_, n_trees_per_iteration)
self.predictors_ = predictors = []
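The net effect of this hunk: boosting no longer starts from all-zero raw predictions but from a constant baseline supplied by the loss (the target mean for least squares, log-odds or log class proportions for the cross-entropy losses). A minimal sketch of that initialization outside pygbm, using a hypothetical helper (init_raw_predictions is not part of the library):

import numpy as np

def init_raw_predictions(n_samples, baseline_prediction, n_trees_per_iteration):
    # Hypothetical helper mirroring the new fit() logic: every sample starts
    # from the loss's constant baseline instead of from zero.
    raw_predictions = np.zeros(
        shape=(n_samples, n_trees_per_iteration),
        dtype=baseline_prediction.dtype,
    )
    raw_predictions += baseline_prediction  # broadcasts over all samples
    return raw_predictions

y_train = np.array([1.0, 2.0, 3.0, 10.0], dtype=np.float32)
baseline = np.mean(y_train)  # 4.0, what LeastSquares.get_baseline_prediction returns
print(init_raw_predictions(len(y_train), baseline, 1))  # every row is 4.0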
@@ -373,8 +378,9 @@ def _raw_predict(self, X, binned=False):
n_samples = X.shape[0]
raw_predictions = np.zeros(
shape=(n_samples, self.n_trees_per_iteration_),
dtype=np.float32
dtype=self.baseline_prediction_.dtype
)
raw_predictions += self.baseline_prediction_
# Should we parallelize this?
for predictors_of_ith_iteration in self.predictors_:
for k, predictor in enumerate(predictors_of_ith_iteration):
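The matching change in _raw_predict adds the same baseline at prediction time; without it, predictions would be offset from the training targets by a constant. A rough sketch of the accumulation (the per-tree predict call below is a stand-in for whatever the fitted tree objects expose):

import numpy as np

def raw_predict_sketch(X, predictors, baseline_prediction, n_trees_per_iteration):
    # predictors: one list of fitted trees per boosting iteration.
    raw = np.zeros((X.shape[0], n_trees_per_iteration),
                   dtype=baseline_prediction.dtype)
    raw += baseline_prediction                 # start from the baseline
    for trees_at_iteration in predictors:
        for k, tree in enumerate(trees_at_iteration):
            raw[:, k] += tree.predict(X)       # add each tree's contribution
    return raw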
76 changes: 60 additions & 16 deletions pygbm/loss.py
@@ -39,7 +39,7 @@ def _expit(x):
class BaseLoss(ABC):
"""Base class for a loss."""

def init_gradients_and_hessians(self, n_samples, n_trees_per_iteration):
def init_gradients_and_hessians(self, n_samples, prediction_dim):
"""Return initial gradients and hessians.
Unless hessians are constant, arrays are initialized with undefined
@@ -49,19 +49,20 @@ def init_gradients_and_hessians(self, n_samples, n_trees_per_iteration):
----------
n_samples : int
The number of samples passed to `fit()`
n_trees_per_iteration : int
The number of trees built at each iteration. Equals 1 for
regression and binary classification, or K where K is the number of
classes for multiclass classification.
prediction_dim : int
The dimension of a raw prediction, i.e. the number of trees
built at each iteration. Equals 1 for regression and binary
classification, or K where K is the number of classes for
multiclass classification.
Returns
-------
gradients : array-like, shape=(n_samples * n_trees_per_iteration)
hessians : array-like, shape=(n_samples * n_trees_per_iteration).
gradients : array-like, shape=(n_samples * prediction_dim)
hessians : array-like, shape=(n_samples * prediction_dim).
If hessians are constant (e.g. for ``LeastSquares`` loss), the shape
is (1,) and the array is initialized to ``1``.
"""
shape = n_samples * n_trees_per_iteration
shape = n_samples * prediction_dim
gradients = np.empty(shape=shape, dtype=np.float32)
if self.hessian_is_constant:
hessians = np.ones(shape=1, dtype=np.float32)
@@ -70,6 +71,25 @@ def init_gradients_and_hessians(self, n_samples, n_trees_per_iteration):

return gradients, hessians
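As the docstring above notes, gradients and hessians are flat 1D arrays of length n_samples * prediction_dim rather than 2D arrays; class k occupies the contiguous slice [n_samples * k, n_samples * (k + 1)). A quick shape check, assuming pygbm at this commit is importable:

from pygbm.loss import _LOSSES

n_samples, prediction_dim = 5, 3
loss = _LOSSES['categorical_crossentropy']()
gradients, hessians = loss.init_gradients_and_hessians(
    n_samples=n_samples, prediction_dim=prediction_dim)
print(gradients.shape)  # (15,), i.e. (n_samples * prediction_dim,)
print(hessians.shape)   # (15,) here; (1,) for losses with constant hessians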

@abstractmethod
def get_baseline_prediction(self, y_train, prediction_dim):
"""Return initial predictions (before the first iteration).
Parameters
----------
y_train : array-like, shape=(n_samples,)
The target training values.
prediction_dim : int
The dimension of one prediction: 1 for binary classification and
regression, n_classes for multiclass classification.
Returns
-------
baseline_prediction: float or array of shape (1, prediction_dim)
The baseline prediction.
"""
pass
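To illustrate the contract this new abstract method imposes, here is a hypothetical loss (not part of pygbm) whose natural baseline is the median of the targets; a real implementation would subclass BaseLoss and also provide __call__ and update_gradients_and_hessians:

import numpy as np

class LeastAbsoluteDeviation:  # hypothetical, for illustration only
    hessian_is_constant = True

    def get_baseline_prediction(self, y_train, prediction_dim):
        # The constant that minimizes the absolute error is the median.
        return np.median(y_train)

    def inverse_link_function(self, raw_predictions):
        return raw_predictions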

@abstractmethod
def update_gradients_and_hessians(self, gradients, hessians, y_true,
raw_predictions):
@@ -81,14 +101,14 @@ def update_gradients_and_hessians(self, gradients, hessians, y_true,
Parameters
----------
gradients : array-like, shape=(n_samples * n_trees_per_iteration)
gradients : array-like, shape=(n_samples * prediction_dim)
The gradients (treated as OUT array).
hessians : array-like, shape=(n_samples * n_trees_per_iteration) or \
hessians : array-like, shape=(n_samples * prediction_dim) or \
(1,)
The hessians (treated as OUT array).
y_true : array-like, shape=(n_samples,)
The true target values of each training sample.
raw_predictions : array-like, shape=(n_samples, n_trees_per_iteration)
raw_predictions : array-like, shape=(n_samples, prediction_dim)
The raw_predictions (i.e. values from the trees) of the tree
ensemble at iteration ``i - 1``.
"""
@@ -112,6 +132,9 @@ def __call__(self, y_true, raw_predictions, average=True):
loss = np.power(y_true - raw_predictions, 2)
return loss.mean() if average else loss

def get_baseline_prediction(self, y_train, prediction_dim):
return np.mean(y_train)

def inverse_link_function(self, raw_predictions):
return raw_predictions
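A small worked check of the least-squares baseline: the mean is the constant that minimizes the squared error, so it is the best possible zero-tree model.

import numpy as np

y_train = np.array([2.0, 4.0, 9.0])
baseline = np.mean(y_train)                 # 5.0
print(np.sum((y_train - baseline) ** 2))    # 26.0
print(np.sum((y_train - 4.0) ** 2))         # 29.0: any other constant does worse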

@@ -158,6 +181,14 @@ def __call__(self, y_true, raw_predictions, average=True):
loss = np.logaddexp(0, raw_predictions) - y_true * raw_predictions
return loss.mean() if average else loss

def get_baseline_prediction(self, y_train, prediction_dim):
proba_positive_class = np.mean(y_train)
eps = np.finfo(y_train.dtype).eps
proba_positive_class = np.clip(proba_positive_class, eps, 1 - eps)
# log(x / (1 - x)) is the inverse of the sigmoid, i.e. the link
# function of the Binomial model.
return np.log(proba_positive_class / (1 - proba_positive_class))
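A quick numeric check of this log-odds baseline: applying the sigmoid (the inverse link) to it recovers the empirical proportion of the positive class, so the zero-tree model already predicts the right average probability.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

y_train = np.array([0.0, 0.0, 0.0, 1.0])   # 25% positive class
p = y_train.mean()                          # 0.25
baseline = np.log(p / (1 - p))              # about -1.0986
print(sigmoid(baseline))                    # 0.25 again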

def update_gradients_and_hessians(self, gradients, hessians, y_true,
raw_predictions):
return _update_gradients_hessians_binary_crossentropy(
@@ -204,13 +235,26 @@ class CategoricalCrossEntropy(BaseLoss):

def __call__(self, y_true, raw_predictions, average=True):
one_hot_true = np.zeros_like(raw_predictions)
n_trees_per_iteration = raw_predictions.shape[1]
for k in range(n_trees_per_iteration):
prediction_dim = raw_predictions.shape[1]
for k in range(prediction_dim):
one_hot_true[:, k] = (y_true == k)

return (logsumexp(raw_predictions, axis=1) -
(one_hot_true * raw_predictions).sum(axis=1))

def get_baseline_prediction(self, y_train, prediction_dim):
init_value = np.zeros(
shape=(1, prediction_dim),
dtype=np.float32
)
eps = np.finfo(y_train.dtype).eps
for k in range(prediction_dim):
proba_kth_class = np.mean(y_train == k)
proba_kth_class = np.clip(proba_kth_class, eps, 1 - eps)
init_value[:, k] += np.log(proba_kth_class)

return init_value
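Same idea in the multiclass case: the baseline stores the log of each class proportion, so the softmax (the inverse link) of the baseline gives back the empirical class distribution (up to the eps clipping of empty classes).

import numpy as np

y_train = np.array([0, 0, 1, 2])                     # proportions 0.5, 0.25, 0.25
proportions = np.array([(y_train == k).mean() for k in range(3)])
baseline = np.log(proportions)                       # one value per class
exp_baseline = np.exp(baseline)
print(exp_baseline / exp_baseline.sum())             # [0.5, 0.25, 0.25]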

def update_gradients_and_hessians(self, gradients, hessians, y_true,
raw_predictions):
return _update_gradients_hessians_categorical_crossentropy(
@@ -227,7 +271,7 @@ def predict_proba(self, raw_predictions):
def _update_gradients_hessians_categorical_crossentropy(
gradients, hessians, y_true, raw_predictions):
# Here gradients and hessians are of shape
# (n_samples * n_trees_per_iteration,).
# (n_samples * prediction_dim,).
# y_true is of shape (n_samples,).
# raw_predictions is of shape (n_samples, prediction_dim)
#
Expand All @@ -238,9 +282,9 @@ def _update_gradients_hessians_categorical_crossentropy(
# That would however require to pass a copy of raw_predictions, so it does
# not get partially overwritten at the end of the loop when
# _update_y_pred() is called (see sklearn PR 12715)
n_samples, n_trees_per_iteration = raw_predictions.shape
n_samples, prediction_dim = raw_predictions.shape
starts, ends, n_threads = get_threads_chunks(total_size=n_samples)
for k in range(n_trees_per_iteration):
for k in range(prediction_dim):
gradients_at_k = gradients[n_samples * k:n_samples * (k + 1)]
hessians_at_k = hessians[n_samples * k:n_samples * (k + 1)]
for thread_idx in prange(n_threads):
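For reference, the per-class slicing above (gradients_at_k, hessians_at_k) relies on the flat layout described earlier: the gradients of class k for all samples sit in gradients[n_samples * k:n_samples * (k + 1)]. A shape-only illustration with arbitrary values:

import numpy as np

n_samples, prediction_dim = 4, 3
gradients = np.arange(n_samples * prediction_dim, dtype=np.float32)
for k in range(prediction_dim):
    gradients_at_k = gradients[n_samples * k:n_samples * (k + 1)]
    print(k, gradients_at_k)  # class 0 -> [0 1 2 3], class 1 -> [4 5 6 7], ...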
3 changes: 2 additions & 1 deletion pygbm/utils.py
@@ -42,7 +42,8 @@ def get_lightgbm_estimator(pygbm_estimator):
'min_data_in_bin': 1,
'min_sum_hessian_in_leaf': 1e-3,
'min_gain_to_split': 0,
'verbosity': 10 if pygbm_params['verbose'] else 0
'verbosity': 10 if pygbm_params['verbose'] else 0,
'boost_from_average': True,
}
# TODO: change hardcoded values when / if they're arguments to the
# estimator.
2 changes: 1 addition & 1 deletion tests/test_compare_lightgbm.py
@@ -63,7 +63,7 @@ def test_same_predictions_regression(seed, min_samples_leaf, n_samples,
pred_lgbm = est_lightgbm.predict(X_train)
pred_pygbm = est_pygbm.predict(X_train)
# less than 1% of the predictions are different up to the 3rd decimal
assert np.mean(abs(pred_lgbm - pred_pygbm) > 1e-3) < .01
assert np.mean(abs(pred_lgbm - pred_pygbm) > 1e-3) < .011

if max_leaf_nodes < 10 and n_samples >= 1000:
pred_lgbm = est_lightgbm.predict(X_test)
62 changes: 59 additions & 3 deletions tests/test_loss.py
@@ -2,6 +2,7 @@
from numpy.testing import assert_almost_equal
from scipy.optimize import newton
from scipy.special import logsumexp
from sklearn.utils import assert_all_finite
import pytest

from pygbm.loss import _LOSSES
@@ -80,12 +81,12 @@ def fprime2(x):
assert np.allclose(get_gradients(y_true, optimum), 0)


@pytest.mark.parametrize('loss, n_classes, n_trees_per_iteration', [
@pytest.mark.parametrize('loss, n_classes, prediction_dim', [
('least_squares', 0, 1),
('binary_crossentropy', 2, 1),
('categorical_crossentropy', 3, 3),
])
def test_numerical_gradients(loss, n_classes, n_trees_per_iteration):
def test_numerical_gradients(loss, n_classes, prediction_dim):
# Make sure gradients and hessians computed in the loss are correct, by
# comparing with their approximations computed with finite central
# differences.
@@ -98,7 +99,7 @@ def test_numerical_gradients(loss, n_classes, n_trees_per_iteration):
else:
y_true = rng.randint(0, n_classes, size=n_samples).astype(np.float64)
raw_predictions = rng.normal(
size=(n_samples, n_trees_per_iteration)
size=(n_samples, prediction_dim)
).astype(np.float64)
loss = _LOSSES[loss]()
get_gradients, get_hessians = get_derivatives_helper(loss)
@@ -154,3 +155,58 @@ def test_logsumexp():
b = np.full(n, 10000, dtype='float64')
desired = 10000.0 + np.log(n)
assert_almost_equal(_logsumexp(b), desired)


def test_baseline_least_squares():
rng = np.random.RandomState(0)

loss = _LOSSES['least_squares']()
y_train = rng.normal(size=100)
baseline_prediction = loss.get_baseline_prediction(y_train, 1)
assert baseline_prediction.shape == tuple() # scalar
# Make sure baseline prediction is the mean of all targets
assert_almost_equal(baseline_prediction, y_train.mean())


def test_baseline_binary_crossentropy():
rng = np.random.RandomState(0)

loss = _LOSSES['binary_crossentropy']()
for y_train in (np.zeros(shape=100), np.ones(shape=100)):
y_train = y_train.astype(np.float32)
baseline_prediction = loss.get_baseline_prediction(y_train, 1)
assert_all_finite(baseline_prediction)
assert_almost_equal(loss.inverse_link_function(baseline_prediction),
y_train[0])

# Make sure baseline prediction is equal to link_function(p), where p
# is the proba of the positive class. We want predict_proba() to return p,
# and by definition
# p = inverse_link_function(raw_prediction) = sigmoid(raw_prediction)
# So we want raw_prediction = link_function(p) = log(p / (1 - p))
y_train = rng.randint(0, 2, size=100).astype(np.float32)
baseline_prediction = loss.get_baseline_prediction(y_train, 1)
assert baseline_prediction.shape == tuple() # scalar
p = y_train.mean()
assert_almost_equal(baseline_prediction, np.log(p / (1 - p)))


def test_baseline_categorical_crossentropy():
rng = np.random.RandomState(0)

prediction_dim = 4
loss = _LOSSES['categorical_crossentropy']()
for y_train in (np.zeros(shape=100), np.ones(shape=100)):
y_train = y_train.astype(np.float32)
baseline_prediction = loss.get_baseline_prediction(y_train,
prediction_dim)
assert_all_finite(baseline_prediction)

# Same logic as in the test above. Here inverse_link_function = softmax and
# link_function = log
y_train = rng.randint(0, prediction_dim + 1, size=100).astype(np.float32)
baseline_prediction = loss.get_baseline_prediction(y_train, prediction_dim)
assert baseline_prediction.shape == (1, prediction_dim)
for k in range(prediction_dim):
p = (y_train == k).mean()
assert_almost_equal(baseline_prediction[:, k], np.log(p))
