Commit

update shape precision component

oegedijk committed Jan 31, 2021
1 parent 6547fcb commit 2895948
Showing 3 changed files with 77 additions and 29 deletions.
7 changes: 6 additions & 1 deletion RELEASE_NOTES.md
@@ -1,7 +1,8 @@
# Release Notes

## Version 0.3.1:

This version is mostly about pre-calculating and optimizing the classifier statistics
components. Those components should now be much more responsive with large datasets.

### New Features
- new methods `roc_auc_curve(pos_label)` and `pr_auc_curve(pos_label)`
@@ -29,7 +30,11 @@
  - dashboard should be more responsive for large datasets
- pre-calculating confusion matrices
  - dashboard should be more responsive for large datasets
- pre-calculating classification_dfs
  - dashboard should be more responsive for large datasets
- confusion matrix: added axis title, moved predicted labels to bottom of graph
- precision plot: when only the cutoff is adjusted, simply update the cutoff
  line without recalculating the whole plot

### Other Changes
-
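
The caching described in these notes is transparent to callers. As a rough usage sketch (not part of this commit; `ClassifierExplainer` and `metrics()` are the package's real API, but the toy data and model here are invented):

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from explainerdashboard import ClassifierExplainer

# toy model, purely for illustration
X, y = make_classification(n_samples=200, n_features=5, random_state=0)
X = pd.DataFrame(X, columns=[f"col{i}" for i in range(5)])
model = LogisticRegression().fit(X, y)

explainer = ClassifierExplainer(model, X, y)
explainer.metrics(cutoff=0.5)    # first call pre-calculates the whole cutoff grid
explainer.metrics(cutoff=0.37)   # on-grid cutoff: served straight from the cache
explainer.metrics(cutoff=0.333)  # off-grid cutoff: computed on demand
```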
15 changes: 13 additions & 2 deletions explainerdashboard/dashboard_components/classifier_components.py
@@ -297,6 +297,8 @@ def __init__(self, explainer, title="Prediction", name=None,
Shows the predicted probability for each {self.explainer.target} label.
"""

self.register_dependencies("metrics")

def layout(self):
return dbc.Card([
make_hideable(
@@ -590,9 +592,18 @@ def update_div_visibility(bins_or_quantiles):
Input('precision-cutoff-'+self.name, 'value'),
Input('precision-multiclass-'+self.name, 'value'),
Input('pos-label-'+self.name, 'value')],
#[State('tabs', 'value')],
[State('precision-graph-'+self.name, 'figure')],
)
def update_precision_graph(bin_size, quantiles, bins, cutoff, multiclass, pos_label):
def update_precision_graph(bin_size, quantiles, bins, cutoff, multiclass, pos_label, fig):
ctx = dash.callback_context
# which Input fired this callback?
trigger = ctx.triggered[0]['prop_id'].split('.')[0]
# if only the cutoff slider changed, just move the cutoff line on the
# existing figure instead of recalculating the entire plot:
if trigger == 'precision-cutoff-'+self.name and fig is not None:
return go.Figure(fig).update_shapes(dict(
type='line',
xref='x', yref='y2',
x0=cutoff, x1=cutoff,
y0=0, y1=1.0,
))
if bins == 'bin_size':
return self.explainer.plot_precision(
bin_size=bin_size, cutoff=cutoff,
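
For context, a minimal standalone sketch of the trigger-dispatch idea used in the callback above: ask `dash.callback_context` which Input fired, and when it was only the cutoff slider, patch the existing figure's line with `update_shapes()` instead of rebuilding the plot. This is not code from this repo; it assumes dash ≥ 2 and invents the component ids and toy figure.

```python
import dash
from dash import dcc, html
from dash.dependencies import Input, Output, State
import plotly.graph_objects as go

app = dash.Dash(__name__)
app.layout = html.Div([
    dcc.Slider(id='cutoff', min=0.01, max=0.99, step=0.01, value=0.5),
    dcc.Graph(id='graph'),
])

@app.callback(
    Output('graph', 'figure'),
    [Input('cutoff', 'value')],
    [State('graph', 'figure')])
def update_graph(cutoff, fig):
    trigger = dash.callback_context.triggered[0]['prop_id'].split('.')[0]
    if trigger == 'cutoff' and fig is not None:
        # cheap path: move the existing cutoff line, keep all traces as-is
        return go.Figure(fig).update_shapes(dict(x0=cutoff, x1=cutoff))
    # expensive path: build the figure from scratch
    fig = go.Figure(go.Scatter(x=[0, 1], y=[0, 1]))
    fig.add_shape(type='line', x0=cutoff, x1=cutoff, y0=0, y1=1)
    return fig

if __name__ == '__main__':
    app.run_server()
```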
84 changes: 58 additions & 26 deletions explainerdashboard/explainers.py
@@ -2068,19 +2068,35 @@ def metrics(self, cutoff=0.5, pos_label=None):
"""
if self.y_missing:
raise ValueError("No y was passed to explainer, so cannot calculate metrics!")
y_true = self.y_binary(pos_label)
y_pred = np.where(self.pred_probas(pos_label) > cutoff, 1, 0)

metrics_dict = {
'accuracy' : accuracy_score(y_true, y_pred),
'precision' : precision_score(y_true, y_pred),
'recall' : recall_score(y_true, y_pred),
'f1' : f1_score(y_true, y_pred),
'roc_auc_score' : roc_auc_score(y_true, self.pred_probas(pos_label)),
'pr_auc_score' : average_precision_score(y_true, self.pred_probas(pos_label)),
'log_loss' : log_loss(y_true, self.pred_probas(pos_label))
}
return metrics_dict
def get_metrics(cutoff, pos_label):
y_true = self.y_binary(pos_label)
y_pred = np.where(self.pred_probas(pos_label) > cutoff, 1, 0)

metrics_dict = {
'accuracy' : accuracy_score(y_true, y_pred),
'precision' : precision_score(y_true, y_pred, zero_division=0),
'recall' : recall_score(y_true, y_pred),
'f1' : f1_score(y_true, y_pred),
'roc_auc_score' : roc_auc_score(y_true, self.pred_probas(pos_label)),
'pr_auc_score' : average_precision_score(y_true, self.pred_probas(pos_label)),
'log_loss' : log_loss(y_true, self.pred_probas(pos_label))
}
return metrics_dict

if not hasattr(self, "_metrics"):
_ = self.pred_probas()
print("Calculating metrics...", flush=True)
self._metrics = dict()
for label in range(len(self.labels)):
self._metrics[label] = dict()
for cut in np.linspace(0.01, 0.99, 99):
self._metrics[label][np.round(cut, 2)] = \
get_metrics(cut, label)
if cutoff in self._metrics[pos_label]:
return self._metrics[pos_label][cutoff]
else:
return get_metrics(cutoff, pos_label)
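
The memoization above — and the near-identical blocks in `get_classification_df()` and `get_confusion_matrix()` below — boil down to one pattern. A paraphrase as a generic helper, illustrative only and not code from the repo:

```python
import numpy as np

class CutoffGridCache:
    """Pre-computes func(cutoff, label) for every label at every cutoff
    on a 0.01..0.99 grid; off-grid cutoffs fall back to a live call."""

    def __init__(self, func, n_labels):
        self.func = func          # func(cutoff, label) -> result
        self.n_labels = n_labels
        self._cache = None

    def get(self, cutoff, label):
        if self._cache is None:
            # fill the whole grid once, on first access
            self._cache = {
                lab: {round(float(cut), 2): self.func(cut, lab)
                      for cut in np.linspace(0.01, 0.99, 99)}
                for lab in range(self.n_labels)}
        if cutoff in self._cache[label]:
            return self._cache[label][cutoff]   # grid hit: O(1) lookup
        return self.func(cutoff, label)         # off-grid: compute on demand
```

Filling the grid costs 99 × n_labels evaluations up front, which is why the first call prints "Calculating metrics..." once and every later cutoff-slider move is answered from the cache.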

@insert_pos_label
def metrics_descriptions(self, cutoff=0.5, round=3, pos_label=None):
@@ -2253,27 +2269,41 @@ def get_liftcurve_df(self, pos_label=None):
return self._liftcurve_dfs[pos_label]

@insert_pos_label
def get_classification_df(self, cutoff=0.5, percentage=False, pos_label=None):
def get_classification_df(self, cutoff=0.5, pos_label=None):
"""Returns a dataframe with number of observations in each class above
and below the cutoff.
Args:
cutoff (float, optional): Cutoff to split on. Defaults to 0.5.
percentage (bool, optional): Normalize results. Defaults to False.
pos_label (int, optional): Pos label to generate dataframe for.
Defaults to self.pos_label.
Returns:
pd.DataFrame
"""
clas_df = pd.DataFrame(index=pd.RangeIndex(0, len(self.labels)))
clas_df['below'] = self.y[self.pred_probas(pos_label) < cutoff].value_counts(normalize=percentage)
clas_df['above'] = self.y[self.pred_probas(pos_label) >= cutoff].value_counts(normalize=percentage)
clas_df = clas_df.fillna(0)
clas_df['total'] = clas_df.sum(axis=1)
clas_df.index = self.labels
return clas_df

def get_clas_df(cutoff, pos_label):
clas_df = pd.DataFrame(index=pd.RangeIndex(0, len(self.labels)))
clas_df['below'] = self.y[self.pred_probas(pos_label) < cutoff].value_counts()
clas_df['above'] = self.y[self.pred_probas(pos_label) >= cutoff].value_counts()
clas_df = clas_df.fillna(0)
clas_df['total'] = clas_df.sum(axis=1)
clas_df.index = self.labels
return clas_df

if not hasattr(self, "_classification_dfs"):
_ = self.pred_probas()
print("Calculating classification_dfs...", flush=True)
self._classification_dfs = dict()
for label in range(len(self.labels)):
self._classification_dfs[label] = dict()
for cut in np.linspace(0.01, 0.99, 99):
self._classification_dfs[label][np.round(cut, 2)] = \
get_clas_df(cut, label)
if cutoff in self._classification_dfs[pos_label]:
return self._classification_dfs[pos_label][cutoff]
else:
return get_clas_df(cutoff, pos_label)
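
To make concrete what `get_clas_df()` builds, a hypothetical toy run (label names and probabilities invented):

```python
import numpy as np
import pandas as pd

labels = ['no', 'yes']
y = pd.Series([0, 0, 1, 1, 1])
probas = np.array([0.2, 0.6, 0.4, 0.7, 0.9])  # P(pos_label) per observation
cutoff = 0.5

# count each true class on either side of the cutoff
df = pd.DataFrame(index=pd.RangeIndex(0, len(labels)))
df['below'] = y[probas < cutoff].value_counts()
df['above'] = y[probas >= cutoff].value_counts()
df = df.fillna(0)
df['total'] = df.sum(axis=1)
df.index = labels
print(df)
#      below  above  total
# no       1      1      2
# yes      1      2      3
```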

@insert_pos_label
def roc_auc_curve(self, pos_label=None):
"""Returns a dict with output from sklearn.metrics.roc_curve() for pos_label:
@@ -2314,9 +2344,9 @@ def get_binary_cm(y, pred_probas, cutoff, pos_label):
self._confusion_matrices['binary'] = dict()
for label in range(len(self.labels)):
self._confusion_matrices['binary'][label] = dict()
for cutoff in np.linspace(0.01, 0.99, 99):
self._confusion_matrices['binary'][label][np.round(cutoff, 2)] = \
get_binary_cm(self.y, self.pred_probas_raw, cutoff, label)
for cut in np.linspace(0.01, 0.99, 99):
self._confusion_matrices['binary'][label][np.round(cut, 2)] = \
get_binary_cm(self.y, self.pred_probas_raw, cut, label)
self._confusion_matrices['multi'] = confusion_matrix(self.y, self.pred_probas_raw.argmax(axis=1))
if binary:
if cutoff in self._confusion_matrices['binary'][pos_label]:
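
The body of `get_binary_cm` is truncated out of the hunk above; judging from how it is called with `(self.y, self.pred_probas_raw, cutoff, label)`, it binarizes the multi-class probabilities against one positive label and a cutoff. A hedged guess at its shape, not the repo's actual implementation:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

def binary_cm(y, pred_probas, cutoff, pos_label):
    # one-vs-rest: positive class vs everything else
    y_true = (np.asarray(y) == pos_label).astype(int)
    y_pred = (pred_probas[:, pos_label] > cutoff).astype(int)
    return confusion_matrix(y_true, y_pred, labels=[0, 1])

probas = np.array([[0.8, 0.2], [0.3, 0.7], [0.1, 0.9]])
print(binary_cm([0, 1, 1], probas, cutoff=0.5, pos_label=1))
# [[1 0]
#  [0 2]]
```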
@@ -2447,7 +2477,9 @@ def plot_classification(self, cutoff=0.5, percentage=True, pos_label=None):
plotly fig
"""
return plotly_classification_plot(self.get_classification_df(cutoff=cutoff, pos_label=pos_label), percentage=percentage)
return plotly_classification_plot(
self.get_classification_df(cutoff=cutoff, pos_label=pos_label),
percentage=percentage)

@insert_pos_label
def plot_roc_auc(self, cutoff=0.5, pos_label=None):
