Attempt to bring in categorical support (#46)

#### Reference Issues/PRs
Helps bring the fork in line with the changes in scikit-learn#12866.

#### What does this implement/fix? Explain your changes.

Adds a `categorical` parameter to the forest estimators (`RandomForestClassifier`, `RandomForestRegressor`, `ExtraTreesClassifier`, `ExtraTreesRegressor`, and `RandomTreesEmbedding`) so that selected features are treated as nominal categories rather than ordinal values, mirroring the NOCATS work in scikit-learn#12866.

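A minimal usage sketch of the new parameter, assuming this fork is installed in place of stock scikit-learn; the toy data and exact values are purely illustrative:

```python
import numpy as np

from sklearn.ensemble import RandomForestClassifier

# Toy data: column 0 is continuous, column 1 holds integer-coded categories.
rng = np.random.default_rng(0)
X = np.column_stack([rng.normal(size=200), rng.integers(0, 5, size=200)])
y = (X[:, 1] % 2 == 0).astype(int)

# Mark column 1 as categorical. Per the docstrings added in this commit,
# `categorical` also accepts a boolean mask of length n_features, 'all', or None.
clf = RandomForestClassifier(n_estimators=10, categorical=[1], random_state=0)
clf.fit(X, y)
print(clf.predict_proba(X[:5]))
```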
#### Any other comments?

The commit also touches `predict_quantiles` in `sklearn/ensemble/_forest.py`: `quantiles` may now be a scalar or an array-like (normalized with `np.atleast_1d`), and `np.quantile` is called with the `method` keyword instead of the deprecated `interpolation` keyword; a standalone sketch of the NumPy side of that change is shown below. In addition, `benchmarks/bench_tree_nocats.py` is added to compare native categorical splits ("NOCATS") against one-hot encoding.
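A short illustration in plain NumPy, independent of this fork (the sample values are made up):

```python
import numpy as np

# `predict_quantiles` accepts a scalar, list, or array of quantiles;
# a scalar is promoted to a 1-D array before use.
quantiles = np.atleast_1d(np.array(0.5))  # -> array([0.5])

# Stand-in for the stacked leaf-node samples, shape (n_leaf_samples, n_outputs).
leaf_node_samples = np.array([[1.0], [2.0], [3.0], [10.0]])

# NumPy >= 1.22 renamed the `interpolation` keyword of np.quantile to `method`.
print(np.quantile(leaf_node_samples, quantiles, axis=0, method="nearest"))
```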

---------

Signed-off-by: Adam Li <adam2392@gmail.com>
adam2392 committed Jul 20, 2023
1 parent e9d702b commit 9a614f4
Showing 15 changed files with 1,462 additions and 100 deletions.
128 changes: 128 additions & 0 deletions benchmarks/bench_tree_nocats.py
@@ -0,0 +1,128 @@
from itertools import product
from timeit import timeit

import numpy as np
import pandas as pd

from sklearn.datasets import fetch_openml
from sklearn.ensemble import ExtraTreesClassifier, RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import OneHotEncoder


def get_data(trunc_ncat):
# the data is located here: https://www.openml.org/d/4135
X, y = fetch_openml(data_id=4135, return_X_y=True)
X = pd.DataFrame(X)

Xdicts = []
for trunc in trunc_ncat:
X_trunc = X % trunc if trunc > 0 else X
keep_idx = np.array(
[idx[0] for idx in X_trunc.groupby(list(X.columns)).groups.values()]
)
X_trunc = X_trunc.values[keep_idx]
y_trunc = y[keep_idx]

X_ohe = OneHotEncoder(categories="auto").fit_transform(X_trunc)

Xdicts.append({"X": X_trunc, "y": y_trunc, "ohe": False, "trunc": trunc})
Xdicts.append({"X": X_ohe, "y": y_trunc, "ohe": True, "trunc": trunc})

return Xdicts


# Training dataset
trunc_factor = [2, 3, 4, 5, 6, 8, 10, 12, 14, 16, 64, 0]
data = get_data(trunc_factor)
results = []
# Loop over classifiers and datasets
for Xydict, clf_type in product(data, [RandomForestClassifier, ExtraTreesClassifier]):
# Can't use non-truncated categorical data with RandomForest
# and it becomes intractable with too many categories
if (
clf_type is RandomForestClassifier
and not Xydict["ohe"]
and (not Xydict["trunc"] or Xydict["trunc"] > 16)
):
continue

X, y = Xydict["X"], Xydict["y"]
tech = "One-hot" if Xydict["ohe"] else "NOCATS"
trunc = "truncated({})".format(Xydict["trunc"]) if Xydict["trunc"] > 0 else "full"
cat = "none" if Xydict["ohe"] else "all"
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=17).split(X, y)

traintimes = []
testtimes = []
aucs = []
name = "({}, {}, {})".format(clf_type.__name__, trunc, tech)

for train, test in cv:
# Train
clf = clf_type(
n_estimators=10,
max_features=None,
min_samples_leaf=1,
random_state=23,
bootstrap=False,
max_depth=None,
categorical=cat,
)

traintimes.append(
timeit(
"clf.fit(X[train], y[train])".format(),
"from __main__ import clf, X, y, train",
number=1,
)
)

"""
# Check that all leaf nodes are pure
for est in clf.estimators_:
leaves = est.tree_.children_left < 0
print(np.max(est.tree_.impurity[leaves]))
#assert(np.all(est.tree_.impurity[leaves] == 0))
"""

# Test
probs = []
testtimes.append(
timeit(
"probs.append(clf.predict_proba(X[test]))",
"from __main__ import probs, clf, X, test",
number=1,
)
)

aucs.append(roc_auc_score(y[test], probs[0][:, 1]))

traintimes = np.array(traintimes)
testtimes = np.array(testtimes)
aucs = np.array(aucs)
results.append(
[
name,
traintimes.mean(),
traintimes.std(),
testtimes.mean(),
testtimes.std(),
aucs.mean(),
aucs.std(),
]
)

results_df = pd.DataFrame(results)
results_df.columns = [
"name",
"train time mean",
"train time std",
"test time mean",
"test time std",
"auc mean",
"auc std",
]
results_df = results_df.set_index("name")
print(results_df)
97 changes: 74 additions & 23 deletions sklearn/ensemble/_forest.py
@@ -725,7 +725,7 @@ def predict_quantiles(self, X, quantiles=0.5, method="nearest"):
----------
X : {array-like, sparse matrix} of shape (n_samples, n_features)
Input data.
quantiles : float, optional
quantiles : array-like, float, optional
The quantiles at which to evaluate, by default 0.5 (median).
method : str, optional
The method to interpolate, by default 'linear'. Can be any keyword
@@ -746,7 +746,7 @@ def predict_quantiles(self, X, quantiles=0.5, method="nearest"):
X = self._validate_X_predict(X)

if not isinstance(quantiles, (np.ndarray, list)):
quantiles = np.array([quantiles])
quantiles = np.atleast_1d(np.array(quantiles))

# if we trained a binning tree, then we should re-bin the data
# XXX: this is inefficient and should be improved to be in line with what
@@ -777,15 +777,15 @@ def predict_quantiles(self, X, quantiles=0.5, method="nearest"):

# (n_total_leaf_samples, n_outputs)
leaf_node_samples = np.vstack(
(
[
est.leaf_nodes_samples_[leaf_nodes[jdx]]
for jdx, est in enumerate(self.estimators_)
)
]
)

# get quantiles across all leaf node samples
y_hat[idx, ...] = np.quantile(
leaf_node_samples, quantiles, axis=0, interpolation=method
leaf_node_samples, quantiles, axis=0, method=method
)

if is_classifier(self):
@@ -1550,6 +1550,17 @@ class RandomForestClassifier(ForestClassifier):
.. versionadded:: 1.4
categorical : array-like or str
Array of feature indices, boolean array of length n_features,
``'all'`` or `None`. Indicates which features should be
considered as categorical rather than ordinal. For decision trees,
the maximum number of categories is 64. In practice, the limit will
often be lower because the process of searching for the best possible
split grows exponentially with the number of categories. However, a
shortcut due to Breiman (1984) is used when fitting data with binary
labels using the ``Gini`` or ``Entropy`` criteria. In this case,
the runtime is linear in the number of categories.
Attributes
----------
estimator_ : :class:`~sklearn.tree.DecisionTreeClassifier`
@@ -1693,6 +1704,7 @@ def __init__(
max_bins=None,
store_leaf_values=False,
monotonic_cst=None,
categorical=None,
):
super().__init__(
estimator=DecisionTreeClassifier(),
@@ -1710,6 +1722,7 @@
"ccp_alpha",
"store_leaf_values",
"monotonic_cst",
"categorical",
),
bootstrap=bootstrap,
oob_score=oob_score,
@@ -1733,6 +1746,7 @@
self.min_impurity_decrease = min_impurity_decrease
self.monotonic_cst = monotonic_cst
self.ccp_alpha = ccp_alpha
self.categorical = categorical


class RandomForestRegressor(ForestRegressor):
@@ -1935,6 +1949,17 @@ class RandomForestRegressor(ForestRegressor):
.. versionadded:: 1.4
categorical : array-like or str
Array of feature indices, boolean array of length n_features,
``'all'`` or `None`. Indicates which features should be
considered as categorical rather than ordinal. For decision trees,
the maximum number of categories is 64. In practice, the limit will
often be lower because the process of searching for the best possible
split grows exponentially with the number of categories. However, a
shortcut due to Breiman (1984) is used when fitting data with binary
labels using the ``Gini`` or ``Entropy`` criteria. In this case,
the runtime is linear in the number of categories.
Attributes
----------
estimator_ : :class:`~sklearn.tree.DecisionTreeRegressor`
@@ -2065,6 +2090,7 @@ def __init__(
max_bins=None,
store_leaf_values=False,
monotonic_cst=None,
categorical=None,
):
super().__init__(
estimator=DecisionTreeRegressor(),
@@ -2082,6 +2108,7 @@
"ccp_alpha",
"store_leaf_values",
"monotonic_cst",
"categorical",
),
bootstrap=bootstrap,
oob_score=oob_score,
@@ -2104,6 +2131,7 @@
self.min_impurity_decrease = min_impurity_decrease
self.ccp_alpha = ccp_alpha
self.monotonic_cst = monotonic_cst
self.categorical = categorical


class ExtraTreesClassifier(ForestClassifier):
@@ -2316,24 +2344,16 @@ class ExtraTreesClassifier(ForestClassifier):
.. versionadded:: 1.4
monotonic_cst : array-like of int of shape (n_features), default=None
Indicates the monotonicity constraint to enforce on each feature.
- 1: monotonically increasing
- 0: no constraint
- -1: monotonically decreasing
If monotonic_cst is None, no constraints are applied.
Monotonicity constraints are not supported for:
- multiclass classifications (i.e. when `n_classes > 2`),
- multioutput classifications (i.e. when `n_outputs_ > 1`),
- classifications trained on data with missing values.
The constraints hold over the probability of the positive class.
Read more in the :ref:`User Guide <monotonic_cst_gbdt>`.
.. versionadded:: 1.4
categorical : array-like or str
Array of feature indices, boolean array of length n_features,
``'all'`` or `None`. Indicates which features should be
considered as categorical rather than ordinal. For decision trees,
the maximum number of categories is 64. In practice, the limit will
often be lower because the process of searching for the best possible
split grows exponentially with the number of categories. However, a
shortcut due to Breiman (1984) is used when fitting data with binary
labels using the ``Gini`` or ``Entropy`` criteria. In this case,
the runtime is linear in the number of categories.
Attributes
----------
@@ -2467,6 +2487,7 @@ def __init__(
max_bins=None,
store_leaf_values=False,
monotonic_cst=None,
categorical=None,
):
super().__init__(
estimator=ExtraTreeClassifier(),
@@ -2484,6 +2505,7 @@
"ccp_alpha",
"store_leaf_values",
"monotonic_cst",
"categorical",
),
bootstrap=bootstrap,
oob_score=oob_score,
@@ -2507,6 +2529,7 @@
self.min_impurity_decrease = min_impurity_decrease
self.ccp_alpha = ccp_alpha
self.monotonic_cst = monotonic_cst
self.categorical = categorical


class ExtraTreesRegressor(ForestRegressor):
@@ -2704,6 +2727,17 @@ class ExtraTreesRegressor(ForestRegressor):
.. versionadded:: 1.4
categorical : array-like or str
Array of feature indices, boolean array of length n_features,
``'all'`` or `None`. Indicates which features should be
considered as categorical rather than ordinal. For decision trees,
the maximum number of categories is 64. In practice, the limit will
often be lower because the process of searching for the best possible
split grows exponentially with the number of categories. However, a
shortcut due to Breiman (1984) is used when fitting data with binary
labels using the ``Gini`` or ``Entropy`` criteria. In this case,
the runtime is linear in the number of categories.
Attributes
----------
estimator_ : :class:`~sklearn.tree.ExtraTreeRegressor`
@@ -2819,6 +2853,7 @@ def __init__(
max_bins=None,
store_leaf_values=False,
monotonic_cst=None,
categorical=None,
):
super().__init__(
estimator=ExtraTreeRegressor(),
@@ -2836,6 +2871,7 @@
"ccp_alpha",
"store_leaf_values",
"monotonic_cst",
"categorical",
),
bootstrap=bootstrap,
oob_score=oob_score,
@@ -2858,6 +2894,7 @@
self.min_impurity_decrease = min_impurity_decrease
self.ccp_alpha = ccp_alpha
self.monotonic_cst = monotonic_cst
self.categorical = categorical


class RandomTreesEmbedding(TransformerMixin, BaseForest):
Expand Down Expand Up @@ -2969,6 +3006,17 @@ class RandomTreesEmbedding(TransformerMixin, BaseForest):
new forest. See :term:`Glossary <warm_start>` and
:ref:`gradient_boosting_warm_start` for details.
categorical : array-like or str
Array of feature indices, boolean array of length n_features,
``'all'`` or `None`. Indicates which features should be
considered as categorical rather than ordinal. For decision trees,
the maximum number of categories is 64. In practice, the limit will
often be lower because the process of searching for the best possible
split grows exponentially with the number of categories. However, a
shortcut due to Breiman (1984) is used when fitting data with binary
labels using the ``Gini`` or ``Entropy`` criteria. In this case,
the runtime is linear in the number of categories.
Attributes
----------
estimator_ : :class:`~sklearn.tree.ExtraTreeRegressor` instance
@@ -3073,6 +3121,7 @@ def __init__(
verbose=0,
warm_start=False,
store_leaf_values=False,
categorical=None,
):
super().__init__(
estimator=ExtraTreeRegressor(),
@@ -3088,6 +3137,7 @@
"min_impurity_decrease",
"random_state",
"store_leaf_values",
"categorical",
),
bootstrap=False,
oob_score=False,
@@ -3106,6 +3156,7 @@
self.max_leaf_nodes = max_leaf_nodes
self.min_impurity_decrease = min_impurity_decrease
self.sparse_output = sparse_output
self.categorical = categorical

def _set_oob_score_and_attributes(self, X, y, scoring_function=None):
raise NotImplementedError("OOB score not supported by tree embedding")
