Skip to content

Commit

Permalink
fixed shadow tree tab updating bug
Browse files (browse the repository at this point in the history)
  • Loading branch information
Oege Dijk authored and Oege Dijk committed Jan 31, 2020
1 parent 415c779 commit 176ad8d
Show file tree
Hide file tree
Showing 13 changed files with 224 additions and 1,033 deletions.
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"python.pythonPath": "/anaconda/bin/python",
"python.pythonPath": "C:\\ProgramData\\Anaconda3\\envs\\ww_env2\\python.exe",
"python.testing.unittestArgs": [
"-v",
"-s",
Expand Down
8 changes: 4 additions & 4 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,18 @@
# TODO:

## Layout:
- set all tabs default to False

## Plots:
- add multiclass confusion matrix
- individual trees: highlight selected tree
- fix shap dependence summary going full width

### Regression plots:

## Explainers:

## Dashboard:
- add dependence plot to importances list
- COntributions: add div margin to index selector
- Contributions: add div margin to index selector
- move number of features to display
- add option for vertical contributions?
- reformat contributions table
Expand All @@ -26,10 +24,12 @@
- Move pdp function to explainer_methods.py
- add feature explanations


## Library level:
- fix forever updating bug (seems shadow tree related?)
- fix jupyter reload pdp bug
- just add kind='tree', 'linear', 'deep', etc
- submit pull request to dtreeviz to accept shadowtree as parameter
- just add shap='tree', 'linear', 'deep', etc instead of separate classes
- add long description to pypi: https://packaging.python.org/guides/making-a-pypi-friendly-readme/
- Add tests
- Test with lightgbm, catboost, extratrees
Expand Down
928 changes: 22 additions & 906 deletions dashboard_examples.ipynb

Large diffs are not rendered by default.

Binary file modified docs/source/_build/html/.doctrees/environment.pickle
Binary file not shown.
37 changes: 0 additions & 37 deletions explainerdashboard/TODO.md

This file was deleted.

20 changes: 9 additions & 11 deletions explainerdashboard/dashboard_tabs/model_summary_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ def layout(self):
0.75: '0.75', 0.99: '0.99'},
included=False,
tooltip = {'always_visible' : True})
], style={'margin': 20}),
], style={'margin': 0}),
])
]),
dbc.Row([
Expand All @@ -247,7 +247,7 @@ def layout(self):
0.75: '0.75', 0.99: '0.99'},
included=False,
tooltip = {'always_visible' : True})
], style={'margin': 20}),
], style={'margin': 0}),
])
]),
dbc.Row([
Expand Down Expand Up @@ -280,13 +280,6 @@ def layout(self):

def register_callbacks(self, app, **kwargs):

@app.callback(
Output('precision-cutoff', 'value'),
[Input('fraction-cutoff', 'value')]
)
def update_cutoff(fraction):
return np.round(self.explainer.cutoff_fraction(fraction), 2)

@app.callback(
[Output('bin-size-div', 'style'),
Output('quantiles-div', 'style')],
Expand Down Expand Up @@ -349,8 +342,6 @@ def update_precision_graph(cutoff, percentage, pos_label, tab):
return self.explainer.plot_confusion_matrix(
cutoff=cutoff, normalized=percentage)



@app.callback(
Output('roc-auc-graph', 'figure'),
[Input('precision-cutoff', 'value'),
Expand All @@ -369,6 +360,13 @@ def update_precision_graph(cutoff, pos_label, tab):
def update_precision_graph(cutoff, pos_label, tab):
return self.explainer.plot_pr_auc(cutoff=cutoff)

@app.callback(
Output('precision-cutoff', 'value'),
[Input('fraction-cutoff', 'value')]
)
def update_cutoff(fraction):
return np.round(self.explainer.cutoff_fraction(fraction), 2)

class RegressionModelStats:
def __init__(self, explainer, round=2, logs=False, vs_actual=False, ratio=False):
self.explainer = explainer
Expand Down
109 changes: 56 additions & 53 deletions explainerdashboard/dashboard_tabs/shadow_trees_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ def shadow_trees_layout(explainer, round=2, **kwargs):
], inline=True),
dcc.Store(id='tree-index-store'),
html.H4('(click on a prediction to see decision path)'),
dcc.Loading(id="loading-trees-graph",
children=[dcc.Graph(id='tree-predictions-graph')]),
# dcc.Loading(id="loading-trees-graph",
# children=[
dcc.Graph(id='tree-predictions-graph'),
#]),
], width={"size": 8, "offset": 2})
]),
dbc.Row([
Expand All @@ -77,7 +79,7 @@ def shadow_trees_layout(explainer, round=2, **kwargs):
]),
dbc.Row([
dbc.Col([
dcc.Loading(id="loading-svg-graph",
dcc.Loading(id="loading-tree-svg-graph",
children=[html.Img(id='dtreeviz-svg')])
])

Expand All @@ -86,29 +88,41 @@ def shadow_trees_layout(explainer, round=2, **kwargs):


def shadow_trees_callbacks(explainer, app, round=2, **kwargs):
@app.callback(
Output('tree-input-index', 'value'),
[Input('tree-index-button', 'n_clicks')],
[State('tabs', 'value')]
)
def update_tree_input_index(n_clicks, tab):
return explainer.random_index(return_str=True)

@app.callback(
Output('tree-index-store', 'data'),
[Input('tree-input-index', 'value')]
)
def update_tree_index_store(index):
if (explainer.idxs is None
and str(index).isdigit()
and int(index) >= 0
and int(index) <= len(explainer)):
return int(index)
if (explainer.idxs is not None
and index in explainer.idxs):
return index
[Output('tree-basevalue', 'children'),
Output('tree-predictions-table', 'columns'),
Output('tree-predictions-table', 'data')],
[Input('tree-predictions-graph', 'clickData')],
#Input('label-store', 'data')], #this causes issues for some reason, only on this tab??
[State('tree-index-store', 'data'),
State('tabs', 'value')])
def display_tree_click_data(clickData, idx, tab):
if clickData is not None and idx is not None:
model = int(clickData['points'][0]['text'].split('tree no ')[1].split(':')[0]) if clickData is not None else 0
(baseval, prediction, shadowtree_df) = \
explainer.shadowtree_df_summary(model, idx, round=round)
columns=[{'id': c, 'name': c}
for c in shadowtree_df.columns.tolist()]
baseval_str = f"Tree no {model}, Starting prediction : {baseval}, final prediction : {prediction}"
print(baseval, columns, shadowtree_df)
return (baseval_str, columns, shadowtree_df.to_dict('records'))
raise PreventUpdate

@app.callback(
Output('dtreeviz-svg', 'src'),
[Input('tree-predictions-graph', 'clickData'),
#Input('label-store', 'data')#this causes issues for some reason, only on this tab??
],
[State('tree-index-store', 'data'),
State('tabs', 'value')])
def display_click_data(clickData, idx, tab):
if clickData is not None and idx is not None and explainer.graphviz_available:
model = int(clickData['points'][0]['text'].split('tree no ')[1].split(':')[0])
svg_encoded = explainer.decision_path_encoded(model, idx)
return svg_encoded
return ""

@app.callback(
Output('tree-predictions-graph', 'figure'),
[Input('tree-index-store', 'data'),
Expand All @@ -118,38 +132,27 @@ def update_tree_index_store(index):
def update_tree_graph(index, pos_label, tab):
if index is not None:
return explainer.plot_trees(index, round=round)
raise PreventUpdate
return {}

@app.callback(
[Output('tree-basevalue', 'children'),
Output('tree-predictions-table', 'columns'),
Output('tree-predictions-table', 'data'),],
[Input('tree-predictions-graph', 'clickData'),
Input('tree-index-store', 'data'),
Input('label-store', 'data')],
[State('tabs', 'value')])
def display_click_data(clickData, idx, pos_label, tab):
if clickData is not None and idx is not None:
model = int(clickData['points'][0]['text'].split('tree no ')[1].split(':')[0]) if clickData is not None else 0
(baseval, prediction,
shadowtree_df) = explainer.shadowtree_df_summary(model, idx, round=round)
columns=[{'id': c, 'name': c} for c in shadowtree_df.columns.tolist()]
baseval_str = f"Tree no {model}, Starting prediction : {baseval}, final prediction : {prediction}"
return (baseval_str, columns, shadowtree_df.to_dict('records'))
raise PreventUpdate

Output('tree-index-store', 'data'),
[Input('tree-input-index', 'value')],
[State('tabs', 'value')]
)
def update_tree_index_store(index, tab):
if (explainer.idxs is None
and str(index).isdigit()
and int(index) >= 0
and int(index) <= len(explainer)):
return int(index)
if (explainer.idxs is not None and index in explainer.idxs):
return index
return None

@app.callback(
Output('dtreeviz-svg', 'src'),
[Input('tree-predictions-graph', 'clickData'),
Input('tree-index-store', 'data'),
Input('label-store', 'data')],
[State('tabs', 'value')])
def display_click_data(clickData, idx, pos_label, tab):
if (clickData is not None
and idx is not None
and explainer.graphviz_available):
model = int(clickData['points'][0]['text'].split('tree no ')[1].split(':')[0])
svg_encoded = explainer.decision_path_encoded(model, idx)
return svg_encoded
raise PreventUpdate
Output('tree-input-index', 'value'),
[Input('tree-index-button', 'n_clicks')],
[State('tabs', 'value')]
)
def update_tree_input_index(n_clicks, tab):
return explainer.random_index(return_str=True)
14 changes: 7 additions & 7 deletions explainerdashboard/dashboard_tabs/shap_dependence_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def register_callbacks(self, app):
shap_dependence_callbacks(self.explainer, app)


def shap_dependence_layout(explainer, n_features=10, **kwargs):
def shap_dependence_layout(explainer, n_features=10, cats=True, **kwargs):

cats_display = 'none' if explainer.cats is None else 'inline-block'
return dbc.Container([
Expand All @@ -52,8 +52,8 @@ def shap_dependence_layout(explainer, n_features=10, **kwargs):
dbc.Label("Depth:"),
dcc.Dropdown(id='dependence-scatter-depth',
options = [{'label': str(i+1), 'value':i+1}
for i in range(len(explainer.columns)-1)],
value=min(n_features, len(explainer.columns)-1))],
for i in range(len(explainer.columns_ranked(cats))-1)],
value=min(n_features, len(explainer.columns_ranked(cats))-1))],
width=3),
dbc.Col([
dbc.FormGroup(
Expand Down Expand Up @@ -97,15 +97,15 @@ def shap_dependence_layout(explainer, n_features=10, **kwargs):
html.Label('Plot dependence for column:'),
dcc.Dropdown(id='dependence-col',
options=[{'label': col, 'value':col}
for col in explainer.columns_cats],
value=explainer.columns_cats[0])],
for col in explainer.columns_ranked(cats)],
value=explainer.columns_ranked(cats)[0])],
width=5),
dbc.Col([
html.Label('Color observation by column:'),
dcc.Dropdown(id='dependence-color-col',
options=[{'label': col, 'value':col}
for col in explainer.columns_cats],
value=explainer.columns_cats[0])],
for col in explainer.columns_ranked(cats)],
value=explainer.columns_ranked(cats)[1])],
width=5),
dbc.Col([
html.Label('Highlight:'),
Expand Down
2 changes: 1 addition & 1 deletion explainerdashboard/dashboard_tabs/shap_interactions_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def shap_interactions_layout(explainer,
dbc.Col([
dbc.Label("Feature"),
dcc.Dropdown(id='interaction-col',
options=[{'label': col, 'value':col}
options=[{'label': col, 'value': col}
for col in explainer.columns_ranked(cats)],
value=explainer.columns_ranked(cats)[0])],
width=4),
Expand Down
4 changes: 3 additions & 1 deletion explainerdashboard/dashboards.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ def _insert_tabs(self):
if self.shap_interaction:
self.tabs.append(ShapInteractionsTab(self.explainer, **self.kwargs))
if self.shadow_trees:
assert hasattr(self.explainer, 'shadow_trees')
assert hasattr(self.explainer, 'shadow_trees'), \
"""the explainer object has no shadow_trees property. This tab
only works with a RandomForestClassifierBunch or RandomForestRegressionBunch"""
self.tabs.append(ShadowTreesTab(self.explainer, **self.kwargs))

def run(self, port=8050, **kwargs):
Expand Down

0 comments on commit 176ad8d

Please sign in to comment.