diff --git a/bayesflow/diagnostics.py b/bayesflow/diagnostics.py index 94cb4d882..891be2c52 100644 --- a/bayesflow/diagnostics.py +++ b/bayesflow/diagnostics.py @@ -1060,6 +1060,8 @@ def plot_calibration_curves( # Determine n_subplots dynamically n_row = int(np.ceil(num_models / 6)) n_col = int(np.ceil(num_models / n_row)) + + # Compute calibration cal_errs, probs_true, probs_pred = expected_calibration_error(true_models, pred_models, num_bins) # Initialize figure @@ -1094,8 +1096,6 @@ def plot_calibration_curves( ax[j].spines["top"].set_visible(False) ax[j].set_xlim([0 - epsilon, 1 + epsilon]) ax[j].set_ylim([0 - epsilon, 1 + epsilon]) - ax[j].set_xlabel("Predicted probability", fontsize=label_fontsize) - ax[j].set_ylabel("True probability", fontsize=label_fontsize) ax[j].set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) ax[j].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0]) ax[j].grid(alpha=0.5) @@ -1111,6 +1111,18 @@ def plot_calibration_curves( size=legend_fontsize, ) + # Only add x-labels to the bottom row + bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :] + for _ax in bottom_row: + _ax.set_xlabel("Predicted probability", fontsize=label_fontsize) + + # Only add y-labels to left-most row + if n_row == 1: # if there is only one row, the ax array is 1D + ax[0].set_ylabel("True probability", fontsize=label_fontsize) + else: # if there is more than one row, the ax array is 2D + for _ax in axarr[:, 0]: + _ax.set_ylabel("True probability", fontsize=label_fontsize) + fig.tight_layout() return fig