diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index b05b7d2..1b118c6 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -1,5 +1,35 @@ # Release Notes + +## 0.2.20: +### Breaking Changes +- `WhatIfComponent` deprecated. Use `WhatIfComposite` or connect components + yourself to a `FeatureInputComponent` +- renaming properties: + `explainer.cats` -> `explainer.onehot_cols` + `explainer.cats_dict` -> `explainer.onehot_dict` + +### New Features +- Adds support for model with categorical features that were not onehot encoded + (e.g. CatBoost) +- Adds filter on number of categories to display in violin plots and pdp plot, + and how to sort the categories (alphabetical, by frequency or by mean abs shap) + +### Bug Fixes +- fixes bug where str tab indicators returned e.g. the old ImportancesTab instead of ImportancesComposite +- + +### Improvements +- No longer dependening on PDPbox dependency: built own partial dependence + functions with categorical feature support +- autodetect xgboost.core.Booster or lightgbm.Booster and give ValueError to + use the sklearn compatible wrappers instead. + +### Other Changes +- Introduces list of categorical columns: `explainer.categorical_cols` +- Introduces dictionary with categorical columns categories: `explainer.categorical_dict` +- Introduces list of all categorical features: `explainer.cat_cols` + ## 0.2.19 ### Breaking Changes - ExplainerHub: parameter `user_json` is now called `users_file` (and default to a `users.yaml` file) diff --git a/TODO.md b/TODO.md index e809d42..5081101 100644 --- a/TODO.md +++ b/TODO.md @@ -3,6 +3,7 @@ ## Bugs: - dash contributions reload bug: Exception: Additivity check failed in TreeExplainer! +- shap dependence: when no point cloud, do not highlight! 
## Layout: - Find a proper frontender to help :) @@ -20,20 +21,26 @@ - https://community.plotly.com/t/announcing-plotly-py-4-12-horizontal-and-vertical-lines-and-rectangles/46783 - add some of these: https://towardsdatascience.com/introducing-shap-decision-plots-52ed3b4a1cba - +- shap dependence plot, sort categorical features by: + - alphabet + - number of obs + - mean abs shap ### Classifier plots: - move predicted and actual to outer layer of ConfusionMatrixComponent - move predicted below graph? - pdp: add multiclass option - no icelines just mean and index with different thickness + - new method? ### Regression plots: + ## Explainers: +- minimize pd.DataFrame and np.array size: + - astype(float16), pd.category, etc - pass n_jobs to pdp_isolate -- autodetect xgboost booster or catboost.core and suggest XGBClassifier, etc - make X_cats with categorical encoding .astype("category") - add ExtraTrees and GradientBoostingClassifier to tree visualizers - add plain language explanations @@ -45,6 +52,7 @@ - rename RandomForestExplainer and XGBExplainer methods into something more logical - Breaking change! + ## notebooks: @@ -68,8 +76,8 @@ ### Components - autodetect when uuid name get rendered and issue warning -- Add side-by-side option to cutoff selector component +- Add side-by-side option to cutoff selector component - add filter to index selector using pattern matching callbacks: - https://dash.plotly.com/pattern-matching-callbacks - add querystring method to ExplainerComponents @@ -94,7 +102,6 @@ - Add this method? : https://arxiv.org/abs/2006.04750? 
## Tests: -- add wizard test - add tests for InterpretML EBM (shap 0.37) - write tests for explainerhub CLI add user - test model_output='probability' and 'raw' or 'logodds' seperately @@ -102,6 +109,7 @@ - write tests for explainer_plots ## Docs: +- add cats_topx cats_sort to docs - add hide_wizard and wizard to docs - add hide_poweredby to docs - add Docker deploy example (from issue) diff --git a/explainerdashboard/dashboard_components/overview_components.py b/explainerdashboard/dashboard_components/overview_components.py index f2c0f11..7462d05 100644 --- a/explainerdashboard/dashboard_components/overview_components.py +++ b/explainerdashboard/dashboard_components/overview_components.py @@ -3,7 +3,6 @@ 'ImportancesComponent', 'FeatureInputComponent', 'PdpComponent', - 'WhatIfComponent', ] from math import ceil @@ -151,7 +150,7 @@ def __init__(self, explainer, title="Feature Importances", name=None, """ super().__init__(explainer, title, name) - if not self.explainer.cats: + if not self.explainer.onehot_cols: self.hide_cats = True assert importance_type in ['shap', 'permutation'], \ @@ -280,10 +279,11 @@ def __init__(self, explainer, title="Partial Dependence Plot", name=None, hide_title=False, hide_subtitle=False, hide_footer=False, hide_selector=False, hide_dropna=False, hide_sample=False, - hide_gridlines=False, hide_gridpoints=False, + hide_gridlines=False, hide_gridpoints=False, hide_cats_sort=False, feature_input_component=None, pos_label=None, col=None, index=None, cats=True, dropna=True, sample=100, gridlines=50, gridpoints=10, + cats_sort='freq', description=None, **kwargs): """Show Partial Dependence Plot component @@ -307,6 +307,7 @@ def __init__(self, explainer, title="Partial Dependence Plot", name=None, hide_sample (bool, optional): Hide sample size input. Defaults to False. hide_gridlines (bool, optional): Hide gridlines input. Defaults to False. hide_gridpoints (bool, optional): Hide gridpounts input. Defaults to False. 
+ hide_cats_sort (bool, optional): Hide the categorical sorting dropdown. Defaults to False. feature_input_component (FeatureInputComponent): A FeatureInputComponent that will give the input to the graph instead of the index selector. If not None, hide_index=True. Defaults to None. @@ -319,6 +320,8 @@ def __init__(self, explainer, title="Partial Dependence Plot", name=None, sample (int, optional): Sample size to calculate average partial dependence. Defaults to 100. gridlines (int, optional): Number of ice lines to display in plot. Defaults to 50. gridpoints (int, optional): Number of breakpoints on horizontal axis Defaults to 10. + cats_sort (str, optional): how to sort categories: 'alphabet', + 'freq' or 'shap'. Defaults to 'freq'. description (str, optional): Tooltip to display when hover over component title. When None default text is shown. """ @@ -329,7 +332,7 @@ def __init__(self, explainer, title="Partial Dependence Plot", name=None, if self.col is None: self.col = self.explainer.columns_ranked_by_shap(self.cats)[0] - if not self.explainer.cats: + if not self.explainer.onehot_cols: self.hide_cats = True if self.feature_input_component is not None: @@ -432,7 +435,7 @@ def layout(self): ]), hide=self.hide_dropna), make_hideable( dbc.Col([ - dbc.Label("Pdp sample size:", id='pdp-sample-label-'+self.name ), + dbc.Label("Sample:", id='pdp-sample-label-'+self.name ), dbc.Tooltip("Number of observations to use to calculate average partial dependence", target='pdp-sample-label-'+self.name ), dbc.Input(id='pdp-sample-'+self.name, value=self.sample, @@ -455,11 +458,34 @@ def layout(self): dbc.Input(id='pdp-gridpoints-'+self.name, value=self.gridpoints, type="number", min=0, max=100, step=1), ]), hide=self.hide_gridpoints), + make_hideable( + html.Div([ + dbc.Col([ + html.Label('Sort categories:', id='pdp-categories-sort-label-'+self.name), + dbc.Tooltip("How to sort the categories: alphabetically, most common " + "first (Frequency), or highest mean absolute SHAP 
value first (Shap impact)", + target='pdp-categories-sort-label-'+self.name), + dbc.Select(id='pdp-categories-sort-'+self.name, + options = [{'label': 'Alphabetically', 'value': 'alphabet'}, + {'label': 'Frequency', 'value': 'freq'}, + {'label': 'Shap impact', 'value': 'shap'}], + value=self.cats_sort), + ])], + id='pdp-categories-sort-div-'+self.name, + style={} if self.col in self.explainer.cat_cols else dict(display="none") + ), hide=self.hide_cats_sort), ], form=True), ]), hide=self.hide_footer) ]) def component_callbacks(self, app): + + @app.callback( + Output('pdp-categories-sort-div-'+self.name, 'style'), + Input('pdp-col-'+self.name, 'value') + ) + def update_pdp_sort_div(col): + return {} if col in self.explainer.cat_cols else dict(display="none") @app.callback( Output('pdp-col-'+self.name, 'options'), @@ -480,12 +506,13 @@ def update_pdp_graph(cats, pos_label): Input('pdp-sample-'+self.name, 'value'), Input('pdp-gridlines-'+self.name, 'value'), Input('pdp-gridpoints-'+self.name, 'value'), + Input('pdp-categories-sort-'+self.name, 'value'), Input('pos-label-'+self.name, 'value')] ) - def update_pdp_graph(index, col, drop_na, sample, gridlines, gridpoints, pos_label): + def update_pdp_graph(index, col, drop_na, sample, gridlines, gridpoints, sort, pos_label): return self.explainer.plot_pdp(col, index, - drop_na=bool(drop_na), sample=sample, gridlines=gridlines, gridpoints=gridpoints, - pos_label=pos_label) + drop_na=bool(drop_na), sample=sample, gridlines=gridlines, + gridpoints=gridpoints, sort=sort, pos_label=pos_label) else: @app.callback( Output('pdp-graph-'+self.name, 'figure'), @@ -494,14 +521,15 @@ def update_pdp_graph(index, col, drop_na, sample, gridlines, gridpoints, pos_lab Input('pdp-sample-'+self.name, 'value'), Input('pdp-gridlines-'+self.name, 'value'), Input('pdp-gridpoints-'+self.name, 'value'), + Input('pdp-categories-sort-'+self.name, 'value'), Input('pos-label-'+self.name, 'value'), *self.feature_input_component._feature_callback_inputs] 
) - def update_pdp_graph(col, drop_na, sample, gridlines, gridpoints, pos_label, *inputs): + def update_pdp_graph(col, drop_na, sample, gridlines, gridpoints, sort, pos_label, *inputs): X_row = self.explainer.get_row_from_input(inputs, ranked_by_shap=True) return self.explainer.plot_pdp(col, X_row=X_row, - drop_na=bool(drop_na), sample=sample, gridlines=gridlines, gridpoints=gridpoints, - pos_label=pos_label) + drop_na=bool(drop_na), sample=sample, gridlines=gridlines, + gridpoints=gridpoints, sort=sort, pos_label=pos_label) class FeatureInputComponent(ExplainerComponent): @@ -543,18 +571,16 @@ def __init__(self, explainer, title="Feature Input", name=None, self._input_features = self.explainer.columns_ranked_by_shap(cats=True) self._feature_inputs = [ self._generate_dash_input( - feature, self.explainer.cats, self.explainer.cats_dict) + feature, self.explainer.onehot_cols, self.explainer.onehot_dict, self.explainer.categorical_dict) for feature in self._input_features] self._feature_callback_inputs = [Input('feature-input-'+feature+'-input-'+self.name, 'value') for feature in self._input_features] self._feature_callback_outputs = [Output('feature-input-'+feature+'-input-'+self.name, 'value') for feature in self._input_features] if self.description is None: self.description = """ Adjust the input values to see predictions for what if scenarios.""" - def _generate_dash_input(self, col, cats, cats_dict): - if col in cats: - col_values = [ - col_val[len(col)+1:] if col_val.startswith(col+"_") else col_val - for col_val in cats_dict[col]] + def _generate_dash_input(self, col, onehot_cols, onehot_dict, cat_dict): + if col in cat_dict: + col_values = cat_dict[col] return dbc.FormGroup([ dbc.Label(col), dcc.Dropdown(id='feature-input-'+col+'-input-'+self.name, @@ -562,6 +588,19 @@ def _generate_dash_input(self, col, cats, cats_dict): clearable=False), dbc.FormText(f"Select any {col}") if not self.hide_range else None, ]) + elif col in onehot_cols: + col_values = 
onehot_dict[col] + display_values = [ + col_val[len(col)+1:] if col_val.startswith(col+"_") else col_val + for col_val in col_values] + return dbc.FormGroup([ + dbc.Label(col), + dcc.Dropdown(id='feature-input-'+col+'-input-'+self.name, + options=[dict(label=display, value=col_val) + for display, col_val in zip(display_values, col_values)], + clearable=False), + dbc.FormText(f"Select any {col}") if not self.hide_range else None, + ]) else: min_range = np.round(self.explainer.X[col][lambda x: x != self.explainer.na_fill].min(), 2) max_range = np.round(self.explainer.X[col][lambda x: x != self.explainer.na_fill].max(), 2) @@ -626,155 +665,5 @@ def update_whatif_inputs(index): return feature_values -class WhatIfComponent(ExplainerComponent): - def __init__(self, explainer, title="What if...", name=None, - hide_title=False, hide_subtitle=False, hide_index=False, - hide_selector=False, hide_contributions=False, hide_pdp=False, - index=None, pdp_col=None, pos_label=None, description=None, - **kwargs): - """Interaction Dependence Component. - - Args: - explainer (Explainer): explainer object constructed with either - ClassifierExplainer() or RegressionExplainer() - title (str, optional): Title of tab or page. Defaults to - "What if...". - name (str, optional): unique name to add to Component elements. - If None then random uuid is generated to make sure - it's unique. Defaults to None. - hide_title (bool, optional): hide the title - hide_subtitle (bool, optional): Hide subtitle. Defaults to False. - hide_index (bool, optional): hide the index selector - hide_selector (bool, optional): hide the pos_label selector - hide_contributions (bool, optional): hide the contributions graph - hide_pdp (bool, optional): hide the pdp graph - index (str, int, optional): default index - pdp_col (str, optional): default pdp feature col - pos_label ({int, str}, optional): initial pos label. 
- Defaults to explainer.pos_label - description (str, optional): Tooltip to display when hover over - component title. When None default text is shown. - """ - super().__init__(explainer, title, name) - - assert len(explainer.columns) == len(set(explainer.columns)), \ - "Not all column names are unique, so cannot launch whatif component/tab!" - - if self.pdp_col is None: - self.pdp_col = self.explainer.columns_ranked_by_shap(cats=True)[0] - - self.index_name = 'whatif-index-'+self.name - - self._input_features = self.explainer.columns_cats - self._feature_inputs = [ - self._generate_dash_input( - feature, self.explainer.cats, self.explainer.cats_dict) - for feature in self._input_features] - self._feature_callback_inputs = [Input('whatif-'+feature+'-input-'+self.name, 'value') for feature in self._input_features] - self._feature_callback_outputs = [Output('whatif-'+feature+'-input-'+self.name, 'value') for feature in self._input_features] - - self.selector = PosLabelSelector(explainer, name=self.name, pos_label=pos_label) - - self.register_dependencies('preds', 'shap_values') - - - def _generate_dash_input(self, col, cats, cats_dict): - if col in cats: - col_values = [ - col_val[len(col)+1:] if col_val.startswith(col+"_") else col_val - for col_val in cats_dict[col]] - return html.Div([ - html.P(col), - dcc.Dropdown(id='whatif-'+col+'-input-'+self.name, - options=[dict(label=col_val, value=col_val) for col_val in col_values], - clearable=False) - ]) - else: - return html.Div([ - html.P(col), - dbc.Input(id='whatif-'+col+'-input-'+self.name, type="number"), - ]) - - def layout(self): - return dbc.Card([ - make_hideable( - dbc.CardHeader([ - dbc.Row([ - dbc.Col([ - html.H1(self.title) - ]), - ]), - ]), hide=self.hide_title), - dbc.CardBody([ - dbc.Row([ - make_hideable( - dbc.Col([ - dbc.Label(f"{self.explainer.index_name}:"), - dcc.Dropdown(id='whatif-index-'+self.name, - options = [{'label': str(idx), 'value':idx} - for idx in self.explainer.idxs], - 
value=self.index) - ], md=4), hide=self.hide_index), - make_hideable( - dbc.Col([self.selector.layout() - ], md=2), hide=self.hide_selector), - ], form=True), - dbc.Row([ - dbc.Col([ - html.H3("Edit Feature input:") - ]) - ]), - dbc.Row([ - dbc.Col([ - *self._feature_inputs[:int((len(self._feature_inputs) + 1)/2)] - ]), - dbc.Col([ - *self._feature_inputs[int((len(self._feature_inputs) + 1)/2):] - ]), - ]), - dbc.Row([ - make_hideable( - dbc.Col([ - html.H3("Prediction and contributions:"), - dcc.Graph(id='whatif-contrib-graph-'+self.name, - config=dict(modeBarButtons=[['toImage']], displaylogo=False)), - ]), hide=self.hide_contributions), - make_hideable( - dbc.Col([ - html.H3("Partial dependence:"), - dcc.Dropdown(id='whatif-pdp-col-'+self.name, - options=[dict(label=col, value=col) for col in self._input_features], - value=self.pdp_col), - dcc.Graph(id='whatif-pdp-graph-'+self.name, - config=dict(modeBarButtons=[['toImage']], displaylogo=False)), - ]), hide=self.hide_pdp), - ]) - ]) - ]) - def component_callbacks(self, app): - @app.callback( - [Output('whatif-contrib-graph-'+self.name, 'figure'), - Output('whatif-pdp-graph-'+self.name, 'figure')], - [Input('whatif-pdp-col-'+self.name, 'value'), - Input('pos-label-'+self.name, 'value'), - *self._feature_callback_inputs, - ], - ) - def update_whatif_plots(pdp_col, pos_label, *input_args): - X_row = pd.DataFrame(dict(zip(self._input_features, input_args)), index=[0]).fillna(0) - contrib_plot = self.explainer.plot_shap_contributions(X_row=X_row, pos_label=pos_label) - pdp_plot = self.explainer.plot_pdp(pdp_col, X_row=X_row, pos_label=pos_label) - return contrib_plot, pdp_plot - - @app.callback( - [*self._feature_callback_outputs], - [Input('whatif-index-'+self.name, 'value')] - ) - def update_whatif_inputs(index): - idx = self.explainer.get_int_idx(index) - if idx is None: - raise PreventUpdate - feature_values = self.explainer.X_cats.iloc[[idx]].values[0].tolist() - return feature_values diff --git 
a/explainerdashboard/dashboard_components/regression_components.py b/explainerdashboard/dashboard_components/regression_components.py index 04e6cb4..bce7e73 100644 --- a/explainerdashboard/dashboard_components/regression_components.py +++ b/explainerdashboard/dashboard_components/regression_components.py @@ -915,9 +915,9 @@ def layout(self): "When you have some real outliers it can help to remove them" " from the plot so it is easier to see the overall pattern.", target='reg-vs-col-winsor-label-'+self.name), - dbc.Input(id='reg-vs-col-winsor-'+self.name, - value=self.winsor, - type="number", min=0, max=49, step=1), + dbc.Input(id='reg-vs-col-winsor-'+self.name, + value=self.winsor, + type="number", min=0, max=49, step=1), ], md=4), hide=self.hide_winsor), make_hideable( dbc.Col([ @@ -951,7 +951,10 @@ def register_callbacks(self, app): Input('reg-vs-col-winsor-'+self.name, 'value')], ) def update_residuals_graph(col, display, points, winsor): - style = {} if col in self.explainer.cats else dict(display="none") + if col in self.explainer.onehot_cols or col in self.explainer.categorical_cols: + style = {} + else: + style = dict(display="none") if display == 'observed': return self.explainer.plot_y_vs_feature( col, points=bool(points), winsor=winsor, dropna=True), style diff --git a/explainerdashboard/dashboard_components/shap_components.py b/explainerdashboard/dashboard_components/shap_components.py index 82d36ef..ab85642 100644 --- a/explainerdashboard/dashboard_components/shap_components.py +++ b/explainerdashboard/dashboard_components/shap_components.py @@ -56,7 +56,7 @@ def __init__(self, explainer, title='Shap Summary', name=None, """ super().__init__(explainer, title, name) - if self.explainer.cats is None or not self.explainer.cats: + if not self.explainer.onehot_cols: self.hide_cats = True if self.depth is not None: @@ -201,9 +201,12 @@ def __init__(self, explainer, title='Shap Dependence', name=None, subtitle="Relationship between feature value and SHAP 
value", hide_title=False, hide_subtitle=False, hide_cats=False, hide_col=False, hide_color_col=False, hide_index=False, - hide_selector=False, - pos_label=None, cats=True, col=None, - color_col=None, index=None, description=None, **kwargs): + hide_selector=False, hide_cats_topx=False, hide_cats_sort=False, + hide_footer=False, + pos_label=None, cats=True, + col=None, color_col=None, index=None, + cats_topx=10, cats_sort='freq', + description=None, **kwargs): """Show shap dependence graph Args: @@ -222,6 +225,9 @@ def __init__(self, explainer, title='Shap Dependence', name=None, hide_color_col (bool, optional): hide color feature selector Defaults to False. hide_index (bool, optional): hide index selector Defaults to False. hide_selector (bool, optional): hide pos label selector. Defaults to False. + hide_cats_topx (bool, optional): hide the categories topx input. Defaults to False. + hide_cats_sort (bool, optional): hide the categories sort selector.Defaults to False. + hide_footer (bool, optional): hide the footer. pos_label ({int, str}, optional): initial pos label. Defaults to explainer.pos_label cats (bool, optional): group cats. Defaults to True. @@ -229,6 +235,10 @@ def __init__(self, explainer, title='Shap Dependence', name=None, color_col (str, optional): Color plot by values of this Feature. Defaults to None. index (int, optional): Highlight a particular index. Defaults to None. + cats_topx (int, optional): number of categories to display for + categorical features. + cats_sort (str, optional): how to sort categories: 'alphabet', + 'freq' or 'shap'. Defaults to 'freq'. description (str, optional): Tooltip to display when hover over component title. When None default text is shown. 
""" @@ -324,12 +334,42 @@ def layout(self): dcc.Graph(id='shap-dependence-graph-'+self.name, config=dict(modeBarButtons=[['toImage']], displaylogo=False))]), ]), + make_hideable( + dbc.CardFooter([ + html.Div([ + dbc.Row([ + make_hideable( + dbc.Col([ + dbc.Label("Categories:", id='shap-dependence-n-categories-label-'+self.name), + dbc.Tooltip("Number of categories to display", + target='shap-dependence-n-categories-label-'+self.name), + dbc.Input(id='shap-dependence-n-categories-'+self.name, + value=self.cats_topx, + type="number", min=1, max=50, step=1), + ], md=2), self.hide_cats_topx), + make_hideable( + dbc.Col([ + html.Label('Sort categories:', id='shap-dependence-categories-sort-label-'+self.name), + dbc.Tooltip("How to sort the categories: alphabetically, most common " + "first (Frequency), or highest mean absolute SHAP value first (Shap impact)", + target='shap-dependence-categories-sort-label-'+self.name), + dbc.Select(id='shap-dependence-categories-sort-'+self.name, + options = [{'label': 'Alphabetically', 'value': 'alphabet'}, + {'label': 'Frequency', 'value': 'freq'}, + {'label': 'Shap impact', 'value': 'shap'}], + value=self.cats_sort), + ], md=4), hide=self.hide_cats_sort), + ]) + ], id='shap-dependence-categories-div-'+self.name, + style={} if self.col in self.explainer.cat_cols else dict(display="none")) + ]), hide=self.hide_footer), ]) def component_callbacks(self, app): @app.callback( [Output('shap-dependence-color-col-'+self.name, 'options'), - Output('shap-dependence-color-col-'+self.name, 'value')], + Output('shap-dependence-color-col-'+self.name, 'value'), + Output('shap-dependence-categories-div-'+self.name, 'style')], [Input('shap-dependence-col-'+self.name, 'value')], [State('shap-dependence-group-cats-'+self.name, 'value'), State('pos-label-'+self.name, 'value')]) @@ -339,21 +379,29 @@ def set_color_col_dropdown(col, cats, pos_label): options = ([{'label': col, 'value':col} for col in sorted_interact_cols] + [dict(label="None", 
value="no_color_col")]) - value = sorted_interact_cols[1] - return (options, value) + if col in self.explainer.cat_cols: + value = None + style = dict() + else: + value = sorted_interact_cols[1] + style = dict(display="none") + return (options, value, style) @app.callback( Output('shap-dependence-graph-'+self.name, 'figure'), [Input('shap-dependence-color-col-'+self.name, 'value'), Input('shap-dependence-index-'+self.name, 'value'), + Input('shap-dependence-n-categories-'+self.name, 'value'), + Input('shap-dependence-categories-sort-'+self.name, 'value'), Input('pos-label-'+self.name, 'value')], [State('shap-dependence-col-'+self.name, 'value')]) - def update_dependence_graph(color_col, index, pos_label, col): + def update_dependence_graph(color_col, index, topx, sort, pos_label, col): if col is not None: if color_col =="no_color_col": color_col, index = None, None return self.explainer.plot_shap_dependence( - col, color_col, highlight_index=index, pos_label=pos_label) + col, color_col, topx=topx, sort=sort, + highlight_index=index, pos_label=pos_label) raise PreventUpdate @app.callback( @@ -447,7 +495,7 @@ def __init__(self, explainer, title="Interactions Summary", name=None, self.col = self.explainer.columns_ranked_by_shap(self.cats)[0] if self.depth is not None: self.depth = min(self.depth, self.explainer.n_features(self.cats)-1) - if not self.explainer.cats: + if not self.explainer.onehot_cols: self.hide_cats = True self.index_name = 'interaction-summary-index-'+self.name self.selector = PosLabelSelector(explainer, name=self.name, pos_label=pos_label) @@ -607,8 +655,10 @@ def __init__(self, explainer, title="Interaction Dependence", name=None, subtitle="Relation between feature value and shap interaction value", hide_title=False, hide_subtitle=False, hide_cats=False, hide_col=False, hide_interact_col=False, hide_index=False, - hide_selector=False, hide_top=False, hide_bottom=False, + hide_selector=False, hide_cats_topx=False, hide_cats_sort=False, + 
hide_top=False, hide_bottom=False, pos_label=None, cats=True, col=None, interact_col=None, + cats_topx=10, cats_sort='freq', description=None, index=None, **kwargs): """Interaction Dependence Component. @@ -635,6 +685,10 @@ def __init__(self, explainer, title="Interaction Dependence", name=None, Defaults to False. hide_selector (bool, optional): hide pos label selector. Defaults to False. + hide_cats_topx (bool, optional): hide the categories topx input. + Defaults to False. + hide_cats_sort (bool, optional): hide the categories sort selector. + Defaults to False. hide_top (bool, optional): Hide the top interaction graph (col vs interact_col). Defaults to False. hide_bottom (bool, optional): hide the bottom interaction graph @@ -645,6 +699,10 @@ def __init__(self, explainer, title="Interaction Dependence", name=None, col (str, optional): Feature to find interactions for. Defaults to None. interact_col (str, optional): Feature to interact with. Defaults to None. highlight (int, optional): Index row to highlight Defaults to None. + cats_topx (int, optional): number of categories to display for + categorical features. + cats_sort (str, optional): how to sort categories: 'alphabet', + 'freq' or 'shap'. Defaults to 'freq'. description (str, optional): Tooltip to display when hover over component title. When None default text is shown. 
""" @@ -735,21 +793,76 @@ def layout(self): dbc.Row([ dbc.Col([ make_hideable( - dcc.Loading(id='loading-interaction-dependence-graph-'+self.name, - children=[dcc.Graph(id='interaction-dependence-graph-'+self.name, + dcc.Loading(id='loading-interaction-dependence-top-graph-'+self.name, + children=[dcc.Graph(id='interaction-dependence-top-graph-'+self.name, config=dict(modeBarButtons=[['toImage']], displaylogo=False))]), hide=self.hide_top), ]), ]), + html.Div([ + dbc.Row([ + make_hideable( + dbc.Col([ + dbc.Label("Categories:", id='interaction-dependence-top-n-categories-label-'+self.name), + dbc.Tooltip("Number of categories to display", + target='interaction-dependence-top-n-categories-label-'+self.name), + dbc.Input(id='interaction-dependence-top-n-categories-'+self.name, + value=self.cats_topx, + type="number", min=1, max=50, step=1), + ], md=2), self.hide_cats_topx), + make_hideable( + dbc.Col([ + html.Label('Sort categories:', id='interaction-dependence-top-categories-sort-label-'+self.name), + dbc.Tooltip("How to sort the categories: alphabetically, most common " + "first (Frequency), or highest mean absolute SHAP value first (Shap impact)", + target='interaction-dependence-top-categories-sort-label-'+self.name), + dbc.Select(id='interaction-dependence-top-categories-sort-'+self.name, + options = [{'label': 'Alphabetically', 'value': 'alphabet'}, + {'label': 'Frequency', 'value': 'freq'}, + {'label': 'Shap impact', 'value': 'shap'}], + value=self.cats_sort), + ], md=4), hide=self.hide_cats_sort), + ]) + ], id='interaction-dependence-top-categories-div-'+self.name, + style={} if self.interact_col in self.explainer.cat_cols else dict(display="none")), + dbc.Row([ dbc.Col([ make_hideable( - dcc.Loading(id='loading-reverse-interaction-graph-'+self.name, - children=[dcc.Graph(id='interaction-dependence-reverse-graph-'+self.name, + dcc.Loading(id='loading-reverse-interaction-bottom-graph-'+self.name, + 
children=[dcc.Graph(id='interaction-dependence-bottom-graph-'+self.name, config=dict(modeBarButtons=[['toImage']], displaylogo=False))]), hide=self.hide_bottom), ]), ]), + + html.Div([ + dbc.Row([ + make_hideable( + dbc.Col([ + dbc.Label("Categories:", id='interaction-dependence-bottom-n-categories-label-'+self.name), + dbc.Tooltip("Number of categories to display", + target='interaction-dependence-bottom-n-categories-label-'+self.name), + dbc.Input(id='interaction-dependence-bottom-n-categories-'+self.name, + value=self.cats_topx, + type="number", min=1, max=50, step=1), + ], md=2), self.hide_cats_topx), + make_hideable( + dbc.Col([ + html.Label('Sort categories:', id='interaction-dependence-bottom-categories-sort-label-'+self.name), + dbc.Tooltip("How to sort the categories: alphabetically, most common " + "first (Frequency), or highest mean absolute SHAP value first (Shap impact)", + target='interaction-dependence-bottom-categories-sort-label-'+self.name), + dbc.Select(id='interaction-dependence-bottom-categories-sort-'+self.name, + options = [{'label': 'Alphabetically', 'value': 'alphabet'}, + {'label': 'Frequency', 'value': 'freq'}, + {'label': 'Shap impact', 'value': 'shap'}], + value=self.cats_sort), + ], md=4), hide=self.hide_cats_sort), + ]) + ], id='interaction-dependence-bottom-categories-div-'+self.name, + style={} if self.col in self.explainer.cat_cols else dict(display="none")), + ]), ]) @@ -778,18 +891,39 @@ def update_interaction_dependence_interact_col(col, pos_label, cats, old_interac raise PreventUpdate @app.callback( - [Output('interaction-dependence-graph-'+self.name, 'figure'), - Output('interaction-dependence-reverse-graph-'+self.name, 'figure')], + [Output('interaction-dependence-top-graph-'+self.name, 'figure'), + Output('interaction-dependence-top-categories-div-'+self.name, 'style')], + [Input('interaction-dependence-interact-col-'+self.name, 'value'), + Input('interaction-dependence-index-'+self.name, 'value'), + 
Input('interaction-dependence-top-n-categories-'+self.name, 'value'), + Input('interaction-dependence-top-categories-sort-'+self.name, 'value'), + Input('pos-label-'+self.name, 'value'), + Input('interaction-dependence-col-'+self.name, 'value')]) + def update_dependence_graph(interact_col, index, topx, sort, pos_label, col): + if col is not None and interact_col is not None: + style = {} if interact_col in self.explainer.cat_cols else dict(display="none") + return (self.explainer.plot_shap_interaction( + col, interact_col, highlight_index=index, pos_label=pos_label, + topx=topx, sort=sort), + style) + raise PreventUpdate + + @app.callback( + [Output('interaction-dependence-bottom-graph-'+self.name, 'figure'), + Output('interaction-dependence-bottom-categories-div-'+self.name, 'style')], [Input('interaction-dependence-interact-col-'+self.name, 'value'), Input('interaction-dependence-index-'+self.name, 'value'), + Input('interaction-dependence-bottom-n-categories-'+self.name, 'value'), + Input('interaction-dependence-bottom-categories-sort-'+self.name, 'value'), Input('pos-label-'+self.name, 'value'), Input('interaction-dependence-col-'+self.name, 'value')]) - def update_dependence_graph(interact_col, index, pos_label, col): + def update_dependence_graph(interact_col, index, topx, sort, pos_label, col): if col is not None and interact_col is not None: + style = {} if col in self.explainer.cat_cols else dict(display="none") return (self.explainer.plot_shap_interaction( - col, interact_col, highlight_index=index, pos_label=pos_label), - self.explainer.plot_shap_interaction( - interact_col, col, highlight_index=index, pos_label=pos_label)) + interact_col, col, highlight_index=index, pos_label=pos_label, + topx=topx, sort=sort), + style) raise PreventUpdate @@ -890,7 +1024,7 @@ def __init__(self, explainer, title="Contributions Plot", name=None, if self.depth is not None: self.depth = min(self.depth, self.explainer.n_features(self.cats)) - if not self.explainer.cats: + 
if not self.explainer.onehot_cols: self.hide_cats = True if self.feature_input_component is not None: @@ -1098,7 +1232,7 @@ def __init__(self, explainer, title="Contributions Table", name=None, if self.depth is not None: self.depth = min(self.depth, self.explainer.n_features(self.cats)) - if not self.explainer.cats: + if not self.explainer.onehot_cols: self.hide_cats = True if self.feature_input_component is not None: diff --git a/explainerdashboard/dashboard_methods.py b/explainerdashboard/dashboard_methods.py index 4ecb2d1..a685ad8 100644 --- a/explainerdashboard/dashboard_methods.py +++ b/explainerdashboard/dashboard_methods.py @@ -175,11 +175,9 @@ def __init__(self, explainer, title=None, name=None): """ self._store_child_params(no_param=['explainer']) if not hasattr(self, "name") or self.name is None: - self.name = "uuid"+shortuuid.ShortUUID().random(length=5) - if title is not None: - self.title = title - if not hasattr(self, "title"): - self.title = "Custom" + self.name = name or "uuid"+shortuuid.ShortUUID().random(length=5) + if not hasattr(self, "title") or self.title is None: + self.title = title or "Custom" self._components = [] self._dependencies = [] @@ -423,18 +421,22 @@ def instantiate_component(component, explainer, name=None, **kwargs): """ if inspect.isclass(component) and issubclass(component, ExplainerComponent): - init_argspec = inspect.getargspec(component.__init__) - if not init_argspec.keywords: - kwargs = {k:v for k,v in kwargs.items() if k in init_argspec.args} - if "name" in init_argspec.args: + init_argspec = inspect.getfullargspec(component.__init__) + assert len(init_argspec.args) > 1 and init_argspec.args[1] == 'explainer', \ + (f"The first parameter of {component.__name__}.__init__ should be 'explainer'. 
" + f"Instead the __init__ is: {component.__name__}.__init__{inspect.signature(component.__init__)}") + if not init_argspec.varkw: + kwargs = {k:v for k,v in kwargs.items() if k in init_argspec.args + init_argspec.kwonlyargs} + if "name" in init_argspec.args+init_argspec.kwonlyargs: component = component(explainer, name=name, **kwargs) else: print(f"ExplainerComponent {component} does not accept a name parameter, " - f"so cannot assign name={name}!" - "Make sure to set name explicitly yourself if you want to " - "deploy across multiple workers or a cluster, as otherwise " - "each instance in the cluster will generate its own random " - "uuid name!") + f"so cannot assign name='{name}': " + f"{component.__name__}.__init__{inspect.signature(component.__init__)}. " + "Make sure to set super().__init__(name=...) explicitly yourself " + "inside the __init__ if you want to deploy across multiple " + "workers or a cluster, as otherwise each instance in the " + "cluster will generate its own random uuid name!") component = component(explainer, **kwargs) return component elif isinstance(component, ExplainerComponent): diff --git a/explainerdashboard/dashboards.py b/explainerdashboard/dashboards.py index 0ab46cd..2a7c9c5 100644 --- a/explainerdashboard/dashboards.py +++ b/explainerdashboard/dashboards.py @@ -415,18 +415,6 @@ def __init__(self, explainer=None, tabs=None, assert 'BaseExplainer' in str(explainer.__class__.mro()), \ ("explainer should be an instance of BaseExplainer, such as " "ClassifierExplainer or RegressionExplainer!") - - if self.explainer.cats_only: - print("Note: explainer contains a model and data that deal with " - "categorical features directly. 
Not all elements of the " - "ExplainerDashboard are compatible with such models, and " - "so setting the following **kwargs: " - "cats=True, hide_cats=True, shap_interaction=False", flush=True) - kwargs.update(dict( - cats=True, hide_cats=True, shap_interaction=False)) - if kwargs: - print("**kwargs: Passing the following keyword arguments to all the dashboard" - f" ExplainerComponents: {', '.join([f'{k}={v}' for k,v in kwargs.items()])}...") if tabs is None: tabs = [] @@ -434,7 +422,7 @@ def __init__(self, explainer=None, tabs=None, print("No y labels were passed to the Explainer, so setting" " model_summary=False...", flush=True) model_summary = False - if shap_interaction and (not explainer.interactions_should_work or self.explainer.cats_only): + if shap_interaction and (not explainer.interactions_should_work): print("For this type of model and model_output interactions don't " "work, so setting shap_interaction=False...", flush=True) shap_interaction = False @@ -670,19 +658,22 @@ def _store_params(self, no_store=None, no_attr=None, no_param=None): def _convert_str_tabs(self, component): if isinstance(component, str): if component == 'importances': - return ImportancesTab + return ImportancesComposite elif component == 'model_summary': - return ModelSummaryTab + if self.explainer.is_classifier: + return ClassifierModelStatsComposite + else: + return RegressionModelStatsComposite elif component == 'contributions': - return ContributionsTab + return IndividualPredictionsComposite elif component == 'whatif': - return WhatIfTab + return WhatIfComposite elif component == 'shap_dependence': - return ShapDependenceTab + return ShapDependenceComposite elif component == 'shap_interaction': - return ShapInteractionsTab + return ShapInteractionsComposite elif component == 'decision_trees': - return DecisionTreesTab + return DecisionTreesComposite return component @staticmethod @@ -1831,13 +1822,6 @@ def pdp(self, title="Partial Dependence Plots", **kwargs): comp = 
PdpComponent(self._explainer, **kwargs) self._run_component(comp, title) - @delegates_kwargs(WhatIfComponent) - @delegates_doc(WhatIfComponent) - def whatif(self, title="What if...", **kwargs): - """Show What if... component inline in notebook""" - comp = WhatIfComponent(self._explainer, **kwargs) - self._run_component(comp, title) - class InlineExplainerComponent: def __init__(self, inline_explainer, name): diff --git a/explainerdashboard/explainer_methods.py b/explainerdashboard/explainer_methods.py index ea7ef38..7054205 100644 --- a/explainerdashboard/explainer_methods.py +++ b/explainerdashboard/explainer_methods.py @@ -70,7 +70,7 @@ def guess_shap(model): def parse_cats(X, cats, sep:str="_"): - """parse onehot encoded columns to a cats_dict. + """parse onehot encoded columns to a onehot_dict. - cats can be a dict where you enumerate each individual onehot encoded column belonging to each categorical feature, e.g. cats={ 'Sex':['Sex_female', 'Sex_male'], @@ -83,40 +83,43 @@ def parse_cats(X, cats, sep:str="_"): Asserts that all columns can be found in X.columns. Asserts that all columns are only passed once. """ - cols = X.columns + all_cols = X.columns + onehot_cols = [] + onehot_dict = {} + col_counter = Counter() - cats_dict = {} + if isinstance(cats, dict): for k, v in cats.items(): - assert set(v).issubset(set(cols)), \ - f"These cats columns for {k} could not be found in X.columns: {set(v)-set(cols)}!" + assert set(v).issubset(set(all_cols)), \ + f"These cats columns for {k} could not be found in X.columns: {set(v)-set(all_cols)}!" 
col_counter.update(v) - cats_dict = cats + onehot_dict = cats elif isinstance(cats, list): for cat in cats: if isinstance(cat, str): - cats_dict[cat] = [c for c in cols if c.startswith(cat + sep)] - col_counter.update(cats_dict[cat]) + onehot_dict[cat] = [c for c in all_cols if c.startswith(cat + sep)] + col_counter.update(onehot_dict[cat]) if isinstance(cat, dict): for k, v in cat.items(): - assert set(v).issubset(set(cols)), \ - f"These cats columns for {k} could not be found in X.columns: {set(v)-set(cols)}!" + assert set(v).issubset(set(all_cols)), \ + f"These cats columns for {k} could not be found in X.columns: {set(v)-set(all_cols)}!" col_counter.update(v) - cats_dict[k] = v + onehot_dict[k] = v multi_cols = [v for v, c in col_counter.most_common() if c > 1] assert not multi_cols, \ (f"The following columns seem to have been passed to cats multiple times: {multi_cols}. " "Please make sure that each onehot encoded column is only assigned to one cat column!") - assert not set(cats_dict.keys()) & set(cols), \ - (f"These new cats columns are already in X.columns: {list(set(cats_dict.keys()) & set(cols))}! " + assert not set(onehot_dict.keys()) & set(all_cols), \ + (f"These new cats columns are already in X.columns: {list(set(onehot_dict.keys()) & set(all_cols))}! " "Please select a different name for your new cats columns!") for col, count in col_counter.most_common(): assert set(X[col].astype(int).unique()).issubset({0,1}), \ f"{col} is not a onehot encoded column (i.e. has values other than 0, 1)!" 
- cats_list = list(cats_dict.keys()) - for col in [col for col in cols if col not in col_counter.keys()]: - cats_dict[col] = [col] - return cats_list, cats_dict + onehot_cols = list(onehot_dict.keys()) + for col in [col for col in all_cols if col not in col_counter.keys()]: + onehot_dict[col] = [col] + return onehot_cols, onehot_dict @@ -199,14 +202,16 @@ def retrieve_onehot_value(X, encoded_col, onehot_cols, sep="_"): # if not a single 1 then encoded feature must have been dropped feature_value[np.max(X[onehot_cols].values, axis=1) == 0] = -1 - mapping = {-1: "NOT_ENCODED"} - col_values = [col[len(encoded_col)+1:] if col.startswith(encoded_col+sep) - else col for col in onehot_cols] - mapping.update({i: col for i, col in enumerate(col_values)}) + if all([col.startswith(encoded_col+"_") for col in onehot_cols]): + mapping = {-1: encoded_col+"_NOT_ENCODED"} + else: + mapping = {-1: "NOT_ENCODED"} + + mapping.update({i: col for i, col in enumerate(onehot_cols)}) return pd.Series(feature_value).map(mapping).values -def merge_categorical_columns(X, cats_dict=None, sep="_"): +def merge_categorical_columns(X, onehot_dict=None, sep="_"): """ Returns a new feature Dataframe X_cats where the onehotencoded categorical features have been merged back with the old value retrieved @@ -215,7 +220,7 @@ Args: X (pd.DataFrame): original dataframe with onehotencoded columns, e.g. columns=['Age', 'Sex_Male', 'Sex_Female"]. - cats_dict (dict): dict of features with lists for onehot-encoded variables, + onehot_dict (dict): dict of features with lists for onehot-encoded variables, e.g. {'Fare': ['Fare'], 'Sex' : ['Sex_male', 'Sex_Female']} sep (str): separator used in the encoding, e.g. "_" for Sex_Male. Defaults to "_". Returns: pd.DataFrame, with onehot encodings merged back into categorical columns. 
""" X_cats = X.copy() - for col_name, col_list in cats_dict.items(): + for col_name, col_list in onehot_dict.items(): if len(col_list) > 1: X_cats[col_name] = retrieve_onehot_value(X, col_name, col_list, sep) X_cats.drop(col_list, axis=1, inplace=True) return X_cats -def X_cats_to_X(X_cats, cats_dict, X_columns, sep="_"): +def remove_cat_names(X_cats, onehot_dict): + """removes the leading category names in the onehotencoded columns. + Turning e.g 'Sex_male' into 'male', etc""" + X_cats = X_cats.copy() + for cat, cols in onehot_dict.items(): + if len(cols) > 1: + mapping = {c:c[len(cat)+1:] for c in cols if c.startswith(cat+'_')} + X_cats.loc[:, cat] = X_cats.loc[:, cat].map(mapping, na_action='ignore').values + return X_cats + + +def X_cats_to_X(X_cats, onehot_dict, X_columns, sep="_"): """ re-onehotencodes a dataframe where onehotencoded columns had previously been merged with merge_categorical_columns(...) Args: X_cats (pd.DataFrame): dataframe with merged categorical columns cats - cats_dict (dict): dict of features with lists for onehot-encoded variables, + onehot_dict (dict): dict of features with lists for onehot-encoded variables, e.g. 
{'Fare': ['Fare'], 'Sex' : ['Sex_male', 'Sex_Female']} X_columns: list of columns of original dataframe @@ -247,18 +263,14 @@ def X_cats_to_X(X_cats, cats_dict, X_columns, sep="_"): """ non_cat_cols = [col for col in X_cats.columns if col in X_columns] X_new = X_cats[non_cat_cols].copy() - for cat, labels in cats_dict.items(): - if len(labels) > 1: - for label in labels: - if label.startswith(cat+sep): - label_val = label[len(cat)+len(sep):] - else: - label_val = label - X_new[label] = (X_cats[cat]==label_val).astype(int) + for cat, cols in onehot_dict.items(): + if len(cols) > 1: + for col in cols: + X_new[col] = (X_cats[cat]==col).astype(np.int8) return X_new[X_columns] -def merge_categorical_shap_values(X, shap_values, cats_dict=None, sep="_"): +def merge_categorical_shap_values(X, shap_values, onehot_dict=None, sep="_"): """ Returns a new feature new shap values np.array where the shap values of onehotencoded categorical features have been @@ -269,13 +281,13 @@ def merge_categorical_shap_values(X, shap_values, cats_dict=None, sep="_"): in the shap_values np.ndarray. shap_values (np.ndarray): numpy array of shap values, output of e.g. shap.TreeExplainer(X).shap_values() - cats_dict (dict): dict of features with lists for onehot-encoded variables, + onehot_dict (dict): dict of features with lists for onehot-encoded variables, e.g. {'Fare': ['Fare'], 'Sex' : ['Sex_male', 'Sex_Female']} sep (str): seperator used between variable and category. Defaults to "_". 
""" shap_df = pd.DataFrame(shap_values, columns=X.columns) - for col_name, col_list in cats_dict.items(): + for col_name, col_list in onehot_dict.items(): if len(col_list) > 1: shap_df[col_name] = shap_df[col_list].sum(axis=1) shap_df.drop(col_list, axis=1, inplace=True) @@ -283,7 +295,7 @@ def merge_categorical_shap_values(X, shap_values, cats_dict=None, sep="_"): def merge_categorical_shap_interaction_values(shap_interaction_values, - old_columns, new_columns, cats_dict): + old_columns, new_columns, onehot_dict): """ Returns a 3d numpy array shap_interaction_values where the onehot-encoded categorical columns have been added up together. @@ -299,7 +311,7 @@ def merge_categorical_shap_interaction_values(shap_interaction_values, e.g. ["Age", "Sex_Male", "Sex_Female"] new_columns (list of str): list of column names without onehotencodings, e.g. ["Age", "Sex"] - cats_dict (dict): dict of features with lists for onehot-encoded variables, + onehot_dict (dict): dict of features with lists for onehot-encoded variables, e.g. 
{'Fare': ['Fare'], 'Sex' : ['Sex_male', 'Sex_Female']} Returns: @@ -323,9 +335,9 @@ def merge_categorical_shap_interaction_values(shap_interaction_values, newcol_idx1 = new_columns.index(new_col1) newcol_idx2 = new_columns.index(new_col2) oldcol_idxs1 = [old_columns.index(col) - for col in cats_dict[new_col1]] + for col in onehot_dict[new_col1]] oldcol_idxs2 = [old_columns.index(col) - for col in cats_dict[new_col2]] + for col in onehot_dict[new_col2]] siv[:, newcol_idx1, newcol_idx2] = \ shap_interaction_values[:, oldcol_idxs1, :][:, :, oldcol_idxs2]\ .sum(axis=(1, 2)) @@ -359,7 +371,7 @@ def _scorer(clf, X, y): return _scorer -def permutation_importances(model, X, y, metric, cats_dict=None, +def permutation_importances(model, X, y, metric, onehot_dict=None, greater_is_better=True, needs_proba=False, pos_label=1, n_repeats=1, n_jobs=None, sort=True, verbose=0): """ @@ -372,7 +384,7 @@ def permutation_importances(model, X, y, metric, cats_dict=None, y (pd.Series): series of targets metric: metric to be evaluated (usually R2 for regression, roc_auc for classification) - cats_dict (dict): dict of features with lists for onehot-encoded variables, + onehot_dict (dict): dict of features with lists for onehot-encoded variables, e.g. {'Fare': ['Fare'], 'Sex' : ['Sex_male', 'Sex_Female']} greater_is_better (bool): indicates whether the higher score on the metric indicates a better model. 
@@ -388,8 +400,8 @@ def permutation_importances(model, X, y, metric, cats_dict=None, """ X = X.copy() - if cats_dict is None: - cats_dict = {col:[col] for col in X.columns} + if onehot_dict is None: + onehot_dict = {col:[col] for col in X.columns} if isinstance(metric, str): scorer = make_scorer(metric, greater_is_better=greater_is_better, needs_proba=needs_proba) @@ -412,7 +424,7 @@ def _permutation_importance(model, X, y, scorer, col_name, col_list, baseline, n scores = Parallel(n_jobs=n_jobs)(delayed(_permutation_importance)( model, X, y, scorer, col_name, col_list, baseline, n_repeats - ) for col_name, col_list in cats_dict.items()) + ) for col_name, col_list in onehot_dict.items()) importances_df = pd.DataFrame(scores, columns=['Feature', 'Score']) importances_df['Importance'] = baseline - importances_df['Score'] @@ -423,7 +435,7 @@ def _permutation_importance(model, X, y, scorer, col_name, col_list, baseline, n return importances_df -def cv_permutation_importances(model, X, y, metric, cats_dict=None, greater_is_better=True, +def cv_permutation_importances(model, X, y, metric, onehot_dict=None, greater_is_better=True, needs_proba=False, pos_label=None, cv=None, n_repeats=1, n_jobs=None, verbose=0): """ @@ -435,7 +447,7 @@ def cv_permutation_importances(model, X, y, metric, cats_dict=None, greater_is_b y (pd.Series): series of targets metric: metric to be evaluated (usually R2 for regression, roc_auc for classification) - cats_dict (dict): dict of features with lists for onehot-encoded variables, + onehot_dict (dict): dict of features with lists for onehot-encoded variables, e.g. {'Fare': ['Fare'], 'Sex' : ['Sex_male', 'Sex_Female']} greater_is_better (bool): indicates whether the higher score on the metric indicates a better model. @@ -448,7 +460,7 @@ def cv_permutation_importances(model, X, y, metric, cats_dict=None, greater_is_b verbose (int): set to 1 to print output for debugging. Defaults to 0. 
""" if cv is None: - return permutation_importances(model, X, y, metric, cats_dict, + return permutation_importances(model, X, y, metric, onehot_dict, greater_is_better=greater_is_better, needs_proba=needs_proba, pos_label=pos_label, @@ -465,7 +477,7 @@ def cv_permutation_importances(model, X, y, metric, cats_dict=None, greater_is_b model.fit(X_train, y_train) - imp = permutation_importances(model, X_test, y_test, metric, cats_dict, + imp = permutation_importances(model, X_test, y_test, metric, onehot_dict, greater_is_better=greater_is_better, needs_proba=needs_proba, pos_label=pos_label, @@ -482,23 +494,23 @@ def cv_permutation_importances(model, X, y, metric, cats_dict=None, greater_is_b .sort_values('Importance', ascending=False) -def mean_absolute_shap_values(columns, shap_values, cats_dict=None): +def mean_absolute_shap_values(columns, shap_values, onehot_dict=None): """ Returns a dataframe with the mean absolute shap values for each feature. Args: columns (list of str): list of column names shap_values (np.ndarray): 2d array of SHAP values - cats_dict (dict): dict of features with lists for onehot-encoded variables, + onehot_dict (dict): dict of features with lists for onehot-encoded variables, e.g. {'Fare': ['Fare'], 'Sex' : ['Sex_male', 'Sex_Female']} Returns: pd.DataFrame with columns 'Feature' and 'MEAN_ABS_SHAP'. 
""" - if cats_dict is None: - cats_dict = {col:[col] for col in columns} + if onehot_dict is None: + onehot_dict = {col:[col] for col in columns} shap_abs_mean_dict = {} - for col_name, col_list in cats_dict.items(): + for col_name, col_list in onehot_dict.items(): shap_abs_mean_dict[col_name] = np.absolute( shap_values[:, [columns.index(col) for col in col_list]].sum(axis=1) ).mean() @@ -510,9 +522,41 @@ def mean_absolute_shap_values(columns, shap_values, cats_dict=None): }).sort_values('MEAN_ABS_SHAP', ascending=False).reset_index(drop=True) return shap_df +def get_grid_points(array, n_grid_points=10, min_percentage=0, max_percentage=100): + """seperates a numerical array into a number of grid points. Helper function + for get_pdp_df. + + Args: + array (np.array): array + n_grid_points (int, optional): number of points to divide array in. + Defaults to 10. + min_percentage (int, optional): Minimum percentage to start at, + ignoring outliers. Defaults to 0. + max_percentage (int, optional): Maximum percentage to reach, ignoring + outliers. Defaults to 100. 
+ + Raises: + ValueError: if array is not of a numeric dtype + + Returns: + np.array + """ + + if isinstance(array, pd.Series): + array = array.values + else: + array = np.array(array) + if not is_numeric_dtype(array): + raise ValueError("array should be a numeric dtype!") + + percentile_grids = np.linspace(start=min_percentage, stop=max_percentage, num=n_grid_points) + value_grids = np.percentile(array, percentile_grids) + return value_grids + def get_pdp_df(model, X_sample:pd.DataFrame, feature:Union[str, List], pos_label=1, - n_grid_points=10, min_percentage=0, max_percentage=100): + n_grid_points:int=10, min_percentage:int=0, max_percentage:int=100, + multiclass:bool=False, grid_values:List=None): """Returns a dataframe with partial dependence for every row in X_sample for a number of feature values Args: @@ -530,31 +574,36 @@ max_percentage (int, optional): For numeric features: maximum percentage of samples to end x axis by. If smaller than 100 a form of winsorizing the x axis. Defaults to 100. + multiclass (bool, optional): for classifier models, return a list of dataframes, + one for each predicted label. + grid_values (list, optional): list of grid values. Defaults to None, in which + case it will be inferred from X_sample. 
""" - def get_grid_points(array, n_grid_points=10, min_percentage=0, max_percentage=100): - if not is_numeric_dtype(array): - raise ValueError("array should be a numeric dtype!") - if isinstance(array, pd.Series): - array = array.values - percentile_grids = np.linspace(start=min_percentage, stop=max_percentage, num=n_grid_points) - value_grids = np.percentile(array, percentile_grids) - return value_grids - - if isinstance(feature, str): - if not is_numeric_dtype(X_sample[feature]): - grid_values = sorted(X_sample[feature].unique().tolist()) + + + if grid_values is None: + if isinstance(feature, str): + if not is_numeric_dtype(X_sample[feature]): + grid_values = sorted(X_sample[feature].unique().tolist()) + else: + grid_values = get_grid_points(X_sample[feature], + n_grid_points=n_grid_points, + min_percentage=min_percentage, + max_percentage=max_percentage).tolist() + elif isinstance(feature, list): + grid_values = feature else: - grid_values = get_grid_points(X_sample[feature], - n_grid_points=n_grid_points, - min_percentage=min_percentage, - max_percentage=max_percentage).tolist() - elif isinstance(feature, list): - grid_values = feature - else: - raise ValueError("feature should either be a column name (str), " - "or a list of onehot-encoded columns!") + raise ValueError("feature should either be a column name (str), " + "or a list of onehot-encoded columns!") - pdp_df = pd.DataFrame() + if hasattr(model, "predict_proba"): + n_labels = model.predict_proba(X_sample.iloc[[0]]).shape[1] + if multiclass: + pdp_dfs = [pd.DataFrame() for i in range(n_labels)] + else: + pdp_df = pd.DataFrame() + else: + pdp_df = pd.DataFrame() for grid_value in grid_values: dtemp = X_sample.copy() if isinstance(feature, list): @@ -565,12 +614,19 @@ def get_grid_points(array, n_grid_points=10, min_percentage=0, max_percentage=10 else: dtemp.loc[:, feature] = grid_value if hasattr(model, "predict_proba"): - preds = model.predict_proba(dtemp)[:, pos_label] + pred_probas = 
model.predict_proba(dtemp) + if multiclass: + for i in range(n_labels): + pdp_dfs[i][grid_value] = pred_probas[:, i] + else: + pdp_df[grid_value] = pred_probas[:, pos_label] else: preds = model.predict(dtemp) - pdp_df[grid_value] = preds - - return pdp_df + pdp_df[grid_value] = preds + if multiclass: + return pdp_dfs + else: + return pdp_df def get_precision_df(pred_probas, y_true, bin_size=None, quantiles=None, diff --git a/explainerdashboard/explainer_plots.py b/explainerdashboard/explainer_plots.py index ba7b093..0b2b250 100644 --- a/explainerdashboard/explainer_plots.py +++ b/explainerdashboard/explainer_plots.py @@ -827,7 +827,8 @@ def plotly_dependence_plot(X, shap_values, col_name, interact_col_name=None, def plotly_shap_violin_plot(X, shap_values, col_name, color_col=None, points=False, - interaction=False, units="", highlight_index=None, idxs=None, index_name="index"): + interaction=False, units="", highlight_index=None, idxs=None, index_name="index", + cats_order=None): """Generates a violin plot for displaying shap value distributions for categorical features. 
@@ -856,10 +857,13 @@ def plotly_shap_violin_plot(X, shap_values, col_name, color_col=None, points=Fal x = X[col_name].copy() shaps = shap_values[:, X.columns.get_loc(col_name)] - n_cats = X[col_name].nunique() + if cats_order is None: + cats_order = sorted(X[col_name].unique().tolist()) + + n_cats = len(cats_order) if idxs is not None: - assert len(idxs)==X.shape[0] + assert len(idxs)==X.shape[0]==len(shaps) idxs = np.array([str(idx) for idx in idxs]) else: idxs = np.array([str(i) for i in range(X.shape[0])]) @@ -879,9 +883,10 @@ def plotly_shap_violin_plot(X, shap_values, col_name, color_col=None, points=Fal else: fig = make_subplots(rows=1, cols=n_cats, shared_yaxes=True) - fig.update_yaxes(range=[shaps.min()*1.3 if shaps.min() < 0 else shaps.min()*0.76, shaps.max()*1.3]) + shap_range = shaps.max() - shaps.min() + fig.update_yaxes(range=[shaps.min()-0.1*shap_range, shaps.max()+0.1*shap_range]) - for i, cat in enumerate(X[col_name].unique()): + for i, cat in enumerate(cats_order): col = 1+i*2 if points or color_col is not None else 1+i fig.add_trace(go.Violin( x=x[x == cat], @@ -1012,7 +1017,7 @@ def plotly_pdp(pdp_df, Defaults to None. index_feature_value (str, float, optional): value of feature for index. Defaults to None. - index_prediction (float, optional): Final prediction for index. + index_prediction (float, optional): Baseline prediction for index. Defaults to None. absolute (bool, optional): Display absolute pdp lines. If false then display relative to base. Defaults to True. 
@@ -1120,7 +1125,12 @@ def plotly_pdp(pdp_df, ) ) - annotations.append(go.layout.Annotation(x=pdp_df.columns[int(0.5*len(pdp_df.columns))], y=index_prediction, text=f"baseline pred = {np.round(index_prediction,2)}")) + annotations.append( + go.layout.Annotation( + x=pdp_df.columns[int(0.5*len(pdp_df.columns))], + y=index_prediction, + text=f"baseline pred = {str(np.round(index_prediction,round))}") + ) fig.update_layout(annotations=annotations) fig.update_layout(shapes=shapes) diff --git a/explainerdashboard/explainers.py b/explainerdashboard/explainers.py index f0cfd7f..3ef5c43 100644 --- a/explainerdashboard/explainers.py +++ b/explainerdashboard/explainers.py @@ -95,19 +95,26 @@ def __init__(self, model, X, y=None, permutation_metric=r2_score, self.X, self.X_background = X, X_background self.model = model - if not all([is_numeric_dtype(X[col]) for col in X.columns]): - self.cats_only = True - self.cats = [col for col in X.columns if not is_numeric_dtype(X[col])] - self.cats_dict = {col:self.X[col].unique().tolist() for col in self.cats} - print("Warning: detected non-numeric columns in X! " - f"Autodetecting the following categorical columns: {self.cats}. \n" - "Setting self.cats_only=True, which means that passing cats=False " - "to explainer methods will not work, and shap interaction values " - "will not work... 
ExplainerDashboard will disable these features " - " by default.", flush=True) - else: - self.cats_only = False - self.cats, self.cats_dict = parse_cats(self.X, cats) + if safe_is_instance(model, "xgboost.core.Booster"): + raise ValueError("For xgboost models, currently only the scikit-learn " + "compatible wrappers xgboost.sklearn.XGBClassifier and " + "xgboost.sklearn.XGBRegressor are supported, so please use those " + "instead of xgboost.Booster!") + + if safe_is_instance(model, "lightgbm.Booster"): + raise ValueError("For lightgbm, currently only the scikit-learn " + "compatible wrappers lightgbm.LGBMClassifier and lightgbm.LGBMRegressor " + "are supported, so please use those instead of lightgbm.Booster!") + + self.onehot_cols, self.onehot_dict = parse_cats(self.X, cats) + self.categorical_cols = [col for col in X.columns if not is_numeric_dtype(X[col])] + self.categorical_dict = {col:sorted(X[col].unique().tolist()) for col in self.categorical_cols} + self.cat_cols = self.onehot_cols + self.categorical_cols + if self.categorical_cols: + print(f"Warning: Detected the following categorical columns: {self.categorical_cols}." 
+ "Unfortunately for now shap interaction values do not work with" + "categorical columns.", flush=True) + self.interactions_should_work = False if y is not None: self.y = pd.Series(y) @@ -170,7 +177,10 @@ def __init__(self, model, X, y=None, permutation_metric=r2_score, self.is_classifier = False self.is_regression = False self.interactions_should_work = True - + if safe_is_instance(self.model, "CatBoostRegressor", "CatBoostClassifier"): + self.interactions_should_work = False + else: + self.interactions_should_work = True @classmethod def from_file(cls, filepath): @@ -295,8 +305,6 @@ def check_cats(self, col1, col2=None): Boolean whether cats should be True """ - if self.cats_only: - return True if col2 is None: if col1 in self.columns: return False @@ -460,7 +468,7 @@ def columns_ranked_by_shap(self, cats=False, pos_label=None): list of columns """ - if cats or self.cats_only: + if cats: return self.mean_abs_shap_cats(pos_label).Feature.tolist() else: return self.mean_abs_shap(pos_label).Feature.tolist() @@ -475,7 +483,7 @@ def n_features(self, cats=False): int, number of features """ - if cats or self.cats_only: + if cats: return len(self.columns_cats) else: return len(self.columns) @@ -497,16 +505,55 @@ def equivalent_col(self, col): Returns: col """ - if self.cats_only: - return col - if col in self.cats: + if col in self.columns_cats: # first onehot-encoded columns - return self.cats_dict[col][0] + return self.onehot_dict[col][0] elif col in self.columns: # the cat that the col belongs to - return [k for k, v in self.cats_dict.items() if col in v][0] + return [k for k, v in self.onehot_dict.items() if col in v][0] return None + def ordered_cats(self, col, topx=None, sort='alphabet'): + """Return a list of categories in an categorical column, sorted + by mode. + + Args: + col (str): Categorical feature to return categories for. + topx (int, optional): Return topx top categories. Defaults to None. 
+ sort (str, optional): Sorting method, either alphabetically ('alphabet'), + by frequency ('freq') or mean absolute shap ('shap'). + Defaults to 'alphabet'. + + Raises: + ValueError: if sort is other than 'alphabet', 'freq', 'shap' + + Returns: + list + """ + assert col in self.cat_cols, \ + f"{col} is not a categorical feature!" + if sort=='alphabet': + if topx is None: + return sorted(self.X_cats[col].unique().tolist()) + else: + return sorted(self.X_cats[col].unique().tolist())[:topx] + elif sort=='freq': + if topx is None: + return self.X_cats[col].value_counts().index.tolist() + else: + return self.X_cats[col].value_counts().nlargest(topx).index.tolist() + elif sort=='shap': + if topx is None: + return (pd.Series(self.shap_values_cats[:, self.columns_cats.index(col)], + index=self.X_cats[col]).abs().groupby(level=0).mean() + .sort_values(ascending=False).index.tolist()) + else: + return (pd.Series(self.shap_values_cats[:, self.columns_cats.index(col)], + index=self.X_cats[col]).abs().groupby(level=0).mean() + .sort_values(ascending=False).nlargest(topx).index.tolist()) + else: + raise ValueError(f"sort='{sort}', but should be in {{'alphabet', 'freq', 'shap'}}") + def get_row_from_input(self, inputs:List, ranked_by_shap=False): """returns a single row pd.DataFrame from a given list of *inputs""" if len(inputs)==1 and isinstance(inputs[0], list): @@ -567,14 +614,14 @@ def get_col(self, col): pd.Series with values of col """ - assert col in self.columns or col in self.cats, \ + assert col in self.columns or col in self.onehot_cols, \ f"{col} not in columns!" 
if col in self.X.columns: return self.X[col] - elif col in self.cats: + elif col in self.onehot_cols: return pd.Series(retrieve_onehot_value( - self.X, col, self.cats_dict[col]), name=col) + self.X, col, self.onehot_dict[col]), name=col) def get_col_value_plus_prediction(self, col, index=None, X_row=None, pos_label=None): """return value of col and prediction for either index or X_row @@ -589,7 +636,7 @@ def get_col_value_plus_prediction(self, col, index=None, X_row=None, pos_label=N """ - assert (col in self.X.columns) or (col in self.cats),\ + assert (col in self.X.columns) or (col in self.onehot_cols),\ f"{col} not in columns of dataset" if index is not None: assert index in self, f"index {index} not found" @@ -597,8 +644,8 @@ def get_col_value_plus_prediction(self, col, index=None, X_row=None, pos_label=N if col in self.X.columns: col_value = self.X[col].iloc[idx] - elif col in self.cats: - col_value = retrieve_onehot_value(self.X, col, self.cats_dict[col])[idx] + elif col in self.onehot_cols: + col_value = retrieve_onehot_value(self.X, col, self.onehot_dict[col])[idx] if self.is_classifier: if pos_label is None: @@ -615,15 +662,15 @@ def get_col_value_plus_prediction(self, col, index=None, X_row=None, pos_label=N if ((len(X_row.columns) == len(self.X_cats.columns)) and (X_row.columns == self.X_cats.columns).all()): - X_row = X_cats_to_X(X_row, self.cats_dict, self.X.columns) + X_row = X_cats_to_X(X_row, self.onehot_dict, self.X.columns) else: assert (X_row.columns == self.X.columns).all(), \ "X_row should have the same columns as self.X or self.X_cats!" 
if col in X_row.columns: col_value = X_row[col].item() - elif col in self.cats: - col_value = retrieve_onehot_value(X_row, col, self.cats_dict[col]).item() + elif col in self.onehot_cols: + col_value = retrieve_onehot_value(X_row, col, self.onehot_dict[col]).item() if self.is_classifier: if pos_label is None: @@ -654,16 +701,9 @@ def permutation_importances(self): def permutation_importances_cats(self): """permutation importances with categoricals grouped""" if not hasattr(self, '_perm_imps_cats'): - if self.cats_only: - self._perm_imps_cats = cv_permutation_importances( - self.model, self.X, self.y, self.metric, - cv=self.permutation_cv, - n_jobs=self.n_jobs, - needs_proba=self.is_classifier) - else: - self._perm_imps_cats = cv_permutation_importances( + self._perm_imps_cats = cv_permutation_importances( self.model, self.X, self.y, self.metric, - cats_dict=self.cats_dict, + onehot_dict=self.onehot_dict, cv=self.permutation_cv, n_jobs=self.n_jobs, needs_proba=self.is_classifier) @@ -673,10 +713,7 @@ def permutation_importances_cats(self): def X_cats(self): """X with categorical variables grouped together""" if not hasattr(self, '_X_cats'): - if self.cats_only: - self._X_cats = self.X - else: - self._X_cats = merge_categorical_columns(self.X, self.cats_dict) + self._X_cats = merge_categorical_columns(self.X, self.onehot_dict) return self._X_cats @property @@ -714,11 +751,8 @@ def shap_values(self): def shap_values_cats(self): """SHAP values when categorical features have been grouped""" if not hasattr(self, '_shap_values_cats'): - if self.cats_only: - self._shap_values_cats = self.shap_explainer.shap_values(self.X) - else: - self._shap_values_cats = merge_categorical_shap_values( - self.X, self.shap_values, self.cats_dict) + self._shap_values_cats = merge_categorical_shap_values( + self.X, self.shap_values, self.onehot_dict) return make_callable(self._shap_values_cats) @property @@ -745,7 +779,7 @@ def shap_interaction_values_cats(self): if not hasattr(self, 
'_shap_interaction_values_cats'): self._shap_interaction_values_cats = \ merge_categorical_shap_interaction_values( - self.shap_interaction_values, self.X, self.X_cats, self.cats_dict) + self.shap_interaction_values, self.X, self.X_cats, self.onehot_dict) return make_callable(self._shap_interaction_values_cats) @property @@ -782,12 +816,12 @@ def calculate_properties(self, include_interactions=True): self.mean_abs_shap) if not self.y_missing: _ = self.permutation_importances - if self.cats is not None: + if self.onehot_cols: _ = (self.mean_abs_shap_cats, self.X_cats, self.shap_values_cats) if self.interactions_should_work and include_interactions: _ = self.shap_interaction_values - if self.cats is not None: + if self.onehot_cols: _ = self.shap_interaction_values_cats def metrics(self, *args, **kwargs): @@ -813,7 +847,7 @@ def mean_abs_shap_df(self, topx=None, cutoff=None, cats=False, pos_label=None): pd.DataFrame: shap_df """ - if cats or self.cats_only: + if cats: shap_df = self.mean_abs_shap_cats(pos_label) else: shap_df = self.mean_abs_shap(pos_label) @@ -839,7 +873,7 @@ def shap_top_interactions(self, col, topx=None, cats=False, pos_label=None): list: top_interactions """ - if cats or self.cats_only: + if cats: if hasattr(self, '_shap_interaction_values'): col_idx = self.X_cats.columns.get_loc(col) top_interactions = self.X_cats.columns[ @@ -888,7 +922,7 @@ def shap_interaction_values_by_col(self, col, cats=False, pos_label=None): np.array(N,N): shap_interaction_values """ - if cats or self.cats_only: + if cats: return self.shap_interaction_values_cats(pos_label)[:, self.X_cats.columns.get_loc(col), :] else: @@ -915,7 +949,7 @@ def permutation_importances_df(self, topx=None, cutoff=None, cats=False, pd.DataFrame: importance_df """ - if cats or self.cats_only: + if cats: importance_df = self.permutation_importances_cats(pos_label) else: importance_df = self.permutation_importances(pos_label) @@ -986,14 +1020,13 @@ def contrib_df(self, index=None, X_row=None, 
cats=True, topx=None, cutoff=None, if X_row is not None: if ((len(X_row.columns) == len(self.X_cats.columns)) and (X_row.columns == self.X_cats.columns).all()): - if cats or self.cats_only: + if cats: X_row_cats = X_row - if not self.cats_only: - X_row = X_cats_to_X(X_row, self.cats_dict, self.X.columns) + X_row = X_cats_to_X(X_row, self.onehot_dict, self.X.columns) else: assert (X_row.columns == self.X.columns).all(), \ "X_row should have the same columns as self.X or self.X_cats!" - X_row_cats = merge_categorical_columns(X_row, self.cats_dict) + X_row_cats = merge_categorical_columns(X_row, self.onehot_dict) shap_values = self.shap_explainer.shap_values(X_row) if self.is_classifier: @@ -1002,10 +1035,10 @@ def contrib_df(self, index=None, X_row=None, cats=True, topx=None, cutoff=None, shap_values = shap_values[self.get_pos_label_index(pos_label)] if cats: - if not self.cats_only: - shap_values = merge_categorical_shap_values(X_row, shap_values, self.cats_dict) + shap_values = merge_categorical_shap_values(X_row, shap_values, self.onehot_dict) return get_contrib_df(self.shap_base_value(pos_label), shap_values[0], - X_row_cats, topx, cutoff, sort, cols) + remove_cat_names(X_row_cats, self.onehot_dict), + topx, cutoff, sort, cols) else: return get_contrib_df(self.shap_base_value(pos_label), shap_values[0], X_row, topx, cutoff, sort, cols) @@ -1014,7 +1047,8 @@ def contrib_df(self, index=None, X_row=None, cats=True, topx=None, cutoff=None, if cats: return get_contrib_df(self.shap_base_value(pos_label), self.shap_values_cats(pos_label)[idx], - self.X_cats.iloc[[idx]], topx, cutoff, sort, cols) + remove_cat_names(self.X_cats.iloc[[idx]], self.onehot_dict), + topx, cutoff, sort, cols) else: return get_contrib_df(self.shap_base_value(pos_label), self.shap_values(pos_label)[idx], @@ -1064,7 +1098,7 @@ def interactions_df(self, col, cats=False, topx=None, cutoff=None, """ importance_df = mean_absolute_shap_values( - self.columns_cats if (cats or self.cats_only) else 
self.columns, + self.columns_cats if cats else self.columns, self.shap_interaction_values_by_col(col, cats, pos_label)) if topx is None: topx = len(importance_df) @@ -1090,12 +1124,12 @@ def formatted_contrib_df(self, index, round=None, lang='en', pos_label=None): cdf.loc[cdf.col=='base_value', 'value'] = np.nan cdf['row_id'] = self.get_int_idx(index) cdf['name_id'] = index - cdf['cat_value'] = np.where(cdf.col.isin(self.cats), cdf.value, np.nan) - cdf['cont_value'] = np.where(cdf.col.isin(self.cats), np.nan, cdf.value) + cdf['cat_value'] = np.where(cdf.col.isin(self.onehot_cols), cdf.value, np.nan) + cdf['cont_value'] = np.where(cdf.col.isin(self.onehot_cols), np.nan, cdf.value) if round is not None: rounded_cont = np.round(cdf['cont_value'].values.astype(float), round) - cdf['value'] = np.where(cdf.col.isin(self.cats), cdf.cat_value, rounded_cont) - cdf['type'] = np.where(cdf.col.isin(self.cats), 'cat', 'cont') + cdf['value'] = np.where(cdf.col.isin(self.onehot_cols), cdf.cat_value, rounded_cont) + cdf['type'] = np.where(cdf.col.isin(self.onehot_cols), 'cat', 'cont') cdf['abs_contribution'] = np.abs(cdf.contribution) cdf = cdf[['row_id', 'name_id', 'contribution', 'abs_contribution', 'col', 'value', 'cat_value', 'cont_value', 'type', 'index']] @@ -1108,14 +1142,55 @@ def formatted_contrib_df(self, index, round=None, lang='en', pos_label=None): 'Cat_Value', 'Cont_Value', 'Value_Type', 'Feature_Order'] return cdf - def pdp_df(self, col, index=None, X_row=None, drop_na=True, - sample=500, num_grid_points=20, pos_label=None): - assert col in self.X.columns or col in self.cats, \ + def pdp_df(self, col, index=None, X_row=None, drop_na=True, sample=500, + n_grid_points=10, pos_label=None, sort='freq'): + """Return a pdp_df for generating partial dependence plots. + + Args: + col (str): Feature to generate partial dependence for. + index ({int, str}, optional): Index to include on first row + of pdp_df. Defaults to None. 
+ X_row (pd.DataFrame, optional): Single row to put on first row of pdp_df. + Defaults to None. + drop_na (bool, optional): Drop self.na_fill values. Defaults to True. + sample (int, optional): Sample size for pdp_df. Defaults to 500. + n_grid_points (int, optional): Number of grid points on x axis. + Defaults to 10. + pos_label ([type], optional): [description]. Defaults to None. + sort (str, optional): For categorical features: how to sort: + 'alphabet', 'freq', 'shap'. Defaults to 'freq'. + + Returns: + pd.DataFrame + """ + assert col in self.X.columns or col in self.onehot_cols, \ f"{col} not in columns of dataset" - if col in self.cats and not self.cats_only: - features = self.cats_dict[col] + if col in self.onehot_cols: + features = self.ordered_cats(col, n_grid_points, sort) + if index is not None or X_row is not None: + val, pred = self.get_col_value_plus_prediction(col, index, X_row) + if val not in features: + features[-1] = val + grid_values = None + elif col in self.categorical_cols: + features = col + grid_values = self.ordered_cats(col, n_grid_points, sort) + if index is not None or X_row is not None: + val, pred = self.get_col_value_plus_prediction(col, index, X_row) + if val not in grid_values: + grid_values[-1] = val else: features = col + if drop_na: + vals = np.delete(self.X[col].values, np.where(self.X[col].values==self.na_fill), axis=0) + grid_values = get_grid_points(vals, n_grid_points=n_grid_points) + else: + grid_values = get_grid_points(self.X[col].values, n_grid_points=n_grid_points) + if index is not None or X_row is not None: + val, pred = self.get_col_value_plus_prediction(col, index, X_row) + if val not in grid_values: + grid_values = np.append(grid_values, val).sort() + if pos_label is None: pos_label = self.pos_label @@ -1137,7 +1212,7 @@ def pdp_df(self, col, index=None, X_row=None, drop_na=True, elif X_row is not None: if ((len(X_row.columns) == len(self.X_cats.columns)) and (X_row.columns == self.X_cats.columns).all()): - X_row = 
X_cats_to_X(X_row, self.cats_dict, self.X.columns) + X_row = X_cats_to_X(X_row, self.onehot_dict, self.X.columns) else: assert (X_row.columns == self.X.columns).all(), \ "X_row should have the same columns as self.X or self.X_cats!" @@ -1163,17 +1238,14 @@ def pdp_df(self, col, index=None, X_row=None, drop_na=True, else: sampleX = self.X.sample(min(sample, len(self.X))) - # if only a single value (i.e. not onehot encoded, take that value - # instead of list): - pdp_df = get_pdp_df( model=self.model, X_sample=sampleX, - feature=features, - n_grid_points=num_grid_points, pos_label=pos_label) + feature=features, n_grid_points=n_grid_points, + pos_label=pos_label, grid_values=grid_values) if all([str(c).startswith(col+"_") for c in pdp_df.columns]): pdp_df.columns = [str(c)[len(col)+1:] for c in pdp_df.columns] - if self.is_classifier: + if self.is_classifier and self.model_output == 'probability': pdp_df = pdp_df.multiply(100) return pdp_df @@ -1196,7 +1268,7 @@ def get_dfs(self, cats=True, round=None, lang='en', pos_label=None): pd.DataFrame, pd.DataFrame, pd.DataFrame: cols_df, shap_df, contribs_df """ - if cats or self.cats_only: + if cats: cols_df = self.X_cats.copy() shap_df = pd.DataFrame(self.shap_values_cats(pos_label), columns = self.X_cats.columns) else: @@ -1303,7 +1375,7 @@ def plot_interactions(self, col, cats=False, topx=None, pos_label=None): plotly.fig: fig """ - if col in self.cats or self.cats_only: + if col in self.onehot_cols: cats = True interactions_df = self.interactions_df(col, cats=cats, topx=topx, pos_label=pos_label) title = f"Average interaction shap values for {col}" @@ -1380,7 +1452,7 @@ def plot_shap_summary(self, index=None, topx=None, cats=False, pos_label=None): else: title = f"Impact of Feature on Prediction
(SHAP values)" - if cats or self.cats_only: + if cats: return plotly_shap_scatter_plot( self.shap_values_cats(pos_label), self.X_cats, @@ -1419,7 +1491,7 @@ def plot_shap_interaction_summary(self, col, index=None, topx=None, cats=False, Returns: fig """ - if col in self.cats or self.cats_only: + if col in self.onehot_cols: cats = True interact_cols = self.shap_top_interactions(col, cats=cats, pos_label=pos_label) if topx is None: topx = len(interact_cols) @@ -1431,7 +1503,8 @@ def plot_shap_interaction_summary(self, col, index=None, topx=None, cats=False, idxs=self.idxs.values, highlight_index=index, na_fill=self.na_fill, index_name=self.index_name) - def plot_shap_dependence(self, col, color_col=None, highlight_index=None, pos_label=None): + def plot_shap_dependence(self, col, color_col=None, highlight_index=None, + topx=None, sort='alphabet', pos_label=None): """plot shap dependence Plots a shap dependence plot: @@ -1442,8 +1515,13 @@ def plot_shap_dependence(self, col, color_col=None, highlight_index=None, pos_la col(str): feature to be displayed color_col(str): if color_col provided then shap values colored (blue-red) according to feature color_col (Default value = None) - highlight_idx: individual observation to be highlighed in the plot. + highlight_index: individual observation to be highlighed in the plot. (Default value = None) + topx (int, optional): for categorical features only display topx + categories. + sort (str): for categorical features, how to sort the categories: + alphabetically 'alphabet', most frequent first 'freq', + highest mean absolute value first 'shap'. Defaults to 'alphabet'. 
pos_label: positive class (Default value = None) Returns: @@ -1453,7 +1531,8 @@ def plot_shap_dependence(self, col, color_col=None, highlight_index=None, pos_la highlight_idx = self.get_int_idx(highlight_index) if cats: - if col in self.cats: + + if col in self.onehot_cols or col in self.categorical_cols: return plotly_shap_violin_plot( self.X_cats, self.shap_values_cats(pos_label), @@ -1461,7 +1540,8 @@ def plot_shap_dependence(self, col, color_col=None, highlight_index=None, pos_la color_col, highlight_index=highlight_idx, idxs=self.idxs.values, - index_name=self.index_name) + index_name=self.index_name, + cats_order=self.ordered_cats(col, topx, sort)) else: return plotly_dependence_plot( self.X_cats, @@ -1474,19 +1554,30 @@ def plot_shap_dependence(self, col, color_col=None, highlight_index=None, pos_la idxs=self.idxs.values, index_name=self.index_name) else: - return plotly_dependence_plot( - self.X, - self.shap_values(pos_label), - col, - color_col, - na_fill=self.na_fill, - units=self.units, - highlight_index=highlight_idx, - idxs=self.idxs.values, - index_name=self.index_name) + if col in self.categorical_cols: + return plotly_shap_violin_plot( + self.X_cats, + self.shap_values_cats(pos_label), + col, + color_col, + highlight_index=highlight_idx, + idxs=self.idxs.values, + index_name=self.index_name, + cats_order=self.ordered_cats(col, topx, sort)) + else: + return plotly_dependence_plot( + self.X, + self.shap_values(pos_label), + col, + color_col, + na_fill=self.na_fill, + units=self.units, + highlight_index=highlight_idx, + idxs=self.idxs.values, + index_name=self.index_name) def plot_shap_interaction(self, col, interact_col, highlight_index=None, - pos_label=None): + topx=10, sort='alphabet', pos_label=None): """plots a dependence plot for shap interaction effects Args: @@ -1502,13 +1593,13 @@ def plot_shap_interaction(self, col, interact_col, highlight_index=None, cats = self.check_cats(col, interact_col) highlight_idx = self.get_int_idx(highlight_index) 
- if cats and interact_col in self.cats: + if cats and (interact_col in self.onehot_cols or interact_col in self.categorical_cols): return plotly_shap_violin_plot( self.X_cats, self.shap_interaction_values_by_col(col, cats, pos_label=pos_label), interact_col, col, interaction=True, units=self.units, highlight_index=highlight_idx, idxs=self.idxs.values, - index_name=self.index_name) + index_name=self.index_name, cats_order=self.ordered_cats(interact_col, topx, sort)) else: return plotly_dependence_plot(self.X_cats if cats else self.X, self.shap_interaction_values_by_col(col, cats, pos_label=pos_label), @@ -1517,7 +1608,8 @@ def plot_shap_interaction(self, col, interact_col, highlight_index=None, index_name=self.index_name) def plot_pdp(self, col, index=None, X_row=None, drop_na=True, sample=100, - gridlines=100, gridpoints=10, pos_label=None): + gridlines=100, gridpoints=10, sort='freq', round=2, + pos_label=None): """plot partial dependence plot (pdp) returns plotly fig for a partial dependence plot showing ice lines @@ -1538,6 +1630,10 @@ def plot_pdp(self, col, index=None, X_row=None, drop_na=True, sample=100, defaults to 100 gridpoints(ints: int, optional): number of points on the x axis to calculate the pdp for, defaults to 10 + sort (str, optional): For categorical features: how to sort: + 'alphabet', 'freq', 'shap'. Defaults to 'freq'. + round (int, optional): round float prediction to number of digits. + Defaults to 2. 
pos_label: (Default value = None) Returns: @@ -1545,29 +1641,26 @@ def plot_pdp(self, col, index=None, X_row=None, drop_na=True, sample=100, """ pdp_df = self.pdp_df(col, index, X_row, - drop_na=drop_na, sample=sample, num_grid_points=gridpoints, pos_label=pos_label) + drop_na=drop_na, sample=sample, n_grid_points=gridpoints, + pos_label=pos_label, sort=sort) units = "Predicted %" if self.model_output=='probability' else self.units - if index is not None: - col_value, pred = self.get_col_value_plus_prediction(col, index=index, pos_label=pos_label) - return plotly_pdp(pdp_df, - display_index=0, # the idx to be displayed is always set to the first row by self.pdp_df() - index_feature_value=col_value, index_prediction=pred, - feature_name=col, - num_grid_lines=min(gridlines, sample, len(self.X)), - target=self.target, units=units) - elif X_row is not None: - col_value, pred = self.get_col_value_plus_prediction(col, X_row=X_row, pos_label=pos_label) + if index is not None or X_row is not None: + col_value, pred = self.get_col_value_plus_prediction(col, index=index, X_row=X_row, pos_label=pos_label) + if (col in self.cat_cols + and col_value not in pdp_df.columns + and col_value[len(col)+1:] in pdp_df.columns): + col_value = col_value[len(col)+1:] return plotly_pdp(pdp_df, display_index=0, # the idx to be displayed is always set to the first row by self.pdp_df() - index_feature_value=col_value, index_prediction=pred, + index_feature_value=col_value, + index_prediction=pred, feature_name=col, num_grid_lines=min(gridlines, sample, len(self.X)), - target=self.target, units=units) - + round=round, target=self.target, units=units) else: return plotly_pdp(pdp_df, feature_name=col, num_grid_lines=min(gridlines, sample, len(self.X)), - target=self.target, units=units) + round=round, target=self.target, units=units) class ClassifierExplainer(BaseExplainer): @@ -1607,7 +1700,7 @@ def __init__(self, model, X, y=None, permutation_metric=roc_auc_score, self._params_dict = 
{**self._params_dict, **dict( labels=labels, pos_label=pos_label)} - if self.cats_only and model_output == 'probability': + if self.categorical_cols and model_output == 'probability': print("Warning: Models that deal with categorical features directly " f"such as {self.model.__class__.__name__} are incompatible with model_output='probability'" " for now. So setting model_output='logodds'...", flush=True) @@ -1818,16 +1911,16 @@ def permutation_importances_cats(self): """permutation importances with categoricals grouped""" if not hasattr(self, '_perm_imps_cats'): print("Calculating categorical permutation importances (if slow, try setting n_jobs parameter)...", flush=True) - if self.cats_only: + if self.onehot_cols: + self._perm_imps_cats = [cv_permutation_importances( + self.model, self.X, self.y, self.metric, + onehot_dict=self.onehot_dict, + cv=self.permutation_cv, + needs_proba=self.is_classifier, + pos_label=label) for label in range(len(self.labels))] + else: _ = self.permutation_importances self._perm_imps_cats = self._perm_imps - else: - self._perm_imps_cats = [cv_permutation_importances( - self.model, self.X, self.y, self.metric, - cats_dict=self.cats_dict, - cv=self.permutation_cv, - needs_proba=self.is_classifier, - pos_label=label) for label in range(len(self.labels))] return default_list(self._perm_imps_cats, self.pos_label) @property @@ -1839,7 +1932,6 @@ def shap_base_value(self): if isinstance(self._shap_base_value, np.ndarray) and len(self._shap_base_value) == 1: self._shap_base_value = self._shap_base_value[0] if isinstance(self._shap_base_value, np.ndarray): - self._shap_base_value = list(self._shap_base_value) if len(self.labels)==2 and isinstance(self._shap_base_value, (np.floating, float)): if self.model_output == 'probability': @@ -1884,15 +1976,10 @@ def shap_values(self): def shap_values_cats(self): """SHAP values with categoricals grouped together""" if not hasattr(self, '_shap_values_cats'): - if self.cats_only: - _ = self.shap_values - 
self._shap_values_cats = self._shap_values - else: - _ = self.shap_values - self._shap_values_cats = [ - merge_categorical_shap_values( - self.X, sv, self.cats_dict) for sv in self._shap_values] - + _ = self.shap_values + self._shap_values_cats = [ + merge_categorical_shap_values( + self.X, sv, self.onehot_dict) for sv in self._shap_values] return default_list(self._shap_values_cats, self.pos_label) @@ -1930,7 +2017,7 @@ def shap_interaction_values_cats(self): _ = self.shap_interaction_values self._shap_interaction_values_cats = [ merge_categorical_shap_interaction_values( - siv, self.X, self.X_cats, self.cats_dict) + siv, self.X, self.X_cats, self.onehot_dict) for siv in self._shap_interaction_values] return default_list(self._shap_interaction_values_cats, self.pos_label) @@ -2117,7 +2204,7 @@ def prediction_result_df(self, index=None, X_row=None, add_star=True, logodds=Fa pred_probas = self.pred_probas_raw[int_idx, :] elif X_row is not None: if X_row.columns.tolist()==self.X_cats.columns.tolist(): - X_row = X_cats_to_X(X_row, self.cats_dict, self.X.columns) + X_row = X_cats_to_X(X_row, self.onehot_dict, self.X.columns) pred_probas = self.model.predict_proba(X_row)[0, :] preds_df = pd.DataFrame(dict( @@ -2568,7 +2655,7 @@ def prediction_result_df(self, index=None, X_row=None, round=3): elif X_row is not None: if X_row.columns.tolist()==self.X_cats.columns.tolist(): - X_row = X_cats_to_X(X_row, self.cats_dict, self.X.columns) + X_row = X_cats_to_X(X_row, self.onehot_dict, self.X.columns) assert np.all(X_row.columns==self.X.columns), \ ("The column names of X_row should match X! 
Instead X_row.columns" f"={X_row.columns.tolist()}...") diff --git a/setup.py b/setup.py index 322c6fd..7720079 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='explainerdashboard', - version='0.2.19.1', + version='0.2.20.0', description='explainerdashboard allows you quickly build an interactive dashboard to explain the inner workings of your machine learning model.', long_description=""" diff --git a/tests/test_cats_only.py b/tests/test_cats_only.py index 838212a..2747e26 100644 --- a/tests/test_cats_only.py +++ b/tests/test_cats_only.py @@ -3,6 +3,8 @@ import pandas as pd import numpy as np +import plotly.graph_objs as go + from catboost import CatBoostClassifier, CatBoostRegressor from explainerdashboard import RegressionExplainer, ClassifierExplainer @@ -17,21 +19,123 @@ def setUp(self): _, self.names = titanic_names() model = CatBoostRegressor(iterations=5, verbose=0).fit(X_train, y_train) - explainer = RegressionExplainer(model, X_test, y_test, - cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}, - 'Deck', 'Embarked']) + explainer = RegressionExplainer(model, X_test, y_test, cats=['Deck', 'Embarked']) X_cats, y_cats = explainer.X_cats, explainer.y - model = CatBoostRegressor(iterations=5, verbose=0).fit(X_cats, y_cats, cat_features=[5, 6, 7]) - self.explainer = RegressionExplainer(model, X_cats, y_cats) + model = CatBoostRegressor(iterations=5, verbose=0).fit(X_cats, y_cats, cat_features=[8, 9]) + self.explainer = RegressionExplainer(model, X_cats, y_cats, cats=['Sex']) + + + def test_explainer_len(self): + self.assertEqual(len(self.explainer), self.test_len) + + def test_int_idx(self): + self.assertEqual(self.explainer.get_int_idx(self.names[0]), 0) + + def test_random_index(self): + self.assertIsInstance(self.explainer.random_index(), int) + self.assertIsInstance(self.explainer.random_index(return_str=True), str) + + def test_row_from_input(self): + input_row = self.explainer.get_row_from_input( + 
self.explainer.X.iloc[[0]].values.tolist()) + self.assertIsInstance(input_row, pd.DataFrame) + input_row = self.explainer.get_row_from_input( + self.explainer.X_cats.iloc[[0]].values.tolist()) + self.assertIsInstance(input_row, pd.DataFrame) + + input_row = self.explainer.get_row_from_input( + self.explainer.X_cats + [self.explainer.columns_ranked_by_shap(cats=True)] + .iloc[[0]].values.tolist(), ranked_by_shap=True) + self.assertIsInstance(input_row, pd.DataFrame) + + input_row = self.explainer.get_row_from_input( + self.explainer.X + [self.explainer.columns_ranked_by_shap(cats=False)] + .iloc[[0]].values.tolist(), ranked_by_shap=True) + self.assertIsInstance(input_row, pd.DataFrame) + + def test_prediction_result_df(self): + df = self.explainer.prediction_result_df(0) + self.assertIsInstance(df, pd.DataFrame) def test_preds(self): self.assertIsInstance(self.explainer.preds, np.ndarray) + def test_pred_percentiles(self): + self.assertIsInstance(self.explainer.pred_percentiles, np.ndarray) + + def test_columns_ranked_by_shap(self): + self.assertIsInstance(self.explainer.columns_ranked_by_shap(), list) + self.assertIsInstance(self.explainer.columns_ranked_by_shap(cats=True), list) + + def test_equivalent_col(self): + self.assertEqual(self.explainer.equivalent_col("Sex_female"), "Sex") + self.assertEqual(self.explainer.equivalent_col("Sex"), "Sex_female") + self.assertIsNone(self.explainer.equivalent_col("random")) + + def test_ordered_cats(self): + self.assertEqual(self.explainer.ordered_cats("Sex"), ['Sex_female', 'Sex_male']) + self.assertEqual(self.explainer.ordered_cats("Deck", topx=2, sort='alphabet'), ['Deck_A', 'Deck_B']) + + self.assertIsInstance(self.explainer.ordered_cats("Deck", sort='freq'), list) + self.assertIsInstance(self.explainer.ordered_cats("Deck", topx=3, sort='freq'), list) + self.assertIsInstance(self.explainer.ordered_cats("Deck", sort='shap'), list) + self.assertIsInstance(self.explainer.ordered_cats("Deck", topx=3, sort='shap'), list) + + 
def test_get_col(self): + self.assertIsInstance(self.explainer.get_col("Sex"), pd.Series) + self.assertEqual(self.explainer.get_col("Sex").dtype, "object") + + self.assertIsInstance(self.explainer.get_col("Age"), pd.Series) + self.assertEqual(self.explainer.get_col("Age").dtype, np.float) + def test_permutation_importances(self): self.assertIsInstance(self.explainer.permutation_importances, pd.DataFrame) self.assertIsInstance(self.explainer.permutation_importances_cats, pd.DataFrame) + def test_X_cats(self): + self.assertIsInstance(self.explainer.X_cats, pd.DataFrame) + + def test_columns_cats(self): + self.assertIsInstance(self.explainer.columns_cats, list) + + def test_metrics(self): + self.assertIsInstance(self.explainer.metrics(), dict) + self.assertIsInstance(self.explainer.metrics_descriptions(), dict) + + def test_mean_abs_shap_df(self): + self.assertIsInstance(self.explainer.mean_abs_shap_df(), pd.DataFrame) + + def test_permutation_importances_df(self): + self.assertIsInstance(self.explainer.permutation_importances_df(), pd.DataFrame) + self.assertIsInstance(self.explainer.permutation_importances_df(topx=3), pd.DataFrame) + self.assertIsInstance(self.explainer.permutation_importances_df(cats=True), pd.DataFrame) + self.assertIsInstance(self.explainer.permutation_importances_df(cutoff=0.01), pd.DataFrame) + + def test_contrib_df(self): + self.assertIsInstance(self.explainer.contrib_df(0), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_df(0, cats=False), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_df(0, topx=3), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_df(0, sort='high-to-low'), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_df(0, sort='low-to-high'), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_df(0, sort='importance'), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_df(X_row=self.explainer.X.iloc[[0]]), pd.DataFrame) + 
self.assertIsInstance(self.explainer.contrib_df(X_row=self.explainer.X_cats.iloc[[0]]), pd.DataFrame) + + + def test_contrib_summary_df(self): + self.assertIsInstance(self.explainer.contrib_summary_df(0), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_summary_df(0, cats=False), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_summary_df(0, topx=3), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_summary_df(0, round=3), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_summary_df(0, sort='high-to-low'), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_summary_df(0, sort='low-to-high'), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_summary_df(0, sort='importance'), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_summary_df(X_row=self.explainer.X.iloc[[0]]), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_summary_df(X_row=self.explainer.X_cats.iloc[[0]]), pd.DataFrame) + def test_shap_base_value(self): self.assertIsInstance(self.explainer.shap_base_value, (np.floating, float)) @@ -42,24 +146,212 @@ def test_shap_values(self): self.assertIsInstance(self.explainer.shap_values, np.ndarray) self.assertIsInstance(self.explainer.shap_values_cats, np.ndarray) - # @unittest.expectedFailure - # def test_shap_interaction_values(self): - # self.assertIsInstance(self.explainer.shap_interaction_values, np.ndarray) - # self.assertIsInstance(self.explainer.shap_interaction_values_cats, np.ndarray) - def test_mean_abs_shap(self): self.assertIsInstance(self.explainer.mean_abs_shap, pd.DataFrame) self.assertIsInstance(self.explainer.mean_abs_shap_cats, pd.DataFrame) def test_calculate_properties(self): - self.explainer.calculate_properties(include_interactions=False) + self.explainer.calculate_properties() def test_pdp_df(self): self.assertIsInstance(self.explainer.pdp_df("Age"), pd.DataFrame) - self.assertIsInstance(self.explainer.pdp_df("Gender"), pd.DataFrame) + 
self.assertIsInstance(self.explainer.pdp_df("Sex"), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("Deck"), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("Age", index=0), pd.DataFrame) - self.assertIsInstance(self.explainer.pdp_df("Gender", index=0), pd.DataFrame) + self.assertIsInstance(self.explainer.pdp_df("Sex", index=0), pd.DataFrame) + + def test_get_dfs(self): + cols_df, shap_df, contribs_df = self.explainer.get_dfs() + self.assertIsInstance(cols_df, pd.DataFrame) + self.assertIsInstance(shap_df, pd.DataFrame) + self.assertIsInstance(contribs_df, pd.DataFrame) + + def test_plot_importances(self): + fig = self.explainer.plot_importances() + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_importances(kind='permutation') + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_importances(topx=3) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_importances(cats=True) + self.assertIsInstance(fig, go.Figure) + + def test_plot_shap_summary(self): + fig = self.explainer.plot_shap_summary() + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_summary(topx=3) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_summary(cats=True) + self.assertIsInstance(fig, go.Figure) + + def test_plot_shap_dependence(self): + fig = self.explainer.plot_shap_dependence("Age") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_dependence("Sex") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_dependence("Age", "Sex") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_dependence("Sex_female", "Age") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_dependence("Age", highlight_index=0) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_dependence("Sex", highlight_index=0) + self.assertIsInstance(fig, go.Figure) + + fig = 
self.explainer.plot_shap_dependence("Deck", topx=3, sort="freq") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_dependence("Deck", topx=3, sort="shap") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_dependence("Deck", sort="freq") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_dependence("Deck", sort="shap") + self.assertIsInstance(fig, go.Figure) + + def test_plot_shap_contributions(self): + fig = self.explainer.plot_shap_contributions(0) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_contributions(0, cats=False) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_contributions(0, topx=3) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_contributions(0, sort='high-to-low') + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_contributions(0, sort='low-to-high') + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_contributions(0, sort='importance') + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_contributions(X_row=self.explainer.X.iloc[[0]]) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_contributions(X_row=self.explainer.X_cats.iloc[[0]]) + self.assertIsInstance(fig, go.Figure) + + def test_plot_pdp(self): + fig = self.explainer.plot_pdp("Age") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_pdp("Sex") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_pdp("Sex", index=0) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_pdp("Age", index=0) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_pdp("Age", X_row=self.explainer.X.iloc[[0]]) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_pdp("Age", X_row=self.explainer.X_cats.iloc[[0]]) + self.assertIsInstance(fig, go.Figure) + + def test_yaml(self): + yaml = 
self.explainer.to_yaml() + self.assertIsInstance(yaml, str) + + def test_residuals(self): + self.assertIsInstance(self.explainer.residuals, pd.Series) + + def test_prediction_result_markdown(self): + result_index = self.explainer.prediction_result_markdown(0) + self.assertIsInstance(result_index, str) + result_name = self.explainer.prediction_result_markdown(self.names[0]) + self.assertIsInstance(result_name, str) + + def test_metrics(self): + metrics_dict = self.explainer.metrics() + self.assertIsInstance(metrics_dict, dict) + self.assertTrue('root_mean_squared_error' in metrics_dict) + self.assertTrue('mean_absolute_error' in metrics_dict) + self.assertTrue('R-squared' in metrics_dict) + self.assertIsInstance(self.explainer.metrics_descriptions(), dict) + + def test_plot_predicted_vs_actual(self): + fig = self.explainer.plot_predicted_vs_actual(logs=False) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_predicted_vs_actual(logs=True) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_predicted_vs_actual(log_x=True, log_y=True) + self.assertIsInstance(fig, go.Figure) + + def test_plot_residuals(self): + fig = self.explainer.plot_residuals() + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_residuals(vs_actual=True) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_residuals(residuals='ratio') + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_residuals(residuals='log-ratio') + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_residuals(residuals='log-ratio', vs_actual=True) + self.assertIsInstance(fig, go.Figure) + + def test_plot_residuals_vs_feature(self): + fig = self.explainer.plot_residuals_vs_feature("Age") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_residuals_vs_feature("Age", residuals='log-ratio') + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_residuals_vs_feature("Age", dropna=True) + 
self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_residuals_vs_feature("Sex", points=False) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_residuals_vs_feature("Sex", winsor=10) + self.assertIsInstance(fig, go.Figure) + + def test_plot_y_vs_feature(self): + fig = self.explainer.plot_y_vs_feature("Age") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_y_vs_feature("Age", dropna=True) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_y_vs_feature("Sex", points=False) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_y_vs_feature("Sex", winsor=10) + self.assertIsInstance(fig, go.Figure) + + def test_plot_preds_vs_feature(self): + fig = self.explainer.plot_preds_vs_feature("Age") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_preds_vs_feature("Age", dropna=True) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_preds_vs_feature("Sex", points=False) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_preds_vs_feature("Sex", winsor=10) + self.assertIsInstance(fig, go.Figure) class CatBoostClassifierTests(unittest.TestCase): @@ -70,24 +362,128 @@ def setUp(self): model = CatBoostClassifier(iterations=100, verbose=0).fit(X_train, y_train) explainer = ClassifierExplainer( model, X_test, y_test, - cats=[{'Gender': ['Sex_female', 'Sex_male', 'Sex_nan']}, - 'Deck', 'Embarked'], - labels=['Not survived', 'Survived'], - idxs=test_names) + cats=['Deck', 'Embarked'], + labels=['Not survived', 'Survived']) X_cats, y_cats = explainer.X_cats, explainer.y - model = CatBoostClassifier(iterations=5, verbose=0).fit(X_cats, y_cats, cat_features=[5, 6, 7]) - self.explainer = ClassifierExplainer(model, X_cats, y_cats) + model = CatBoostClassifier(iterations=5, verbose=0).fit(X_cats, y_cats, cat_features=[8, 9]) + self.explainer = ClassifierExplainer(model, X_cats, y_cats, + cats=['Sex'], labels=['Not survived', 'Survived']) + + 
def test_explainer_len(self): + self.assertEqual(len(self.explainer), len(titanic_survive()[2])) + + def test_int_idx(self): + self.assertEqual(self.explainer.get_int_idx(titanic_names()[1][0]), 0) + + def test_random_index(self): + self.assertIsInstance(self.explainer.random_index(), int) + self.assertIsInstance(self.explainer.random_index(return_str=True), str) + + def test_ordered_cats(self): + self.assertEqual(self.explainer.ordered_cats("Sex"), ['Sex_female', 'Sex_male']) + self.assertEqual(self.explainer.ordered_cats("Deck", topx=2, sort='alphabet'), ['Deck_A', 'Deck_B']) + + self.assertIsInstance(self.explainer.ordered_cats("Deck", sort='freq'), list) + self.assertIsInstance(self.explainer.ordered_cats("Deck", topx=3, sort='freq'), list) + self.assertIsInstance(self.explainer.ordered_cats("Deck", sort='shap'), list) + self.assertIsInstance(self.explainer.ordered_cats("Deck", topx=3, sort='shap'), list) + def test_preds(self): self.assertIsInstance(self.explainer.preds, np.ndarray) - def test_pred_probas(self): - self.assertIsInstance(self.explainer.pred_probas, np.ndarray) + def test_row_from_input(self): + input_row = self.explainer.get_row_from_input( + self.explainer.X.iloc[[0]].values.tolist()) + self.assertIsInstance(input_row, pd.DataFrame) + + input_row = self.explainer.get_row_from_input( + self.explainer.X_cats.iloc[[0]].values.tolist()) + self.assertIsInstance(input_row, pd.DataFrame) + + input_row = self.explainer.get_row_from_input( + self.explainer.X_cats + [self.explainer.columns_ranked_by_shap(cats=True)] + .iloc[[0]].values.tolist(), ranked_by_shap=True) + self.assertIsInstance(input_row, pd.DataFrame) + + input_row = self.explainer.get_row_from_input( + self.explainer.X + [self.explainer.columns_ranked_by_shap(cats=False)] + .iloc[[0]].values.tolist(), ranked_by_shap=True) + self.assertIsInstance(input_row, pd.DataFrame) + + def test_pred_percentiles(self): + self.assertIsInstance(self.explainer.pred_percentiles, np.ndarray) + + def 
test_columns_ranked_by_shap(self): + self.assertIsInstance(self.explainer.columns_ranked_by_shap(), list) + self.assertIsInstance(self.explainer.columns_ranked_by_shap(cats=True), list) + + def test_equivalent_col(self): + self.assertEqual(self.explainer.equivalent_col("Sex_female"), "Sex") + self.assertEqual(self.explainer.equivalent_col("Sex"), "Sex_female") + self.assertIsNone(self.explainer.equivalent_col("random")) + + def test_get_col(self): + self.assertIsInstance(self.explainer.get_col("Sex"), pd.Series) + self.assertEqual(self.explainer.get_col("Sex").dtype, "object") + + self.assertIsInstance(self.explainer.get_col("Deck"), pd.Series) + self.assertEqual(self.explainer.get_col("Deck").dtype, "object") + + self.assertIsInstance(self.explainer.get_col("Age"), pd.Series) + self.assertEqual(self.explainer.get_col("Age").dtype, np.float) def test_permutation_importances(self): self.assertIsInstance(self.explainer.permutation_importances, pd.DataFrame) self.assertIsInstance(self.explainer.permutation_importances_cats, pd.DataFrame) + + def test_X_cats(self): + self.assertIsInstance(self.explainer.X_cats, pd.DataFrame) + + def test_columns_cats(self): + self.assertIsInstance(self.explainer.columns_cats, list) + + def test_metrics(self): + self.assertIsInstance(self.explainer.metrics(), dict) + + def test_mean_abs_shap_df(self): + self.assertIsInstance(self.explainer.mean_abs_shap_df(), pd.DataFrame) + + def test_top_interactions(self): + self.assertIsInstance(self.explainer.shap_top_interactions("Age"), list) + self.assertIsInstance(self.explainer.shap_top_interactions("Age", topx=4), list) + self.assertIsInstance(self.explainer.shap_top_interactions("Age", cats=True), list) + self.assertIsInstance(self.explainer.shap_top_interactions("Sex", cats=True), list) + + def test_permutation_importances_df(self): + self.assertIsInstance(self.explainer.permutation_importances_df(), pd.DataFrame) + self.assertIsInstance(self.explainer.permutation_importances_df(topx=3), 
pd.DataFrame) + self.assertIsInstance(self.explainer.permutation_importances_df(cats=True), pd.DataFrame) + self.assertIsInstance(self.explainer.permutation_importances_df(cutoff=0.01), pd.DataFrame) + + def test_contrib_df(self): + self.assertIsInstance(self.explainer.contrib_df(0), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_df(0, cats=False), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_df(0, topx=3), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_df(0, sort='high-to-low'), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_df(0, sort='low-to-high'), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_df(0, sort='importance'), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_df(X_row=self.explainer.X.iloc[[0]]), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_df(X_row=self.explainer.X_cats.iloc[[0]]), pd.DataFrame) + + def test_contrib_summary_df(self): + self.assertIsInstance(self.explainer.contrib_summary_df(0), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_summary_df(0, cats=False), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_summary_df(0, topx=3), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_summary_df(0, round=3), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_summary_df(0, sort='low-to-high'), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_summary_df(0, sort='high-to-low'), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_summary_df(0, sort='importance'), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_summary_df(X_row=self.explainer.X.iloc[[0]]), pd.DataFrame) + self.assertIsInstance(self.explainer.contrib_summary_df(X_row=self.explainer.X_cats.iloc[[0]]), pd.DataFrame) def test_shap_base_value(self): self.assertIsInstance(self.explainer.shap_base_value, (np.floating, float)) @@ -99,28 +495,154 @@ def test_shap_values(self): 
self.assertIsInstance(self.explainer.shap_values, np.ndarray) self.assertIsInstance(self.explainer.shap_values_cats, np.ndarray) - # @unittest.expectedFailure - # def test_shap_interaction_values(self): - # self.assertIsInstance(self.explainer.shap_interaction_values, np.ndarray) - # self.assertIsInstance(self.explainer.shap_interaction_values_cats, np.ndarray) - def test_mean_abs_shap(self): self.assertIsInstance(self.explainer.mean_abs_shap, pd.DataFrame) self.assertIsInstance(self.explainer.mean_abs_shap_cats, pd.DataFrame) def test_calculate_properties(self): - self.explainer.calculate_properties(include_interactions=False) + self.explainer.calculate_properties() + + def test_prediction_result_df(self): + df = self.explainer.prediction_result_df(0) + self.assertIsInstance(df, pd.DataFrame) def test_pdp_df(self): self.assertIsInstance(self.explainer.pdp_df("Age"), pd.DataFrame) - self.assertIsInstance(self.explainer.pdp_df("Gender"), pd.DataFrame) + self.assertIsInstance(self.explainer.pdp_df("Sex"), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("Deck"), pd.DataFrame) self.assertIsInstance(self.explainer.pdp_df("Age", index=0), pd.DataFrame) - self.assertIsInstance(self.explainer.pdp_df("Gender", index=0), pd.DataFrame) + self.assertIsInstance(self.explainer.pdp_df("Sex_male", index=0), pd.DataFrame) + self.assertIsInstance(self.explainer.pdp_df("Age", X_row=self.explainer.X.iloc[[0]]), pd.DataFrame) + self.assertIsInstance(self.explainer.pdp_df("Age", X_row=self.explainer.X_cats.iloc[[0]]), pd.DataFrame) + + def test_plot_importances(self): + fig = self.explainer.plot_importances() + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_importances(kind='permutation') + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_importances(topx=3) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_importances(cats=True) + self.assertIsInstance(fig, go.Figure) + + def test_plot_shap_contributions(self): + fig 
= self.explainer.plot_shap_contributions(0) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_contributions(0, cats=False) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_contributions(0, topx=3) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_contributions(0, cutoff=0.05) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_contributions(0, sort='high-to-low') + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_contributions(0, sort='low-to-high') + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_contributions(0, sort='importance') + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_contributions(X_row=self.explainer.X.iloc[[0]], sort='importance') + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_contributions(X_row=self.explainer.X_cats.iloc[[0]], sort='importance') + self.assertIsInstance(fig, go.Figure) + + def test_plot_shap_summary(self): + fig = self.explainer.plot_shap_summary() + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_summary(topx=3) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_summary(cats=True) + self.assertIsInstance(fig, go.Figure) + + def test_plot_shap_dependence(self): + fig = self.explainer.plot_shap_dependence("Age") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_dependence("Sex_female") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_dependence("Age", "Sex") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_dependence("Sex_female", "Age") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_dependence("Age", highlight_index=0) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_dependence("Sex", highlight_index=0) + self.assertIsInstance(fig, go.Figure) + + fig = 
self.explainer.plot_shap_dependence("Deck", topx=3, sort="freq") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_dependence("Deck", topx=3, sort="shap") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_dependence("Deck", sort="freq") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_shap_dependence("Deck", sort="shap") + self.assertIsInstance(fig, go.Figure) + + def test_plot_pdp(self): + fig = self.explainer.plot_pdp("Age") + self.assertIsInstance(fig, go.Figure) + fig = self.explainer.plot_pdp("Sex") + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_pdp("Sex", index=0) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_pdp("Age", index=0) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_pdp("Age", X_row=self.explainer.X.iloc[[0]]) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_pdp("Age", X_row=self.explainer.X_cats.iloc[[0]]) + self.assertIsInstance(fig, go.Figure) + + def test_yaml(self): + yaml = self.explainer.to_yaml() + self.assertIsInstance(yaml, str) + + def test_pos_label(self): + self.explainer.pos_label = 1 + self.explainer.pos_label = "Not survived" + self.assertIsInstance(self.explainer.pos_label, int) + self.assertIsInstance(self.explainer.pos_label_str, str) + self.assertEqual(self.explainer.pos_label, 0) + self.assertEqual(self.explainer.pos_label_str, "Not survived") + + def test_get_prop_for_label(self): + self.explainer.pos_label = 1 + tmp = self.explainer.pred_percentiles + self.explainer.pos_label = 0 + self.assertTrue(np.alltrue(self.explainer.get_prop_for_label("pred_percentiles", 1)==tmp)) + + def test_pred_probas(self): + self.assertIsInstance(self.explainer.pred_probas, np.ndarray) + + def test_metrics(self): self.assertIsInstance(self.explainer.metrics(), dict) self.assertIsInstance(self.explainer.metrics(cutoff=0.9), dict) + 
self.assertIsInstance(self.explainer.metrics_descriptions(cutoff=0.9), dict) + def test_precision_df(self): self.assertIsInstance(self.explainer.precision_df(), pd.DataFrame) @@ -131,4 +653,97 @@ def test_lift_curve_df(self): self.assertIsInstance(self.explainer.lift_curve_df(), pd.DataFrame) def test_prediction_result_markdown(self): - self.assertIsInstance(self.explainer.prediction_result_markdown(0), str) \ No newline at end of file + self.assertIsInstance(self.explainer.prediction_result_markdown(0), str) + + def test_calculate_properties(self): + self.explainer.calculate_properties() + + def test_plot_precision(self): + fig = self.explainer.plot_precision() + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_precision(multiclass=True) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_precision(quantiles=10, cutoff=0.5) + self.assertIsInstance(fig, go.Figure) + + def test_plot_cumulative_precision(self): + fig = self.explainer.plot_cumulative_precision() + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_cumulative_precision(percentile=0.5) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_cumulative_precision(percentile=0.1, pos_label=0) + self.assertIsInstance(fig, go.Figure) + + def test_plot_confusion_matrix(self): + fig = self.explainer.plot_confusion_matrix(normalized=False, binary=False) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_confusion_matrix(normalized=False, binary=True) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_confusion_matrix(normalized=True, binary=False) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_confusion_matrix(normalized=True, binary=True) + self.assertIsInstance(fig, go.Figure) + + def test_plot_lift_curve(self): + fig = self.explainer.plot_lift_curve() + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_lift_curve(percentage=True) + self.assertIsInstance(fig, 
go.Figure) + + fig = self.explainer.plot_lift_curve(cutoff=0.5) + self.assertIsInstance(fig, go.Figure) + + def test_plot_lift_curve(self): + fig = self.explainer.plot_lift_curve() + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_lift_curve(percentage=True) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_lift_curve(cutoff=0.5) + self.assertIsInstance(fig, go.Figure) + + def test_plot_classification(self): + fig = self.explainer.plot_classification() + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_classification(percentage=True) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_classification(cutoff=0) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_classification(cutoff=1) + self.assertIsInstance(fig, go.Figure) + + def test_plot_roc_auc(self): + fig = self.explainer.plot_roc_auc(0.5) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_roc_auc(0.0) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_roc_auc(1.0) + self.assertIsInstance(fig, go.Figure) + + def test_plot_pr_auc(self): + fig = self.explainer.plot_pr_auc(0.5) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_pr_auc(0.0) + self.assertIsInstance(fig, go.Figure) + + fig = self.explainer.plot_pr_auc(1.0) + self.assertIsInstance(fig, go.Figure) + + def test_plot_prediction_result(self): + fig = self.explainer.plot_prediction_result(0) + self.assertIsInstance(fig, go.Figure) \ No newline at end of file diff --git a/tests/test_classifier_explainer.py b/tests/test_classifier_explainer.py index 7a0004d..8e7acb1 100644 --- a/tests/test_classifier_explainer.py +++ b/tests/test_classifier_explainer.py @@ -107,6 +107,9 @@ def test_plot_lift_curve(self): fig = self.explainer.plot_lift_curve(cutoff=0.5) self.assertIsInstance(fig, go.Figure) + fig = self.explainer.plot_lift_curve(add_wizard=False, round=3) + self.assertIsInstance(fig, go.Figure) + def 
test_plot_lift_curve(self): fig = self.explainer.plot_lift_curve() self.assertIsInstance(fig, go.Figure)