Commit

fix calibration score check (#325)
benisraeldan committed Dec 27, 2021
1 parent c4af9e8 commit eb5fbbb
Showing 2 changed files with 35 additions and 7 deletions.
14 changes: 7 additions & 7 deletions deepchecks/checks/performance/calibration_score.py
@@ -49,15 +49,15 @@ def _calibration_score(self, dataset: Dataset, model):
         y_pred = model.predict_proba(ds_x)
 
         briers_scores = {}
-        unique_labels = dataset.label_col.unique()
+        unique_labels = model.classes_
 
         if len(unique_labels) == 2:
             briers_scores[0] = brier_score_loss(ds_y, y_pred[:, 1])
         else:
-            for n_class in unique_labels:
+            for n_class, class_name in enumerate(unique_labels):
                 prob_pos = y_pred[:, n_class]
-                clf_score = brier_score_loss(ds_y == n_class, prob_pos, pos_label=n_class)
-                briers_scores[n_class] = clf_score
+                clf_score = brier_score_loss(ds_y == class_name, prob_pos, pos_label=class_name)
+                briers_scores[class_name] = clf_score
 
         fig = go.Figure()
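The substance of the fix is visible in this hunk: scikit-learn orders the columns of predict_proba by model.classes_, while dataset.label_col.unique() returns labels in whatever order they first appear, so indexing y_pred by a raw label value could select the wrong probability column (or fail outright for non-integer labels). Enumerating classes_ keeps the column index (n_class) and the class label (class_name) paired. Below is a minimal standalone sketch of that pattern using plain scikit-learn on iris; the AdaBoost model and the astype(int) binarization are illustrative stand-ins, not the check's exact code.

from sklearn.datasets import load_iris
from sklearn.ensemble import AdaBoostClassifier
from sklearn.metrics import brier_score_loss

X, y = load_iris(return_X_y=True)
model = AdaBoostClassifier().fit(X, y)

y_pred = model.predict_proba(X)  # column order follows model.classes_
briers_scores = {}

# enumerate pairs the probability-column index with the actual class label,
# so the two can no longer drift apart
for n_class, class_name in enumerate(model.classes_):
    prob_pos = y_pred[:, n_class]
    y_true_binary = (y == class_name).astype(int)  # one-vs-rest target
    briers_scores[class_name] = brier_score_loss(y_true_binary, prob_pos)

print(briers_scores)  # one brier score per class, keyed by the class label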
@@ -78,17 +78,17 @@ def _calibration_score(self, dataset: Dataset, model):
                 name=f'(brier:{briers_scores[0]:9.4f})',
             ))
         else:
-            for n_class in unique_labels:
+            for n_class, class_name in enumerate(unique_labels):
                 prob_pos = y_pred[:, n_class]
 
                 fraction_of_positives, mean_predicted_value = \
-                    calibration_curve(ds_y == n_class, prob_pos, n_bins=10)
+                    calibration_curve(ds_y == class_name, prob_pos, n_bins=10)
 
                 fig.add_trace(go.Scatter(
                     x=mean_predicted_value,
                     y=fraction_of_positives,
                     mode='lines+markers',
-                    name=f'{n_class} (brier:{briers_scores[n_class]:9.4f})',
+                    name=f'{class_name} (brier:{briers_scores[class_name]:9.4f})',
                 ))
 
         fig.update_layout(title_text='Calibration plots (reliability curve)',
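The second hunk applies the same index/label split to the reliability plot: the probability column is still picked by position, while the one-vs-rest target and the trace name now use the class label. A sketch of that plotting step, continuing the variables (model, y, y_pred, briers_scores) from the sketch above; plotly is assumed to be installed, and the code mirrors only the structure of the check's figure, not its exact styling.

from sklearn.calibration import calibration_curve
import plotly.graph_objects as go

fig = go.Figure()
for n_class, class_name in enumerate(model.classes_):
    prob_pos = y_pred[:, n_class]
    # fraction of actual positives vs. mean predicted probability per bin
    fraction_of_positives, mean_predicted_value = \
        calibration_curve((y == class_name).astype(int), prob_pos, n_bins=10)
    fig.add_trace(go.Scatter(
        x=mean_predicted_value,
        y=fraction_of_positives,
        mode='lines+markers',
        name=f'{class_name} (brier:{briers_scores[class_name]:9.4f})',
    ))
fig.update_layout(title_text='Calibration plots (reliability curve)')
fig.show()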
28 changes: 28 additions & 0 deletions tests/checks/performance/calibration_score_test.py
@@ -49,3 +49,31 @@ def test_model_info_object(iris_labeled_dataset, iris_adaboost):
     1: close_to(0.002, 0.05),
     2: close_to(0.28, 0.05)
 }))
+
+
+def test_model_info_object(iris_labeled_dataset, iris_adaboost):
+    # Arrange
+    check = CalibrationScore()
+    # Act X
+    result = check.run(iris_labeled_dataset, iris_adaboost).value
+    # Assert
+    assert len(result) == 3  # iris has 3 targets
+
+    assert_that(result, has_entries({
+        0: close_to(0.99, 0.05),
+        1: close_to(0.002, 0.05),
+        2: close_to(0.28, 0.05)
+    }))
+
+
+def test_binary_model_info_object(iris_dataset_single_class_labeled, iris_random_forest_single_class):
+    # Arrange
+    check = CalibrationScore()
+    # Act X
+    result = check.run(iris_dataset_single_class_labeled, iris_random_forest_single_class).value
+    # Assert
+    assert len(result) == 1  # binary model yields a single brier score
+
+    assert_that(result, has_entries({
+        0: close_to(0.0002, 0.0005)
+    }))
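The new binary test expects exactly one entry because the check's two-class branch (first hunk above) stores a single brier score, computed from the positive-class probability column. A standalone illustration of that behaviour, using a two-class slice of iris and a scikit-learn random forest as stand-ins for the iris_dataset_single_class_labeled and iris_random_forest_single_class fixtures, whose exact construction is not shown in this diff.

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import brier_score_loss

X, y = load_iris(return_X_y=True)
mask = y < 2                      # keep classes 0 and 1 only -> a binary problem
X_bin, y_bin = X[mask], y[mask]

model = RandomForestClassifier(random_state=0).fit(X_bin, y_bin)
y_pred = model.predict_proba(X_bin)

# Two classes: the check keeps a single score under the key 0,
# computed from the probability of the positive class (column 1).
result = {0: brier_score_loss(y_bin, y_pred[:, 1])}

assert len(result) == 1           # matches the new binary test's expectation
print(result)                     # typically near zero: the two classes are easily separated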
