[MRG] More docstrings (#67)
Also added 2 separate toctrees:

- Estimator API 
- Private API

With warnings everywhere about how the API and default values are subject to change without notice.

Also made some functions private (`_inverse_link_function`, `_find_binning_thresholds`)
NicolasHug committed Dec 13, 2018
1 parent 8c43e40 commit e28761b
Showing 13 changed files with 285 additions and 138 deletions.
48 changes: 18 additions & 30 deletions doc/source/index.rst
@@ -6,41 +6,29 @@
Welcome to pygbm's documentation!
=================================

.. .. toctree::
.. :maxdepth: 2
.. :caption: API Reference
.. :hidden:
.. warning::
Pygbm's API and default values are likely to change in future
versions, without any deprecation cycle.

Gradient Boosting Estimators
============================

.. automodule:: pygbm.gradient_boosting
:members:
:exclude-members: BaseGradientBoostingMachine
.. toctree::
:maxdepth: 2
:caption: Estimator API
:hidden:

Grower
======
public_api

.. automodule:: pygbm.grower
:members:
.. toctree::
:maxdepth: 2
:caption: Private API
:hidden:

Splitting
=========
private_api

.. automodule:: pygbm.splitting
:members:

Binning
=======
Indices and tables
==================

.. automodule:: pygbm.binning
:members:



.. Indices and tables
.. ==================
.. * :ref:`genindex`
.. * :ref:`modindex`
.. * :ref:`search`
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
62 changes: 27 additions & 35 deletions pygbm/binning.py
@@ -1,49 +1,31 @@
"""
This module contains the BinMapper class.
BinMapper is used for mapping a real-valued dataset into integer-valued bins
with equally-spaced thresholds.
"""
import numpy as np
from numba import njit, prange
from sklearn.utils import check_random_state, check_array
from sklearn.base import BaseEstimator, TransformerMixin


def find_binning_thresholds(data, max_bins=256, subsample=int(2e5),
random_state=None):
def _find_binning_thresholds(data, max_bins=256, subsample=int(2e5),
random_state=None):
"""Extract feature-wise equally-spaced quantiles from numerical data
Subsample the dataset if too large as the feature-wise quantiles
should be stable.
If the number of unique values for a given feature is less than
``max_bins``, then the unique values are used instead of the quantiles.
Parameters
----------
data: array-like, shape=(n_samples, n_features)
The numerical dataset to analyse.
max_bins: int, optional (default=256)
The number of bins to extract for each feature. As we code the binned
values as 8-bit integers, max_bins should be no larger than 256.
subsample: int, optional (default=2e5)
Number of random subsamples to consider to compute the quantiles.
random_state: int or numpy.random.RandomState or None, \
optional (default=None)
Pseudo-random number generator to control the random sub-sampling.
Return
------
binning_thresholds: tuple of arrays
For each feature, stores the increasing numeric values that can
be used to separate the bins.
len(binning_thresholds) == n_features.
Each array has size ``(n_bins - 1)`` where:
``n_bins == min(max_bins, len(np.unique(data[:, feature_idx])))``
be used to separate the bins. len(binning_thresholds) == n_features.
"""
if not (2 <= max_bins <= 256):
raise ValueError(f'max_bins={max_bins} should be no smaller than 2 '
f'and no larger than 256.')
rng = check_random_state(random_state)
if data.shape[0] > subsample:
if subsample is not None and data.shape[0] > subsample:
subset = rng.choice(np.arange(data.shape[0]), subsample)
data = data[subset]
dtype = data.dtype
@@ -124,24 +106,33 @@ def _map_num_col_to_bins(data, binning_thresholds, binned):


class BinMapper(BaseEstimator, TransformerMixin):
"""Transformer that maps a dataset into integer-valued bins
"""Transformer that maps a dataset into integer-valued bins.
The bins are created in a feature-wise fashion, with equally-spaced
quantiles.
Large datasets are subsampled, but the feature-wise quantiles should
remain stable.
If the number of unique values for a given feature is less than
``max_bins``, then the unique values of this feature are used instead of
the quantiles.
Parameters
----------
max_bins : int, optional (default=256)
The maximum number of bins to use. If for a given feature the number of
unique values is less than ``max_bins``, then those unique values
will be used instead of the quantiles.
subsample : int, optional (default=1e5)
will be used to compute the bin thresholds, instead of the quantiles.
subsample : int or None, optional (default=1e5)
If ``n_samples > subsample``, then ``subsample`` samples will be
randomly chosen to compute the quantiles.
TODO: accept None?
randomly chosen to compute the quantiles. If ``None``, the whole data
is used.
random_state: int or numpy.random.RandomState or None, \
optional (default=None)
Pseudo-random number generator to control the random sub-sampling.
See `scikit-learn glossary
<https://scikit-learn.org/stable/glossary.html#term-random-state>`_.
"""
def __init__(self, max_bins=256, subsample=int(1e5), random_state=None):
self.max_bins = max_bins
@@ -161,7 +152,7 @@ def fit(self, X, y=None):
self : object
"""
X = check_array(X)
self.bin_thresholds_ = find_binning_thresholds(
self.bin_thresholds_ = _find_binning_thresholds(
X, self.max_bins, subsample=self.subsample,
random_state=self.random_state)

@@ -182,5 +173,6 @@ def transform(self, X):
Returns
-------
X_binned : array-like
The binned data"""
The binned data
"""
return _map_to_bins(X, binning_thresholds=self.bin_thresholds_)
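
As a usage sketch (not part of this commit), the ``BinMapper`` documented above could be exercised as follows, assuming the ``fit``/``transform`` API and the ``bin_thresholds_`` attribute described in the docstrings:

# Minimal sketch: bin a small random dataset with BinMapper.
import numpy as np
from pygbm.binning import BinMapper

rng = np.random.RandomState(0)
X = rng.normal(size=(1000, 3))            # real-valued data, shape (n_samples, n_features)

mapper = BinMapper(max_bins=256, random_state=0)
X_binned = mapper.fit_transform(X)        # integer bin codes, one column per feature

print(X_binned.dtype)                     # 8-bit integers, so codes fit in [0, 255]
print(len(mapper.bin_thresholds_))        # one array of thresholds per feature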
60 changes: 41 additions & 19 deletions pygbm/gradient_boosting.py
@@ -94,7 +94,6 @@ def fit(self, X, y):
# TODO: add support for mixed-typed (numerical + categorical) data
# TODO: add support for missing data
# TODO: add support for pre-binned data (pass-through)?
# TODO: test input checking
X, y = check_X_y(X, y, dtype=[np.float32, np.float64])
y = self._encode_y(y)
if X.shape[0] == 1 or X.shape[1] == 1:
@@ -410,21 +409,23 @@ class GradientBoostingRegressor(BaseGradientBoostingMachine, RegressorMixin):
----------
loss : {'least_squares'}, optional(default='least_squares')
The loss function to use in the boosting process.
learning_rate : float, optional(default=TODO)
learning_rate : float, optional(default=0.1)
The learning rate, also known as *shrinkage*. This is used as a
multiplicative factor for the leaves values.
max_iter : int, optional(default=TODO)
multiplicative factor for the leaves values. Use ``1`` for no
shrinkage.
max_iter : int, optional(default=100)
The maximum number of iterations of the boosting process, i.e. the
maximum number of trees.
max_leaf_nodes : int, optional(default=TODO)
The maximum number of leaves for each tree.
max_depth : int, optional(default=TODO)
max_leaf_nodes : int or None, optional(default=None)
The maximum number of leaves for each tree. If None, there is no
maximum limit.
max_depth : int or None, optional(default=None)
The maximum depth of each tree. The depth of a tree is the number of
nodes to go from the root to the deepest leaf.
min_samples_leaf : int, optional(default=TODO)
min_samples_leaf : int, optional(default=20)
The minimum number of samples per leaf.
l2_regularization : float, optional(default=TODO)
The L2 regularization parameter.
l2_regularization : float, optional(default=0)
The L2 regularization parameter. Use 0 for no regularization.
max_bins : int, optional(default=256)
The maximum number of bins to use. Before training, each feature of
the input array ``X`` is binned into at most ``max_bins`` bins, which
@@ -458,6 +459,16 @@ class GradientBoostingRegressor(BaseGradientBoostingMachine, RegressorMixin):
is enabled. See
`scikit-learn glossary
<https://scikit-learn.org/stable/glossary.html#term-random-state>`_.
Examples
--------
>>> from sklearn.datasets import load_boston
>>> from pygbm import GradientBoostingRegressor
>>> X, y = load_boston(return_X_y=True)
>>> est = GradientBoostingRegressor().fit(X, y)
>>> est.score(X, y)
0.92...
"""

_VALID_LOSSES = ('least_squares',)
@@ -532,22 +543,24 @@ class GradientBoostingClassifier(BaseGradientBoostingMachine, ClassifierMixin):
generalizes to 'categorical_crossentropy' for multiclass
classification. 'auto' will automatically choose either loss depending
on the nature of the problem.
learning_rate : float, optional(default=TODO)
learning_rate : float, optional(default=1)
The learning rate, also known as *shrinkage*. This is used as a
multiplicative factor for the leaves values.
max_iter : int, optional(default=TODO)
multiplicative factor for the leaves values. Use ``1`` for no
shrinkage.
max_iter : int, optional(default=100)
The maximum number of iterations of the boosting process, i.e. the
maximum number of trees for binary classification. For multiclass
classification, `n_classes` trees per iteration are built.
max_leaf_nodes : int, optional(default=TODO)
The maximum number of leaves for each tree.
max_depth : int, optional(default=TODO)
max_leaf_nodes : int or None, optional(default=None)
The maximum number of leaves for each tree. If None, there is no
maximum limit.
max_depth : int or None, optional(default=None)
The maximum depth of each tree. The depth of a tree is the number of
nodes to go from the root to the deepest leaf.
min_samples_leaf : int, optional(default=TODO)
min_samples_leaf : int, optional(default=20)
The minimum number of samples per leaf.
l2_regularization : float, optional(default=TODO)
The L2 regularization parameter.
l2_regularization : float, optional(default=0)
The L2 regularization parameter. Use 0 for no regularization.
max_bins : int, optional(default=256)
The maximum number of bins to use. Before training, each feature of
the input array ``X`` is binned into at most ``max_bins`` bins, which
@@ -579,6 +592,15 @@ class GradientBoostingClassifier(BaseGradientBoostingMachine, ClassifierMixin):
binning process, and the train/validation data split if early stopping
is enabled. See `scikit-learn glossary
<https://scikit-learn.org/stable/glossary.html#term-random-state>`_.
Examples
--------
>>> from sklearn.datasets import load_iris
>>> from pygbm import GradientBoostingClassifier
>>> X, y = load_iris(return_X_y=True)
>>> clf = GradientBoostingClassifier().fit(X, y)
>>> clf.score(X, y)
0.97...
"""

_VALID_LOSSES = ('binary_crossentropy', 'categorical_crossentropy',
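
The doctest examples added above report training-set scores; as a hedged sketch (not part of this commit), the same estimators can be evaluated on held-out data through the standard scikit-learn API they inherit:

# Sketch: score GradientBoostingClassifier on a held-out split.
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from pygbm import GradientBoostingClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = GradientBoostingClassifier(max_iter=100, learning_rate=0.1, random_state=0)
clf.fit(X_train, y_train)
print(clf.score(X_test, y_test))          # accuracy on unseen data, via ClassifierMixin.score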
30 changes: 16 additions & 14 deletions pygbm/grower.py
@@ -1,7 +1,8 @@
"""
This module contains the TreeGrower class which builds a regression tree
fitting a Newton-Raphson step, based on the gradients and hessians of the
training data.
This module contains the TreeGrower class.
TreeGrower builds a regression tree fitting a Newton-Raphson step, based on
the gradients and hessians of the training data.
"""
from heapq import heappush, heappop
import numpy as np
@@ -25,9 +26,9 @@ class TreeNode:
samples_indices : array of int
The indices of the samples at the node
sum_gradients : float
The sum of the gradients of the samples at the nodes
The sum of the gradients of the samples at the node
sum_hessians : float
The sum of the hessians of the samples at the nodes
The sum of the hessians of the samples at the node
parent : TreeNode or None, optional(default=None)
The parent of the node. None for root.
@@ -38,9 +39,9 @@
samples_indices : array of int
The indices of the samples at the node
sum_gradients : float
The sum of the gradients of the samples at the nodes
The sum of the gradients of the samples at the node
sum_hessians : float
The sum of the hessians of the samples at the nodes
The sum of the hessians of the samples at the node
parent : TreeNode or None, optional(default=None)
The parent of the node. None for root.
split_info : SplitInfo or None
@@ -130,12 +131,13 @@ class TreeGrower:
hessians : array-like, shape=(n_samples,)
The hessians of each training sample. Those are the hessians of the
loss w.r.t the predictions, evaluated at iteration ``i - 1``.
max_leaf_nodes : int, optional(default=TODO)
The maximum number of leaves for each tree.
max_depth : int, optional(default=TODO)
max_leaf_nodes : int or None, optional(default=None)
The maximum number of leaves for each tree. If None, there is no
maximum limit.
max_depth : int or None, optional(default=None)
The maximum depth of each tree. The depth of a tree is the number of
nodes to go from the root to the deepest leaf.
min_samples_leaf : int, optional(default=TODO)
min_samples_leaf : int, optional(default=20)
The minimum number of samples per leaf.
min_gain_to_split : float, optional(default=0.)
The minimum gain needed to split a node. Splits with lower gain will
@@ -148,13 +150,13 @@
equal to ``max_bins``. If it's an int, all features are considered to
have the same number of bins. If None, all features are considered to
have ``max_bins`` bins.
l2_regularization : float, optional(default=TODO)
l2_regularization : float, optional(default=0)
The L2 regularization parameter.
min_hessian_to_split : float, optional(default=TODO)
min_hessian_to_split : float, optional(default=1e-3)
The minimum sum of hessians needed in each node. Splits that result in
at least one child having a sum of hessians less than
min_hessian_to_split are discarded.
shrinkage : float, optional(default=TODO)
shrinkage : float, optional(default=1)
The shrinkage parameter to apply to the leaves values, also known as
learning rate.
"""
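
As a hypothetical sketch of the private API documented above (not part of this commit, and subject to change without notice), a single tree could be grown from binned data; the ``grow()`` call and the uint8/float32 dtype expectations are assumptions beyond what the docstring shows:

# Sketch: grow one regression tree on binned data with least-squares
# gradients; grow() and the exact dtype requirements are assumed here.
import numpy as np
from pygbm.binning import BinMapper
from pygbm.grower import TreeGrower

rng = np.random.RandomState(0)
X = rng.normal(size=(1000, 4))
y = rng.normal(size=1000)

X_binned = np.asfortranarray(BinMapper().fit_transform(X))   # uint8 bin codes

# Gradients/hessians of a least-squares loss at an initial prediction of 0.
gradients = (0.0 - y).astype(np.float32)
hessians = np.ones_like(gradients)

grower = TreeGrower(X_binned, gradients, hessians,
                    max_leaf_nodes=31, min_samples_leaf=20, shrinkage=0.1)
grower.grow()                             # split nodes until no further split is possible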
