From ef4d4c11f5d952d7119f43758046f57f722951c3 Mon Sep 17 00:00:00 2001 From: Oege Dijk Date: Mon, 18 Mar 2024 21:30:45 +0100 Subject: [PATCH] fix categorical dtype isinstance checks --- explainerdashboard/explainer_methods.py | 2 +- tests/test_catboost_regression.py | 3 ++- tests/test_classifier_base.py | 5 +++-- tests/test_regression_base.py | 3 ++- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/explainerdashboard/explainer_methods.py b/explainerdashboard/explainer_methods.py index c286194..6997197 100644 --- a/explainerdashboard/explainer_methods.py +++ b/explainerdashboard/explainer_methods.py @@ -402,7 +402,7 @@ def merge_categorical_columns( cat_pieces.append(pd.DataFrame({col_name: merged_col})) else: if not drop_regular: - if isinstance(X[col_name], pd.CategoricalDtype): + if isinstance(X[col_name].dtype, pd.CategoricalDtype): cat_pieces.append( pd.DataFrame({col_name: pd.Categorical(X[col_name])}) ) diff --git a/tests/test_catboost_regression.py b/tests/test_catboost_regression.py index 164236a..431ef80 100644 --- a/tests/test_catboost_regression.py +++ b/tests/test_catboost_regression.py @@ -95,7 +95,8 @@ def test_get_col(precalculated_catboost_regression_explainer): precalculated_catboost_regression_explainer.get_col("Sex"), pd.Series ) assert isinstance( - precalculated_catboost_regression_explainer.get_col("Sex"), pd.CategoricalDtype + precalculated_catboost_regression_explainer.get_col("Sex").dtype, + pd.CategoricalDtype, ) assert isinstance( diff --git a/tests/test_classifier_base.py b/tests/test_classifier_base.py index 058896d..063dc73 100644 --- a/tests/test_classifier_base.py +++ b/tests/test_classifier_base.py @@ -127,12 +127,13 @@ def test_get_col(precalculated_rf_classifier_explainer): precalculated_rf_classifier_explainer.get_col("Gender"), pd.Series ) assert isinstance( - precalculated_rf_classifier_explainer.get_col("Gender"), pd.CategoricalDtype + precalculated_rf_classifier_explainer.get_col("Gender").dtype, + pd.CategoricalDtype, ) assert isinstance(precalculated_rf_classifier_explainer.get_col("Deck"), pd.Series) assert isinstance( - precalculated_rf_classifier_explainer.get_col("Deck"), pd.CategoricalDtype + precalculated_rf_classifier_explainer.get_col("Deck").dtype, pd.CategoricalDtype ) assert isinstance(precalculated_rf_classifier_explainer.get_col("Age"), pd.Series) diff --git a/tests/test_regression_base.py b/tests/test_regression_base.py index 9226576..5bc4839 100644 --- a/tests/test_regression_base.py +++ b/tests/test_regression_base.py @@ -76,7 +76,8 @@ def test_get_col(precalculated_rf_regression_explainer): precalculated_rf_regression_explainer.get_col("Gender"), pd.Series ) assert isinstance( - precalculated_rf_regression_explainer.get_col("Gender"), pd.CategoricalDtype + precalculated_rf_regression_explainer.get_col("Gender").dtype, + pd.CategoricalDtype, ) assert isinstance(precalculated_rf_regression_explainer.get_col("Age"), pd.Series)