In [None]:
# Automatically reload edited modules (e.g. the local `utils` and `lf2i` packages) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
from tqdm import tqdm
import pickle
# from stylesheets.register_roboto import register_roboto
# register_roboto()

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import matplotlib.patches as patches
from matplotlib.lines import Line2D
from matplotlib.legend_handler import HandlerLine2D, HandlerPatch
from matplotlib.patches import FancyArrowPatch
import seaborn as sns
import numpy as np
import pandas as pd
import torch
from torch.distributions import Uniform, Normal
from sbi.inference import FMPE, SNPE

from lf2i.inference import LF2I
from lf2i.test_statistics.posterior import Posterior
from lf2i.plot.coverage_diagnostics import coverage_probability_plot
from lf2i.utils.other_methods import hpd_region
from lf2i.utils.calibration_diagnostics_inputs import preprocess_predict_p_values

from utils import GaussianMean, DividedPatchHandler

In [None]:
# Simulation budgets.
B = 100_000                        # training simulations for the posterior estimator (simulate_for_test_statistic)
B_PRIME = EVAL_GRID_SIZE = 50_000  # calibration simulations (simulate_for_critical_values); also the evaluation-grid size
B_DOUBLE_PRIME = 20_000            # diagnostics simulations (simulate_for_diagnostics)

# Nominal levels. Later cells iterate over and index this, so it must stay a tuple even with a
# single entry. Other levels used previously: 0.95, 0.683.
CONFIDENCE_LEVEL = (0.9,)

# Passed through as `norm_posterior` to hpd_region / posterior.log_prob / lf2i.diagnostics.
NORM_POSTERIOR_SAMPLES = None

PRIOR = Normal(loc=torch.Tensor([0]), scale=torch.Tensor([1]))
REFERENCE = Uniform(low=torch.Tensor([-11]), high=torch.Tensor([11]))  # estimate on a slightly larger space to avoid boundary effects
EVAL_GRID_DISTR = Uniform(low=torch.Tensor([-10]), high=torch.Tensor([10]))

In [None]:
# Simulator for the 1-D example, built on the PRIOR and REFERENCE distributions defined above.
simulator = GaussianMean(prior=PRIOR, reference=REFERENCE, poi_dim=1)

### ESTIMATION

In [None]:
# Training set for the posterior estimator: B simulated (parameter, sample) pairs.
# Cached on disk so a re-run reuses the same simulations (local, trusted pickle).
try:
    with open("results/results_revised/schema_b", "rb") as f:
        schema_b = pickle.load(f)
        b_params, b_samples = schema_b["params"], schema_b["samples"]
except FileNotFoundError:
    b_params, b_samples = simulator.simulate_for_test_statistic(size=B)
    with open("results/results_revised/schema_b", "wb") as f:
        pickle.dump({"params": b_params, "samples": b_samples}, f)

b_params.shape, b_samples.shape

In [None]:
# Sanity check: marginal distribution of the training parameters.
plt.hist(b_params, bins=100, density=True, alpha=0.5, color="blue", label="b_params")
plt.title("Histogram of b_params")
plt.xlabel(r"$\theta$")
plt.ylabel("Density")
plt.legend()
plt.show()

In [None]:
# Joint distribution of the training parameters and their simulated samples.
plt.hist2d(b_params.squeeze(), b_samples.squeeze(), bins=100, cmap='viridis')
plt.colorbar(label='Counts')
plt.xlabel('b_params')
plt.ylabel('b_samples')
plt.title('2D Histogram of b_samples vs b_params')
plt.show()

In [None]:
# Neural posterior estimator fitted on the B training simulations (SNPE with a MAF density
# estimator, on CPU). The fitted posterior object is cached on disk.
try:
    with open('results/results_revised/pstr_schema.pkl', 'rb') as f:
        posterior = pickle.load(f)
except FileNotFoundError:
    estimator = SNPE(prior=simulator.prior, density_estimator='maf', device='cpu')
    # Samples are passed as a single-column matrix, one row per simulation.
    _ = estimator.append_simulations(b_params, b_samples.reshape(-1, 1)).train()
    posterior = estimator.build_posterior()
    with open('results/results_revised/pstr_schema.pkl', 'wb') as f:
        pickle.dump(posterior, f)

In [None]:
# Observation used throughout the schema figure, plus 100k prior draws and 100k posterior
# draws conditioned on it. Cached on disk.
try:
    with open('results/results_revised/pstr_schema_samples.pkl', 'rb') as f:
        pster_schema_samples = pickle.load(f)
        obs_theta, obs_x, prior_samples, posterior_samples = (
            pster_schema_samples[key]
            for key in ('obs_theta', 'obs_x', 'prior_samples', 'posterior_samples')
        )
except FileNotFoundError:
    obs_theta = torch.Tensor([[4]])
    # Fixed observation for a reproducible figure. Alternative used before:
    # simulator.likelihood(obs_theta).sample(sample_shape=(1, )).squeeze(0)
    obs_x = torch.tensor([[3.7]])
    prior_samples = simulator.prior.sample(sample_shape=(100_000, ))
    posterior_samples = posterior.sample(sample_shape=(100_000, ), x=obs_x.reshape(1, ))
    with open('results/results_revised/pstr_schema_samples.pkl', 'wb') as f:
        pickle.dump(
            dict(
                obs_theta=obs_theta,
                obs_x=obs_x,
                prior_samples=prior_samples,
                posterior_samples=posterior_samples,
            ),
            f,
        )

In [None]:
# HPD credible sets for obs_x, one per nominal level. Each level uses a fresh uniform grid
# drawn from EVAL_GRID_DISTR. Cached on disk.
try:
    with open('results/results_revised/credible_sets_schema.pkl', 'rb') as f:
        credible_sets_schema = pickle.load(f)
        credible_sets = credible_sets_schema['credible_sets']
except FileNotFoundError:
    credible_sets = []
    for cl in CONFIDENCE_LEVEL:
        grid_for_level = EVAL_GRID_DISTR.sample(sample_shape=(EVAL_GRID_SIZE, ))
        # Element [1] of hpd_region's output is what this notebook stores as the credible set.
        credible_set = hpd_region(
            posterior=posterior,
            param_grid=grid_for_level,
            x=obs_x.reshape(-1, ),
            credible_level=cl,
            num_level_sets=10_000,
            norm_posterior=NORM_POSTERIOR_SAMPLES,
        )[1].numpy()
        credible_sets.append(credible_set)
    with open('results/results_revised/credible_sets_schema.pkl', 'wb') as f:
        pickle.dump({'credible_sets': credible_sets}, f)

In [None]:
# try:
#     with open('results/results_revised/credible_sets_schema_10_obs_x.pkl', 'rb') as f:
#         credible_sets_schema_10_obs_x = pickle.load(f)
#         obs_x_samples = credible_sets_schema_10_obs_x['obs_x_samples']
#         credible_sets_list = credible_sets_schema_10_obs_x['credible_sets_list']
# except FileNotFoundError:
#     # Draw 10 samples of obs_x from the likelihood at obs_theta
#     obs_x_samples = simulator.likelihood(obs_theta).sample(sample_shape=(10, )).squeeze(-1)

#     # Compute credible sets for each obs_x sample
#     credible_sets_list = []
#     for i in range(10):
#         cs = []
#         for cl in CONFIDENCE_LEVEL:
#             cs.append(
#                 hpd_region(
#                     posterior=posterior,
#                     param_grid=EVAL_GRID_DISTR.sample(sample_shape=(EVAL_GRID_SIZE, )),
#                     x=obs_x_samples[i].reshape(1, ),
#                     credible_level=cl,
#                     num_level_sets=10_000,
#                     norm_posterior=NORM_POSTERIOR_SAMPLES
#                 )[1].numpy()
#             )
#         credible_sets_list.append(cs)

#     # Save the credible sets for each sample
#     with open('results/results_revised/credible_sets_10_obs_x.pkl', 'wb') as f:
#         pickle.dump({
#             'obs_x_samples': obs_x_samples,
#             'credible_sets_list': credible_sets_list
#         }, f)

In [None]:
# try:
#     credible_sets_schema_10_obs_x = {}
#     obs_x_samples_dict = {}
#     credible_sets_list_dict = {}
#     thetas = [-8, -4, 0, 4, 8]
#     for theta_val in thetas:
#         try:
#             with open(f'results/results_revised/credible_sets_schema_10_obs_x_theta_{theta_val}.pkl', 'rb') as f:
#                 data = pickle.load(f)
#                 obs_x_samples_dict[theta_val] = data['obs_x_samples']
#                 credible_sets_list_dict[theta_val] = data['credible_sets_list']
#         except FileNotFoundError:
#             obs_theta_val = torch.tensor([[theta_val]], dtype=torch.float32)
#             obs_x_samples = simulator.likelihood(obs_theta_val).sample(sample_shape=(10, )).squeeze(-1)
#             credible_sets_list = []
#             for i in range(10):
#                 cs = []
#                 for cl in CONFIDENCE_LEVEL:
#                     cs.append(
#                         hpd_region(
#                             posterior=posterior,
#                             param_grid=EVAL_GRID_DISTR.sample(sample_shape=(EVAL_GRID_SIZE, )),
#                             x=obs_x_samples[i].reshape(1, ),
#                             credible_level=cl,
#                             num_level_sets=10_000,
#                             norm_posterior=NORM_POSTERIOR_SAMPLES
#                         )[1].numpy()
#                     )
#                 credible_sets_list.append(cs)
#             obs_x_samples_dict[theta_val] = obs_x_samples
#             credible_sets_list_dict[theta_val] = credible_sets_list
#             with open(f'results/results_revised/credible_sets_schema_10_obs_x_theta_{theta_val}.pkl', 'wb') as f:
#                 pickle.dump({
#                     'obs_x_samples': obs_x_samples,
#                     'credible_sets_list': credible_sets_list
#                 }, f)
#     credible_sets_schema_10_obs_x = {
#         'obs_x_samples_dict': obs_x_samples_dict,
#         'credible_sets_list_dict': credible_sets_list_dict
#     }
# except Exception as e:
#     print(f"Error: {e}")

### CALIBRATION

In [None]:
# Calibration set: B_PRIME simulations passed to LF2I as T_prime. Cached on disk.
# Only a missing cache triggers re-simulation. Any other error (corrupt pickle, missing key,
# Ctrl-C) is surfaced instead of silently regenerating the data, matching the other cache cells.
try:
    with open('results/results_revised/schema_b_prime.pkl', 'rb') as f:
        schema_b_prime = pickle.load(f)
        b_prime_params = schema_b_prime["params"]
        b_prime_samples = schema_b_prime["samples"]
except FileNotFoundError:
    b_prime_params, b_prime_samples = simulator.simulate_for_critical_values(size=B_PRIME)
    with open('results/results_revised/schema_b_prime.pkl', 'wb') as f:
        pickle.dump({"params": b_prime_params, "samples": b_prime_samples}, f)
b_prime_params.shape, b_prime_samples.shape

In [None]:
# Sanity check: marginal distribution of the calibration parameters
# (same layout as the b_params histogram).
plt.hist(b_prime_params, bins=100, density=True, alpha=0.5, color="blue", label="b_prime_params")
plt.title("Histogram of b_prime_params")
plt.xlabel(r"$\theta$")
plt.ylabel("Density")
plt.legend()
plt.show()

In [None]:
# Joint distribution of the calibration-set parameters and their simulated samples.
fig, ax = plt.subplots()
*_, counts_mesh = ax.hist2d(b_prime_params.squeeze(), b_prime_samples.squeeze(), bins=100, cmap='viridis')
fig.colorbar(counts_mesh, ax=ax, label='Counts')
ax.set(
    xlabel='b_prime_params',
    ylabel='b_prime_samples',
    title='2D Histogram of b_prime_samples vs b_prime_params',
)
plt.show()

In [None]:
# LF2I (FreB) confidence sets for obs_x, calibrated with p-values on the B_PRIME calibration set.
# The confidence sets and the LF2I object that produced them are cached together. They must stay
# consistent, because later cells use `lf2i.calibration_model` and `lf2i.diagnostics`.
try:
    with open('results/results_revised/confidence_sets_schema.pkl', 'rb') as f:
        confidence_sets = pickle.load(f)
    with open('results/results_revised/lf2i_obj_schema.pkl', 'rb') as f:
        lf2i = pickle.load(f)
except FileNotFoundError:
    # Reuse a previously saved LF2I object if there is one; otherwise build a new one.
    try:
        with open('results/results_revised/lf2i_obj_schema.pkl', 'rb') as f:
            lf2i = pickle.load(f)
    except FileNotFoundError:
        lf2i = LF2I(
            test_statistic=Posterior(
                poi_dim=1, estimator=posterior,
            )
        )
    confidence_sets = lf2i.inference(
        x=obs_x,
        evaluation_grid=EVAL_GRID_DISTR.sample(sample_shape=(EVAL_GRID_SIZE, )),
        confidence_level=CONFIDENCE_LEVEL,
        calibration_method='p-values',
        calibration_model='cat-gb',
        calibration_model_kwargs={'iterations': 1000, 'depth': 9},
        # Cross-validated alternative:
        # calibration_model_kwargs={
        #     'cv': {'iterations': [100, 300, 500, 700, 1000], 'depth': [1, 3, 5, 7, 9]},
        #     'n_iter': 25
        # },
        T_prime=(b_prime_params, b_prime_samples),
        num_augment=10
    )
    with open('results/results_revised/confidence_sets_schema.pkl', 'wb') as f:
        pickle.dump(confidence_sets, f)
    # Save the object as it is *after* inference, so the cached calibration model is the one that
    # produced the cached confidence sets. Do not re-load an older pickle here: that would replace
    # the freshly calibrated in-memory object with a stale copy.
    with open('results/results_revised/lf2i_obj_schema.pkl', 'wb') as f:
        pickle.dump(lf2i, f)

In [None]:
# try:
#     confidence_sets_schema_10_obs_x = {}
#     obs_x_samples_conf_dict = {}
#     confidence_sets_list_dict = {}
#     thetas = [-8, -4, 0, 4, 8]
#     for theta_val in thetas:
#         try:
#             with open(f'results/results_revised/confidence_sets_schema_10_obs_x_theta_{theta_val}.pkl', 'rb') as f:
#                 data = pickle.load(f)
#                 obs_x_samples_conf_dict[theta_val] = data['obs_x_samples']
#                 confidence_sets_list_dict[theta_val] = data['confidence_sets_list']
#         except FileNotFoundError:
#             obs_theta_val = torch.tensor([[theta_val]], dtype=torch.float32)
#             obs_x_samples = simulator.likelihood(obs_theta_val).sample(sample_shape=(10, )).squeeze(-1)
#             confidence_sets_many = lf2i.inference(
#                 x=obs_x_samples,
#                 evaluation_grid=EVAL_GRID_DISTR.sample(sample_shape=(EVAL_GRID_SIZE, )),
#                 confidence_level=CONFIDENCE_LEVEL,
#                 calibration_method='p-values',
#                 calibration_model='cat-gb',
#                 calibration_model_kwargs={'iterations': 1000, 'depth': 9},
#                 # calibration_model_kwargs={
#                 #     'cv': {'iterations': [100, 300, 500, 700, 1000], 'depth': [1, 3, 5, 7, 9]},
#                 #     'n_iter': 25
#                 # },
#                 T_prime=(b_prime_params, b_prime_samples),
#                 num_augment=10
#             )
#             obs_x_samples_conf_dict[theta_val] = obs_x_samples
#             confidence_sets_list_dict[theta_val] = confidence_sets_many
#             with open(f'results/results_revised/confidence_sets_schema_10_obs_x_theta_{theta_val}.pkl', 'wb') as f:
#                 pickle.dump({
#                     'obs_x_samples': obs_x_samples,
#                     'confidence_sets_list': confidence_sets_many
#                 }, f)
#     confidence_sets_schema_10_obs_x = {
#         'obs_x_samples_conf_dict': obs_x_samples_conf_dict,
#         'confidence_sets_list_dict': confidence_sets_list_dict
#     }
# except Exception as e:
#     print(f"Error: {e}")

In [None]:
# For the schema figure: a grid drawn from the simulator's reference distribution, the posterior
# log-density at obs_x on that grid, and the p-values predicted by the fitted multi-level
# calibration model (column 1 of predict_proba). All three are cached on disk.
try:
    with open('results/results_revised/schema_eval_grid.pkl', 'rb') as f:
        eval_grid = pickle.load(f)
    with open('results/results_revised/schema_posterior_xobs.pkl', 'rb') as f:
        posterior_xobs = pickle.load(f)
    with open('results/results_revised/schema_p_values.pkl', 'rb') as f:
        p_values = pickle.load(f)
except FileNotFoundError:
    eval_grid = simulator.reference.sample(sample_shape=(EVAL_GRID_SIZE, ))
    posterior_xobs = posterior.log_prob(theta=eval_grid, x=obs_x, norm_posterior=NORM_POSTERIOR_SAMPLES)
    p_value_model = lf2i.calibration_model['multiple_levels']
    p_value_features = preprocess_predict_p_values('confidence_sets', posterior_xobs, eval_grid, p_value_model)
    p_values = p_value_model.predict_proba(X=p_value_features)[:, 1]
    for cache_path, cache_obj in (
        ('results/results_revised/schema_eval_grid.pkl', eval_grid),
        ('results/results_revised/schema_posterior_xobs.pkl', posterior_xobs),
        ('results/results_revised/schema_p_values.pkl', p_values),
    ):
        with open(cache_path, 'wb') as f:
            pickle.dump(cache_obj, f)


### DIAGNOSTICS

In [None]:
# Diagnostics set: B_DOUBLE_PRIME fresh simulations used to estimate local coverage. Cached on disk.
try:
    with open('results/results_revised/schema_b_double_prime.pkl', 'rb') as f:
        schema_diagnostics_samples = pickle.load(f)
        b_double_prime_params, b_double_prime_samples = (
            schema_diagnostics_samples['b_double_prime_params'],
            schema_diagnostics_samples['b_double_prime_samples'],
        )
except FileNotFoundError:
    b_double_prime_params, b_double_prime_samples = simulator.simulate_for_diagnostics(size=B_DOUBLE_PRIME)
    with open('results/results_revised/schema_b_double_prime.pkl', 'wb') as f:
        pickle.dump(
            dict(
                b_double_prime_params=b_double_prime_params,
                b_double_prime_samples=b_double_prime_samples,
            ),
            f,
        )

b_double_prime_params.shape, b_double_prime_samples.shape

In [None]:
# Sanity check: marginal distribution of the diagnostics-set parameters
# (same layout as the b_params histogram).
plt.hist(b_double_prime_params, bins=100, density=True, alpha=0.5, color="blue", label="b_double_prime_params")
plt.title("Histogram of b_double_prime_params")
plt.xlabel(r"$\theta$")
plt.ylabel("Density")
plt.legend()
plt.show()

In [None]:
# Joint distribution of the diagnostics-set parameters and their simulated samples
# (same layout as the training- and calibration-set 2D histograms).
plt.hist2d(b_double_prime_params.squeeze(), b_double_prime_samples.squeeze(), bins=100, cmap='viridis')
plt.colorbar(label='Counts')
plt.xlabel('b_double_prime_params')
plt.ylabel('b_double_prime_samples')
plt.title('2D Histogram of b_double_prime_samples vs b_double_prime_params')
plt.show()

#### Local coverage diagnostics for the FreB (Frequentist-Bayes) confidence sets

In [None]:
# Local coverage diagnostics of the LF2I (FreB) confidence sets: one entry per nominal level,
# estimated with splines on the diagnostics set. Cached on disk as a dict keyed by level.
try:
    with open('results/results_revised/diagn_confidence.pkl', 'rb') as f:
        diagn_objects_conf = pickle.load(f)
except FileNotFoundError:
    diagn_objects_conf = {}

# Compute only the levels missing from the cache. Editing CONFIDENCE_LEVEL then extends the
# cache instead of making the plotting cells fail with a KeyError on a stale file.
missing_conf_levels = [cl for cl in CONFIDENCE_LEVEL if cl not in diagn_objects_conf]
for cl in missing_conf_levels:
    diagnostics_estimator_confset, out_parameters_confset, mean_proba_confset, upper_proba_confset, lower_proba_confset = lf2i.diagnostics(
        region_type='lf2i',
        confidence_level=cl,
        calibration_method='p-values',
        coverage_estimator='splines',
        T_double_prime=(b_double_prime_params, b_double_prime_samples),
    )
    diagn_objects_conf[cl] = (diagnostics_estimator_confset, out_parameters_confset, mean_proba_confset, upper_proba_confset, lower_proba_confset)
if missing_conf_levels:
    with open('results/results_revised/diagn_confidence.pkl', 'wb') as f:
        pickle.dump(diagn_objects_conf, f)

In [None]:
# Estimated local coverage of the FreB confidence sets across the parameter space, at the first
# nominal level. The dashed line marks that nominal level for reference.
_, out_params, mean_proba, _, _ = diagn_objects_conf[CONFIDENCE_LEVEL[0]]
plt.scatter(out_params, mean_proba, color='blue', label='Mean Probability')
plt.axhline(CONFIDENCE_LEVEL[0], color='black', linestyle='--', label='Nominal level')
plt.ylim(0, 1)
plt.xlabel(r'$\theta$')
plt.ylabel('Estimated coverage')
plt.legend()
plt.show()

In [None]:
# _, out_params, mean_proba, _, _ = diagn_objects_conf[CONFIDENCE_LEVEL[1]]
# plt.scatter(out_params, mean_proba, color='blue', label='Mean Probability')
# plt.ylim(0, 1)
# plt.show()

#### Local coverage diagnostics for the HPD (highest-posterior-density) credible sets

In [None]:
# Local coverage diagnostics of the HPD credible sets (region_type='posterior'): one entry per
# nominal level. Cached on disk as a dict keyed by level.
try:
    with open('results/results_revised/diagn_credible.pkl', 'rb') as f:
        diagn_objects_cred = pickle.load(f)
except FileNotFoundError:
    diagn_objects_cred = {}

# Compute only the levels missing from the cache, so a stale cache cannot make later cells
# fail with a KeyError after CONFIDENCE_LEVEL is edited.
missing_cred_levels = [cl for cl in CONFIDENCE_LEVEL if cl not in diagn_objects_cred]
for cl in missing_cred_levels:
    # For region_type='posterior' diagnostics returns six values; the last one is not needed here.
    diagnostics_estimator_credible, out_parameters_credible, mean_proba_credible, upper_proba_credible, lower_proba_credible, _ = lf2i.diagnostics(
        region_type='posterior',
        confidence_level=cl,
        coverage_estimator='splines',
        T_double_prime=(b_double_prime_params, b_double_prime_samples),
        posterior_estimator=posterior,
        evaluation_grid=eval_grid[::30],  # every 30th grid point, just to speed up drawing the schema
        num_level_sets=5_000,
        norm_posterior=NORM_POSTERIOR_SAMPLES
    )
    diagn_objects_cred[cl] = (diagnostics_estimator_credible, out_parameters_credible, mean_proba_credible, upper_proba_credible, lower_proba_credible)
if missing_cred_levels:
    with open('results/results_revised/diagn_credible.pkl', 'wb') as f:
        pickle.dump(diagn_objects_cred, f)

In [None]:
# Estimated local coverage of the HPD credible sets across the parameter space, at the first
# nominal level. The dashed line marks that nominal level for reference.
_, out_params, mean_proba, _, _ = diagn_objects_cred[CONFIDENCE_LEVEL[0]]
plt.scatter(out_params, mean_proba, color='blue', label='Mean Probability')
plt.axhline(CONFIDENCE_LEVEL[0], color='black', linestyle='--', label='Nominal level')
plt.ylim(0, 1)
plt.xlabel(r'$\theta$')
plt.ylabel('Estimated coverage')
plt.legend()
plt.show()

In [None]:
# _, out_params, mean_proba, _, _ = diagn_objects_cred[CONFIDENCE_LEVEL[1]]
# plt.scatter(out_params, mean_proba, color='blue', label='Mean Probability')
# plt.ylim(0, 1)
# plt.show()

### FIGURE PREAMBLE

In [None]:
# Figure styling: all text rendered through LaTeX in a serif font, with a preamble loading the
# packages the labels rely on, on a light-grey figure background.
plt.rcParams.update({
    'text.usetex': True,
    'font.family': 'serif',
    'text.latex.preamble': r'''
    \usepackage{amsmath}  % For \mathbb
    \usepackage{amssymb}  % For \mathbb
    \usepackage{bm}       % For bold math symbols
    \usepackage{underscore} % If underscores are needed
''',
    'figure.facecolor': 'gainsboro',
})

# fig, ax = plt.subplots(2, 2, figsize=(19, 13))
# plt.subplots_adjust(hspace=0.32)
# plot1(ax[0][0])
# plot2(ax[1][0])
# plot4(ax[0][1])
# plot3(ax[1][1])
# fig.patches.append(patches.Rectangle((0.097, 0.035), 0.818, 0.428, transform=fig.transFigure, edgecolor='black', linewidth=2, facecolor="gainsboro", zorder=-1))
# fig.patches.append(patches.Rectangle((0.097, 0.477), 0.818, 0.488, transform=fig.transFigure, edgecolor='black', linewidth=2, facecolor="gainsboro", zorder=-1))
# fig.text(0.25, 0.488, r"{\bf RESHAPE}", fontsize=20, color='black', ha='center', va='center', bbox=dict(facecolor='white', edgecolor='red', boxstyle='round,pad=0.4', linewidth=3), zorder=20)
# fig.add_artist(FancyArrowPatch(
#     posA=(0.25, 0.55), posB=(0.25, 0.43),
#     connectionstyle="arc3,rad=0", arrowstyle='-|>', mutation_scale=30, 
#     color='red', linewidth=6, zorder=10
# ))
# # ax[0].set_title(r'Neural Density Estimation and'+'\n'r'High-Posterior-Density Sets', size=25, pad=24)
# fig.suptitle(r'\textbf{1D Synthetic Example}', size=40, y=1.01)
# ax[0][0].set_title(r'\textbf{From Highest Posterior Density}'+'\n'+r'\textbf{to Frequentist-Bayes Intervals}', size=25, pad=26)
# ax[0][1].set_title(r'\textbf{Local Diagnostics}', size=25, pad=37)  # r"{\bf B)} " + 
# # ax[1].set_title(r'Valid Scientific Inference via'+'\n'+r'Frequentist-Bayes Sets', size=25, pad=24)
# # fig.text(0.16, 0.925, r"{\bf A)}", fontsize=25, color='black', ha='center', va='center', zorder=20)
# # fig.text(0.639, 0.99, r"{\bf b)}", fontsize=25, color='black', ha='center', va='center', zorder=20)
# # plt.tight_layout(pad=3)

# # plt.savefig('./outputs/Figure1_draft.pdf', bbox_inches='tight')
# # plt.savefig('./outputs/Figure1_draft.png', bbox_inches='tight')
# plt.show()

### PANEL A

In [None]:
def plot_A_left(axis=None):
    """Draw the "Parameter space" panel of the forward-problem schema.

    The panel contains:
    - a dashed KDE of the module-level ``prior_samples``, labelled "Training prior";
    - a dashed crimson vertical line and a star on the x-axis at the true parameter ``obs_theta``;
    - curved arrows linking each text label to the element it describes.

    Ticks are hidden and the left/bottom spines end in arrowheads, so the panel reads as a
    schematic rather than a quantitative plot.

    Args:
        axis: Matplotlib Axes to draw on. If None, a new 8.5 x 5 figure is created and shown.

    Returns:
        The Axes drawn on when ``axis`` is given; None otherwise (the figure is shown instead).
    """
    if axis is None:
        fig = plt.figure(figsize=(8.5, 5))
        ax = fig.gca()
    else:
        ax = axis

    prior_color = 'lightgrey'
    truth_color = 'crimson'
    x_lo, x_hi = -7, 7
    theta_star = obs_theta.item()
    # Blended transform: x in data units, y in axes units (0 = bottom edge, 1 = top edge).
    # Annotations therefore keep their vertical placement whatever the density scale is.
    xdata_yaxes = ax.get_xaxis_transform()

    # Prior density.
    sns.kdeplot(prior_samples.squeeze(), ax=ax, color=prior_color, linewidth=4, zorder=10,
                linestyle='--', clip_on=True)

    # True parameter: full-height dashed line.
    ax.axvline(x=theta_star, ymin=0, ymax=1, color=truth_color,
               linestyle='--', linewidth=4, zorder=30)

    # Label for the true parameter, plus a curved arrow from the label down to the star below.
    ax.text(theta_star - 0.2, 0.17, s=r'\textbf{True parameter $\theta^*$}',
            transform=xdata_yaxes,
            horizontalalignment='right', verticalalignment='center',
            zorder=50, fontdict={'size': 25})
    ax.add_patch(FancyArrowPatch(
        posA=(0.8, 0.13),
        posB=(theta_star, 0.02),
        arrowstyle="->",
        color="k",
        linewidth=2,
        mutation_scale=25,
        connectionstyle="arc3,rad=0.1",
        transform=xdata_yaxes,
        zorder=60,
    ))

    # Star on the x-axis at theta*. A white-filled star is drawn first (presumably to hide the
    # spine underneath), then an unfilled crimson star on top. clip_on=False lets the marker
    # overhang the axes.
    for star_face, star_edge in (('white', 'white'), ('none', truth_color)):
        ax.scatter(theta_star, 0,
                   transform=xdata_yaxes, marker='*',
                   facecolor=star_face, edgecolor=star_edge, s=300, linewidth=2,
                   zorder=50, clip_on=False)

    # "Training prior" label with a curved arrow pointing at the prior curve.
    ax.text(-2.0, 0.73, s=r'Training prior', transform=xdata_yaxes,
            horizontalalignment='center', verticalalignment='center',
            zorder=50, fontdict={'size': 25})
    ax.add_patch(FancyArrowPatch(
        posA=(-2.0, 0.71),
        posB=(-0.9, 0.5),
        arrowstyle="->",
        color="k",
        linewidth=2,
        mutation_scale=25,
        connectionstyle="arc3,rad=0.3",
        transform=xdata_yaxes,
        zorder=60,
    ))

    ax.set_ylabel('Plausibility', size=25)
    ax.set_xlabel(r'$\theta$', size=27)
    ax.set_title(r'\textbf{Parameter space}', size=25, ha='center', va='center')
    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(0, 0.4)

    # Schematic look: no ticks or grid; only the left/bottom spines, drawn thick.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(2)
    ax.spines['bottom'].set_linewidth(2)
    ax.spines['bottom'].set_zorder(40)

    # Arrowheads at the right end of the x-axis and the top of the y-axis.
    ax.plot(x_hi, 0, ">k", transform=xdata_yaxes,
            clip_on=False, markersize=10, zorder=101)
    ax.plot(x_lo, 1, "^k", transform=xdata_yaxes,
            clip_on=False, markersize=10)

    if axis is None:
        plt.show()
        return None
    return ax

In [None]:
def plot_A_topright(axis=None, obs_theta=None, num_samples=100_000):
    """Draw the "Observable data" panel: the likelihood of X at the true parameter.

    Draws ``num_samples`` samples from ``simulator.likelihood(obs_theta)`` (module-level
    ``simulator``) and plots their KDE as a solid goldenrod curve, labelled
    "Likelihood for true theta*". Ticks are hidden and the left/bottom spines end in
    arrowheads, matching the other schema panels.

    Args:
        axis: Matplotlib Axes to draw on. If None, a new figure and Axes are created.
        obs_theta: Parameter passed to ``simulator.likelihood``. Required. The notebook passes
            the ``obs_theta`` tensor defined earlier (TODO confirm other shapes are supported).
        num_samples: Number of likelihood samples used for the KDE.

    Returns:
        The Axes drawn on.

    Raises:
        ValueError: If ``obs_theta`` is None.
    """
    if obs_theta is None:
        raise ValueError("plot_A_topright requires obs_theta (the true parameter).")
    if axis is None:
        _, axis = plt.subplots(figsize=(6, 3))

    x_lo, x_hi = -7, 7
    # Blended transform: x in data units, y in axes units.
    xdata_yaxes = axis.get_xaxis_transform()

    # KDE of simulated data at the true parameter.
    obs_samples = simulator.likelihood(obs_theta).sample(sample_shape=(num_samples, ))
    sns.kdeplot(obs_samples.numpy().flatten(), ax=axis, color='goldenrod', linewidth=4, zorder=10,
                linestyle='-', clip_on=True)

    axis.text(0.0, 0.73, s=r'Likelihood' + '\n' + r'for true $\theta^*$', transform=xdata_yaxes,
              horizontalalignment='center', verticalalignment='center',
              zorder=50, fontdict={'size': 22})

    axis.set_ylabel('Probability density', fontsize=18)
    axis.set_xlabel(r'$X$', fontsize=24, labelpad=10)
    axis.set_title(r'\textbf{Observable data}', fontsize=22)
    axis.set_xlim(x_lo, x_hi)

    # Schematic look: no ticks; only the left/bottom spines, drawn thick.
    axis.set_xticks([])
    axis.set_yticks([])
    axis.spines['top'].set_visible(False)
    axis.spines['right'].set_visible(False)
    axis.spines['left'].set_linewidth(2)
    axis.spines['bottom'].set_linewidth(2)
    axis.spines['bottom'].set_zorder(40)

    # Arrowheads at the right end of the x-axis and the top of the y-axis.
    axis.plot(x_hi, 0, ">k", transform=xdata_yaxes,
              clip_on=False, markersize=10, zorder=101)
    axis.plot(x_lo, 1, "^k", transform=xdata_yaxes,
              clip_on=False, markersize=10)
    return axis

In [None]:
def plot_A_bottomright(axis=None, obs_theta=None, level=0.95, num_samples=100_000):
    """Draw the "Simulated data" panel: simulated X with its central prediction interval.

    Draws ``num_samples`` samples from ``simulator.likelihood(obs_theta)`` (module-level
    ``simulator``) and plots them as a goldenrod histogram. The central ``level`` interval
    (empirical quantiles with equal tail mass on each side) is drawn as a filled bar just
    below the x-axis, with a "Prediction interval" label and arrow. The default level of 0.95
    reproduces the 2.5%/97.5% quantiles.

    Args:
        axis: Matplotlib Axes to draw on. If None, a new figure and Axes are created.
        obs_theta: Parameter passed to ``simulator.likelihood``. Required.
        level: Probability mass of the central interval, strictly between 0 and 1.
        num_samples: Number of likelihood samples drawn.

    Returns:
        The Axes drawn on.

    Raises:
        ValueError: If ``obs_theta`` is None or ``level`` is outside (0, 1).
    """
    if obs_theta is None:
        raise ValueError("plot_A_bottomright requires obs_theta (the true parameter).")
    if not 0 < level < 1:
        raise ValueError(f"level must lie strictly between 0 and 1, got {level}.")
    if axis is None:
        _, axis = plt.subplots(figsize=(6, 3))

    x_lo, x_hi = -7, 7
    interval_color = 'goldenrod'
    # Blended transform: x in data units, y in axes units (negative y = below the x-axis).
    xdata_yaxes = axis.get_xaxis_transform()

    obs_samples = simulator.likelihood(obs_theta).sample(sample_shape=(num_samples, ))
    axis.hist(obs_samples.numpy().flatten(), bins=20, facecolor='goldenrod', edgecolor='brown',
              linewidth=2, alpha=0.5, label='obs_samples')

    # Central `level` interval: tail mass (1 - level) / 2 on each side.
    tail = (1 - level) / 2
    interval_lo, interval_hi = torch.quantile(obs_samples, torch.tensor([tail, 1 - tail])).numpy()
    ends = [interval_lo, interval_hi]

    # Interval bar below the x-axis: top and bottom edges, side edges, translucent fill, then
    # dotted guides up to the axis. The bar lies outside the axes, hence clip_on=False.
    band_top, band_bottom = -0.22, -0.28
    for band_y in (band_top, band_bottom):
        axis.plot(ends, [band_y] * 2, transform=xdata_yaxes, color=interval_color,
                  linewidth=4, zorder=51, clip_on=False)
    axis.vlines(ends, ymin=[band_bottom] * 2, ymax=[band_top] * 2, transform=xdata_yaxes,
                color=interval_color, linewidth=4, zorder=50, clip_on=False)
    axis.fill_between(ends, band_bottom, band_top, transform=xdata_yaxes,
                      color=interval_color, alpha=0.3, zorder=50, clip_on=False)
    axis.vlines(x=ends, ymin=band_top, ymax=0, transform=xdata_yaxes, color=interval_color,
                linestyle=':', linewidth=3, zorder=29, clip_on=False)

    # Label to the left of the bar, with a curved arrow to the bar's left end.
    axis.text(-2.5, -0.3, s=r'\textbf{Prediction interval}', transform=xdata_yaxes,
              horizontalalignment='right', verticalalignment='center',
              zorder=50, fontdict={'size': 22})
    axis.add_patch(FancyArrowPatch(
        posA=(-2.5, -0.3),
        posB=(interval_lo, -0.24),
        arrowstyle="->",
        color="k",
        linewidth=2,
        mutation_scale=25,
        connectionstyle="arc3,rad=0.1",
        transform=xdata_yaxes,
        zorder=60,
        clip_on=False,  # the arrow lives below the axes
    ))

    axis.set_ylabel('Probability density', fontsize=18)
    axis.set_xlabel(r'$X$', fontsize=24, labelpad=10)
    axis.set_title(r'\textbf{Simulated data}', fontsize=22)
    axis.set_xlim(x_lo, x_hi)

    # Schematic look: no ticks; only the left/bottom spines, drawn thick.
    axis.set_xticks([])
    axis.set_yticks([])
    axis.spines['top'].set_visible(False)
    axis.spines['right'].set_visible(False)
    axis.spines['left'].set_linewidth(2)
    axis.spines['bottom'].set_linewidth(2)
    axis.spines['bottom'].set_zorder(40)

    # Arrowheads at the right end of the x-axis and the top of the y-axis.
    axis.plot(x_hi, 0, ">k", transform=xdata_yaxes,
              clip_on=False, markersize=10, zorder=101)
    axis.plot(x_lo, 1, "^k", transform=xdata_yaxes,
              clip_on=False, markersize=10)
    return axis

In [None]:
from matplotlib.patches import FancyArrowPatch
import matplotlib.gridspec as gridspec

# Panel A ("Forward problem"): a tall left subplot plus two stacked right subplots,
# joined by labelled horizontal arrows drawn in figure coordinates.
fig = plt.figure(figsize=(12, 5))
gs = gridspec.GridSpec(2, 2, width_ratios=[1.8, 1], height_ratios=[1, 1], wspace=0.6, hspace=0.5)

# Large left subplot spanning both grid rows
ax_left = fig.add_subplot(gs[:, 0])
plot_A_left(axis=ax_left)

# Two small right subplots stacked vertically
ax_right_top = fig.add_subplot(gs[0, 1])
ax_right_bottom = fig.add_subplot(gs[1, 1])
plot_A_topright(axis=ax_right_top, obs_theta=obs_theta)
plot_A_bottomright(axis=ax_right_bottom, obs_theta=obs_theta)

plt.suptitle(r'\textbf{Forward problem}', fontsize=36, y=1.08)

# One thick black arrow per right subplot, pointing at its left edge, labelled above.
for target_ax, arrow_label in [(ax_right_top, "Nature"), (ax_right_bottom, "Model of Nature")]:
    # get_position() already returns figure-fraction coordinates. Do not pass `fig`
    # here: the first positional parameter of Axes.get_position is `original`.
    bbox = target_ax.get_position()
    y_mid = bbox.y0 + bbox.height / 2
    arrow_start = (bbox.x0 - 0.18, y_mid)  # just to the left of the subplot
    arrow_end = (bbox.x0 - 0.02, y_mid)    # just short of the subplot's left edge

    fig.add_artist(FancyArrowPatch(
        arrow_start, arrow_end,
        transform=fig.transFigure,
        arrowstyle='-|>',
        mutation_scale=40,  # thicker arrow head
        color='black',
        linewidth=4,        # thicker line
        zorder=200,
        connectionstyle="arc3,rad=0",
    ))

    # Label centred (slightly left-shifted) above the arrow
    fig.text(
        (arrow_start[0] + arrow_end[0]) / 2 - 0.02,
        arrow_start[1] + 0.02,  # slightly above the arrow
        arrow_label,
        ha='center', va='bottom',
        fontsize=25,
        color='black',
        fontweight='bold',
    )

plt.savefig('panel_a.pdf', bbox_inches='tight', dpi=300)
plt.savefig('panel_a.png', bbox_inches='tight', dpi=300)
plt.show()

### PANEL B

In [None]:
def plot_B_left(axis=None):
    """Draw the "Observable data" panel (left side of panel B).

    Plots a KDE of 100,000 draws from ``simulator.likelihood(obs_theta)``. It marks a single
    observed data point ``X_obs = obs_theta - 0.25`` with a dashed vertical line and an
    ``X`` marker on the x-axis, and labels both the KDE and the observation with curved arrows.

    Relies on the notebook-level globals ``simulator`` and ``obs_theta``.

    Args:
        axis: matplotlib Axes to draw on. If None, the current Axes (``plt.gca()``) is used.
    """
    if axis is None:
        axis = plt.gca()

    obs_samples = simulator.likelihood(obs_theta).sample(sample_shape=(100_000, ))
    # The single observation shown in the figure, as a Python float
    # (obs_theta is expected to hold exactly one element).
    x_obs = (obs_theta + torch.tensor([-0.25])).item()
    # Blended transform: x in data units, y as a fraction of the axes height.
    xaxis_tf = axis.get_xaxis_transform()

    def _add_curved_arrow(start_xy, end_xy, rad):
        """Add a black '->' arc arrow between two points given in blended (data-x, axes-y) coords."""
        axis.add_patch(FancyArrowPatch(
            posA=start_xy,
            posB=end_xy,
            arrowstyle="->",
            color="k",
            linewidth=2,
            mutation_scale=25,
            connectionstyle=f"arc3,rad={rad}",  # rad controls the arc's curvature
            transform=xaxis_tf,
            zorder=60,
        ))

    # Data distribution p(X | obs_theta)
    sns.kdeplot(obs_samples.numpy().flatten(), ax=axis, color='yellowgreen', linewidth=4, zorder=10,
                linestyle='-', clip_on=True, fill=False, alpha=1.0, label='obs_samples')

    # Observed data point: dashed vertical line, label, and an X marker on the x-axis
    axis.axvline(x=x_obs, ymin=0, ymax=1,
                 color='goldenrod',
                 linestyle='--', linewidth=4, zorder=30)
    axis.text(
        x_obs - 4.2, 0.14, r'\textbf{Observed data $X_{\textrm{obs}}$}',
        ha='center', va='bottom', fontsize=22, color='k', fontweight='bold',
        transform=xaxis_tf, zorder=40,
    )
    axis.scatter(x_obs, 0,
                 transform=xaxis_tf, marker='X',
                 facecolor='white', edgecolor='goldenrod', s=300, linewidth=2,
                 zorder=50, clip_on=False)
    # Arrow from the "Observed data" label down to the X marker
    _add_curved_arrow((0.8, 0.13), (x_obs - 0.3, 0.02), rad=0.1)

    # Label for the data distribution, with an arrow pointing at the KDE curve
    axis.text(-1.0, 0.75, s=r'Data distribution', transform=xaxis_tf,
              horizontalalignment='center', verticalalignment='center',
              zorder=50, fontdict={'size': 22})
    _add_curved_arrow((0.0, 0.71), (x_obs - 0.5, 0.65), rad=0.3)

    axis.set_ylabel('Probability density', fontsize=25)
    axis.set_xlabel(r'$X$', fontsize=25, labelpad=10)
    axis.set_title(r'\textbf{Observable data}', fontsize=25)
    axis.set_xlim(-7, 7)
    axis.set_xticks([])
    axis.set_yticks([])
    axis.spines['top'].set_visible(False)
    axis.spines['right'].set_visible(False)
    axis.spines['left'].set_linewidth(2)
    axis.spines['bottom'].set_linewidth(2)
    axis.spines['bottom'].set_zorder(40)

    # Arrow heads at the ends of the x- and y-axis spines
    axis.plot(7, 0, ">k", transform=xaxis_tf,
              clip_on=False, markersize=10, zorder=101)
    axis.plot(-7, 1, "^k", transform=xaxis_tf,
              clip_on=False, markersize=10)

In [None]:
def plot_B_topright(axis=None, obs_theta=None):
    """Draw the top-right "Parameter space" panel of panel B, marking the true parameter.

    Draws a dashed crimson vertical line at ``obs_theta`` spanning the full axes height. It
    adds a hollow crimson star on the x-axis, drawn over a white star that masks the spine,
    and a "True parameter" label. The view is fixed to x in [-7, 7] and y in [0, 0.6], with
    ticks hidden and only the left/bottom spines shown.

    Args:
        axis: Axes to draw on. If None, a new 8.5x5 figure is created and shown.
        obs_theta: one-element tensor (anything supporting ``.item()``) with the true parameter.

    Returns:
        The Axes drawn on when ``axis`` is given; otherwise None (the figure is shown).
    """
    import matplotlib.pyplot as plt
    import numpy as np

    # Reseeds numpy's *global* RNG. Nothing below is random; this is kept because later
    # cells may implicitly depend on this global state.
    np.random.seed(42)

    ax = plt.figure(figsize=(8.5, 5)).gca() if axis is None else axis

    truth_color = 'crimson'
    theta_star = obs_theta.item()
    # Blended transform: x in data units, y as a fraction of the axes height.
    xaxis_tf = ax.get_xaxis_transform()

    # True-parameter line. axvline's ymin/ymax are axes fractions, so it spans the full height.
    ax.axvline(x=theta_star, ymin=0, ymax=1,
               color=truth_color,
               linestyle='--', linewidth=4, zorder=30)

    # Star on the x-axis: a white filled star hides the spine, then a hollow crimson star on top.
    ax.scatter(theta_star, 0,
               transform=xaxis_tf, marker='*',
               facecolor='white', edgecolor='white', s=300, linewidth=2,
               zorder=50, clip_on=False)
    ax.scatter(theta_star, 0,
               transform=xaxis_tf, marker='*',
               facecolor='none', edgecolor=truth_color, s=300, linewidth=2,
               zorder=50, clip_on=False)

    ax.text(-6, 0.25, s=r'\textbf{True parameter $\theta^*$}',
            transform=xaxis_tf,
            horizontalalignment='left', verticalalignment='center',
            zorder=50, fontdict={'size': 20})

    # Axis styling
    ax.set_xlabel(r'$\theta$', size=24)
    ax.set_title(r'\textbf{Parameter space}', fontsize=22)
    ax.set_xlim(-7, 7)
    ax.set_ylim(0, 0.6)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.grid(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(2)
    ax.spines['bottom'].set_linewidth(2)
    ax.spines['bottom'].set_zorder(40)

    # Arrow head at the right end of the x-axis spine (no y-axis arrow in this panel)
    ax.plot(7, 0, ">k", transform=xaxis_tf,
            clip_on=False, markersize=10, zorder=101)

    if axis is None:
        plt.show()
        return None
    return ax

In [None]:
def plot_B_bottomright(axis=None):
    """Draw the bottom-right panel of panel B: the estimated posterior and its credible interval(s).

    For each level ``cl`` in the notebook constant ``CONFIDENCE_LEVEL``, the interval is the
    *equal-tailed* percentile interval ``[(1-cl)/2, 1-(1-cl)/2]`` of the global
    ``posterior_samples``; it is not an HPD set. Each interval is drawn as follows:

    * an outlined, semi-transparent bar just below the x-axis;
    * dotted projections from the interval endpoints up to the posterior KDE;
    * a shaded region between the KDE and the straight chord joining the KDE heights
      at the two endpoints;
    * a dashed horizontal line at the KDE height of the lower endpoint.

    Side effects: reseeds numpy's global RNG and enables LaTeX text rendering through
    ``plt.rc`` / ``plt.rcParams`` (global matplotlib state).

    Args:
        axis: Axes to draw on. If None, a new 8.5x5 figure is created and shown.

    Returns:
        The Axes drawn on when ``axis`` is given; otherwise None (the figure is shown).
    """
    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    from matplotlib.patches import FancyArrowPatch

    # Reseeds numpy's *global* RNG; kept because later cells may implicitly depend on it.
    np.random.seed(42)

    # Equal-tailed credible intervals from the posterior draws.
    credible_sets = []
    for cl in CONFIDENCE_LEVEL:
        tail_pct = (1 - cl) * 100 / 2
        lower = np.percentile(posterior_samples, tail_pct)
        upper = np.percentile(posterior_samples, 100 - tail_pct)
        credible_sets.append(np.array([lower, upper]))

    # Global matplotlib text settings (LaTeX rendering); affects all later figures.
    plt.rc('text', usetex=True)  # Enable LaTeX
    plt.rc('font', family='serif')  # Use a serif font
    plt.rcParams['text.latex.preamble'] = r'''
        \usepackage{amsmath}  % For \mathbb
        \usepackage{amssymb}  % For \mathbb
        \usepackage{bm}       % For bold math symbols
        \usepackage{underscore} % If underscores are needed
    '''

    ax = plt.figure(figsize=(8.5, 5)).gca() if axis is None else axis

    posterior_color = 'purple'
    # Colours cycle if there are more confidence levels than colours.
    interval_colors = [posterior_color, 'hotpink']
    # Blended transform: x in data units, y as a fraction of the axes height.
    xaxis_tf = ax.get_xaxis_transform()

    # Posterior KDE; its line data is reused to place the interval projections.
    kde_ax = sns.kdeplot(posterior_samples.squeeze(), color='black', ax=ax, linewidth=2,
                         zorder=30, clip_on=True)
    posterior_line = kde_ax.get_lines()[-1]
    x_posterior = posterior_line.get_xdata()
    y_posterior = posterior_line.get_ydata()

    # Extend the KDE curve down to y=0 at both ends of the visible x-range [-7, 7].
    ax.plot([-7, x_posterior[0]], [0, y_posterior[0]], color='black', linewidth=2, zorder=30)
    ax.plot([x_posterior[-1], 7], [y_posterior[-1], 0], color='black', linewidth=2, zorder=30)

    # Pin the bottom of the y-range at 0 (this also turns off y autoscaling).
    ax.set_ylim(bottom=0)

    # Vertical extent of the interval bars, in axes-fraction units (negative = below the axis).
    # Every level uses the same extent, so multiple bars overlap.
    bar_top, bar_bottom = -0.16, -0.20

    for i, cl in enumerate(CONFIDENCE_LEVEL):
        color = interval_colors[i % len(interval_colors)]
        credible_min = credible_sets[i].min()
        credible_max = credible_sets[i].max()
        # KDE heights at the two interval endpoints
        intersect_point = np.interp([credible_min, credible_max], x_posterior, y_posterior)

        # Outlined, semi-transparent bar below the x-axis
        ax.plot([credible_min, credible_max], [bar_top] * 2,
                transform=xaxis_tf, color=color, linewidth=4,
                zorder=50, clip_on=False)
        ax.plot([credible_min, credible_max], [bar_bottom] * 2,
                transform=xaxis_tf, color=color, linewidth=4,
                zorder=50, clip_on=False)
        ax.vlines([credible_min, credible_max], ymin=[bar_bottom] * 2,
                  ymax=[bar_top] * 2, transform=xaxis_tf,
                  color=color, linewidth=4, zorder=50, clip_on=False)
        ax.fill_between([credible_min, credible_max], bar_bottom, bar_top,
                        transform=xaxis_tf, color=color, alpha=0.3,
                        zorder=50, clip_on=False)
        if i == 0:
            # Opaque white underlay so the translucent fill is not blended with what lies behind it
            ax.fill_between([credible_min, credible_max], bar_bottom, bar_top,
                            transform=xaxis_tf, color="white", alpha=1.0,
                            zorder=49, clip_on=False)

        # Dotted projections up to the KDE. These use data coordinates, so ymin (0.5 * bar_bottom)
        # is a small data-unit offset below y=0 rather than an axes fraction.
        ax.vlines(x=[credible_min, credible_max], ymin=0.5 * bar_bottom, ymax=intersect_point,
                  transform=ax.transData, color=color, linestyle=':',
                  linewidth=3, zorder=29, clip_on=False)

        # Shade between the KDE and the chord joining its heights at the two endpoints.
        mask = (x_posterior >= credible_min) & (x_posterior <= credible_max)
        x_polygon = np.concatenate([[credible_min], x_posterior[mask], [credible_max], [credible_min]])
        y_polygon = np.concatenate([[intersect_point[0]], y_posterior[mask], [intersect_point[1]], [intersect_point[0]]])
        ax.fill(x_polygon, y_polygon, color=color, alpha=0.4,
                zorder=28 + i, linestyle='--', clip_on=True)

        # Dashed "level" line at the lower endpoint's KDE height, centred on the first interval
        first_center = (credible_sets[0].min() + credible_sets[0].max()) / 2
        ax.hlines(y=intersect_point[0],
                  xmin=first_center - 2.9,
                  xmax=first_center + 2.9,
                  linestyle='--', linewidth=2, color='k', zorder=28 + i)

    # Text labels
    ax.text(-1.6, 0.92, s=r'Posterior', transform=xaxis_tf,
            horizontalalignment='center', verticalalignment='center',
            zorder=50, fontdict={'size': 22})
    ax.text(-2.5, -0.17, s=r'\textbf{Credible interval}', transform=xaxis_tf,
            horizontalalignment='right', verticalalignment='center',
            zorder=50, fontdict={'size': 22})
    # Curved arrow from the label to the left end of the last level's interval bar
    ax.add_patch(FancyArrowPatch(
        posA=(-2.5, -0.17),
        posB=(credible_sets[-1].min(), -0.18),
        arrowstyle="->",
        color="k",
        linewidth=2,
        mutation_scale=25,
        connectionstyle="arc3,rad=0.1",
        transform=xaxis_tf,
        zorder=60,
        clip_on=False,  # the arrow lies below the axes, so it must not be clipped
    ))

    # Axis styling
    ax.set_ylabel('Plausibility', size=22)
    ax.set_xlabel(r'$\theta$', size=24, zorder=16)
    ax.set_title(r'\textbf{Parameter space}', fontsize=22)
    ax.set_xlim(-7, 7)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.grid(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(2)
    ax.spines['bottom'].set_linewidth(2)
    ax.spines['bottom'].set_zorder(40)

    # Arrow heads at the ends of the x- and y-axis spines
    ax.plot(7, 0, ">k", transform=xaxis_tf,
            clip_on=False, markersize=10, zorder=101)
    ax.plot(-7, 1, "^k", transform=xaxis_tf,
            clip_on=False, markersize=10)

    if axis is None:
        plt.show()
        return None
    return ax


In [None]:
from matplotlib.patches import FancyArrowPatch
import matplotlib.gridspec as gridspec

# Panel B ("Inverse problem"): a tall left subplot (observed data) plus two stacked right
# subplots (true parameter, posterior), joined by labelled arrows in figure coordinates.
fig = plt.figure(figsize=(12, 5))
gs = gridspec.GridSpec(2, 2, width_ratios=[1.8, 1], height_ratios=[0.6, 1], wspace=0.6, hspace=0.4)

# Large left subplot spanning both grid rows
ax_left = fig.add_subplot(gs[:, 0])
plot_B_left(axis=ax_left)

# Two small right subplots stacked vertically
ax_right_top = fig.add_subplot(gs[0, 1])
ax_right_bottom = fig.add_subplot(gs[1, 1])
plot_B_topright(axis=ax_right_top, obs_theta=obs_theta)
plot_B_bottomright(axis=ax_right_bottom)

plt.suptitle(r'\textbf{Inverse problem}', fontsize=36, y=1.08)

# One thick black arrow per right subplot, pointing at its left edge, labelled above.
for target_ax, arrow_label in [(ax_right_top, "Nature"), (ax_right_bottom, "Generative AI")]:
    # get_position() already returns figure-fraction coordinates. Do not pass `fig`
    # here: the first positional parameter of Axes.get_position is `original`.
    bbox = target_ax.get_position()
    y_mid = bbox.y0 + bbox.height / 2
    arrow_start = (bbox.x0 - 0.18, y_mid)  # just to the left of the subplot
    arrow_end = (bbox.x0 - 0.02, y_mid)    # just short of the subplot's left edge

    fig.add_artist(FancyArrowPatch(
        arrow_start, arrow_end,
        transform=fig.transFigure,
        arrowstyle='-|>',
        mutation_scale=40,  # thicker arrow head
        color='black',
        linewidth=4,        # thicker line
        zorder=200,
        connectionstyle="arc3,rad=0",
    ))

    # Label centred (slightly left-shifted) above the arrow
    fig.text(
        (arrow_start[0] + arrow_end[0]) / 2 - 0.02,
        arrow_start[1] + 0.02,  # slightly above the arrow
        arrow_label,
        ha='center', va='bottom',
        fontsize=25,
        color='black',
        fontweight='bold',
    )

plt.savefig('panel_b.pdf', bbox_inches='tight', dpi=300)
plt.savefig('panel_b.png', bbox_inches='tight', dpi=300)
plt.show()

### PANEL C

In [None]:
def plot_C_left(axis, obs_theta):
    """Draw a panel C prior subplot: the training-prior KDE versus the true parameter.

    Draws 100,000 samples from a N(0, 2.5^2) "training prior" and plots their KDE as a
    light-grey dashed curve on x in [-8, 8]. Marks ``obs_theta`` with a dashed vertical line
    and a star on the x-axis. A parameter with ``|obs_theta| > 1`` is labelled theta*_A
    (crimson, "Misaligned prior"); otherwise it is labelled theta*_B (blue, "Well-aligned
    prior"). Text labels with arrows point at the prior curve and the true parameter.

    Args:
        axis: matplotlib Axes to draw on.
        obs_theta: one-element tensor holding the true parameter value.
    """
    is_theta_A = (obs_theta.abs().item() > 1)
    theta_star = obs_theta.item()
    truth_color = 'crimson' if is_theta_A else 'blue'
    true_param_label = r'\textbf{True parameter $\theta^*_A$}' if is_theta_A else r'\textbf{True parameter $\theta^*_B$}'
    # The true-parameter label goes on the same side of the plot as the parameter.
    label_on_left = theta_star < 0

    training_prior = Normal(loc=torch.Tensor([0]), scale=torch.Tensor([2.5]))
    prior_samples = training_prior.sample(sample_shape=(100_000, ))

    # Training-prior KDE and the true parameter
    axis.set_xlim(-8, 8)
    sns.kdeplot(prior_samples.numpy().flatten(), ax=axis, color='lightgray', linestyle='--', linewidth=2)
    axis.axvline(x=theta_star, ymin=0, ymax=1,
                 color=truth_color, linestyle='--', linewidth=4, zorder=30)
    axis.scatter(theta_star, 0,
                 marker='*', facecolor='white', edgecolor=truth_color, s=300, linewidth=2, zorder=50, clip_on=False)
    axis.set_ylabel('Plausibility', fontsize=18)

    axis.set_xticks([])
    axis.set_yticks([])
    axis.set_title(r'\textbf{Misaligned prior}' if is_theta_A else r'\textbf{Well-aligned prior}', fontsize=18)

    # "Training prior" label in data coordinates, with an arrow to the KDE curve.
    axis.text(-4.0, 0.1, r'Training prior', size=18, ha='center', va='center', zorder=50,
              bbox=dict(color='white', alpha=0.5, boxstyle='round,pad=0.2', linewidth=3))
    # No transform is passed, so add_patch places this arrow in data coordinates.
    axis.add_patch(FancyArrowPatch(
        posA=(-4.4, 0.12),
        posB=(-1.5, 0.14),
        arrowstyle="->",
        color="k",
        linewidth=2,
        mutation_scale=25,
        connectionstyle="arc3,rad=-0.3",
        zorder=60,
    ))

    # True-parameter label (x in data units, y as an axes fraction), with an arrow to the line.
    axis.text(-7 if label_on_left else 7, 0.2, s=true_param_label,
              transform=axis.get_xaxis_transform(),
              horizontalalignment='left' if label_on_left else 'right', verticalalignment='center',
              zorder=50, fontdict={'size': 18},
              bbox=dict(color='white', alpha=0.5, boxstyle='round,pad=0.2', linewidth=3))
    axis.add_patch(FancyArrowPatch(
        posA=(theta_star - 3.4, 0.13),
        posB=(theta_star, 0.03),
        arrowstyle="->",
        color="k",
        linewidth=2,
        mutation_scale=25,
        connectionstyle="arc3,rad=0.1",
        transform=axis.get_xaxis_transform(),
        zorder=60,
    ))

    # Axis styling, with arrow heads at the ends of the x- and y-axis spines
    axis.set_xlabel(r'$\theta$', fontsize=20, labelpad=10)
    axis.spines['top'].set_visible(False)
    axis.spines['right'].set_visible(False)
    axis.spines['left'].set_linewidth(2)
    axis.spines['bottom'].set_linewidth(2)
    axis.plot(8, 0, ">k", transform=axis.get_xaxis_transform(),
              clip_on=False, markersize=10, zorder=101)
    axis.plot(-8, 1, "^k", transform=axis.get_xaxis_transform(),
              clip_on=False, markersize=10)

In [None]:
def plot_C_middle(axis=None, obs_theta=None, labels=False):
    """Plot the sampling distribution of the data X for a fixed true parameter.

    Draws a normalized histogram of 100,000 draws from
    ``simulator.likelihood(obs_theta)`` (``simulator`` is the notebook-level
    simulator) and marks five fixed pseudo-observations
    ``obs_theta + [-3, -1, 0.25, 1, 1.75]`` with numbered X markers and dashed
    vertical lines.

    Args:
        axis: matplotlib Axes to draw on; defaults to the current Axes.
        obs_theta: one-element tensor holding the true parameter (required).
        labels: unused; kept so existing callers passing it keep working.

    Raises:
        ValueError: if ``obs_theta`` is None.
    """
    if obs_theta is None:
        raise ValueError("plot_C_middle requires obs_theta")
    if axis is None:
        axis = plt.gca()

    theta_val = obs_theta.item()
    is_theta_A = abs(theta_val) > 1
    likelihood_label = r'\textbf{Likelihood for true $\theta^*_A$}' if is_theta_A else r'\textbf{Likelihood for true $\theta^*_B$}'

    obs_samples = simulator.likelihood(obs_theta).sample(sample_shape=(100_000, ))
    # Fixed (not random) pseudo-observations so the figure shows a reproducible story.
    obs_x_samples = obs_theta + torch.tensor([-3, -1.0, 0.25, 1.0, 1.75])
    x_marks = obs_x_samples.numpy()

    # density=True so the bars match the "Probability density" y-label.
    axis.hist(obs_samples.numpy().flatten(), bins=20, density=True, facecolor='goldenrod',
              edgecolor='brown', linewidth=2, alpha=0.5, label='obs_samples',)
    axis.scatter(x_marks, np.zeros_like(x_marks),
                 marker='X', facecolor='white', edgecolor='goldenrod', s=300, linewidth=2, zorder=50, clip_on=False)
    axis.set_xlabel(r'$X$', fontsize=20, labelpad=10)
    for idx, x_val in enumerate(x_marks):
        # Number each observation just above its marker on the x-axis.
        axis.annotate(
            f'{str(idx + 1)}',
            (x_val, 0), xytext=(0, 18), textcoords='offset points',
            ha='center', va='bottom', fontsize=16, color='k', fontweight='bold', zorder=50
        )
        axis.axvline(x=x_val, ymin=0, ymax=1,
                     color='goldenrod',
                     linestyle='--', linewidth=4, zorder=30, alpha=1.0)
    axis.set_ylabel('Probability density', fontsize=18)
    axis.set_title(r'\textbf{Observable data}', fontsize=18)
    # Put the label on the side of the plot away from the bulk of the data.
    axis.text(-6 if theta_val < 0 else 6, 0.73, s=likelihood_label,
              transform=axis.get_xaxis_transform(),
              horizontalalignment='left' if theta_val < 0 else 'right', verticalalignment='center',
              zorder=50, fontdict={'size': 18},
              bbox=dict(color='white', alpha=0.5, boxstyle='round,pad=0.2', linewidth=3))
    axis.set_xlim(-7, 7)
    axis.set_xticks([])
    axis.set_yticks([])
    axis.spines['top'].set_visible(False)
    axis.spines['right'].set_visible(False)
    axis.spines['left'].set_linewidth(2)
    axis.spines['bottom'].set_linewidth(2)
    axis.spines['bottom'].set_zorder(40)
    # Arrowheads at the ends of the x and y axes.
    axis.plot(7, 0, ">k", transform=axis.get_xaxis_transform(),
              clip_on=False, markersize=10, zorder=101)
    axis.plot(-7, 1, "^k", transform=axis.get_xaxis_transform(),
              clip_on=False, markersize=10)

In [None]:
def plot_C_right(axis, obs_theta, message=''):
    """Draw 90% HPD credible intervals for five fixed observations around ``obs_theta``.

    For each pseudo-observation ``x = obs_theta + [-3, -1, 0.25, 1, 1.75]`` a
    credible set is computed with ``hpd_region`` from the notebook-level
    ``posterior``. Each call uses a fresh grid of ``EVAL_GRID_SIZE`` parameters
    sampled from ``EVAL_GRID_DISTR``. The [min, max] span of each set is drawn as
    a stacked box. Boxes that contain ``obs_theta`` are opaque; the others are faded.

    Args:
        axis: matplotlib Axes to draw on.
        obs_theta: one-element tensor holding the true parameter.
        message: a non-empty string selects the "coverage ~20%" title (used for the
            theta*_A row); the empty string selects the "coverage ~80%" title.
    """
    theta_val = obs_theta.item()
    is_theta_A = abs(theta_val) > 1
    true_param_label = r'\textbf{True parameter $\theta^*_A$}' if is_theta_A else r'\textbf{True parameter $\theta^*_B$}'
    truth_color = 'crimson' if is_theta_A else 'blue'
    obs_x_samples = obs_theta + torch.tensor([-3, -1.0, 0.25, 1.0, 1.75])

    # Vertical layout of the interval boxes, in data y-coordinates.
    box_base, box_step, box_height = 0.25, 0.07, 0.04

    credible_sets = [
        hpd_region(
            posterior=posterior,
            param_grid=EVAL_GRID_DISTR.sample(sample_shape=(EVAL_GRID_SIZE, )),
            x=x.reshape(-1, ),
            credible_level=0.9,
            num_level_sets=10_000,
            norm_posterior=NORM_POSTERIOR_SAMPLES
        )[1].numpy()
        for x in obs_x_samples
    ]

    if message != '':
        axis.set_title(r'\textbf{Coverage for $\theta^*\approx 20$\%}', fontsize=18, ha='center')
    else:
        axis.set_title(r'\textbf{Coverage for $\theta^*\approx 80$\%}', fontsize=18, ha='center')
    axis.set_xticks([])
    axis.set_yticks([])
    axis.spines['top'].set_visible(False)
    axis.spines['right'].set_visible(False)
    axis.spines['left'].set_linewidth(2)
    axis.spines['bottom'].set_linewidth(2)

    # Truth line and star marker on the x-axis.
    axis.axvline(x=theta_val, ymin=0, ymax=1,
                 color=truth_color, linestyle='--', linewidth=4, zorder=30)
    axis.scatter(theta_val, 0,
                 transform=axis.get_xaxis_transform(), marker='*',
                 facecolor='white', edgecolor=truth_color, s=300, linewidth=2,
                 zorder=50, clip_on=False)
    axis.set_xlim(-8, 8)
    axis.set_ylim(0, box_base + len(credible_sets) * box_step)
    # Labels go on opposite sides so they do not overlap the truth line.
    axis.text(7 if theta_val < 0 else -7, 0.92, s='Credible\nintervals',
              transform=axis.get_xaxis_transform(),
              horizontalalignment='right' if theta_val < 0 else 'left', verticalalignment='top',
              zorder=50, fontdict={'size': 18},
              bbox=dict(color='white', alpha=0.5, boxstyle='round,pad=0.2', linewidth=3))
    axis.text(-7 if theta_val < 0 else 7, 0.2, s=true_param_label,
              transform=axis.get_xaxis_transform(),
              horizontalalignment='left' if theta_val < 0 else 'right', verticalalignment='center',
              zorder=50, fontdict={'size': 18},
              bbox=dict(color='white', alpha=0.5, boxstyle='round,pad=0.2', linewidth=3))

    # One non-overlapping box per observation, stacked up the y axis.
    for pt_idx, credible_set in enumerate(credible_sets):
        credible_min = credible_set.min()
        credible_max = credible_set.max()
        y_bottom = box_base + pt_idx * box_step
        rect = plt.Rectangle(
            (credible_min, y_bottom),
            credible_max - credible_min,
            box_height,
            linewidth=2,
            edgecolor='purple',
            facecolor='purple',
            # Opaque when the interval covers the true parameter, faded otherwise.
            alpha=0.7 if credible_min <= theta_val <= credible_max else 0.1,
            zorder=40,
            clip_on=False
        )
        axis.add_patch(rect)
        # Observation number, vertically centred on its box.
        axis.text(credible_min - 1.0, y_bottom + box_height / 2, f'{str(pt_idx + 1)}',
                  horizontalalignment='center', verticalalignment='center',
                  zorder=50, fontdict={'size': 16})

    axis.plot(8, 0, ">k", transform=axis.get_xaxis_transform(),
              clip_on=False, markersize=10, zorder=101)
    axis.set_xlabel(r'$\theta$', size=20, labelpad=10)

In [None]:
# Panel C: local coverage of credible intervals at two fixed true parameters.
# Each row goes prior/posterior -> observable data -> credible intervals, joined by arrows.
# FancyArrowPatch comes from the notebook's top import cell.
fig, axs = plt.subplots(2, 3, figsize=(16, 5))
plt.subplots_adjust(wspace=0.6, hspace=0.5)

# One row per true parameter: (theta* value, `labels` flag for the middle panel,
# `message` for the right panel). A non-empty message selects the "~20% coverage" title.
# NOTE: do not name the loop variable `obs_theta`; panel D reads that global.
panel_c_rows = [
    (3.5, True, 'test'),
    (-0.5, False, ''),
]
for row, (theta_value, show_labels, message) in enumerate(panel_c_rows):
    plot_C_left(axis=axs[row, 0], obs_theta=torch.tensor([theta_value]))
    plot_C_middle(axis=axs[row, 1], obs_theta=torch.tensor([theta_value]), labels=show_labels)
    plot_C_right(axis=axs[row, 2], obs_theta=torch.tensor([theta_value]), message=message)

# Arrows from the right edge of each panel to just before the left edge of the next
# panel in the same row, at mid-height, in figure coordinates.
n_rows, n_cols = axs.shape
for row in range(n_rows):
    for col in range(n_cols - 1):
        # get_position() takes an `original` flag, not a figure; the default returns
        # the active position, already in figure coordinates.
        bbox_from = axs[row, col].get_position()
        bbox_to = axs[row, col + 1].get_position()
        start = (bbox_from.x1, bbox_from.y0 + bbox_from.height / 2)
        end = (bbox_to.x0 - 0.02, bbox_to.y0 + bbox_to.height / 2)
        fig.add_artist(FancyArrowPatch(
            start, end,
            transform=fig.transFigure,
            arrowstyle='-|>',
            mutation_scale=30,
            color='black',
            linewidth=2,
            zorder=200,
            connectionstyle="arc3,rad=0"
        ))

plt.suptitle(r'\textbf{Local coverage for a fixed parameter}', fontsize=28, y=1.05)

plt.savefig('panel_c.pdf', bbox_inches='tight', dpi=300)
plt.savefig('panel_c.png', bbox_inches='tight', dpi=300)
plt.show()

### PANEL D

In [None]:
def plot_D_middle(axis, credible_sets_list, obs_theta, which_level=0):
    """
    Draw one credible interval per observation as stacked boxes below a truth line.

    Arguments:
        axis: matplotlib axis to plot on
        credible_sets_list: one entry per observation; each entry is a sequence of
            credible sets (objects with ``.min()``/``.max()``, e.g. numpy arrays)
            indexed by credibility level
        obs_theta: true theta value (one-element torch.Tensor/array, or float)
        which_level: index of the credible set to draw for every observation;
            observations that have no set at that index are skipped
    """
    ax = axis
    theta_val = obs_theta.item() if hasattr(obs_theta, "item") else float(obs_theta)
    x_lo, x_hi = -10, 10

    ax.set_xticks([])
    ax.set_yticks([])
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    for side in ('left', 'bottom'):
        ax.spines[side].set_linewidth(2)

    # Truth line. The star is drawn twice (white fill, then crimson outline) so the
    # dashed line does not show through the marker.
    ax.axvline(x=theta_val, ymin=0, ymax=1, color='crimson', linestyle='--', linewidth=4, zorder=30)
    ax.scatter(theta_val, 0, transform=ax.get_xaxis_transform(), marker='*',
               facecolor='white', edgecolor='white', s=300, linewidth=2,
               zorder=50, clip_on=False)
    ax.scatter(theta_val, 0, transform=ax.get_xaxis_transform(), marker='*',
               facecolor='none', edgecolor='crimson', s=300, linewidth=2,
               zorder=50, clip_on=False)
    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(0, 1.0)

    # Box geometry scales with the number of observations. Groups for different
    # observations are separated by 2 * y_step; levels within a group by 1.2 * rect_height.
    num_pts = len(credible_sets_list)
    rect_height = 0.6 / (num_pts + 1)
    y_base = 0.02
    y_step = rect_height * 1.8
    # Iterate over every observation (previously hard-coded to the first 5).
    for pt_idx, point_sets in enumerate(credible_sets_list):
        if not 0 <= which_level < len(point_sets):
            continue
        cs = point_sets[which_level]
        credible_min = cs.min()
        credible_max = cs.max()
        y_bottom = y_base + pt_idx * y_step * 2 + which_level * rect_height * 1.2

        # Schematic rule with a 0.5 margin: intervals whose upper end reaches past
        # theta - 0.5 get a red edge and cross-hatching.
        if theta_val - 0.5 < credible_max:
            edge_color = 'red'
        elif which_level == 0:
            edge_color = 'purple'
        else:
            edge_color = 'hotpink'
        rect = plt.Rectangle(
            (credible_min, y_bottom),
            credible_max - credible_min,
            rect_height,
            linewidth=2,
            edgecolor=edge_color,
            facecolor='purple' if which_level == 0 else 'hotpink',
            alpha=0.7,
            hatch='' if theta_val - 0.5 > credible_max else 'x',
            zorder=40,
            clip_on=False
        )
        ax.add_patch(rect)

    # Arrowhead at the right end of the x-axis (matches the x-limit).
    ax.plot(x_hi, 0, ">k", transform=ax.get_xaxis_transform(),
            clip_on=False, markersize=10, zorder=101)

In [None]:
def plot_D_left(axis=None):
    """
    Sketch the posterior, its credible interval(s) and the matching density level set.

    Uses the notebook-level ``prior_samples``, ``posterior_samples``, ``obs_theta``
    and ``CONFIDENCE_LEVEL``. For each level ``cl`` the interval is the
    equal-tailed percentile range of ``posterior_samples``. It is labelled "HPD",
    which coincides with the highest-density interval only for a symmetric
    unimodal posterior (presumably the case in this Gaussian-mean example --
    TODO confirm).

    Draws:
        * a dashed KDE of the prior;
        * a KDE of the posterior, extended to zero at x = +/-7;
        * the shaded posterior area inside each interval;
        * the threshold line labelled ``c``;
        * an interval bar below the x-axis;
        * the true parameter as a dashed vertical line with a star.

    Args:
        axis: Axes to draw on. If None, a new 8.5x5 figure is created and shown.

    Returns:
        The Axes when ``axis`` is given; otherwise None (the figure is shown).

    Side effects:
        Seeds NumPy's global RNG with 42 and sets global matplotlib rc params
        (LaTeX text rendering, serif font, LaTeX preamble).
    """
    # NOTE(review): nothing below draws from np.random. The seed is kept only because
    # it resets global RNG state that later cells may implicitly rely on.
    np.random.seed(42)

    # Equal-tailed credible interval [lower, upper] for each level.
    credible_sets = []
    for cl in CONFIDENCE_LEVEL:
        tail_pct = (1-cl)*100/2
        credible_sets.append(np.array([
            np.percentile(posterior_samples, tail_pct),
            np.percentile(posterior_samples, 100 - tail_pct),
        ]))

    plt.rc('text', usetex=True)  # Enable LaTeX
    plt.rc('font', family='serif')  # Use a serif font
    plt.rcParams['text.latex.preamble'] = r'''
        \usepackage{amsmath}  % For \mathbb
        \usepackage{amssymb}  % For \mathbb
        \usepackage{bm}       % For bold math symbols
        \usepackage{underscore} % If underscores are needed
    '''

    ax = plt.figure(figsize=(8.5, 5)).gca() if axis is None else axis

    posterior_color = 'purple'
    prior_color = 'lightgrey'
    truth_color = 'crimson'
    colors = [posterior_color, 'hotpink']
    theta_star = obs_theta.item()

    # Prior and posterior KDEs.
    sns.kdeplot(prior_samples.squeeze(), ax=ax, color=prior_color, linewidth=2, zorder=10,
                linestyle='--', clip_on=True)
    kde_posterior = sns.kdeplot(posterior_samples.squeeze(), color='black', ax=ax, linewidth=2,
                                zorder=30, clip_on=True)

    # Posterior curve data, reused for the level set and the shading below.
    posterior_line = kde_posterior.get_lines()[-1]
    x_posterior = posterior_line.get_xdata()
    y_posterior = posterior_line.get_ydata()

    # Extend the posterior curve down to zero at the x-limits.
    ax.plot([-7, x_posterior[0]], [0, y_posterior[0]], color='black', linewidth=2, zorder=30)
    ax.plot([x_posterior[-1], 7], [y_posterior[-1], 0], color='black', linewidth=2, zorder=30)

    # Anchor the y-axis at zero before adding the level-set elements.
    ax.set_ylim(bottom=0)

    # Interval bar placement below the x-axis, in axes-fraction y (x-axis transform).
    pos_rectangle_offset, neg_rectangle_offset = -0.15, -0.2
    # Horizontal centre of the first interval; anchors the level-set line and its "c" label.
    level_set_center = (credible_sets[0].min() + credible_sets[0].max()) / 2
    xaxis_transform = ax.get_xaxis_transform()

    for i, cl in enumerate(CONFIDENCE_LEVEL):
        credible_min = credible_sets[i].min()
        credible_max = credible_sets[i].max()
        # Posterior density at the interval bounds, i.e. the level-set threshold.
        intersect_point = np.interp([credible_min, credible_max], x_posterior, y_posterior)

        # Interval bar below the x-axis: outline plus translucent fill.
        ax.plot([credible_min, credible_max], [pos_rectangle_offset]*2,
                transform=xaxis_transform, color=colors[i], linewidth=4,
                zorder=50, clip_on=False)
        ax.plot([credible_min, credible_max], [neg_rectangle_offset]*2,
                transform=xaxis_transform, color=colors[i], linewidth=4,
                zorder=50, clip_on=False)
        ax.vlines([credible_min, credible_max], ymin=[neg_rectangle_offset]*2,
                  ymax=[pos_rectangle_offset]*2, transform=xaxis_transform,
                  color=colors[i], linewidth=4, zorder=50, clip_on=False)
        ax.fill_between([credible_min, credible_max], neg_rectangle_offset, pos_rectangle_offset,
                        transform=xaxis_transform, color=colors[i], alpha=0.3,
                        zorder=50, clip_on=False)
        if i == 0:
            # Opaque white backing so the translucent fill is not blended with what lies beneath.
            ax.fill_between([credible_min, credible_max], neg_rectangle_offset, pos_rectangle_offset,
                            transform=xaxis_transform, color="white", alpha=1.0,
                            zorder=49, clip_on=False)

        # Dotted projection lines from the bar up to the posterior curve (data coordinates).
        ax.vlines(x=[credible_min, credible_max], ymin=0.5 * neg_rectangle_offset, ymax=intersect_point,
                  transform=ax.transData, color=colors[i], linestyle=':',
                  linewidth=3, zorder=29, clip_on=False)

        # Shade the posterior between the interval bounds, down to the level-set line.
        mask = (x_posterior >= credible_min) & (x_posterior <= credible_max)
        x_polygon = np.concatenate([[credible_min], x_posterior[mask], [credible_max], [credible_min]])
        y_polygon = np.concatenate([[intersect_point[0]], y_posterior[mask], [intersect_point[1]], [intersect_point[0]]])
        ax.fill(x_polygon, y_polygon, color=colors[i], alpha=0.4,
                zorder=28+i, linestyle='--', clip_on=True)

        # Horizontal level-set line and its "c" label.
        ax.hlines(y=intersect_point[0],
                  xmin=level_set_center-2.9,
                  xmax=level_set_center+2.9,
                  linestyle='--', linewidth=2, color='k', zorder=28+i)
        ax.text(level_set_center+3.05, intersect_point[0],
                s=r'$c$',
                horizontalalignment='left', verticalalignment='center',
                zorder=50, fontdict={'size': 20})

        # Credibility level printed inside the shaded area (raw string: LaTeX-escaped %).
        ax.text((credible_min+credible_max)/2, intersect_point[0]+(0.07 if i == 1 else 0.12),
                s=rf'{cl*100:.0f}\%',
                horizontalalignment='center', verticalalignment='center',
                zorder=50, fontdict={'size': 16})
        ax.text(credible_max + 0.22, 0.5*(neg_rectangle_offset+pos_rectangle_offset),
                s=r'${cl:.0f}\%\text{{ HPD credible interval}}$'.format(cl=100*cl),
                transform=xaxis_transform, horizontalalignment='left', verticalalignment='center',
                zorder=50, fontdict={'size': 20})

    # Truth line spanning the full axes height, plus a star (white fill, then outline).
    ax.axvline(x=theta_star, ymin=0, ymax=1,
               color=truth_color,
               linestyle='--', linewidth=4, zorder=30)
    ax.scatter(theta_star, 0,
               transform=xaxis_transform, marker='*',
               facecolor='white', edgecolor='white', s=300, linewidth=2,
               zorder=50, clip_on=False)
    ax.scatter(theta_star, 0,
               transform=xaxis_transform, marker='*',
               facecolor='none', edgecolor=truth_color, s=300, linewidth=2,
               zorder=50, clip_on=False)

    ax.text(-2.7, 0.5, s=r'Training prior', transform=xaxis_transform,
            horizontalalignment='center', verticalalignment='center',
            zorder=50, fontdict={'size': 22})
    ax.text(-0.75, 0.92, s=r'\textbf{Posterior}', transform=xaxis_transform,
            horizontalalignment='center', verticalalignment='center',
            zorder=50, fontdict={'size': 22})

    ax.set_ylabel('Plausibility', size=25)
    ax.set_xlabel(r'$\theta$', size=27, zorder=20)
    ax.set_xlim(-7, 7)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.grid(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(2)
    ax.spines['bottom'].set_linewidth(2)
    ax.spines['bottom'].set_zorder(40)

    # Arrowheads at the ends of the x and y axes.
    ax.plot(7, 0, ">k", transform=xaxis_transform,
            clip_on=False, markersize=10, zorder=101)
    ax.plot(-7, 1, "^k", transform=xaxis_transform,
            clip_on=False, markersize=10)

    if axis is None:
        plt.show()
    else:
        return ax

def plot_D_right(axis = None):
    """Plot estimated local coverage (in %) of the credible sets against the true parameter.

    For each level ``cl`` in the notebook-level ``CONFIDENCE_LEVEL``, reads
    ``diagn_objects_cred[cl]``:
        * ``[1]`` is the parameter grid;
        * ``[-1]`` and ``[-2]`` are treated as the lower and upper coverage
          bounds (TODO confirm this ordering against the producer of
          ``diagn_objects_cred``).

    The band between the bounds is drawn, with a horizontal line at the nominal
    level. The notebook-level ``obs_theta`` is marked with a vertical line and a star.

    Args:
        axis: Axes to draw on. If None, a new 8.5x5 figure is created and shown.

    Returns:
        The Axes when ``axis`` is given; otherwise None (the figure is shown).

    Side effects:
        Sets global matplotlib rc params (LaTeX text rendering, serif font, preamble).
    """
    plt.rc('text', usetex=True)  # Enable LaTeX
    plt.rc('font', family='serif')  # Use a serif font (e.g., Computer Modern)
    plt.rcParams['text.latex.preamble'] = r'''
        \usepackage{amsmath}  % For \mathbb
        \usepackage{amssymb}  % For \mathbb
        \usepackage{bm}       % For bold math symbols
        \usepackage{underscore} % If underscores are needed
    '''
    truth_color = 'crimson'
    colors = ['purple', 'hotpink']
    linestyles = ['--', ':']
    # .item() avoids NumPy's deprecated float() conversion of a 1-element array.
    theta_star = obs_theta.item()

    ax = plt.figure(figsize=(8.5, 5)).gca() if axis is None else axis

    for i, cl in enumerate(CONFIDENCE_LEVEL):
        diagn = diagn_objects_cred[cl]
        df_plot = pd.DataFrame({
            "parameters": diagn[1].reshape(-1,),
            "lower_proba": diagn[-1].reshape(-1,)*100,
            "upper_proba": diagn[-2].reshape(-1,)*100
        }).sort_values(by="parameters")

        ax.plot(df_plot.parameters, df_plot.lower_proba, color=colors[i], linewidth=3)
        ax.plot(df_plot.parameters, df_plot.upper_proba, color=colors[i], linewidth=3)
        ax.fill_between(x=df_plot.parameters, y1=df_plot.lower_proba, y2=df_plot.upper_proba, alpha=0.2, color=colors[i])
        # Nominal coverage level.
        ax.axhline(y=cl*100, color='black', linestyle=linestyles[i], linewidth=2, zorder=10)
        # Raw f-string: "\%" is a LaTeX escape, not a Python one.
        ax.text(
            0.99, cl*100+5, s=rf'Expected coverage={cl*100:.0f}\%',
            transform=ax.get_yaxis_transform(), horizontalalignment='right', verticalalignment='center', zorder=50, fontdict={'size': 20},
            bbox=dict(facecolor='white', boxstyle='round,pad=0.05', alpha=0.5, linewidth=0)
        )

    # Truth line and star, both in the shared truth color.
    ax.axvline(x=theta_star, ymin=0, ymax=1,
               color=truth_color,
               linestyle='--', linewidth=4, zorder=30)
    ax.scatter(theta_star, 0,
               transform=ax.get_xaxis_transform(), marker='*',
               facecolor='white', edgecolor=truth_color, s=300, linewidth=2,
               zorder=50, clip_on=False)

    ax.set_xlabel(r"True parameter $\theta^*$", fontsize=25)
    ax.set_ylabel(r"Actual coverage [\%]", fontsize=25)
    ax.set_xlim(-10.5, 10.5)
    ax.set_ylim(0, 102)
    ax.set_xticks([])
    ax.set_xticklabels([])
    ax.set_yticks([])
    ax.set_yticklabels([])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_linewidth(2)
    ax.spines['left'].set_linewidth(2)
    ax.grid(False)
    # Arrowheads at the ends of the x and y axes.
    ax.plot(10.5, 0, ">k", transform=ax.get_xaxis_transform(), clip_on=False, markersize=10, zorder=101)
    ax.plot(-10.5, 1, "^k", transform=ax.get_xaxis_transform(), clip_on=False, markersize=10, zorder=101)

    if axis is None:
        plt.show()
    else:
        return ax

In [None]:
# NOTE(review): inset_axes is unused in this cell; kept in case a later cell relies on it.
# Consider moving it to the top import cell.
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Panel D: HPD sketch (left) -> local coverage diagnostics (right), linked by a captioned arrow.
# FancyArrowPatch comes from the notebook's top import cell.
fig, axs = plt.subplots(1, 2, figsize=(16, 4), width_ratios=[1, 1])
plt.subplots_adjust(wspace=0.7)

plot_D_left(axis=axs[0])
axs[0].set_title(r'\textbf{From HPD to FreB intervals}', fontsize=25, pad=10)

plot_D_right(axis=axs[1])
axs[1].set_title(r'\textbf{Local diagnostics for every true $\theta^*$}', size=25, pad=10)

# Arrow from the right edge of the left panel to just before the left edge of the
# right panel, at mid-height. get_position() takes an `original` flag (not a figure)
# and by default returns the active position in figure coordinates.
bbox0 = axs[0].get_position()
bbox1 = axs[1].get_position()
arrow_start = (bbox0.x1, bbox0.y0 + bbox0.height / 2)
arrow_end = (bbox1.x0 - 0.02, bbox1.y0 + bbox1.height / 2)

arrow = FancyArrowPatch(
    arrow_start, arrow_end,
    transform=fig.transFigure,
    arrowstyle='-|>',
    mutation_scale=40,  # Thicker arrow head
    color='black',
    linewidth=4,        # Thicker line
    zorder=200,
    connectionstyle="arc3,rad=0"
)
fig.add_artist(arrow)

# Two-line caption centred on the arrow: first line just above it, second just below.
caption_x = (arrow_start[0] + arrow_end[0]) / 2 - 0.01
for caption, dy, va in (
    (r"For \textit{every} true parameter", 0.02, 'bottom'),
    (r"and \textit{all} observable data", -0.02, 'top'),
):
    fig.text(
        caption_x,
        arrow_start[1] + dy,
        caption,
        ha='center', va=va,
        fontsize=18,
        color='black',
        fontweight='bold'
    )

plt.savefig('panel_d.pdf', bbox_inches='tight')
plt.savefig('panel_d.png', bbox_inches='tight')
plt.show()

### PANEL E

In [None]:
def plot_E_left(axis = None):
    """
    Draw the "p-value function -> FreB confidence interval" schematic (left panel of panel E).

    For every level ``cl`` in ``CONFIDENCE_LEVEL`` this draws the p-value curve over the
    parameter grid, a horizontal line at ``alpha = 1 - cl``, the region where the p-value is
    at least ``alpha`` (shaded), and a box just below the x-axis spanning
    ``[confidence_sets[i].min(), confidence_sets[i].max()]`` labelled as the FreB confidence
    interval. The true parameter ``obs_theta`` is marked with a dashed line and a star.

    Reads the notebook-level globals ``eval_grid`` (torch tensor), ``p_values``,
    ``confidence_sets``, ``obs_theta`` and ``CONFIDENCE_LEVEL``; they must be computed by
    earlier cells.

    Args:
        axis: matplotlib Axes to draw on. If None, a new 8.5x5 figure is created and shown.

    Returns:
        The Axes drawn on when ``axis`` is given; otherwise None (the figure is shown).
    """
    plt.rc('text', usetex=True)  # Enable LaTeX
    plt.rc('font', family='serif')  # Use a serif font (e.g., Computer Modern)
    plt.rcParams['text.latex.preamble'] = r'''
        \usepackage{amsmath}  % For \mathbb
        \usepackage{amssymb}  % For \mathbb
        \usepackage{bm}       % For bold math symbols
        \usepackage{underscore} % If underscores are needed
    '''

    if axis is None:
        fig = plt.figure(figsize=(8.5, 5))
        ax1 = fig.gca()
    else:
        ax1 = axis
    lf2i_color = 'mediumseagreen'
    calib_color = 'black'
    truth_color = 'crimson'

    data_transform = ax1.transData

    # The y coordinate is mixed below: the p-value curve, labels and interval boxes use the
    # x-axis transform (y in axes fraction), while the shaded regions, alpha lines and dotted
    # guides use data coordinates. They only line up when the y-range is exactly [0, 1], so pin
    # it explicitly instead of relying on the default upper limit of a fresh Axes.
    ax1.set_ylim(0, 1)

    # # For prior KDE - ensure it extends to the axis
    # ax1_twin = ax1.twinx()
    # sns.kdeplot(prior_samples.squeeze(), ax=ax1_twin, color='lightgray', linewidth=2, zorder=10, 
    #             linestyle='--', clip_on=True)

    # P-VALUES at X_OBS across POI GRID (always draw on ax1, also when this function made the figure)
    grid = eval_grid.numpy().reshape(-1, )
    sns.lineplot(
        x=grid, y=p_values, color=calib_color, linewidth=2, zorder=20, ax=ax1, transform=ax1.get_xaxis_transform()
    )

    ax1.text(
            -2.05, 0.93, s=r'\textbf{p-value function}',
            transform=ax1.get_xaxis_transform(), horizontalalignment='center', verticalalignment='center', zorder=50, fontdict={'size': 22}
    )

    # Grid sorted by parameter value for the shaded region; independent of the level, so computed once
    sort_idx = np.argsort(grid)
    grid_sorted = grid[sort_idx]
    p_values_sorted = p_values[sort_idx]

    # ALPHA LINES and CONFIDENCE SETS
    linestyles = ['--', '--']
    colors = [lf2i_color, 'yellowgreen']
    # Vertical extent of the interval boxes below the x-axis, in axes-fraction units
    pos_rectangle_offset, neg_rectangle_offset = (-0.15, -0.2)

    for i, cl in enumerate(CONFIDENCE_LEVEL):
        level_alpha = 1 - cl
        confset_min = confidence_sets[i].min()
        confset_max = confidence_sets[i].max()

        # Horizontal alpha line (y in data coordinates, xmin/xmax in axes fraction)
        ax1.axhline(y=level_alpha, xmin=0.4, xmax=0.955, color='black', linewidth=2, linestyle=linestyles[i], zorder=20)
        ax1.text(
            (confset_min+confset_max)/2 - 0.05, level_alpha+0.05, s=r'$\alpha={alpha:.2f}$'.format(alpha=level_alpha), 
            transform=ax1.get_xaxis_transform(), horizontalalignment='center', verticalalignment='center', zorder=50, fontdict={'size': 16}
        )

        # Shade between the p-value curve and the alpha level where the curve is above alpha
        ax1.fill_between(
            grid_sorted,
            p_values_sorted,
            level_alpha,
            where=(p_values_sorted >= level_alpha),
            color=colors[i],
            alpha=0.4,
            zorder=15
        )

        # Dotted guides from inside the interval box (below the x-axis, hence clip_on=False) up to
        # the alpha level. ymin is in data units, equal to axes fraction since the y-range is [0, 1].
        ax1.vlines(
            x=[confset_min, confset_max], 
            ymin=0.8 * neg_rectangle_offset,
            ymax=level_alpha,
            color=colors[i], 
            linestyle=':', 
            linewidth=3, 
            zorder=10,
            transform=data_transform,
            clip_on=False
        )
        
        # Interval box below the x-axis (intentionally using the x-axis transform)
        ax1.plot([confset_min, confset_max], [pos_rectangle_offset]*2, 
                transform=ax1.get_xaxis_transform(), color=colors[i], linewidth=4, 
                zorder=103, clip_on=False)
        ax1.plot([confset_min, confset_max], [neg_rectangle_offset]*2, 
                transform=ax1.get_xaxis_transform(), color=colors[i], linewidth=4, 
                zorder=103, clip_on=False)
        ax1.vlines(
            [confset_min, confset_max], 
            ymin=[neg_rectangle_offset]*2, 
            ymax=[pos_rectangle_offset]*2,
            transform=ax1.get_xaxis_transform(), 
            color=colors[i], 
            linewidth=4, 
            zorder=103, 
            clip_on=False
        )
        ax1.fill_between(  # projection rectangle
            [confset_min, confset_max], neg_rectangle_offset, pos_rectangle_offset, 
            transform=ax1.get_xaxis_transform(),
            color=colors[i], alpha=0.3, zorder=103, clip_on=False
        )
        ax1.fill_between(  # opaque patch underneath to avoid seeing axis etc ...
            [confset_min, confset_max], neg_rectangle_offset, pos_rectangle_offset, 
            transform=ax1.get_xaxis_transform(),
            color="white", alpha=1.0, zorder=102, clip_on=False  # opaque "background"
        )
        ax1.text(confset_max + 0.22, 0.5*(neg_rectangle_offset+pos_rectangle_offset),
                s=r'${cl:.0f}\%\text{{ FreB confidence interval}}$'.format(cl=100*cl),
                transform=ax1.get_xaxis_transform(), horizontalalignment='left', verticalalignment='center', 
                zorder=50, fontdict={'size': 20})

    # TRUTH: dashed vertical line (ymin/ymax in axes fraction) plus a star on the x-axis
    true_theta = obs_theta.reshape(1, ).numpy()[0]  # Extract the scalar value
    ax1.axvline(
        x=true_theta, 
        ymin=0,  # Start at bottom of plot
        ymax=0.96,  # End near top of plot 
        color=truth_color, 
        linestyle='--', 
        linewidth=4, 
        zorder=19
    )
    
    # Star marker on x-axis: white-filled star underneath, outline in truth colour on top
    ax1.scatter(
        true_theta, 0, 
        transform=ax1.get_xaxis_transform(), 
        marker='*', 
        facecolor='white', 
        edgecolor='white', 
        s=300, 
        linewidth=2, 
        zorder=103, 
        clip_on=False
    )
    ax1.scatter(
        true_theta, 0, 
        transform=ax1.get_xaxis_transform(), 
        marker='*', 
        facecolor='none', 
        edgecolor=truth_color, 
        s=300, 
        linewidth=2, 
        zorder=103, 
        clip_on=False
    )

    # Set axis properties
    ax1.set_xlim(-7, 7)
    ax1.set_ylabel(r'$\text{Level }\alpha$', size=25)
    ax1.set_xlabel(r'$\theta$', size=27, zorder=20)
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax1.set_xticklabels([])
    ax1.set_yticklabels([])
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)
    ax1.spines['bottom'].set_linewidth(2)
    ax1.spines['bottom'].set_zorder(101)
    ax1.spines['left'].set_linewidth(2) 
    
    # Arrowheads at the ends of both axes (same zorder as in the other panels)
    ax1.plot(7, 0, ">k", transform=ax1.get_xaxis_transform(), clip_on=False, markersize=10, zorder=101) 
    ax1.plot(-7, 1, "^k", transform=ax1.get_xaxis_transform(), clip_on=False, markersize=10, zorder=101) 

    # Optional legend (disabled):
    # truth_handle = Line2D([], [], color=truth_color, linestyle="--", linewidth=4, marker="*", markersize=20, markerfacecolor="white", markeredgecolor="crimson", markeredgewidth=2, label=r"True $\theta^{\star}$")
    # handles, labels = ax1.get_legend_handles_labels()
    # handles += [truth_handle]
    # labels += [r"True $\theta^{\star}$"]
    # ax1.legend(
    #     handles=handles, labels=labels, 
    #     loc='lower left', prop={'size': 22}, framealpha=0.9,
    # ).set_zorder(21)

    if axis is None:
        # plt.savefig('schema_vsi.pdf', bbox_inches='tight')
        plt.show()
    else:
        return ax1

In [None]:
def plot_E_middle(axis, confidence_sets_list, obs_theta, which_level, max_points=5):
    """
    Stack the confidence intervals of several observations as horizontal boxes up the y-axis.

    For each of the first ``max_points`` entries of ``confidence_sets_list``, the set at index
    ``which_level`` is drawn as a rectangle spanning ``[cs.min(), cs.max()]``, clipped to the
    x-limits. Rectangles whose (clipped) span contains the true parameter get a red, hatched
    outline. Empty sets and sets lying entirely outside the x-limits are skipped.

    Arguments:
        axis: matplotlib axis to plot on
        confidence_sets_list: one entry per observation; each entry is a sequence of confidence
            sets indexed by level (each set must support ``.min()`` / ``.max()``)
        obs_theta: true theta value (torch.Tensor or float)
        which_level: index of the level to draw for every observation; index 0 is filled
            'mediumseagreen', any other index 'yellowgreen'
        max_points: maximum number of observations to draw (default 5, the previously
            hard-coded value); fewer are drawn if the list is shorter
    """
    ax = axis
    ax.set_xticks([])
    ax.set_yticks([])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_linewidth(2)
    ax.spines['bottom'].set_linewidth(2)

    # Truth line and marker (white-filled star underneath, crimson outline on top)
    theta_val = obs_theta.item() if hasattr(obs_theta, "item") else float(obs_theta)
    ax.axvline(x=theta_val, ymin=0, ymax=1, color='crimson', linestyle='--', linewidth=4, zorder=400)
    ax.scatter(theta_val, 0, transform=ax.get_xaxis_transform(), marker='*', 
               facecolor='white', edgecolor='white', s=300, linewidth=2, 
               zorder=500, clip_on=False)
    ax.scatter(theta_val, 0, transform=ax.get_xaxis_transform(), marker='*', 
               facecolor='none', edgecolor='crimson', s=300, linewidth=2, 
               zorder=500, clip_on=False)
    ax.set_xlim(-10, 10)
    ax.set_ylim(0, 1.0)
    x_lo, x_hi = ax.get_xlim()

    # Box geometry: the height shrinks with the number of observations; boxes of different
    # observations are 2 * y_step apart, levels of one observation 1.2 * rect_height apart.
    num_pts = len(confidence_sets_list)
    rect_height = 0.6 / (num_pts + 1)
    y_base = 0.02
    y_step = rect_height * 1.8

    for pt_idx in range(min(num_pts, max_points)):
        for i, cs in enumerate(confidence_sets_list[pt_idx]):
            if i != which_level:
                continue
            if np.asarray(cs).size == 0:
                continue  # empty confidence set: nothing to draw (.min() would raise)
            # Interval end points, clipped to the visible x-range
            confidence_min = max(cs.min(), x_lo)
            confidence_max = min(cs.max(), x_hi)
            if confidence_max < confidence_min:
                continue  # interval lies entirely outside the x-range (would be a negative-width box)
            y_bottom = y_base + pt_idx * y_step * 2 + i * rect_height * 1.2

            level_color = 'mediumseagreen' if i == 0 else 'yellowgreen'
            covers_truth = confidence_min <= theta_val <= confidence_max
            ax.add_patch(plt.Rectangle(
                (confidence_min, y_bottom),
                confidence_max - confidence_min,
                rect_height,
                linewidth=2,
                edgecolor='red' if covers_truth else level_color,
                facecolor=level_color,
                alpha=0.7,
                hatch='x' if covers_truth else '',
                zorder=40,
                clip_on=False
            ))

    # Arrowhead at the right end of the x-axis, as in the other panels (it was hard-coded at
    # x=8, i.e. partway along the [-10, 10] spine)
    ax.plot(x_hi, 0, ">k", transform=ax.get_xaxis_transform(), 
            clip_on=False, markersize=10, zorder=101)

In [None]:
def plot_E_right(axis = None):
    """
    Plot local coverage diagnostics of the FreB confidence intervals (right panel of panel E).

    For every level ``cl`` in ``CONFIDENCE_LEVEL``, the lower/upper coverage curves stored in
    ``diagn_objects_conf[cl]`` are plotted in percent against the true parameter, with the band
    between them shaded and a black horizontal reference line at ``cl * 100``. The true
    parameter ``obs_theta`` is marked with a dashed line and a star on the x-axis.

    Reads the notebook-level globals ``diagn_objects_conf``, ``obs_theta`` (torch tensor with a
    single element) and ``CONFIDENCE_LEVEL``.

    Args:
        axis: matplotlib Axes to draw on. If None, a new 8.5x5 figure is created and shown.

    Returns:
        The Axes drawn on when ``axis`` is given; otherwise None (the figure is shown).
    """
    plt.rc('text', usetex=True)  # Enable LaTeX
    plt.rc('font', family='serif')  # Use a serif font (e.g., Computer Modern)
    plt.rcParams['text.latex.preamble'] = r'''
        \usepackage{amsmath}  % For \mathbb
        \usepackage{amssymb}  % For \mathbb
        \usepackage{bm}       % For bold math symbols
        \usepackage{underscore} % If underscores are needed
    '''

    if axis is None:
        fig = plt.figure(figsize=(8.5, 5))
        ax = fig.gca()
    else:
        ax = axis
    truth_color = 'crimson'
    colors = ['mediumseagreen', 'yellowgreen']
    linestyles = ['--', ':']

    for i, cl in enumerate(CONFIDENCE_LEVEL):
        # NOTE(review): layout of diagn_objects_conf[cl] inferred from usage: [1] parameters,
        # [2] mean coverage, [-2] upper band, [-1] lower band; presumably fractions, hence * 100.
        df_plot = pd.DataFrame({
            "parameters": diagn_objects_conf[cl][1].reshape(-1,),
            "mean_proba": diagn_objects_conf[cl][2].reshape(-1,)*100,
            "lower_proba": diagn_objects_conf[cl][-1].reshape(-1,)*100,
            "upper_proba": diagn_objects_conf[cl][-2].reshape(-1,)*100
        }).sort_values(by="parameters")

        ax.plot(df_plot.parameters, df_plot.lower_proba, color=colors[i], linewidth=3)
        ax.plot(df_plot.parameters, df_plot.upper_proba, color=colors[i], linewidth=3)
        ax.fill_between(x=df_plot.parameters, y1=df_plot.lower_proba, y2=df_plot.upper_proba, alpha=0.2, color=colors[i])
        # Nominal coverage reference line
        ax.axhline(
            y=cl*100,
            color='black', linestyle=linestyles[i], linewidth=2, zorder=10
        )
        # Raw f-string: `\%` is a LaTeX escape, not a Python one (non-raw `\%` is an invalid escape)
        ax.text(
            0.99, cl*100+5, s=rf'Expected coverage={cl*100:.0f}\%', 
            transform=ax.get_yaxis_transform(), horizontalalignment='right', verticalalignment='center', zorder=50, fontdict={'size': 20},
            bbox=dict(facecolor='white', boxstyle='round,pad=0.05', alpha=0.5, linewidth=0)
        )

    # `.item()` avoids NumPy's deprecated float() conversion of an ndim > 0 array
    true_theta = obs_theta.reshape(1).item()
    ax.axvline(x=true_theta, ymin=0, ymax=1, 
              color=truth_color, 
              linestyle='--', linewidth=4, zorder=30)
    
    # Star marker for truth
    ax.scatter(true_theta, 0, 
              transform=ax.get_xaxis_transform(), marker='*', 
              facecolor='white', edgecolor=truth_color, s=300, linewidth=2, 
              zorder=50, clip_on=False)

    ax.set_xlabel(r"True parameter $\theta^*$", fontsize=25)
    ax.set_ylabel(r"Actual coverage [\%]", fontsize=25)
    ax.set_xlim(-10.5, 10.5)
    ax.set_ylim(0, 102)
    ax.set_xticks([])
    ax.set_xticklabels([])
    ax.set_yticks([])
    ax.set_yticklabels([])
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_linewidth(2)
    ax.spines['left'].set_linewidth(2)
    ax.grid(False)
    # Arrowheads at the ends of both axes
    ax.plot(10.5, 0, ">k", transform=ax.get_xaxis_transform(), clip_on=False, markersize=10, zorder=101) 
    ax.plot(-10.5, 1, "^k", transform=ax.get_xaxis_transform(), clip_on=False, markersize=10, zorder=101) 

    if axis is None:
        # plt.savefig('schema_diagn_vsi.pdf', bbox_inches='tight')
        plt.show()
    else:
        return ax

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.patches import FancyArrowPatch

# Panel E: FreB interval construction from the p-value function (left) and local coverage
# diagnostics (right), joined by a figure-level arrow.
fig, axs = plt.subplots(1, 2, figsize=(16, 5), width_ratios=[1, 1])
fig.subplots_adjust(wspace=1.0)

plot_E_left(axis=axs[0])
plot_E_right(axis=axs[1])

# Active positions of both subplots in figure coordinates. The only argument of
# `get_position` is the boolean `original` flag, so it is called without arguments
# (passing the figure, which is truthy, requested the *original* position by accident).
bbox0 = axs[0].get_position()
bbox1 = axs[1].get_position()
# Arrow start: right edge of the left subplot, at mid-height
arrow_start = (bbox0.x1, bbox0.y0 + bbox0.height / 2)
# Arrow end: just short of the left edge of the right subplot, at mid-height
arrow_end = (bbox1.x0 - 0.02, bbox1.y0 + bbox1.height / 2)

arrow = FancyArrowPatch(
    arrow_start, arrow_end,
    transform=fig.transFigure,
    arrowstyle='-|>',
    mutation_scale=40,  # Thicker arrow head
    color='black',
    linewidth=4,        # Thicker line
    zorder=200,
    connectionstyle="arc3,rad=0"
)
# Public API for figure-level artists (instead of mutating `fig.patches` directly)
fig.add_artist(arrow)
# plt.suptitle('Improved performance using FreB', fontsize=32, y=0.5)
fig.savefig('panel_e.pdf', bbox_inches='tight')
fig.savefig('panel_e.png', bbox_inches='tight')
plt.show()

### PANELS D & E

In [None]:
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.patches import FancyArrowPatch
from matplotlib.patches import Rectangle


def _link_axes_with_arrow(fig, ax_left, ax_right):
    """
    Draw a thick black figure-level arrow from the right edge of `ax_left` to just short of the
    left edge of `ax_right`, both at mid-height.

    Returns:
        (start, end) of the arrow in figure coordinates.
    """
    # `get_position()` returns the active position; its only argument is the boolean
    # `original` flag, so it must not be given the figure.
    bbox_left = ax_left.get_position()
    bbox_right = ax_right.get_position()
    start = (bbox_left.x1, bbox_left.y0 + bbox_left.height / 2)
    end = (bbox_right.x0 - 0.02, bbox_right.y0 + bbox_right.height / 2)
    fig.add_artist(FancyArrowPatch(
        start, end,
        transform=fig.transFigure,
        arrowstyle='-|>',
        mutation_scale=40,  # Thicker arrow head
        color='black',
        linewidth=4,        # Thicker line
        zorder=200,
        connectionstyle="arc3,rad=0"
    ))
    return start, end


def _add_theta_span_arrow(ax, theta, half_width=4.5, y=30):
    """Add a red double-headed arrow from theta - half_width to theta + half_width at height y, in data coordinates of `ax`."""
    ax.add_patch(FancyArrowPatch(
        (theta - half_width, y),
        (theta + half_width, y),
        arrowstyle='<|-|>',
        mutation_scale=20,  # Thicker arrow head
        color='red',
        linewidth=3,        # Thicker line
        zorder=200,
        connectionstyle="arc3,rad=0"
    ))


fig, axs = plt.subplots(2, 2, figsize=(16, 8), width_ratios=[1, 1])
fig.subplots_adjust(wspace=0.7, hspace=0.4)

# Top row: panel D
plot_D_left(axis=axs[0][0])
axs[0][0].set_title(r'\textbf{From HPD to FreB intervals}', fontsize=25, pad=10)

plot_D_right(axis=axs[0][1])
axs[0][1].set_title(r'\textbf{Local diagnostics for every true $\theta^*$}', size=25, pad=10)
_add_theta_span_arrow(axs[0][1], obs_theta.item())

arrow_start, arrow_end = _link_axes_with_arrow(fig, axs[0][0], axs[0][1])

# White framed box behind the top arrow, holding the two annotations below
rect_x = arrow_start[0] + 0.005
rect_y = arrow_start[1] - 0.05  # box extends below the arrow
rect_width = arrow_end[0] - rect_x - 0.007  # fits between the two subplots
rect_height = 0.11
rect = Rectangle(
    (rect_x, rect_y),
    rect_width,
    rect_height,
    transform=fig.transFigure,
    linewidth=2,
    edgecolor='k',
    facecolor='white',
    zorder=100
)
fig.add_artist(rect)

# Annotations above and below the top arrow
fig.text(
    (arrow_start[0] + arrow_end[0]) / 2,
    arrow_start[1] + 0.02,  # Slightly above the arrow
    r"For \textbf{every} true parameter",
    ha='center', va='bottom',
    fontsize=18,
    color='black',
    fontweight='bold',
    zorder=201,
)
fig.text(
    (arrow_start[0] + arrow_end[0]) / 2,
    arrow_start[1] - 0.01,  # Slightly below the arrow
    r"and \textbf{all} observable data",
    ha='center', va='top',
    fontsize=18,
    color='black',
    fontweight='bold',
    zorder=201,
)

# Bottom row: panel E
plot_E_left(axis=axs[1][0])
plot_E_right(axis=axs[1][1])
_add_theta_span_arrow(axs[1][1], obs_theta.item())
_link_axes_with_arrow(fig, axs[1][0], axs[1][1])

# Vertical red "RESHAPE" arrow from the top-left panel down to the bottom-left panel
bbox0 = axs[0][0].get_position()
bbox1 = axs[1][0].get_position()
arrow_start = (bbox0.x0 + bbox0.width / 6, bbox0.y0)
arrow_end = (bbox1.x0 + bbox1.width / 6, bbox1.y1 - 0.02)
arrow = FancyArrowPatch(
    arrow_start, arrow_end,
    transform=fig.transFigure,
    arrowstyle='-|>',
    mutation_scale=40,  # Thicker arrow head
    color='red',
    linewidth=4,        # Thicker line
    zorder=20,
)
fig.add_artist(arrow)

fig.text(arrow_start[0], (arrow_start[1] + arrow_end[1]) / 2 + 0.02, r"{\bf RESHAPE}", fontsize=20, color='black', ha='center', va='center', bbox=dict(facecolor='white', edgecolor='red', boxstyle='round,pad=0.5', linewidth=3), zorder=21)

# plt.suptitle('1D synthetic example -- performance of credible intervals', fontsize=32)
fig.savefig('panels_d_e.pdf', bbox_inches='tight')
fig.savefig('panels_d_e.png', bbox_inches='tight')
plt.show()