Skip to content

Commit

Permalink
TST check all importance_types for XGBoost
Browse files Browse the repository at this point in the history
  • Loading branch information
kmike committed Jun 22, 2017
1 parent 443866a commit ac5e586
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions tests/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
)


def test_explain_xgboost(newsgroups_train):
_check_rf_classifier(newsgroups_train, XGBClassifier(n_estimators=10))
@pytest.mark.parametrize(['importance_type'], [['gain'], ['weight'], ['cover']])
def test_explain_xgboost(newsgroups_train, importance_type):
_check_rf_classifier(newsgroups_train, XGBClassifier(n_estimators=10),
importance_type=importance_type)


def test_explain_booster(newsgroups_train):
Expand Down Expand Up @@ -452,14 +454,14 @@ def test_check_booster_args():
regressor = XGBRegressor().fit(x, y)
classifier = XGBClassifier().fit(x, y)
booster, is_regression = _check_booster_args(regressor)
assert is_regression == True
assert is_regression is True
assert isinstance(booster, xgboost.Booster)
_, is_regression = _check_booster_args(regressor, is_regression=True)
assert is_regression == True
assert is_regression is True
_, is_regression = _check_booster_args(classifier)
assert is_regression == False
assert is_regression is False
_, is_regression = _check_booster_args(classifier, is_regression=False)
assert is_regression == False
assert is_regression is False
with pytest.raises(ValueError):
_check_booster_args(classifier, is_regression=True)
with pytest.raises(ValueError):
Expand All @@ -469,4 +471,4 @@ def test_check_booster_args():
assert _booster is booster
assert is_regression is None
_, is_regression = _check_booster_args(booster, is_regression=True)
assert is_regression == True
assert is_regression is True

0 comments on commit ac5e586

Please sign in to comment.