Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions bayesflow/diagnostics/plots/calibration_ecdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@ def calibration_ecdf(
variable_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
test_quantities: dict[str, Callable] = None,
difference: bool = False,
difference: bool = True,
stacked: bool = False,
rank_type: str | np.ndarray = "fractional",
figsize: Sequence[float] = None,
label_fontsize: int = 16,
legend_fontsize: int = 14,
legend_location: str = "upper right",
legend_location: str = "lower right",
title_fontsize: int = 18,
tick_fontsize: int = 12,
rank_ecdf_color: str = "#132a70",
Expand Down Expand Up @@ -59,7 +59,7 @@ def calibration_ecdf(
The posterior draws obtained from n_data_sets
targets : np.ndarray of shape (n_data_sets, n_params)
The prior draws obtained for generating n_data_sets
difference : bool, optional, default: False
difference : bool, optional, default: True
If `True`, plots the ECDF difference.
Enables a more dynamic visualization range.
stacked : bool, optional, default: False
Expand Down Expand Up @@ -98,7 +98,9 @@ def calibration_ecdf(
label_fontsize : int, optional, default: 16
The font size of the x-label and y-label texts
legend_fontsize : int, optional, default: 14
The font size of the legend text
The font size of the legend text.
legend_location : str, optional, default: 'lower right'
The location of the legend.
title_fontsize : int, optional, default: 18
The font size of the title text.
Only relevant if `stacked=False`
Expand Down Expand Up @@ -211,11 +213,13 @@ def calibration_ecdf(
else:
titles = ["Stacked ECDFs"]

for ax, title in zip(plot_data["axes"].flat, titles):
for i, (ax, title) in enumerate(zip(plot_data["axes"].flat, titles)):
ax.fill_between(z, L, U, color=fill_color, alpha=0.2, label=rf"{int((1 - alpha) * 100)}$\%$ Confidence Bands")
ax.legend(fontsize=legend_fontsize, loc=legend_location)
ax.set_title(title, fontsize=title_fontsize)

if i == 0:
ax.legend(fontsize=legend_fontsize, loc=legend_location)

prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)

add_titles_and_labels(
Expand Down
10 changes: 6 additions & 4 deletions bayesflow/diagnostics/plots/calibration_ecdf_from_quantiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ def calibration_ecdf_from_quantiles(
quantiles_key: str = "quantiles",
variable_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
difference: bool = False,
difference: bool = True,
stacked: bool = False,
figsize: Sequence[float] = None,
label_fontsize: int = 16,
legend_fontsize: int = 14,
legend_location: str = "upper right",
legend_location: str = "lower right",
title_fontsize: int = 18,
tick_fontsize: int = 12,
rank_ecdf_color: str = "#132a70",
Expand Down Expand Up @@ -69,7 +69,7 @@ def calibration_ecdf_from_quantiles(
variable_names : list or None, optional, default: None
The parameter names for nice plot titles.
Inferred if None. Only relevant if `stacked=False`.
difference : bool, optional, default: False
difference : bool, optional, default: True
If `True`, plots the ECDF difference.
Enables a more dynamic visualization range.
stacked : bool, optional, default: False
Expand All @@ -82,7 +82,9 @@ def calibration_ecdf_from_quantiles(
label_fontsize : int, optional, default: 16
The font size of the x-label and y-label texts
legend_fontsize : int, optional, default: 14
The font size of the legend text
The font size of the legend text.
legend_location : str, optional, default: 'lower right'
The location of the legend.
title_fontsize : int, optional, default: 18
The font size of the title text.
Only relevant if `stacked=False`
Expand Down
30 changes: 13 additions & 17 deletions bayesflow/diagnostics/plots/coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@ def coverage(
variable_names: Sequence[str] = None,
figsize: Sequence[int] = None,
label_fontsize: int = 16,
legend_fontsize: int = 14,
title_fontsize: int = 18,
tick_fontsize: int = 12,
legend_location: str = "lower right",
color: str = "#132a70",
num_col: int = None,
num_row: int = None,
Expand All @@ -39,7 +41,7 @@ def coverage(
The posterior draws obtained from num_datasets
targets : np.ndarray of shape (num_datasets, num_params)
The true parameter values used for generating num_datasets
difference : bool, optional, default: False
difference : bool, optional, default: True
If True, plots the difference between empirical coverage and ideal coverage
(coverage - width), making deviations from ideal calibration more visible.
If False, plots the standard coverage plot.
Expand All @@ -52,10 +54,14 @@ def coverage(
The figure size passed to the matplotlib constructor. Inferred if None.
label_fontsize : int, optional, default: 16
The font size of the y-label and x-label text
legend_fontsize : int, optional, default: 14
The font size of the legend text
title_fontsize : int, optional, default: 18
The font size of the title text
tick_fontsize : int, optional, default: 12
The font size of the axis ticklabels
legend_location : str, optional, default: 'lower right'
The location of the legend.
color : str, optional, default: '#132a70'
The color for the coverage line
num_row : int, optional, default: None
Expand Down Expand Up @@ -128,17 +134,11 @@ def coverage(
)

# Plot ideal coverage difference line (y = 0)
ax.axhline(y=0, color="skyblue", linewidth=2.0, label="Ideal Coverage")
ax.axhline(y=0, color="black", linestyle="dashed", label="Ideal Coverage")

# Plot empirical coverage difference
ax.plot(width_rep, diff_est, color=color, alpha=1.0, label="Coverage Difference")

# Set axis limits
ax.set_xlim(0, 1)

# Add legend to first subplot
if i == 0:
ax.legend(fontsize=tick_fontsize, loc="upper right")
else:
# Plot confidence ribbon
ax.fill_between(
Expand All @@ -151,23 +151,19 @@ def coverage(
)

# Plot ideal coverage line (y = x)
ax.plot([0, 1], [0, 1], color="skyblue", linewidth=2.0, label="Ideal Coverage")
ax.plot([0, 1], [0, 1], color="black", linestyle="dashed", label="Ideal Coverage")

# Plot empirical coverage
ax.plot(width_rep, coverage_est, color=color, alpha=1.0, label="Empirical Coverage")

# Set axis limits
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

# Add legend to first subplot
if i == 0:
ax.legend(fontsize=tick_fontsize, loc="upper left")
# Add legend to first subplot
if i == 0:
ax.legend(fontsize=legend_fontsize, loc=legend_location)

prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)

# Add labels, titles, and set font sizes
ylabel = "Observed coverage difference" if difference else "Observed coverage"
ylabel = "Empirical coverage difference" if difference else "Empirical coverage"
add_titles_and_labels(
axes=plot_data["axes"],
num_row=plot_data["num_row"],
Expand Down
2 changes: 1 addition & 1 deletion bayesflow/diagnostics/plots/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def loss(
num_col=1,
title=["Loss Trajectory"],
xlabel="Training epoch #",
ylabel="Value",
ylabel="Loss",
title_fontsize=title_fontsize,
label_fontsize=label_fontsize,
)
Expand Down
6 changes: 4 additions & 2 deletions bayesflow/diagnostics/plots/recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,10 @@ def recovery(
The number of rows for the subplots. Dynamically determined if None.
num_col : int, optional, default: None
The number of columns for the subplots. Dynamically determined if None.
xlabel :
ylabel :
xlabel : str, optional, default: "Ground truth"
The label shown on the x-axis.
ylabel : str, optional, default: "Estimate"
The label shown on the y-axis.
markersize : float, optional, default: None
The marker size in points.

Expand Down
29 changes: 22 additions & 7 deletions bayesflow/workflows/basic_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,8 +349,12 @@ def plot_default_diagnostics(
- Loss history (if training history is available).
- Parameter recovery plots.
- Calibration ECDF plots.
- Coverage plots.
- Z-score contraction plots.

Caution: For models with many parameters, plotting all marginal diagnostics becomes unwieldy. Consider
providing `variable_keys` for visualizing the diagnostics for subsets of the parameter space.

Parameters
----------
test_data : Mapping[str, np.ndarray] or int
Expand Down Expand Up @@ -400,6 +404,7 @@ def plot_default_diagnostics(
plot_fns = {
"recovery": bf_plots.recovery,
"calibration_ecdf": bf_plots.calibration_ecdf,
"coverage": bf_plots.coverage,
"z_score_contraction": bf_plots.z_score_contraction,
}

Expand Down Expand Up @@ -499,9 +504,10 @@ def compute_default_diagnostics(
"""
Computes default diagnostic metrics to evaluate the quality of inference. The function computes several
diagnostic metrics, including:
- Root Mean Squared Error (RMSE)
- Posterior contraction
- Calibration error
- (Normalized) Root Mean Squared Error ((N)RMSE) - summarizes the recovery plots
- Log-gamma statistic - summarizes the ECDF calibration plots
- Expected Calibration Error (ECE) - summarizes the coverage plots
- Posterior contraction - partially summarizes the contraction plots

Parameters
----------
Expand Down Expand Up @@ -553,12 +559,12 @@ def compute_default_diagnostics(
**kwargs.get("root_mean_squared_error_kwargs", {}),
)

contraction = bf_metrics.posterior_contraction(
log_gamma = bf_metrics.calibration_log_gamma(
estimates=samples,
targets=test_data,
variable_keys=variable_keys,
variable_names=variable_names,
**kwargs.get("posterior_contraction_kwargs", {}),
**kwargs.get("log_gamma_kwargs", {}),
)

calibration_errors = bf_metrics.calibration_error(
Expand All @@ -569,17 +575,26 @@ def compute_default_diagnostics(
**kwargs.get("calibration_error_kwargs", {}),
)

contraction = bf_metrics.posterior_contraction(
estimates=samples,
targets=test_data,
variable_keys=variable_keys,
variable_names=variable_names,
**kwargs.get("posterior_contraction_kwargs", {}),
)

if as_data_frame:
metrics = pd.DataFrame(
{
root_mean_squared_error["metric_name"]: root_mean_squared_error["values"],
contraction["metric_name"]: contraction["values"],
log_gamma["metric_name"]: log_gamma["values"],
calibration_errors["metric_name"]: calibration_errors["values"],
contraction["metric_name"]: contraction["values"],
},
index=variable_keys or root_mean_squared_error["variable_names"],
).T
else:
metrics = (root_mean_squared_error, contraction, calibration_errors)
metrics = (root_mean_squared_error, log_gamma, calibration_errors, contraction)

return metrics

Expand Down
2 changes: 1 addition & 1 deletion examples/Multimodal_Data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": null,
"id": "2415fd0b-f5d6-4fc9-83d7-8952e6270186",
"metadata": {},
"outputs": [
Expand Down
Loading