Skip to content

Commit

Permalink
fix categorical dtype isinstance checks
Browse files Browse the repository at this point in the history
  • Loading branch information
oegedijk committed Mar 18, 2024
1 parent c775bbe commit ef4d4c1
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 5 deletions.
2 changes: 1 addition & 1 deletion explainerdashboard/explainer_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def merge_categorical_columns(
cat_pieces.append(pd.DataFrame({col_name: merged_col}))
else:
if not drop_regular:
if isinstance(X[col_name], pd.CategoricalDtype):
if isinstance(X[col_name].dtype, pd.CategoricalDtype):
cat_pieces.append(
pd.DataFrame({col_name: pd.Categorical(X[col_name])})
)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_catboost_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ def test_get_col(precalculated_catboost_regression_explainer):
precalculated_catboost_regression_explainer.get_col("Sex"), pd.Series
)
assert isinstance(
precalculated_catboost_regression_explainer.get_col("Sex"), pd.CategoricalDtype
precalculated_catboost_regression_explainer.get_col("Sex").dtype,
pd.CategoricalDtype,
)

assert isinstance(
Expand Down
5 changes: 3 additions & 2 deletions tests/test_classifier_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,13 @@ def test_get_col(precalculated_rf_classifier_explainer):
precalculated_rf_classifier_explainer.get_col("Gender"), pd.Series
)
assert isinstance(
precalculated_rf_classifier_explainer.get_col("Gender"), pd.CategoricalDtype
precalculated_rf_classifier_explainer.get_col("Gender").dtype,
pd.CategoricalDtype,
)

assert isinstance(precalculated_rf_classifier_explainer.get_col("Deck"), pd.Series)
assert isinstance(
precalculated_rf_classifier_explainer.get_col("Deck"), pd.CategoricalDtype
precalculated_rf_classifier_explainer.get_col("Deck").dtype, pd.CategoricalDtype
)

assert isinstance(precalculated_rf_classifier_explainer.get_col("Age"), pd.Series)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_regression_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def test_get_col(precalculated_rf_regression_explainer):
precalculated_rf_regression_explainer.get_col("Gender"), pd.Series
)
assert isinstance(
precalculated_rf_regression_explainer.get_col("Gender"), pd.CategoricalDtype
precalculated_rf_regression_explainer.get_col("Gender").dtype,
pd.CategoricalDtype,
)

assert isinstance(precalculated_rf_regression_explainer.get_col("Age"), pd.Series)
Expand Down

0 comments on commit ef4d4c1

Please sign in to comment.