[MRG] Add support for multiclass classification (#49)
Adds a CategoricalCrossEntropy loss to support multiclass classification. K trees per iteration are built instead of just one.

Squashed commits:

* First draft for pluggable losses
* Added basic test for losses
* Addressed comments
* Removed use of LeastSquares loss in tests/test_predictor.py
* WIP, doesn't work
* Fixed gradient and hessian of LS
* Fixed y_pred shape
* L2 gradient and hessian again
* Removed unnecessary conversions
* Partially fixed logloss
* Added prediction for classifiers, quick hack
* Strict comparison to min_gain_to_split
* Lots of debug stuff (will be cleaned up)
* Updated grower
* Set lower decimal in tests
* Added test
* Changed condition to <= 0 and updated note
* Print stuff, to be removed
* Debugging, ignore
* n_bins is now feature dependent
* Addressed comments, changed default max_bins to 256, tests aren't robust now :(
* Fixed n_bins = n_thresholds + 1... Also corrected some tests
* Added test for n_bins_per_feature
* Removed optional for min_samples_leaf
* Set min_samples_leaf to 1 if None
* Probably fixed overflow issue... Need to update stats BEFORE any 'continue' of course
* Higgs boson benchmark now uses classifiers
* Parametrized test_boston_dataset after bugfix #43
* None is now invalid for min_samples_leaf
* Removed trailing whitespace
* Set default to 20 in SplittingContext
* Used sklearn gradient computation to avoid overflow
* Consolidated compare_lightgbm test
* Some cleaning
* pep8
* Added subclasses Regressor and Classifier
* Added check for losses
* predict() is now custom in child classes
* Higgs Boson benchmark now uses classifiers and accuracy
* Changed log_loss to logistic
* Added test for losses
* Should fix test
* Cosmetics
* Use get_gradients_and_hessians
* Comment
* Put back ROC AUC
* Comment
* Parallelized gradients and hessians updates
* Comment
* First draft, looks like it's working OK. Needs lots of cleaning and testing
* Removed get_gradients methods and added helper in tests
* Some pep8
* Added sanity check for multinomial loss
* Put predict_proba() into loss functions
* Fixed predict_binned
* Updated gradients and hessians before loop over k
* Minor changes
* Fixed verbose printing
* Switch to multinomial if problem is multiclass and loss=logistic
* Parallelized multinomial loss
* Temporary fix for numba update issue
* Cosmetics and comments
* Fixed __call__ and added gradient test
* Minor modif
* Added auto loss for classification
* Fix for auto loss
* Some comments
* Fixed plotting
* Updated notes
* Added custom logsumexp jitted function
* Apply shrinkage at all iterations, including the first one
* Renamed all_gradients and all_hessians
* Addressed comments
* Classes are now label-encoded to support str targets and other inputs
* Changed name _validate_y to _encode_y
* Addressed comments
* Renaming
* y_pred now 2D array
* raw_predict now returns 2D array
* Renamed y_pred into raw_predictions
* pep8 and comments
1 parent c5ebdfc · commit 8284e12

Showing 12 changed files with 531 additions and 147 deletions.
Notes about multiclass classification
=====================================

Adding support for multiclass classification involves a few changes. The
main one is that instead of building one tree per iteration (as in binary
classification and regression), we build K trees per iteration, where K is
the number of classes.

Each tree is a kind of one-vs-rest (OVR) tree, but the trees are not
completely independent, because they influence each other when the
gradients and hessians are updated in
CategoricalCrossEntropy.update_gradients_and_hessians(). Concretely, the K
trees of the i-th iteration do not depend on each other, but each tree at
iteration i depends on *all* K trees of iteration i - 1.

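The coupling comes from the multinomial loss itself: with softmax
probabilities p_k, the gradient of the categorical cross-entropy with
respect to the raw score of class k is p_k - y_k, and the (diagonal)
hessian is p_k * (1 - p_k). A minimal NumPy sketch of this update (the
function name and 2D layout are illustrative, not pygbm's actual internal
API, which uses flat arrays):

```python
import numpy as np

def update_gradients_and_hessians(raw_predictions, y):
    """Multinomial (categorical cross-entropy) gradients and hessians.

    raw_predictions: (n_samples, K) raw scores; y: (n_samples,) int labels.
    Illustrative sketch only; pygbm stores these in flat 1D arrays.
    """
    # Softmax with the usual max-shift for numerical stability.
    shifted = raw_predictions - raw_predictions.max(axis=1, keepdims=True)
    p = np.exp(shifted)
    p /= p.sum(axis=1, keepdims=True)

    y_one_hot = np.eye(p.shape[1])[y]  # (n_samples, K)
    gradients = p - y_one_hot          # dL/dscore_k = p_k - y_k
    hessians = p * (1 - p)             # diagonal hessian approximation
    return gradients, hessians
```

Because the probabilities p depend on the scores of *all* classes, the
gradients for every one of the K trees at iteration i change whenever any
tree of iteration i - 1 changes.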
For a given sample, the probability that it belongs to class k is computed
as a regular softmax over scores = [scores_0, scores_1, ..., scores_{K-1}],
where scores_k = sum(<leaf value of the k-th tree of iteration i>
for i in range(n_iterations)).
The predicted class is then the argmax of the K probabilities.

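The prediction rule above can be sketched as follows (a toy example:
`raw_scores` plays the role of the accumulated per-class leaf values, and
the function names are illustrative):

```python
import numpy as np

def predict_proba(raw_scores):
    """Softmax over accumulated per-class scores, shape (n_samples, K)."""
    # Subtracting the row max keeps exp() from overflowing.
    shifted = raw_scores - raw_scores.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)

def predict(raw_scores):
    """Predicted class = argmax of the K probabilities."""
    return np.argmax(predict_proba(raw_scores), axis=1)
```

Note that the argmax of the probabilities equals the argmax of the raw
scores, since softmax is monotonic; the probabilities are only needed for
predict_proba().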
Regarding implementation details: the gradients and hessians arrays (for
non-constant hessians) are now 1D arrays of size (n_samples *
n_trees_per_iteration), instead of just (n_samples,). raw_predictions is
now a 2D array of shape (n_samples, n_trees_per_iteration).
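The flat gradient layout can be viewed one class at a time without copying,
for example when growing the k-th tree of an iteration. A sketch under the
assumption that the flat array is laid out class-major (one contiguous
block per class; the exact layout is an assumption here, not taken from the
pygbm source):

```python
import numpy as np

n_samples, n_trees_per_iteration = 5, 3

# Gradients stored as one flat 1D array: one block per class/tree.
gradients = np.zeros(n_samples * n_trees_per_iteration)

# View (no copy) of the block for class k, used when growing the k-th tree.
k = 1
gradients_k = gradients.reshape(n_trees_per_iteration, n_samples)[k]

# raw_predictions is 2D: one column of raw scores per class.
raw_predictions = np.zeros((n_samples, n_trees_per_iteration))
```

Since `reshape` on a contiguous array returns a view, writing to
`gradients_k` updates the corresponding slice of the flat array in place.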