Skip to content

Commit

Permalink
IndexSelector: use dynamic options (#219)
Browse files Browse the repository at this point in the history
* IndexSelector: use dynamic options

* converted more index dropdowns to dynamic

* adds InputDropdownComponent

* updates IndexSelector for dynamic search

* bump sklearn to v1.1 and adjusts pipeline tests

* IndexSelector: use dynamic options

* converted more index dropdowns to dynamic

* adds InputDropdownComponent

* updates IndexSelector for dynamic search

* software eng fixup and cleanup

* casts index_list to str by default

* removes itertools import

* bump to 0.4.0

Co-authored-by: Oege Dijk <oege.dijk@fourkind.com>
  • Loading branch information
Achim Gädke and Oege Dijk committed Jun 15, 2022
1 parent f0a3427 commit c247f47
Show file tree
Hide file tree
Showing 11 changed files with 112 additions and 65 deletions.
2 changes: 1 addition & 1 deletion explainerdashboard/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

___version__ = "0.3.8.2"
___version__ = "0.4.0"

from .explainers import ClassifierExplainer, RegressionExplainer
from .dashboards import ExplainerDashboard, ExplainerHub, InlineExplainer
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self, explainer, title="Select Random Index", name=None,

self.selector = PosLabelSelector(explainer, name=self.name, pos_label=pos_label)
self.index_selector = IndexSelector(explainer, 'random-index-clas-index-'+self.name,
index=index, index_dropdown=index_dropdown)
index=index, index_dropdown=index_dropdown, **kwargs)

assert (len(self.slider) == 2 and
self.slider[0] >= 0 and self.slider[0] <=1 and
Expand Down Expand Up @@ -312,7 +312,7 @@ def __init__(self, explainer, title="Prediction", name=None,
self.index_name = 'clas-prediction-index-'+self.name
self.selector = PosLabelSelector(explainer, name=self.name, pos_label=pos_label)
self.index_selector = IndexSelector(explainer, 'clas-prediction-index-'+self.name,
index=index, index_dropdown=index_dropdown)
index=index, index_dropdown=index_dropdown, **kwargs)

if self.feature_input_component is not None:
self.exclude_callbacks(self.feature_input_component)
Expand Down
2 changes: 1 addition & 1 deletion explainerdashboard/dashboard_components/composites.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def __init__(self, explainer, title="What if...", name=None,

self.input = FeatureInputComponent(explainer, name=self.name+"0",
hide_selector=hide_selector, n_input_cols=self.n_input_cols,
**update_params(kwargs, hide_index=True))
**update_params(kwargs, hide_index=False))

if self.explainer.is_classifier:
self.index = ClassifierRandomIndexComponent(explainer, name=self.name+"1",
Expand Down
10 changes: 4 additions & 6 deletions explainerdashboard/dashboard_components/connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def __init__(self, input_index, output_indexes, explainer=None):
self.output_index_names = [self.output_index_names]

@staticmethod
def index_name(indexes):#, multi=False):
def index_name(indexes):
def get_index_name(o):
if isinstance(o, str): return o
elif isinstance(o, ExplainerComponent):
Expand All @@ -277,15 +277,13 @@ def get_index_name(o):
def component_callbacks(self, app):
@app.callback(
[Output(index_name, 'value') for index_name in self.output_index_names],
[Input(self.input_index_name, 'value')]
[Input(self.input_index_name, 'value')],
)
def update_indexes(index):
if self.explainer is not None:
if self.explainer.index_exists(index):
return tuple(index for i in range(len(self.output_index_names)))
else:
raise PreventUpdate
return tuple(index for i in range(len(self.output_index_names)))
return tuple([index for _ in self.output_index_names])
raise PreventUpdate


class HighlightConnector(ExplainerComponent):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(self, explainer, title="Decision Trees", name=None,

self.selector = PosLabelSelector(explainer, name=self.name, pos_label=pos_label)
self.index_selector = IndexSelector(explainer, 'decisiontrees-index-'+self.name,
index=index, index_dropdown=index_dropdown)
index=index, index_dropdown=index_dropdown, **kwargs)


if isinstance(self.explainer, RandomForestExplainer):
Expand Down Expand Up @@ -232,7 +232,7 @@ def __init__(self, explainer, title="Decision path table", name=None,

self.selector = PosLabelSelector(explainer, name=self.name, pos_label=pos_label)
self.index_selector = IndexSelector(explainer, 'decisionpath-table-index-'+self.name,
index=index, index_dropdown=index_dropdown)
index=index, index_dropdown=index_dropdown, **kwargs)

if self.description is None: self.description = """
Shows the path that an observation took down a specific decision tree.
Expand Down Expand Up @@ -359,7 +359,7 @@ def __init__(self, explainer, title="Decision path graph", name=None,

self.selector = PosLabelSelector(explainer, name=self.name, pos_label=pos_label)
self.index_selector = IndexSelector(explainer, 'decisionpath-index-'+self.name,
index=index, index_dropdown=index_dropdown)
index=index, index_dropdown=index_dropdown, **kwargs)
self.register_dependencies("shadow_trees")

def layout(self):
Expand Down
15 changes: 8 additions & 7 deletions explainerdashboard/dashboard_components/overview_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ def __init__(self, explainer, title="Partial Dependence Plot", name=None,
"""
self.selector = PosLabelSelector(explainer, name=self.name, pos_label=pos_label)
self.index_selector = IndexSelector(explainer, 'pdp-index-'+self.name,
index=index, index_dropdown=index_dropdown)
index=index, index_dropdown=index_dropdown, **kwargs)

self.popout = GraphPopout('pdp-'+self.name+'popout', 'pdp-graph-'+self.name, self.title, self.description)

Expand Down Expand Up @@ -631,6 +631,9 @@ def update_pdp_graph(col, drop_na, sample, gridlines, gridpoints, sort, pos_labe
gridpoints=gridpoints, sort=sort, pos_label=pos_label)





class FeatureInputComponent(ExplainerComponent):
def __init__(self, explainer, title="Feature Input", name=None,
subtitle="Adjust the feature values to change the prediction",
Expand Down Expand Up @@ -670,6 +673,7 @@ def __init__(self, explainer, title="Feature Input", name=None,
assert len(explainer.columns) == len(set(explainer.columns)), \
"Not all X column names are unique, so cannot launch FeatureInputComponent component/tab!"

self.index_input = IndexSelector(explainer, name='feature-input-index-'+self.name, **kwargs)
self.index_name = 'feature-input-index-'+self.name


Expand Down Expand Up @@ -783,11 +787,7 @@ def layout(self):
dbc.Row([
make_hideable(
dbc.Col([
dbc.Label(f"{self.explainer.index_name}:"),
dcc.Dropdown(id='feature-input-index-'+self.name,
options = [{'label': str(idx), 'value':idx}
for idx in self.explainer.get_index_list()],
value=self.index)
self.index_input.layout()
], md=4), hide=self.hide_index),
]),
input_row,
Expand All @@ -809,9 +809,10 @@ def to_html(self, state_dict=None, add_header=True):
return html

def component_callbacks(self, app):

@app.callback(
[*self._feature_callback_outputs],
[Input('feature-input-index-'+self.name, 'value')]
[Input(self.index_name, 'value')]
)
def update_whatif_inputs(index):
if index is None or not self.explainer.index_exists(index):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(self, explainer, title=None, name=None,

self.index_name = 'random-index-reg-index-'+self.name
self.index_selector = IndexSelector(explainer, self.index_name,
index=index, index_dropdown=index_dropdown)
index=index, index_dropdown=index_dropdown, **kwargs)

if self.explainer.y_missing:
self.hide_residual_slider = True
Expand Down
39 changes: 19 additions & 20 deletions explainerdashboard/dashboard_components/shap_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
'InteractionSummaryDependenceConnector',
'ShapContributionsTableComponent',
'ShapContributionsGraphComponent']


import dash
from dash import html, dcc, Input, Output, State
Expand Down Expand Up @@ -68,6 +69,8 @@ def __init__(self, explainer, title='Shap Summary', name=None,
if self.depth is not None:
self.depth = min(self.depth, self.explainer.n_features)

self.index_selector = IndexSelector(explainer, 'shap-summary-index-'+self.name,
index=index, **kwargs)
self.index_name = 'shap-summary-index-'+self.name
self.selector = PosLabelSelector(explainer, name=self.name, pos_label=pos_label)
assert self.summary_type in {'aggregate', 'detailed'}
Expand Down Expand Up @@ -132,12 +135,7 @@ def layout(self):
dbc.Tooltip(f"Select {self.explainer.index_name} to highlight in plot. "
"You can also select by clicking on a scatter point in the graph.",
target='shap-summary-index-label-'+self.name),
dcc.Dropdown(id='shap-summary-index-'+self.name,
options = [{'label': str(idx), 'value':idx}
for idx in self.explainer.idxs],
optionHeight=12,
style={'height': '12', 'font-size': '12'},
value=self.index),
self.index_selector.layout()
], id='shap-summary-index-col-'+self.name, style=dict(display="none")),
], md=3), hide=self.hide_index),
make_hideable(
Expand Down Expand Up @@ -174,6 +172,7 @@ def to_html(self, state_dict=None, add_header=True):
return html

def component_callbacks(self, app):

@app.callback(
Output('shap-summary-index-'+self.name, 'value'),
[Input('shap-summary-graph-'+self.name, 'clickData')])
Expand Down Expand Up @@ -288,6 +287,8 @@ def __init__(self, explainer, title='Shap Dependence', name=None,

self.selector = PosLabelSelector(explainer, name=self.name, pos_label=pos_label)

self.index_selector = IndexSelector(explainer, 'shap-dependence-index-'+self.name,
index=index, **kwargs)
self.index_name = 'shap-dependence-index-'+self.name

if self.description is None: self.description = """
Expand Down Expand Up @@ -350,10 +351,7 @@ def layout(self):
"You can also select by clicking on a scatter marker in the accompanying"
" shap summary plot (detailed).",
target='shap-dependence-index-label-'+self.name),
dcc.Dropdown(id='shap-dependence-index-'+self.name,
options = [{'label': str(idx), 'value':idx}
for idx in self.explainer.idxs],
value=self.index)
self.index_selector.layout(),
], md=4), hide=self.hide_index),
]),
dcc.Loading(id="loading-dependence-graph-"+self.name,
Expand Down Expand Up @@ -435,6 +433,7 @@ def to_html(self, state_dict=None, add_header=True):
return html

def component_callbacks(self, app):

@app.callback(
[Output('shap-dependence-color-col-'+self.name, 'options'),
Output('shap-dependence-color-col-'+self.name, 'value'),
Expand Down Expand Up @@ -566,6 +565,9 @@ def __init__(self, explainer, title="Interactions Summary", name=None,
self.col = self.explainer.columns_ranked_by_shap()[0]
if self.depth is not None:
self.depth = min(self.depth, self.explainer.n_features-1)

self.index_selector = IndexSelector(explainer, 'interaction-summary-index-'+self.name,
index=index, **kwargs)
self.index_name = 'interaction-summary-index-'+self.name
self.selector = PosLabelSelector(explainer, name=self.name, pos_label=pos_label)

Expand Down Expand Up @@ -641,10 +643,7 @@ def layout(self):
dbc.Tooltip(f"Select {self.explainer.index_name} to highlight in plot. "
"You can also select by clicking on a scatter point in the graph.",
target='interaction-summary-index-label-'+self.name),
dcc.Dropdown(id='interaction-summary-index-'+self.name,
options = [{'label': str(idx), 'value':idx}
for idx in self.explainer.idxs],
value=self.index),
self.index_selector.layout(),
], id='interaction-summary-index-col-'+self.name, style=dict(display="none")),
], md=3), hide=self.hide_index),
make_hideable(
Expand Down Expand Up @@ -793,6 +792,9 @@ def __init__(self, explainer, title="Interaction Dependence", name=None,
if self.interact_col is None:
self.interact_col = explainer.top_shap_interactions(self.col)[1]

self.index_selector = IndexSelector(explainer, 'interaction-dependence-index-'+self.name,
index=index, **kwargs)
self.index_name = 'interaction-dependence-index-'+self.name

self.selector = PosLabelSelector(explainer, name=self.name, pos_label=pos_label)
self.popout_top = GraphPopout(self.name+'popout-top',
Expand Down Expand Up @@ -852,10 +854,7 @@ def layout(self):
"You can also select by clicking on a scatter marker in the accompanying"
" shap interaction summary plot (detailed).",
target='interaction-dependence-index-label-'+self.name),
dcc.Dropdown(id='interaction-dependence-index-'+self.name,
options = [{'label': str(idx), 'value':idx}
for idx in self.explainer.idxs],
value=self.index)
self.index_selector.layout(),
], md=4), hide=self.hide_index),
]),
dbc.Row([
Expand Down Expand Up @@ -1168,7 +1167,7 @@ def __init__(self, explainer, title="Contributions Plot", name=None,

self.selector = PosLabelSelector(explainer, name=self.name, pos_label=pos_label)
self.index_selector = IndexSelector(explainer, 'contributions-graph-index-'+self.name,
index=index, index_dropdown=index_dropdown)
index=index, index_dropdown=index_dropdown, **kwargs)

self.popout = GraphPopout('contributions-graph-'+self.name+'popout',
'contributions-graph-'+self.name, self.title, self.description)
Expand Down Expand Up @@ -1398,7 +1397,7 @@ def __init__(self, explainer, title="Contributions Table", name=None,
"""
self.selector = PosLabelSelector(explainer, name=self.name, pos_label=pos_label)
self.index_selector = IndexSelector(explainer, 'contributions-table-index-'+self.name,
index=index, index_dropdown=index_dropdown)
index=index, index_dropdown=index_dropdown, **kwargs)

self.register_dependencies('shap_values_df')

Expand Down
63 changes: 52 additions & 11 deletions explainerdashboard/dashboard_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

import dash
from dash import html, dcc, Input, Output, State
from dash.exceptions import PreventUpdate

import dash_bootstrap_components as dbc

Expand Down Expand Up @@ -538,22 +537,64 @@ def layout(self):

class IndexSelector(ExplainerComponent):
"""Either shows a dropdown or a free text input field for selecting an index"""
def __init__(self, explainer, name=None, index=None, index_dropdown=True):
super().__init__(explainer)
def __init__(self, explainer, name:str=None, index:str=None, index_dropdown:bool=True, max_idxs_in_dropdown:int=1000, **kwargs):
"""generates an index selector, either (dynamic) dropdown or free text field with input checker
Args:
explainer (BaseExplainer): explainer
name (str, optional): dash id to assign to the component. Defaults to None, in which case a unique identifier gets generated.
index (str, optional): initial index to select and display. Defaults to None.
index_dropdown (bool, optional): if set to false, input is an open text input instead of a dropdown. Defaults to True.
max_idxs_in_dropdown (int, optional): If the number of rows (idxs) in X_test is larger than this,
use a servers-side dynamically updating set of dropdown options instead of storing all index
options client side. Defaults to 1000.
"""
super().__init__(explainer, name=name)

def layout(self):
if self.index_dropdown:
return dcc.Dropdown(id=self.name,
options = [{'label': str(idx), 'value': str(idx)} for idx in self.explainer.get_index_list()],
placeholder=f"Select {self.explainer.index_name} here...",
value=self.index
)
index_list = self.explainer.get_index_list()
if len(index_list) > self.max_idxs_in_dropdown:
return dcc.Dropdown(
id=self.name,
placeholder=f"Search {self.explainer.index_name} here...",
value=self.index,
clearable=False
)
else:
return dcc.Dropdown(
id=self.name,
placeholder=f"Select {self.explainer.index_name} here...",
options = index_list.astype(str).to_list(),
value=self.index,
)
else:
return dbc.Input(id=self.name, placeholder=f"Type {self.explainer.index_name} here...",
value=self.index, debounce=True, type="text")
return dbc.Input(
id=self.name,
placeholder=f"Type {self.explainer.index_name} here...",
value=self.index,
debounce=True,
type="text"
)

def component_callbacks(self, app):
if not self.index_dropdown:
if self.index_dropdown:
if len(self.explainer.get_index_list()) > self.max_idxs_in_dropdown:
@app.callback(
Output(self.name, "options"),
Input(self.name, "search_value"),
Input(self.name, "value")
)
def update_options(search_value, index):
trigger_prop = dash.callback_context.triggered[0]['prop_id'].split('.')[-1]
if trigger_prop == 'value':
return [index]
index_list = self.explainer.get_index_list()
if search_value:
index_list = index_list[index_list.str.contains(search_value, case=False)]
index_list = index_list[:self.max_idxs_in_dropdown].tolist()
return index_list
else:
@app.callback(
[Output(self.name, 'valid'),
Output(self.name, 'invalid')],
Expand Down

0 comments on commit c247f47

Please sign in to comment.