From d90612d1805051ede425b0cac7650ee9ba59c6b3 Mon Sep 17 00:00:00 2001 From: Flynn Date: Mon, 18 Jan 2021 01:03:11 -0500 Subject: [PATCH] dont pass sample fraction at predict (#62) --- skranger/ensemble/ranger_forest_classifier.py | 2 +- skranger/ensemble/ranger_forest_regressor.py | 4 ++-- skranger/ensemble/ranger_forest_survival.py | 2 +- .../ensemble/test_ranger_forest_classifier.py | 24 +++++++++++++++++++ .../ensemble/test_ranger_forest_regressor.py | 10 ++++++++ tests/ensemble/test_ranger_forest_survival.py | 10 ++++++++ 6 files changed, 48 insertions(+), 4 deletions(-) diff --git a/skranger/ensemble/ranger_forest_classifier.py b/skranger/ensemble/ranger_forest_classifier.py index a323855..f147e99 100644 --- a/skranger/ensemble/ranger_forest_classifier.py +++ b/skranger/ensemble/ranger_forest_classifier.py @@ -269,7 +269,7 @@ def predict_proba(self, X): self.class_weights or [], False, # predict_all self.keep_inbag, - self.sample_fraction_, + [1], # sample_fraction 0.5, # alpha 0.1, # minprop self.holdout, diff --git a/skranger/ensemble/ranger_forest_regressor.py b/skranger/ensemble/ranger_forest_regressor.py index 29fb439..d02d29e 100644 --- a/skranger/ensemble/ranger_forest_regressor.py +++ b/skranger/ensemble/ranger_forest_regressor.py @@ -272,7 +272,7 @@ def _get_terminal_node_forest(self, X): [], # class_weights False, # predict_all self.keep_inbag, - self.sample_fraction_, + [1], # sample_fraction 0, # alpha 0, # minprop self.holdout, @@ -354,7 +354,7 @@ def predict(self, X): [], # class_weights False, # predict_all self.keep_inbag, - self.sample_fraction_, + [1], # sample_fraction self.alpha, self.minprop, self.holdout, diff --git a/skranger/ensemble/ranger_forest_survival.py b/skranger/ensemble/ranger_forest_survival.py index 15b35ec..ff2070b 100644 --- a/skranger/ensemble/ranger_forest_survival.py +++ b/skranger/ensemble/ranger_forest_survival.py @@ -255,7 +255,7 @@ def _predict(self, X): [], # class_weights False, # predict_all self.keep_inbag, - self.sample_fraction_, + [1], # sample_fraction self.alpha, self.minprop, self.holdout, diff --git a/tests/ensemble/test_ranger_forest_classifier.py b/tests/ensemble/test_ranger_forest_classifier.py index abd2230..d20cd54 100644 --- a/tests/ensemble/test_ranger_forest_classifier.py +++ b/tests/ensemble/test_ranger_forest_classifier.py @@ -36,18 +36,33 @@ def test_predict(self, iris_X, iris_y): pred = rfc.predict(iris_X) assert len(pred) == iris_X.shape[0] + # test with single record + iris_X_record = iris_X[0:1, :] + pred = rfc.predict(iris_X_record) + assert len(pred) == 1 + def test_predict_proba(self, iris_X, iris_y): rfc = RangerForestClassifier() rfc.fit(iris_X, iris_y) pred = rfc.predict_proba(iris_X) assert len(pred) == iris_X.shape[0] + # test with single record + iris_X_record = iris_X[0:1, :] + pred = rfc.predict_proba(iris_X_record) + assert len(pred) == 1 + def test_predict_log_proba(self, iris_X, iris_y): rfc = RangerForestClassifier() rfc.fit(iris_X, iris_y) pred = rfc.predict_log_proba(iris_X) assert len(pred) == iris_X.shape[0] + # test with single record + iris_X_record = iris_X[0:1, :] + pred = rfc.predict_log_proba(iris_X_record) + assert len(pred) == 1 + def test_serialize(self, iris_X, iris_y): tf = tempfile.TemporaryFile() rfc = RangerForestClassifier() @@ -145,6 +160,15 @@ def test_sample_fraction(self, iris_X, iris_y): rfc.fit(iris_X, iris_y) assert rfc.sample_fraction_ == [0.69] + # test with single record + iris_X_record = iris_X[0:1, :] + pred = rfc.predict(iris_X_record) + assert len(pred) == 1 + pred = rfc.predict_proba(iris_X_record) + assert len(pred) == 1 + pred = rfc.predict_log_proba(iris_X_record) + assert len(pred) == 1 + def test_sample_fraction_replace(self, iris_X, iris_y, replace): rfc = RangerForestClassifier(replace=replace) rfc.fit(iris_X, iris_y) diff --git a/tests/ensemble/test_ranger_forest_regressor.py b/tests/ensemble/test_ranger_forest_regressor.py index 4de7617..b86bbc2 100644 --- a/tests/ensemble/test_ranger_forest_regressor.py +++ b/tests/ensemble/test_ranger_forest_regressor.py @@ -31,6 +31,11 @@ def test_predict(self, boston_X, boston_y): pred = rfr.predict(boston_X) assert len(pred) == boston_X.shape[0] + # test with single record + boston_X_record = boston_X[0:1, :] + pred = rfr.predict(boston_X_record) + assert len(pred) == 1 + def test_serialize(self, boston_X, boston_y): tf = tempfile.TemporaryFile() rfr = RangerForestRegressor() @@ -125,6 +130,11 @@ def test_sample_fraction(self, iris_X, iris_y): rfr.fit(iris_X, iris_y) assert rfr.sample_fraction_ == [0.69] + # test with single record + iris_X_record = iris_X[0:1, :] + pred = rfr.predict(iris_X_record) + assert len(pred) == 1 + def test_sample_fraction_replace(self, boston_X, boston_y, replace): rfr = RangerForestRegressor(replace=replace) rfr.fit(boston_X, boston_y) diff --git a/tests/ensemble/test_ranger_forest_survival.py b/tests/ensemble/test_ranger_forest_survival.py index 79cf5b1..42c2dd3 100644 --- a/tests/ensemble/test_ranger_forest_survival.py +++ b/tests/ensemble/test_ranger_forest_survival.py @@ -34,6 +34,11 @@ def test_predict(self, lung_X, lung_y): pred = rfs.predict(lung_X) assert len(pred) == lung_X.shape[0] + # test with single record + lung_X_record = lung_X.values[0:1, :] + pred = rfs.predict(lung_X_record) + assert len(pred) == 1 + def test_predict_cumulative_hazard_function(self, lung_X, lung_y): rfs = RangerForestSurvival(n_estimators=N_ESTIMATORS) rfs.fit(lung_X, lung_y) @@ -140,6 +145,11 @@ def test_sample_fraction(self, lung_X, lung_y): rfs.fit(lung_X, lung_y) assert rfs.sample_fraction_ == [0.69] + # test with single record + lung_X_record = lung_X.values[0:1, :] + pred = rfs.predict(lung_X_record) + assert len(pred) == 1 + def test_sample_fraction_replace(self, lung_X, lung_y, replace): rfs = RangerForestSurvival(replace=replace) rfs.fit(lung_X, lung_y)