Skip to content

Commit

Permalink
updated pos_label in all callbacks to deal with multiple workers
Browse files Browse the repository at this point in the history
  • Loading branch information
oegesam committed Feb 2, 2020
1 parent 1142f43 commit 1c394a9
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 558 deletions.
4 changes: 4 additions & 0 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,19 @@
- add multiclass confusion matrix
- individual trees: highlight selected tree
- Add feature names to waterfall plot
- fix percentages difference bug lift plot vs classification plot

### Regression plots:

## Explainers:
- add `get_metrics_dict()` function

## Dashboard:
- FIX NOT LOADING UNTIL CLICKED ON CONTRIBUTIONS TAB BUG!
- add option for vertical contributions?
- reformat contributions table
- add final prediction to contributions table
-

## Methods:

Expand Down
646 changes: 90 additions & 556 deletions dashboard_examples.ipynb

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions explainerdashboard/dashboard_tabs/contributions_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def display_value(contributions_size):
def update_output_div(index, topx, pos_label):
if index is None:
raise PreventUpdate
explainer.pos_label = pos_label #needed in case of multiple workers
int_idx = explainer.get_int_idx(index)
if explainer.is_classifier:
def display_probas(pred_probas_raw, labels, round=2):
Expand Down Expand Up @@ -308,4 +309,5 @@ def update_pdp_col(clickData):
Input('label-store', 'data')]
)
def update_pdp_graph(idx, col, pos_label):
explainer.pos_label = pos_label #needed in case of multiple workers
return explainer.plot_pdp(col, idx, sample=100)
12 changes: 11 additions & 1 deletion explainerdashboard/dashboard_tabs/model_summary_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def register_callbacks(self, app, **kwargs):
[State('tabs', 'value')]
)
def update_importances(tablesize, cats, permutation_shap, pos_label, tab):
self.explainer.pos_label = pos_label #needed in case of multiple workers
return self.explainer.plot_importances(
type=permutation_shap, topx=tablesize, cats=cats)

Expand Down Expand Up @@ -304,6 +305,7 @@ def update_div_visibility(bins_or_quantiles):
[State('tabs', 'value')],
)
def update_precision_graph(percentage, cutoff, pos_label, tab):
self.explainer.pos_label = pos_label #needed in case of multiple workers
return self.explainer.plot_lift_curve(cutoff=cutoff, percentage=percentage)

@app.callback(
Expand All @@ -317,6 +319,7 @@ def update_precision_graph(percentage, cutoff, pos_label, tab):
[State('tabs', 'value')],
)
def update_precision_graph(bin_size, quantiles, bins, cutoff, multiclass, pos_label, tab):
self.explainer.pos_label = pos_label #needed in case of multiple workers
if bins=='bin_size':
return self.explainer.plot_precision(
bin_size=bin_size, cutoff=cutoff, multiclass=multiclass)
Expand All @@ -333,6 +336,7 @@ def update_precision_graph(bin_size, quantiles, bins, cutoff, multiclass, pos_la
[State('tabs', 'value')],
)
def update_precision_graph(percentage, cutoff, pos_label, tab):
self.explainer.pos_label = pos_label #needed in case of multiple workers
return self.explainer.plot_classification(cutoff=cutoff, percentage=percentage)

@app.callback(
Expand All @@ -343,6 +347,7 @@ def update_precision_graph(percentage, cutoff, pos_label, tab):
[State('tabs', 'value')],
)
def update_precision_graph(cutoff, percentage, pos_label, tab):
self.explainer.pos_label = pos_label #needed in case of multiple workers
return self.explainer.plot_confusion_matrix(
cutoff=cutoff, normalized=percentage)

Expand All @@ -353,6 +358,7 @@ def update_precision_graph(cutoff, percentage, pos_label, tab):
Input('tabs', 'value')],
)
def update_precision_graph(cutoff, pos_label, tab):
self.explainer.pos_label = pos_label #needed in case of multiple workers
return self.explainer.plot_roc_auc(cutoff=cutoff)

@app.callback(
Expand All @@ -362,6 +368,7 @@ def update_precision_graph(cutoff, pos_label, tab):
[State('tabs', 'value')],
)
def update_precision_graph(cutoff, pos_label, tab):
self.explainer.pos_label = pos_label #needed in case of multiple workers
return self.explainer.plot_pr_auc(cutoff=cutoff)

@app.callback(
Expand Down Expand Up @@ -452,13 +459,13 @@ def layout(self):


def register_callbacks(self, app, **kwargs):

@app.callback(
Output('model-summary', 'children'),
[Input('label-store', 'data')],
[State('tabs', 'value')]
)
def update_model_summary(pos_label, tab):
self.explainer.pos_label = pos_label #needed in case of multiple workers
rmse = np.round(np.sqrt(mean_squared_error(self.explainer.y, self.explainer.preds)), 2)
mae = np.round(mean_absolute_error(self.explainer.y, self.explainer.preds), 2)
r2 = np.round(r2_score(self.explainer.y, self.explainer.preds), 2)
Expand All @@ -480,6 +487,7 @@ def update_model_summary(pos_label, tab):
[State('tabs', 'value')]
)
def update_predicted_vs_actual_graph(logs, pos_label, tab):
self.explainer.pos_label = pos_label #needed in case of multiple workers
return self.explainer.plot_predicted_vs_actual(logs=logs)

@app.callback(
Expand All @@ -490,6 +498,7 @@ def update_predicted_vs_actual_graph(logs, pos_label, tab):
[State('tabs', 'value')],
)
def update_residuals_graph(pred_or_actual, ratio, pos_label, tab):
self.explainer.pos_label = pos_label #needed in case of multiple workers
vs_actual = pred_or_actual=='vs_actual'
return self.explainer.plot_residuals(vs_actual=vs_actual, ratio=ratio)

Expand All @@ -501,6 +510,7 @@ def update_residuals_graph(pred_or_actual, ratio, pos_label, tab):
[State('tabs', 'value')],
)
def update_residuals_graph(col, ratio, pos_label, tab):
self.explainer.pos_label = pos_label #needed in case of multiple workers
return self.explainer.plot_residuals_vs_feature(col, ratio=ratio, dropna=True)


Expand Down
2 changes: 1 addition & 1 deletion explainerdashboard/dashboard_tabs/shadow_trees_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ def shadow_trees_layout(explainer, round=2, **kwargs):


def shadow_trees_callbacks(explainer, app, round=2, **kwargs):

@app.callback(
[Output('tree-basevalue', 'children'),
Output('tree-predictions-table', 'columns'),
Expand Down Expand Up @@ -129,6 +128,7 @@ def display_click_data(clickData, idx, tab):
[State('tabs', 'value')]
)
def update_tree_graph(index, pos_label, tab):
explainer.pos_label=pos_label
if index is not None:
return explainer.plot_trees(index, round=round)
return {}
Expand Down
2 changes: 2 additions & 0 deletions explainerdashboard/dashboard_tabs/shap_dependence_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def shap_dependence_callbacks(explainer, app, **kwargs):
Input('label-store', 'data')],
[State('tabs', 'value')])
def update_dependence_shap_scatter_graph(summary_type, cats, depth, pos_label, tab):
explainer.pos_label = pos_label #needed in case of multiple workers
ctx = dash.callback_context
if ctx.triggered:
if depth is None: depth = 10
Expand Down Expand Up @@ -199,6 +200,7 @@ def set_color_col_dropdown(col, cats):
[State('dependence-col', 'value'),
State('dependence-group-categoricals', 'checked')])
def update_dependence_graph(color_col, idx, pos_label, col, cats):
explainer.pos_label = pos_label #needed in case of multiple workers
if color_col is not None:
return explainer.plot_shap_dependence(
col, color_col, highlight_idx=idx, cats=cats)
Expand Down
2 changes: 2 additions & 0 deletions explainerdashboard/dashboard_tabs/shap_interactions_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def update_col_options(cats, col, tab):
[State('interaction-group-categoricals', 'checked')])
def update_interaction_scatter_graph(summary_type, col, depth, pos_label, cats):
if col is not None:
explainer.pos_label = pos_label #needed in case of multiple workers
if depth is None:
depth = len(explainer.columns_ranked(cats))-1
if summary_type=='aggregate':
Expand Down Expand Up @@ -205,6 +206,7 @@ def display_scatter_click_data(clickdata):
State('interaction-group-categoricals', 'checked')])
def update_dependence_graph(interact_col, index, pos_label, col, cats):
if interact_col is not None:
explainer.pos_label = pos_label #needed in case of multiple workers
return (explainer.plot_shap_interaction_dependence(
col, interact_col, highlight_idx=index, cats=cats),
explainer.plot_shap_interaction_dependence(
Expand Down

0 comments on commit 1c394a9

Please sign in to comment.