Fix divide by 0 in feature importance when no split is found. (#6676)
trivialfis committed Feb 4, 2021
1 parent 72892cc · commit a4101de
Showing 2 changed files with 24 additions and 6 deletions.
5 changes: 4 additions & 1 deletion python-package/xgboost/sklearn.py
@@ -920,7 +920,10 @@ def feature_importances_(self):
         score = b.get_score(importance_type=self.importance_type)
         all_features = [score.get(f, 0.) for f in b.feature_names]
         all_features = np.array(all_features, dtype=np.float32)
-        return all_features / all_features.sum()
+        total = all_features.sum()
+        if total == 0:
+            return all_features
+        return all_features / total
 
     @property
     def coef_(self):
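Why the guard is needed: `get_score()` only reports features that appear in at least one split, so when the booster never finds a valid split every entry of `all_features` is 0 and the old code divided 0 by 0, turning the whole array into NaN (with a NumPy RuntimeWarning). A minimal standalone sketch of the guarded normalization; the helper name `normalize_importances` is illustrative, not part of XGBoost:

```python
import numpy as np

def normalize_importances(scores: np.ndarray) -> np.ndarray:
    # Mirror of the fix above: when no split was found every score is 0,
    # and dividing by the zero total would turn the whole array into NaN.
    total = scores.sum()
    if total == 0:
        return scores
    return scores / total

raw = np.zeros(4, dtype=np.float32)
print(raw / raw.sum())             # old behaviour: [nan nan nan nan] + RuntimeWarning
print(normalize_importances(raw))  # fixed behaviour: [0. 0. 0. 0.]
```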
25 changes: 20 additions & 5 deletions tests/python/test_with_sklearn.py
@@ -252,7 +252,9 @@ def test_feature_importances_gain():
     xgb_model = xgb.XGBClassifier(
         random_state=0, tree_method="exact",
         learning_rate=0.1,
-        importance_type="gain").fit(X, y)
+        importance_type="gain",
+        use_label_encoder=False,
+    ).fit(X, y)
 
     exp = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0.00326159, 0., 0., 0., 0., 0., 0., 0., 0.,
@@ -270,17 +272,30 @@ def test_feature_importances_gain():
     y = pd.Series(digits['target'])
     X = pd.DataFrame(digits['data'])
     xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact",
+        random_state=0,
+        tree_method="exact",
         learning_rate=0.1,
-        importance_type="gain").fit(X, y)
+        importance_type="gain",
+        use_label_encoder=False,
+    ).fit(X, y)
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
 
     xgb_model = xgb.XGBClassifier(
-        random_state=0, tree_method="exact",
+        random_state=0,
+        tree_method="exact",
         learning_rate=0.1,
-        importance_type="gain").fit(X, y)
+        importance_type="gain",
+        use_label_encoder=False,
+    ).fit(X, y)
     np.testing.assert_almost_equal(xgb_model.feature_importances_, exp)
 
+    # no split can be found
+    cls = xgb.XGBClassifier(
+        min_child_weight=1000, tree_method="hist", n_estimators=1, use_label_encoder=False
+    )
+    cls.fit(X, y)
+    assert np.all(cls.feature_importances_ == 0)
+
 
 def test_select_feature():
     from sklearn.datasets import load_digits
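The added `# no split can be found` case exercises the guard end to end: `min_child_weight=1000` is far beyond any leaf weight reachable on the digits data, so the single boosted tree makes no split at all. A sketch of the same check run outside the test suite, assuming an XGBoost build that includes this fix and scikit-learn installed; the parameter values are illustrative:

```python
import numpy as np
import xgboost as xgb
from sklearn.datasets import load_digits

X, y = load_digits(return_X_y=True)

# min_child_weight=1000 exceeds any achievable leaf weight on this
# dataset, so the single tree never finds a valid split.
cls = xgb.XGBClassifier(
    min_child_weight=1000, tree_method="hist", n_estimators=1,
    use_label_encoder=False,
)
cls.fit(X, y)

# get_score() lists only features that occur in at least one split,
# so it should be empty here and every per-feature score falls back to 0.
print(cls.get_booster().get_score(importance_type="gain"))  # expected: {}

# With the guard, the zeros are returned as-is instead of 0/0 -> NaN.
assert np.all(cls.feature_importances_ == 0)
```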
