1 change: 1 addition & 0 deletions bayesflow/diagnostics/metrics/__init__.py
@@ -1,3 +1,4 @@
from .calibration_error import calibration_error
from .posterior_contraction import posterior_contraction
from .root_mean_squared_error import root_mean_squared_error
from .expected_calibration_error import expected_calibration_error
98 changes: 98 additions & 0 deletions bayesflow/diagnostics/metrics/expected_calibration_error.py
@@ -0,0 +1,98 @@
import numpy as np
from keras import ops
from typing import Sequence, Any, Mapping

from ...utils.exceptions import ShapeError
from sklearn.calibration import calibration_curve


def expected_calibration_error(
estimates: np.ndarray,
targets: np.ndarray,
model_names: Sequence[str] = None,
n_bins: int = 10,
return_probs: bool = False,
) -> Mapping[str, Any]:
"""Estimates the expected calibration error (ECE) of a model comparison network according to [1].

[1] Naeini, M. P., Cooper, G., & Hauskrecht, M. (2015).
Obtaining well calibrated probabilities using Bayesian binning.
In Proceedings of the AAAI conference on artificial intelligence (Vol. 29, No. 1).

Notes
-----
Make sure that ``targets`` are **one-hot encoded** classes!

Parameters
----------
estimates : array of shape (num_sim, num_models)
The predicted posterior model probabilities.
targets : array of shape (num_sim, num_models)
The one-hot-encoded true model indices.
model_names : Sequence[str], optional, default: None
Optional model names to show in the output. By default, models are called "M_" + model index.
n_bins : int, optional, default: 10
The number of bins to use for the calibration curves (and marginal histograms).
Passed into ``sklearn.calibration.calibration_curve()``.
return_probs : bool, optional, default: False
Whether to also return the outputs of ``sklearn.calibration.calibration_curve()`` per model.

Returns
-------
result : dict
Dictionary containing:
- "values" : np.ndarray
The expected calibration error per model
- "metric_name" : str
The name of the metric ("Expected Calibration Error").
- "model_names" : str
The (inferred) variable names.
- "probs_true": (optional) list[np.ndarray]:
Outputs of ``sklearn.calibration.calibration_curve()`` per model
- "probs_pred": (optional) list[np.ndarray]:
Outputs of ``sklearn.calibration.calibration_curve()`` per model
"""

# Convert tensors to numpy, if passed
estimates = ops.convert_to_numpy(estimates)
targets = ops.convert_to_numpy(targets)

if estimates.shape != targets.shape:
raise ShapeError("`estimates` and `targets` must have the same shape.")

if model_names is None:
model_names = ["M_" + str(i) for i in range(estimates.shape[-1])]
elif len(model_names) != estimates.shape[-1]:
raise ShapeError("There must be exactly one `model_name` for each model in `estimates`")

# Prepare containers for the per-model results
ece = []
probs_true = []
probs_pred = []

# Convert one-hot encoded targets to integer model indices
targets = targets.argmax(axis=-1)

# Loop over models and compute the binned calibration error for each
for model_index in range(estimates.shape[-1]):
y_true = (targets == model_index).astype(np.float32)
y_prob = estimates[..., model_index]
prob_true, prob_pred = calibration_curve(y_true, y_prob, n_bins=n_bins)

# Compute ECE by weighting bin errors by bin size
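# (ECE = sum over bins of |fraction of positives - mean predicted prob| * bin frequency;
# the uniform binning below mirrors what calibration_curve uses internally, so that
# bin_total[nonzero] lines up with the prob_true and prob_pred computed above)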
bins = np.linspace(0.0, 1.0, n_bins + 1)
binids = np.searchsorted(bins[1:-1], y_prob)
bin_total = np.bincount(binids, minlength=len(bins))
nonzero = bin_total != 0
error = np.sum(np.abs(prob_true - prob_pred) * (bin_total[nonzero] / len(y_true)))

ece.append(error)
probs_true.append(prob_true)
probs_pred.append(prob_pred)

output = dict(values=np.array(ece), metric_name="Expected Calibration Error", model_names=model_names)

if return_probs:
output["probs_true"] = probs_true
output["probs_pred"] = probs_pred

return output
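For orientation, a minimal sketch of calling the new metric, using made-up inputs that mirror the test fixtures further down (100 simulations, 3 candidate models; the seed and shapes are arbitrary):

import numpy as np
import bayesflow as bf
from bayesflow.utils.numpy_utils import softmax

# Hypothetical inputs: 100 simulated data sets, 3 candidate models
rng = np.random.default_rng(2024)
true_models = np.eye(3)[rng.integers(0, 3, size=100)]        # one-hot encoded true model indices
pred_models = softmax(rng.normal(loc=true_models), axis=-1)   # noisy "posterior model probabilities"

ece = bf.diagnostics.metrics.expected_calibration_error(
    estimates=pred_models,
    targets=true_models,
    n_bins=10,
    return_probs=True,
)

print(ece["model_names"])      # ['M_0', 'M_1', 'M_2'] (auto-generated names)
print(ece["values"].shape)     # (3,), one ECE value per model
print(len(ece["probs_true"]))  # 3 calibration curves, one per model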
23 changes: 14 additions & 9 deletions bayesflow/diagnostics/plots/mc_calibration.py
@@ -5,19 +5,20 @@


from bayesflow.utils import (
expected_calibration_error,
prepare_plot_data,
add_titles_and_labels,
add_metric,
prettify_subplots,
)

from bayesflow.diagnostics.metrics import expected_calibration_error


def mc_calibration(
pred_models: dict[str, np.ndarray] | np.ndarray,
true_models: dict[str, np.ndarray] | np.ndarray,
model_names: Sequence[str] = None,
num_bins: int = 10,
n_bins: int = 10,
label_fontsize: int = 16,
title_fontsize: int = 18,
metric_fontsize: int = 14,
@@ -40,7 +41,7 @@ def mc_calibration(
The one-hot-encoded true model indices per data set.
model_names : list or None, optional, default: None
The model names for nice plot titles. Inferred if None.
num_bins : int, optional, default: 10
n_bins : int, optional, default: 10
The number of bins to use for the calibration curves (and marginal histograms).
label_fontsize : int, optional, default: 16
The font size of the y-label and y-label texts
@@ -77,17 +78,21 @@
default_name="M",
)

# Compute calibration
cal_errors, true_probs, pred_probs = expected_calibration_error(
plot_data["targets"], plot_data["estimates"], num_bins
# Compute ECE and calibration curve probabilities per model
ece = expected_calibration_error(
estimates=pred_models,
targets=true_models,
model_names=plot_data["variable_names"],
n_bins=n_bins,
return_probs=True,
)

for j, ax in enumerate(plot_data["axes"].flat):
# Plot calibration curve
ax.plot(pred_probs[j], true_probs[j], "o-", color=color)
ax.plot(ece["probs_pred"][j], ece["probs_true"][j], "o-", color=color)

# Plot PMP distribution over bins
uniform_bins = np.linspace(0.0, 1.0, num_bins + 1)
uniform_bins = np.linspace(0.0, 1.0, n_bins + 1)
norm_weights = np.ones_like(plot_data["estimates"]) / len(plot_data["estimates"])
ax.hist(plot_data["estimates"][:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3)

@@ -104,7 +109,7 @@
add_metric(
ax,
metric_text=r"$\widehat{{\mathrm{{ECE}}}}$",
metric_value=cal_errors[j],
metric_value=ece["values"][j],
metric_fontsize=metric_fontsize,
)

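A matching sketch for the updated plot call with the renamed n_bins argument, reusing the hypothetical pred_models / true_models arrays from the sketch above and assuming the return value is a matplotlib figure (as the tests below treat it):

fig = bf.diagnostics.plots.mc_calibration(
    pred_models=pred_models,
    true_models=true_models,
    model_names=[r"$\mathcal{M}_0$", r"$\mathcal{M}_1$", r"$\mathcal{M}_2$"],
    n_bins=10,  # renamed from num_bins to match expected_calibration_error
)
fig.savefig("mc_calibration.png", dpi=150)  # or show the figure in an interactive session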
1 change: 0 additions & 1 deletion bayesflow/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@
numpy_utils,
)
from .callbacks import detailed_loss_callback
from .comp_utils import expected_calibration_error
from .devices import devices
from .dict_utils import (
convert_args,
62 changes: 0 additions & 62 deletions bayesflow/utils/comp_utils.py

This file was deleted.

20 changes: 20 additions & 0 deletions tests/test_diagnostics/conftest.py
@@ -1,5 +1,6 @@
import numpy as np
import pytest
from bayesflow.utils.numpy_utils import softmax


@pytest.fixture()
@@ -31,3 +32,22 @@ def random_priors():
"sigma": np.random.standard_normal(size=(64, 1)),
"y": np.random.standard_normal(size=(64, 3, 1)),
}


@pytest.fixture()
def model_names():
return [r"$\mathcal{M}_0$", r"$\mathcal{M}_1$", r"$\mathcal{M}_2$"]


@pytest.fixture()
def true_models():
true_models = np.random.choice(3, 100)
true_models = np.eye(3)[true_models].astype(np.int32)
return true_models


@pytest.fixture()
def pred_models(true_models):
pred_models = np.random.normal(loc=true_models)
pred_models = softmax(pred_models, axis=-1)
return pred_models
24 changes: 24 additions & 0 deletions tests/test_diagnostics/test_diagnostics_metrics.py
@@ -1,4 +1,5 @@
import bayesflow as bf
import pytest


def num_variables(x: dict):
@@ -47,3 +48,26 @@ def test_root_mean_squared_error(random_estimates, random_targets):
assert out["values"].shape == (num_variables(random_estimates),)
assert out["metric_name"] == "NRMSE"
assert out["variable_names"] == ["beta_0", "beta_1", "sigma"]


def test_expected_calibration_error(pred_models, true_models, model_names):
out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models, model_names=model_names)
assert list(out.keys()) == ["values", "metric_name", "model_names"]
assert out["values"].shape == (pred_models.shape[-1],)
assert out["metric_name"] == "Expected Calibration Error"
assert out["model_names"] == [r"$\mathcal{M}_0$", r"$\mathcal{M}_1$", r"$\mathcal{M}_2$"]

# returns probs?
out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models, return_probs=True)
assert list(out.keys()) == ["values", "metric_name", "model_names", "probs_true", "probs_pred"]
assert len(out["probs_true"]) == pred_models.shape[-1]
assert len(out["probs_pred"]) == pred_models.shape[-1]
# default: auto model names
assert out["model_names"] == ["M_0", "M_1", "M_2"]

# handles incorrect input?
with pytest.raises(Exception):
out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models, model_names=["a"])

with pytest.raises(Exception):
out = bf.diagnostics.metrics.expected_calibration_error(pred_models, true_models.transpose())
15 changes: 15 additions & 0 deletions tests/test_diagnostics/test_diagnostics_plots.py
@@ -102,3 +102,18 @@ def test_pairs_posterior(random_estimates, random_targets, random_priors):
priors=random_priors,
dataset_id=[1, 3],
)


def test_mc_calibration(pred_models, true_models, model_names):
out = bf.diagnostics.plots.mc_calibration(pred_models, true_models, model_names=model_names)
assert len(out.axes) == pred_models.shape[-1]
assert out.axes[0].get_ylabel() == "True Probability"
assert out.axes[0].get_xlabel() == "Predicted Probability"
assert out.axes[-1].get_title() == r"$\mathcal{M}_2$"


def test_mc_confusion_matrix(pred_models, true_models, model_names):
out = bf.diagnostics.plots.mc_confusion_matrix(pred_models, true_models, model_names, normalize="true")
assert out.axes[0].get_ylabel() == "True model"
assert out.axes[0].get_xlabel() == "Predicted model"
assert out.axes[0].get_title() == "Confusion Matrix"