Skip to content

Commit

Permalink
dont pass sample fraction at predict (#62)
Browse files Browse the repository at this point in the history
  • Loading branch information
crflynn authored Jan 18, 2021
1 parent d409bd8 commit d90612d
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 4 deletions.
2 changes: 1 addition & 1 deletion skranger/ensemble/ranger_forest_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def predict_proba(self, X):
self.class_weights or [],
False, # predict_all
self.keep_inbag,
self.sample_fraction_,
[1], # sample_fraction
0.5, # alpha
0.1, # minprop
self.holdout,
Expand Down
4 changes: 2 additions & 2 deletions skranger/ensemble/ranger_forest_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def _get_terminal_node_forest(self, X):
[], # class_weights
False, # predict_all
self.keep_inbag,
self.sample_fraction_,
[1], # sample_fraction
0, # alpha
0, # minprop
self.holdout,
Expand Down Expand Up @@ -354,7 +354,7 @@ def predict(self, X):
[], # class_weights
False, # predict_all
self.keep_inbag,
self.sample_fraction_,
[1], # sample_fraction
self.alpha,
self.minprop,
self.holdout,
Expand Down
2 changes: 1 addition & 1 deletion skranger/ensemble/ranger_forest_survival.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def _predict(self, X):
[], # class_weights
False, # predict_all
self.keep_inbag,
self.sample_fraction_,
[1], # sample_fraction
self.alpha,
self.minprop,
self.holdout,
Expand Down
24 changes: 24 additions & 0 deletions tests/ensemble/test_ranger_forest_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,18 +36,33 @@ def test_predict(self, iris_X, iris_y):
pred = rfc.predict(iris_X)
assert len(pred) == iris_X.shape[0]

# test with single record
iris_X_record = iris_X[0:1, :]
pred = rfc.predict(iris_X_record)
assert len(pred) == 1

def test_predict_proba(self, iris_X, iris_y):
rfc = RangerForestClassifier()
rfc.fit(iris_X, iris_y)
pred = rfc.predict_proba(iris_X)
assert len(pred) == iris_X.shape[0]

# test with single record
iris_X_record = iris_X[0:1, :]
pred = rfc.predict_proba(iris_X_record)
assert len(pred) == 1

def test_predict_log_proba(self, iris_X, iris_y):
rfc = RangerForestClassifier()
rfc.fit(iris_X, iris_y)
pred = rfc.predict_log_proba(iris_X)
assert len(pred) == iris_X.shape[0]

# test with single record
iris_X_record = iris_X[0:1, :]
pred = rfc.predict_log_proba(iris_X_record)
assert len(pred) == 1

def test_serialize(self, iris_X, iris_y):
tf = tempfile.TemporaryFile()
rfc = RangerForestClassifier()
Expand Down Expand Up @@ -145,6 +160,15 @@ def test_sample_fraction(self, iris_X, iris_y):
rfc.fit(iris_X, iris_y)
assert rfc.sample_fraction_ == [0.69]

# test with single record
iris_X_record = iris_X[0:1, :]
pred = rfc.predict(iris_X_record)
assert len(pred) == 1
pred = rfc.predict_proba(iris_X_record)
assert len(pred) == 1
pred = rfc.predict_log_proba(iris_X_record)
assert len(pred) == 1

def test_sample_fraction_replace(self, iris_X, iris_y, replace):
rfc = RangerForestClassifier(replace=replace)
rfc.fit(iris_X, iris_y)
Expand Down
10 changes: 10 additions & 0 deletions tests/ensemble/test_ranger_forest_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ def test_predict(self, boston_X, boston_y):
pred = rfr.predict(boston_X)
assert len(pred) == boston_X.shape[0]

# test with single record
boston_X_record = boston_X[0:1, :]
pred = rfr.predict(boston_X_record)
assert len(pred) == 1

def test_serialize(self, boston_X, boston_y):
tf = tempfile.TemporaryFile()
rfr = RangerForestRegressor()
Expand Down Expand Up @@ -125,6 +130,11 @@ def test_sample_fraction(self, iris_X, iris_y):
rfr.fit(iris_X, iris_y)
assert rfr.sample_fraction_ == [0.69]

# test with single record
iris_X_record = iris_X[0:1, :]
pred = rfr.predict(iris_X_record)
assert len(pred) == 1

def test_sample_fraction_replace(self, boston_X, boston_y, replace):
rfr = RangerForestRegressor(replace=replace)
rfr.fit(boston_X, boston_y)
Expand Down
10 changes: 10 additions & 0 deletions tests/ensemble/test_ranger_forest_survival.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@ def test_predict(self, lung_X, lung_y):
pred = rfs.predict(lung_X)
assert len(pred) == lung_X.shape[0]

# test with single record
lung_X_record = lung_X.values[0:1, :]
pred = rfs.predict(lung_X_record)
assert len(pred) == 1

def test_predict_cumulative_hazard_function(self, lung_X, lung_y):
rfs = RangerForestSurvival(n_estimators=N_ESTIMATORS)
rfs.fit(lung_X, lung_y)
Expand Down Expand Up @@ -140,6 +145,11 @@ def test_sample_fraction(self, lung_X, lung_y):
rfs.fit(lung_X, lung_y)
assert rfs.sample_fraction_ == [0.69]

# test with single record
lung_X_record = lung_X.values[0:1, :]
pred = rfs.predict(lung_X_record)
assert len(pred) == 1

def test_sample_fraction_replace(self, lung_X, lung_y, replace):
rfs = RangerForestSurvival(replace=replace)
rfs.fit(lung_X, lung_y)
Expand Down

0 comments on commit d90612d

Please sign in to comment.