Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions bayesflow/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,6 +1060,8 @@ def plot_calibration_curves(
# Determine n_subplots dynamically
n_row = int(np.ceil(num_models / 6))
n_col = int(np.ceil(num_models / n_row))

# Compute calibration
cal_errs, probs_true, probs_pred = expected_calibration_error(true_models, pred_models, num_bins)

# Initialize figure
Expand Down Expand Up @@ -1094,8 +1096,6 @@ def plot_calibration_curves(
ax[j].spines["top"].set_visible(False)
ax[j].set_xlim([0 - epsilon, 1 + epsilon])
ax[j].set_ylim([0 - epsilon, 1 + epsilon])
ax[j].set_xlabel("Predicted probability", fontsize=label_fontsize)
ax[j].set_ylabel("True probability", fontsize=label_fontsize)
ax[j].set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax[j].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax[j].grid(alpha=0.5)
Expand All @@ -1111,6 +1111,18 @@ def plot_calibration_curves(
size=legend_fontsize,
)

# Only add x-labels to the bottom row
bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :]
for _ax in bottom_row:
_ax.set_xlabel("Predicted probability", fontsize=label_fontsize)

# Only add y-labels to left-most row
if n_row == 1: # if there is only one row, the ax array is 1D
ax[0].set_ylabel("True probability", fontsize=label_fontsize)
else: # if there is more than one row, the ax array is 2D
for _ax in axarr[:, 0]:
_ax.set_ylabel("True probability", fontsize=label_fontsize)

fig.tight_layout()
return fig

Expand Down