Skip to content

Commit

Permalink
fix precision rounding bug with empty bins
Browse files Browse the repository at this point in the history
  • Loading branch information
oegedijk committed Jun 17, 2020
1 parent 11d31fc commit 036e6c2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
12 changes: 6 additions & 6 deletions explainerdashboard/explainer_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,13 +404,13 @@ def get_precision_df(pred_probas, y_true, bin_size=None, quantiles=None,
new_row_df = pd.DataFrame(new_row_dict, columns=precision_df.columns)
precision_df = pd.concat([precision_df, new_row_df])
last_p_max = preds.max()
precision_df['p_avg'] = np.round(precision_df['p_avg'], round)
precision_df['precision'] = np.round(precision_df['precision'], round)

precision_df[['p_avg', 'precision']] = precision_df[['p_avg', 'precision']]\
.astype(float).apply(partial(np.round, decimals=round))
if n_classes > 1:
for i in range(n_classes):
precision_df['precision_' + str(i)] = \
np.round(precision_df['precision_' + str(i)], round)
precision_cols = ['precision_' + str(i) for i in range(n_classes)]
precision_df[precision_cols] = precision_df[precision_cols]\
.astype(float).apply(partial(np.round, decimals=round))
return precision_df


Expand Down
7 changes: 4 additions & 3 deletions tests/test_classifier_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ def setUp(self):
model = RandomForestClassifier(n_estimators=5, max_depth=2)
model.fit(X_train, y_train)

self.explainer = ClassifierExplainer(
model, X_test, y_test, roc_auc_score,
shap='tree',
self.explainer = ClassifierExplainer(model, X_test, y_test,
cats=['Sex', 'Cabin', 'Embarked'],
idxs=test_names,
labels=['Not survived', 'Survived'])
Expand Down Expand Up @@ -69,6 +67,9 @@ def test_plot_precision(self):
fig = self.explainer.plot_precision(multiclass=True)
self.assertIsInstance(fig, go.Figure)

fig = self.explainer.plot_precision(quantiles=10, cutoff=0.5)
self.assertIsInstance(fig, go.Figure)

def test_plot_cumulutive_precision(self):
fig = self.explainer.plot_cumulative_precision()
self.assertIsInstance(fig, go.Figure)
Expand Down

0 comments on commit 036e6c2

Please sign in to comment.