Skip to content

Commit

Permalink
rename mixin (#81)
Browse files Browse the repository at this point in the history
* rename to base class

* move

* fix

* fmt

* mixin
  • Loading branch information
crflynn authored Apr 24, 2021
1 parent 16f7264 commit fa1674c
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 77 deletions.
37 changes: 36 additions & 1 deletion skranger/ensemble/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import bisect
import warnings
from collections.abc import Iterable

Expand All @@ -6,7 +7,7 @@
from sklearn.utils.validation import check_is_fitted


class RangerValidationMixin:
class RangerMixin:
@property
def feature_importances_(self):
try:
Expand All @@ -22,6 +23,38 @@ def feature_importances_(self):
"importance must be set to something other than 'none'"
) from None

def get_importance_pvalues(self):
"""Calculate p-values for variable importances.
Uses the fast method from Janitza et al. (2016).
"""
check_is_fitted(self)
if self.importance != "impurity_corrected":
raise ValueError(
"p-values can only be calculated with importance parameter set to 'impurity_corrected'"
)

vimp = np.array(self.ranger_forest_["variable_importance"])
m1 = vimp[vimp < 0]
m2 = vimp[vimp == 0]

if len(m1) == 0:
raise ValueError(
"No negative importance values found, cannot calculate p-values."
)
if len(m2) < 1:
vimp_dist = np.concatenate((m1, -m1))
else:
vimp_dist = np.concatenate((m1, -m1, m2))

vimp_dist.sort()
result = []
for i in range(len(vimp)):
result.append(bisect.bisect_left(vimp_dist, vimp[i]))
pval = 1 - np.array(result) / len(vimp_dist)
return pval

# region validation
def _validate_parameters(self, X, y, sample_weights):
"""Validate ranger parameters and set defaults."""
self.n_jobs_ = max(
Expand Down Expand Up @@ -212,3 +245,5 @@ def _check_inbag(self, sample_weights):
raise ValueError("Cannot use class sampling and inbag.")
if len(self.inbag) != self.n_estimators:
raise ValueError("Size of inbag must be equal to n_estimators.")

# endregion
38 changes: 2 additions & 36 deletions skranger/ensemble/ranger_forest_classifier.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""Scikit-learn wrapper for ranger classification."""
import bisect

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.base import ClassifierMixin
Expand All @@ -10,10 +8,10 @@
from sklearn.utils.validation import check_is_fitted

from skranger.ensemble import ranger
from skranger.ensemble.base import RangerValidationMixin
from skranger.ensemble.base import RangerMixin


class RangerForestClassifier(RangerValidationMixin, ClassifierMixin, BaseEstimator):
class RangerForestClassifier(RangerMixin, ClassifierMixin, BaseEstimator):
r"""Ranger Random Forest Probability/Classification implementation for sci-kit learn.
Provides a sklearn classifier interface to the Ranger C++ library using Cython.
Expand Down Expand Up @@ -318,38 +316,6 @@ def predict_log_proba(self, X):
proba = self.predict_proba(X)
return np.log(proba)

def get_importance_pvalues(self):
"""Calculate p-values for variable importance.
Uses the fast method from Janitza et al. (2016).
"""

check_is_fitted(self)
if self.importance != "impurity_corrected":
raise ValueError(
"p-values can only be calculated with importance parameter set to 'impurity_corrected'"
)

vimp = np.array(self.ranger_forest_["variable_importance"])
m1 = vimp[vimp < 0]
m2 = vimp[vimp == 0]

if len(m1) == 0:
raise ValueError(
"No negative importance values found, cannot calculate p-values."
)
if len(m2) < 1:
vimp_dist = np.concatenate((m1, -m1))
else:
vimp_dist = np.concatenate((m1, -m1, m2))

vimp_dist.sort()
result = []
for i in range(len(vimp)):
result.append(bisect.bisect_left(vimp_dist, vimp[i]))
pval = 1 - np.array(result) / len(vimp_dist)
return pval

def _more_tags(self):
return {
"_xfail_checks": {
Expand Down
38 changes: 2 additions & 36 deletions skranger/ensemble/ranger_forest_regressor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""Scikit-learn wrapper for ranger regression."""
import bisect

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin
Expand All @@ -9,10 +7,10 @@
from sklearn.utils.validation import check_is_fitted

from skranger.ensemble import ranger
from skranger.ensemble.base import RangerValidationMixin
from skranger.ensemble.base import RangerMixin


class RangerForestRegressor(RangerValidationMixin, RegressorMixin, BaseEstimator):
class RangerForestRegressor(RangerMixin, RegressorMixin, BaseEstimator):
r"""Ranger Random Forest Regression implementation for sci-kit learn.
Provides a sklearn regressor interface to the Ranger C++ library using Cython. The
Expand Down Expand Up @@ -396,38 +394,6 @@ def predict(self, X):
)
return np.array(result["predictions"])

def get_importance_pvalues(self):
"""Calculate p-values for variable importance.
Uses the fast method from Janitza et al. (2016).
"""

check_is_fitted(self)
if self.importance != "impurity_corrected":
raise ValueError(
"p-values can only be calculated with importance parameter set to 'impurity_corrected'"
)

vimp = np.array(self.ranger_forest_["variable_importance"])
m1 = vimp[vimp < 0]
m2 = vimp[vimp == 0]

if len(m1) == 0:
raise ValueError(
"No negative importance values found, cannot calculate p-values."
)
if len(m2) < 1:
vimp_dist = np.concatenate((m1, -m1))
else:
vimp_dist = np.concatenate((m1, -m1, m2))

vimp_dist.sort()
result = []
for i in range(len(vimp)):
result.append(bisect.bisect_left(vimp_dist, vimp[i]))
pval = 1 - np.array(result) / len(vimp_dist)
return pval

def _more_tags(self):
return {
"_xfail_checks": {
Expand Down
4 changes: 2 additions & 2 deletions skranger/ensemble/ranger_forest_survival.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
from sklearn.utils.validation import check_is_fitted

from skranger.ensemble import ranger
from skranger.ensemble.base import RangerValidationMixin
from skranger.ensemble.base import RangerMixin


class RangerForestSurvival(RangerValidationMixin, BaseEstimator):
class RangerForestSurvival(RangerMixin, BaseEstimator):
r"""Ranger Random Forest Survival implementation for sci-kit survival.
Provides a sksurv interface to the Ranger C++ library using Cython. The
Expand Down
19 changes: 19 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,25 @@ def boston_X_mod(mod):
)


@pytest.fixture
def lung_X_mod(lung_X, mod):
if mod == "none":
return _lung_X[["Age_in_years", "Karnofsky_score"]]
elif mod == "random":
np.random.seed(42)
return np.concatenate((lung_X, np.random.uniform(size=(lung_X.shape))), 1)
elif mod == "const":
np.random.seed(42)
return np.concatenate(
(
lung_X,
np.random.uniform(size=(lung_X.shape)),
np.zeros(shape=(lung_X.shape)),
),
1,
)


@pytest.fixture
def iris_y():
return _iris_y
Expand Down
2 changes: 1 addition & 1 deletion tests/ensemble/test_ranger_forest_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,8 @@ def test_importance_pvalues(self, iris_X_mod, iris_y, importance, mod):

# Test error for no non-negative importance values
if mod == "none":
rfc.fit(iris_X_mod, iris_y)
with pytest.raises(ValueError):
rfc.fit(iris_X_mod, iris_y)
rfc.get_importance_pvalues()
return

Expand Down
2 changes: 1 addition & 1 deletion tests/ensemble/test_ranger_forest_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def test_importance_pvalues(self, boston_X_mod, boston_y, importance, mod):
# Test error for no non-negative importance values

if mod == "none":
rfc.fit(boston_X_mod, boston_y)
with pytest.raises(ValueError):
rfc.fit(boston_X_mod, boston_y)
rfc.get_importance_pvalues()
return

Expand Down
25 changes: 25 additions & 0 deletions tests/ensemble/test_ranger_forest_survival.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,31 @@ def test_importance(
else:
assert rfs.importance_mode_ == 3

def test_importance_pvalues(self, lung_X_mod, lung_y, importance, mod):
rfs = RangerForestSurvival(importance=importance)
np.random.seed(42)

if importance not in ["none", "impurity", "impurity_corrected", "permutation"]:
with pytest.raises(ValueError):
rfs.fit(lung_X_mod, lung_y)
return

if not importance == "impurity_corrected":
rfs.fit(lung_X_mod, lung_y)
with pytest.raises(ValueError):
rfs.get_importance_pvalues()
return

# Test error for no non-negative importance values
if mod == "none":
rfs.fit(lung_X_mod, lung_y)
with pytest.raises(ValueError):
rfs.get_importance_pvalues()
return

rfs.fit(lung_X_mod, lung_y)
assert len(rfs.get_importance_pvalues()) == lung_X_mod.shape[1]

def test_mtry(self, lung_X, lung_y, mtry):
rfs = RangerForestSurvival(mtry=mtry)

Expand Down

0 comments on commit fa1674c

Please sign in to comment.