Skip to content

Commit

Permalink
Interesting case.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Oct 27, 2021
1 parent 29bbaa8 commit 6bbaac2
Showing 1 changed file with 20 additions and 16 deletions.
36 changes: 20 additions & 16 deletions demo/guide-python/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,31 +49,35 @@ def main() -> None:
# Use built-in categorical data support
# For the scikit-learn interface, the input data must be a pandas or cudf
# DataFrame with categorical features
X, y = make_categorical(100, 10, 4, False)
X, y = make_categorical(100, 1, 2, False)
print(X)
# Set `enable_categorical` to True.
tree_method = "approx"
reg = xgb.XGBRegressor(tree_method=tree_method, enable_categorical=True)
reg = xgb.XGBRegressor(
tree_method=tree_method, enable_categorical=True,
max_depth=1, n_estimators=1
)
reg.fit(X, y, eval_set=[(X, y)])

# Pass in already encoded data
X_enc, y_enc = make_categorical(100, 10, 4, True)
reg_enc = xgb.XGBRegressor(tree_method=tree_method)
X_enc, y_enc = make_categorical(100, 1, 2, True)
reg_enc = xgb.XGBRegressor(tree_method=tree_method, max_depth=1, n_estimators=1)
reg_enc.fit(X_enc, y_enc, eval_set=[(X_enc, y_enc)])

reg_results = np.array(reg.evals_result()["validation_0"]["rmse"])
reg_enc_results = np.array(reg_enc.evals_result()["validation_0"]["rmse"])
# reg_results = np.array(reg.evals_result()["validation_0"]["rmse"])
# reg_enc_results = np.array(reg_enc.evals_result()["validation_0"]["rmse"])

# Check that they have the same results
np.testing.assert_allclose(reg_results, reg_enc_results)
# # Check that they have the same results
# np.testing.assert_allclose(reg_results, reg_enc_results)

# Convert to DMatrix for SHAP value
booster: xgb.Booster = reg.get_booster()
m = xgb.DMatrix(X, enable_categorical=True) # specify categorical data support.
SHAP = booster.predict(m, pred_contribs=True)
margin = booster.predict(m, output_margin=True)
np.testing.assert_allclose(
np.sum(SHAP, axis=len(SHAP.shape) - 1), margin, rtol=1e-3
)
# # Convert to DMatrix for SHAP value
# booster: xgb.Booster = reg.get_booster()
# m = xgb.DMatrix(X, enable_categorical=True) # specify categorical data support.
# SHAP = booster.predict(m, pred_contribs=True)
# margin = booster.predict(m, output_margin=True)
# np.testing.assert_allclose(
# np.sum(SHAP, axis=len(SHAP.shape) - 1), margin, rtol=1e-3
# )


if __name__ == "__main__":
Expand Down

0 comments on commit 6bbaac2

Please sign in to comment.