In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

from tqdm import tqdm
import dill
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.lines as mlines
from matplotlib.patches import FancyArrowPatch
import seaborn as sns
import torch
from torch.distributions import MultivariateNormal
from sbi.inference import FMPE, SNPE, NPSE
from sbi.analysis import pairplot
from sbi.utils import BoxUniform
import sbibm
from lf2i.inference import LF2I
from lf2i.test_statistics.posterior import Posterior
from lf2i.test_statistics.waldo import Waldo
from lf2i.calibration.critical_values import train_qr_algorithm
from lf2i.utils.other_methods import hpd_region
from lf2i.plot.parameter_regions import plot_parameter_regions
from lf2i.plot.coverage_diagnostics import coverage_probability_plot
from tsi.temp.utils import kdeplots2D

### Settings

In [None]:
# Seed every RNG used for sampling so a fresh "Restart & Run All" is reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# All cached pickles and figures are written under results/; create it up front so the
# first dump/savefig does not fail on a fresh checkout.
os.makedirs('results/sbibm_example', exist_ok=True)

POI_DIM = 2  # dimensionality of the parameter of interest (theta_1, theta_2)
PRIOR_LOC = [0.5, 0.5]  # prior mean; coincides with the "well-aligned" true parameter (0.5, 0.5) used below
PRIOR_VAR = 0.5  # isotropic prior variance (alternative used elsewhere: (6*np.sqrt(2.0))**2)
POI_BOUNDS = {r'$\theta_1$': (-1, 1), r'$\theta_2$': (-1, 1)}
PRIOR = MultivariateNormal(
    loc=torch.Tensor(PRIOR_LOC), covariance_matrix=PRIOR_VAR*torch.eye(n=POI_DIM)
)
# Alternative uniform prior (matches the commented-out "uniform_prior" cells):
# PRIOR = BoxUniform(
#     low=torch.tensor((POI_BOUNDS[r'$\theta_1$'][0]-1, POI_BOUNDS[r'$\theta_2$'][0]-1)),
#     high=torch.tensor((POI_BOUNDS[r'$\theta_1$'][1]+1, POI_BOUNDS[r'$\theta_2$'][1]+1))
# )

B = 10_000  # num simulations to estimate posterior and test statistics
B_PRIME = 1_000  # num simulations to estimate critical values
B_DOUBLE_PRIME = 1_000  # num simulations to do diagnostics
EVAL_GRID_SIZE = 30_000  # num evaluation points over parameter space to construct confidence sets
CONFIDENCE_LEVEL = 0.954, 0.683  # tuple of levels (~2-sigma, ~1-sigma); index 0 is used for diagnostics

# Broad reference distribution: uniform on POI_BOUNDS widened by 1 on each side.
# Calibration (T') and diagnostic (T'') parameters are drawn from it instead of the prior.
REFERENCE = BoxUniform(
    low=torch.tensor((POI_BOUNDS[r'$\theta_1$'][0]-1, POI_BOUNDS[r'$\theta_2$'][0]-1)),
    high=torch.tensor((POI_BOUNDS[r'$\theta_1$'][1]+1, POI_BOUNDS[r'$\theta_2$'][1]+1))
)
# Uniform distribution over POI_BOUNDS, sampled to build evaluation grids for the sets.
EVAL_GRID_DISTR = BoxUniform(
    low=torch.tensor((POI_BOUNDS[r'$\theta_1$'][0], POI_BOUNDS[r'$\theta_2$'][0])),
    high=torch.tensor((POI_BOUNDS[r'$\theta_1$'][1], POI_BOUNDS[r'$\theta_2$'][1]))
)

# Extra keyword arguments forwarded to every posterior-based call (Posterior, Waldo, hpd_region).
POSTERIOR_KWARGS = {
    # 'norm_posterior': None
}
DEVICE = 'cpu'

In [None]:
# Load the sbibm "two_moons" benchmark task and its simulator; `simulator(params)` is used
# throughout to turn parameter batches into simulated observations.
task = sbibm.get_task('two_moons')
simulator = task.get_simulator()

In [None]:
# Compare the broad REFERENCE distribution (used for calibration/diagnostic parameters)
# with the prior, zoomed to the parameter-of-interest box.
# Each axis uses its own parameter's bounds (x -> theta_1, y -> theta_2).
kdeplots2D(
    [REFERENCE.sample(sample_shape=(100_000, )), PRIOR.sample(sample_shape=(100_000, ))],
    true_theta=None,
    plot_marginals=True,
    xlim=POI_BOUNDS[r'$\theta_1$'],
    ylim=POI_BOUNDS[r'$\theta_2$'],
    names=['universal', 'prior'],
    axis_labels=[r'$\theta_1$', r'$\theta_2$']
)

### Neural density estimator (NDE)

In [None]:
# try:
#     with open('results/fmpe_strong_prior.pkl', 'rb') as f:
#         fmpe_posterior = dill.load(f)
# except:
#     b_params = PRIOR.sample(sample_shape=(B, ))
#     b_samples = simulator(b_params)
#     b_params.shape, b_samples.shape
#     fmpe = FMPE(
#         prior=PRIOR,
#         device='cpu'
#     )

#     _ = fmpe.append_simulations(b_params, b_samples).train()
#     fmpe_posterior = fmpe.build_posterior()
#     with open('results/fmpe_strong_prior.pkl', 'wb') as f:
#         dill.dump(fmpe_posterior, f)

In [None]:
# try:
#     with open('results/snpe_uniform_prior.pkl', 'rb') as f:
#         snpe_posterior = dill.load(f)
# except:
#     b_params = PRIOR.sample(sample_shape=(B, ))
#     b_samples = simulator(b_params)
#     b_params.shape, b_samples.shape
#     snpe = SNPE(
#         prior=PRIOR,
#         density_estimator='maf',
#         device='cpu'
#     )

#     _ = snpe.append_simulations(b_params, b_samples).train()
#     snpe_posterior = snpe.build_posterior()
#     with open('results/snpe_uniform_prior.pkl', 'wb') as f:
#         dill.dump(snpe_posterior, f)

In [None]:
# Neural posterior estimator: SNPE with a MAF density estimator, trained on B (theta, x)
# pairs drawn from the prior. The trained posterior is cached with dill.
# NOTE: dill.load can execute arbitrary code; only load caches this notebook produced.
try:
    with open('results/snpe_strong_prior.pkl', 'rb') as f:
        snpe_posterior = dill.load(f)
except FileNotFoundError:
    # Only a missing cache triggers retraining; a corrupt cache or interrupt now surfaces.
    b_params = PRIOR.sample(sample_shape=(B, ))
    b_samples = simulator(b_params)
    snpe = SNPE(
        prior=PRIOR,
        density_estimator='maf',
        device=DEVICE
    )

    _ = snpe.append_simulations(b_params, b_samples).train()
    snpe_posterior = snpe.build_posterior()
    with open('results/snpe_strong_prior.pkl', 'wb') as f:
        dill.dump(snpe_posterior, f)

In [None]:
# Sanity check: prior vs. NPE posterior for one observation simulated at theta = (0.5, 0.5).
# (obs_x is reassigned to the full set of test observations in a later cell.)
obs_theta = torch.Tensor([0.5, 0.5])
obs_x = simulator(obs_theta)

kdeplots2D(
    [PRIOR.sample(sample_shape=(50_000, )), snpe_posterior.sample(sample_shape=(50_000, ), x=obs_x)],
    true_theta=obs_theta.unsqueeze(0),
    ignore_lower_than=1e-10, # set to None to see potential fat tails 
    # Each axis uses its own parameter's bounds (x -> theta_1, y -> theta_2).
    xlim=POI_BOUNDS[r'$\theta_1$'],
    ylim=POI_BOUNDS[r'$\theta_2$'],
    names=['Prior', 'Normalizing Flow'],
    axis_labels=[r'$\theta_1$', r'$\theta_2$']
)

### VSI

In [None]:
# Calibration set T' (passed as T_prime below): parameters from the broad REFERENCE
# distribution rather than the prior, paired with simulated observations.
b_prime_params = REFERENCE.sample((B_PRIME, ))
b_prime_samples = simulator(b_prime_params)
(b_prime_params.shape, b_prime_samples.shape)

In [None]:
# Fixed true parameters and their simulated observations, cached so every downstream
# cell (and every re-run) works with the same data.
try:
    with open('results/obs_x_theta.pkl', 'rb') as f:
        examples = dill.load(f)
    true_theta = examples['true_theta']
    obs_x = examples['obs_x']
except FileNotFoundError:
    # Rows 0-3 lie away from the prior mean set in the settings cell; rows 4-7 sit at (0.5, 0.5).
    true_theta = torch.Tensor([[0, 0], [0.5, -0.5], [-0.5, 0.5], [-0.5, -0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]])
    obs_x = simulator(true_theta)
    with open('results/obs_x_theta.pkl', 'wb') as f:
        dill.dump({
            'true_theta': true_theta,
            'obs_x': obs_x
        }, f)

In [None]:
# try:
#     with open('results/lf2i_uniform_prior.pkl', 'rb') as f:
#         lf2i = dill.load(f)
#     confidence_sets = lf2i.inference(
#         x=obs_x,
#         evaluation_grid=EVAL_GRID_DISTR.sample(sample_shape=(EVAL_GRID_SIZE, )),
#         confidence_level=CONFIDENCE_LEVEL,
#         calibration_method='critical-values',
#         calibration_model='cat-gb',
#         calibration_model_kwargs={
#             # 'cv': {'iterations': [100, 300, 500, 700, 1000], 'depth': [1, 3, 5, 7, 9]},
#             # 'n_iter': 25
#             'cv': {'iterations': [100], 'depth': [3]},
#             'n_iter': 1
#         },
#         T_prime=(b_prime_params, b_prime_samples),
#         retrain_calibration=False
#     )
# except:
#     lf2i = LF2I(test_statistic=Posterior(poi_dim=2, estimator=snpe_posterior, **POSTERIOR_KWARGS))
#     confidence_sets = lf2i.inference(
#         x=obs_x,
#         evaluation_grid=EVAL_GRID_DISTR.sample(sample_shape=(EVAL_GRID_SIZE, )),
#         confidence_level=CONFIDENCE_LEVEL,
#         calibration_method='critical-values',
#         calibration_model='cat-gb',
#         calibration_model_kwargs={
#             # 'cv': {'iterations': [100, 300, 500, 700, 1000], 'depth': [1, 3, 5, 7, 9]},
#             # 'n_iter': 25
#             'cv': {'iterations': [100], 'depth': [3]},
#             'n_iter': 1
#         },
#         T_prime=(b_prime_params, b_prime_samples),
#         retrain_calibration=False
#     )
#     with open('results/lf2i_uniform_prior.pkl', 'wb') as f:
#         dill.dump(lf2i, f)

In [None]:
# FreB confidence sets for obs_x using the NPE posterior as test statistic.
# The LF2I object is cached after its first inference run; on a cache hit it is reused with
# retrain_calibration=False (presumably reusing the stored calibration model — confirm in lf2i).
try:
    with open('results/lf2i_strong_prior.pkl', 'rb') as f:
        lf2i = dill.load(f)
    lf2i_from_cache = True
except FileNotFoundError:
    lf2i = LF2I(test_statistic=Posterior(poi_dim=POI_DIM, estimator=snpe_posterior, **POSTERIOR_KWARGS))
    lf2i_from_cache = False

# Inference sits outside the try so its failures are reported instead of silently
# triggering a rebuild that overwrites the cache.
confidence_sets = lf2i.inference(
    x=obs_x,
    evaluation_grid=EVAL_GRID_DISTR.sample(sample_shape=(EVAL_GRID_SIZE, )),
    confidence_level=CONFIDENCE_LEVEL,
    calibration_method='critical-values',
    calibration_model='cat-gb',
    calibration_model_kwargs={
        # full search: 'cv': {'iterations': [100, 300, 500, 700, 1000], 'depth': [1, 3, 5, 7, 9]}, 'n_iter': 25
        'cv': {'iterations': [100], 'depth': [3]},
        'n_iter': 1
    },
    T_prime=(b_prime_params, b_prime_samples),
    retrain_calibration=False
)

if not lf2i_from_cache:
    with open('results/lf2i_strong_prior.pkl', 'wb') as f:
        dill.dump(lf2i, f)

In [None]:
# FreB confidence sets for obs_x using the Waldo statistic, estimated from 5,000 draws of
# the same NPE posterior. Cached the same way as the posterior-statistic LF2I object.
try:
    with open('results/lf2i_strong_prior_waldo.pkl', 'rb') as f:
        lf2iw = dill.load(f)
    lf2iw_from_cache = True
except FileNotFoundError:
    lf2iw = LF2I(test_statistic=Waldo(poi_dim=POI_DIM, estimator=snpe_posterior, estimation_method='posterior', num_posterior_samples=5_000, **POSTERIOR_KWARGS))
    lf2iw_from_cache = False

# Inference sits outside the try so its failures are reported instead of silently
# triggering a rebuild that overwrites the cache.
confidence_setsw = lf2iw.inference(
    x=obs_x,
    evaluation_grid=EVAL_GRID_DISTR.sample(sample_shape=(EVAL_GRID_SIZE, )),
    confidence_level=CONFIDENCE_LEVEL,
    calibration_method='critical-values',
    calibration_model='cat-gb',
    calibration_model_kwargs={
        # full search: 'cv': {'iterations': [100, 300, 500, 700, 1000], 'depth': [1, 3, 5, 7, 9]}, 'n_iter': 25
        'cv': {'iterations': [100], 'depth': [3]},
        'n_iter': 1
    },
    T_prime=(b_prime_params, b_prime_samples),
    retrain_calibration=False
)

if not lf2iw_from_cache:
    with open('results/lf2i_strong_prior_waldo.pkl', 'wb') as f:
        dill.dump(lf2iw, f)

In [None]:
# HPD credible regions of the NPE posterior for every observation and every level:
# credible_sets[i][j] is the region for obs_x[i] at CONFIDENCE_LEVEL[j]
# (note: the reverse index order of confidence_sets[j][i]).
credible_sets = []
for x in tqdm(obs_x, desc='HPD regions'):  # alt. inputs: torch.vstack([task.get_observation(i) for i in range(1, 11)])
    credible_sets_x = []
    for cl in CONFIDENCE_LEVEL:
        # First return value (presumably the credible level actually attained on the grid) is unused.
        _, credible_set = hpd_region(
            posterior=snpe_posterior,
            param_grid=EVAL_GRID_DISTR.sample(sample_shape=(EVAL_GRID_SIZE, )),
            x=x.reshape(-1, ),
            credible_level=cl,
            num_level_sets=10_000,
            **POSTERIOR_KWARGS
        )
        credible_sets_x.append(credible_set)
    credible_sets.append(credible_sets_x)

In [None]:
plt.rc('text', usetex=True)  # Enable LaTeX
plt.rc('font', family='serif')  # Use a serif font (e.g., Computer Modern)
plt.rcParams['text.latex.preamble'] = r'''
    \usepackage{amsmath}  % For \mathbb
    \usepackage{amssymb}  % For \mathbb
    \usepackage{bm}       % For bold math symbols
    \usepackage{underscore} % If underscores are needed
'''

# One background prior sample shared by every panel in this cell.
prior_samples = PRIOR.sample(sample_shape=(50_000, )).numpy()

for idx_obs in range(len(true_theta)):

    # An observation is "well aligned" when it was simulated at the prior mean
    # (rows 4-7 of the default true_theta); every other row lies away from it.
    if torch.allclose(true_theta[idx_obs], PRIOR.mean):
        title = r'\textbf{b)} Prior well aligned with $\theta^{\star}$'
    else:
        title = r'\textbf{a)} Prior poorly aligned with $\theta^{\star}$'

    plot_parameter_regions(
        *credible_sets[idx_obs],
        param_dim=2,
        true_parameter=true_theta[idx_obs, :],
        prior_samples=prior_samples,
        parameter_space_bounds={
            r'$\theta_1$': dict(zip(['low', 'high'], POI_BOUNDS[r'$\theta_1$'])), 
            r'$\theta_2$': dict(zip(['low', 'high'], POI_BOUNDS[r'$\theta_2$'])), 
        },
        colors=[
            'purple', 'deeppink',  # credible sets
        ],
        # raw f-string: "\%" is a LaTeX escape, not a (invalid) Python escape sequence
        region_names=[rf'HPD {int(cl*100):.0f}\%' for cl in CONFIDENCE_LEVEL],
        labels=[r'$\theta_1$', r'$\theta_2$'],
        linestyles=['-', '--'],
        param_names=[r'$\theta_1$', r'$\theta_2$'],
        alpha_shape=False,
        alpha=3,
        scatter=True,
        figsize=(5, 5),
        # save_fig_path=f'./results/sbibm_example/hpd{idx_obs}.pdf',
        remove_legend=True,
        title=title,
        custom_ax=None
    )

In [None]:
# FreB confidence sets (posterior test statistic) for every test observation.
# One background prior sample shared by every panel in this cell.
prior_samples = PRIOR.sample(sample_shape=(50_000, )).numpy()

for idx_obs in range(len(true_theta)):

    plot_parameter_regions(
        # confidence_sets is indexed [level][observation]
        *[confidence_sets[j][idx_obs] for j in range(len(CONFIDENCE_LEVEL))],
        param_dim=2,
        true_parameter=true_theta[idx_obs, :],
        prior_samples=prior_samples,
        parameter_space_bounds={
            r'$\theta_1$': dict(zip(['low', 'high'], POI_BOUNDS[r'$\theta_1$'])), 
            r'$\theta_2$': dict(zip(['low', 'high'], POI_BOUNDS[r'$\theta_2$'])), 
        },
        colors=[
            'teal', 'mediumseagreen',  # confidence sets
        ],
        # raw f-string: "\%" is a LaTeX escape, not a (invalid) Python escape sequence
        region_names=[rf'FreB {int(cl*100):.0f}\%' for cl in CONFIDENCE_LEVEL],
        labels=[r'$\theta_1$', r'$\theta_2$'],
        linestyles=['-', '--'],
        param_names=[r'$\theta_1$', r'$\theta_2$'],
        alpha_shape=False,
        alpha=3,
        scatter=True,
        figsize=(5, 5),
        # save_fig_path=f'./results/sbibm_example/freb{idx_obs}.pdf',
        remove_legend=True,
        title='FreB with Posterior',
        custom_ax=None
    )

In [None]:
# FreB confidence sets (Waldo test statistic) for every test observation.
# One background prior sample shared by every panel in this cell.
prior_samples = PRIOR.sample(sample_shape=(50_000, )).numpy()

for idx_obs in range(len(true_theta)):

    plot_parameter_regions(
        # confidence_setsw is indexed [level][observation]
        *[confidence_setsw[j][idx_obs] for j in range(len(CONFIDENCE_LEVEL))],
        param_dim=2,
        true_parameter=true_theta[idx_obs, :],
        prior_samples=prior_samples,
        parameter_space_bounds={
            r'$\theta_1$': dict(zip(['low', 'high'], POI_BOUNDS[r'$\theta_1$'])), 
            r'$\theta_2$': dict(zip(['low', 'high'], POI_BOUNDS[r'$\theta_2$'])), 
        },
        colors=[
            'teal', 'mediumseagreen',  # confidence sets
        ],
        # raw f-string: "\%" is a LaTeX escape, not a (invalid) Python escape sequence
        region_names=[rf'FreB {int(cl*100):.0f}\%' for cl in CONFIDENCE_LEVEL],
        labels=[r'$\theta_1$', r'$\theta_2$'],
        linestyles=['-', '--'],
        param_names=[r'$\theta_1$', r'$\theta_2$'],
        alpha_shape=False,
        alpha=3,
        scatter=True,
        figsize=(5, 5),
        # distinct name so it does not overwrite the posterior-statistic figures
        # save_fig_path=f'./results/sbibm_example/freb_waldo{idx_obs}.pdf',
        remove_legend=True,
        title='FreB with Waldo',
        custom_ax=None
    )

### Coverage diagnostics

In [None]:
def _save_coverage_scatter(parameters, coverage_proba, title, path):
    """Scatter estimated coverage probability over (theta_1, theta_2), colour range fixed to [0, 1], and save to `path`."""
    plt.scatter(parameters[:, 0], parameters[:, 1], c=coverage_proba)
    plt.title(title)
    plt.clim(vmin=0, vmax=1)
    plt.colorbar()
    plt.savefig(path)
    plt.close()


# Local coverage diagnostics at CONFIDENCE_LEVEL[0] for FreB (posterior statistic), FreB (Waldo)
# and HPD credible regions, estimated on a diagnostic set T'' drawn from REFERENCE.
# NOTE(review): caches written before diagn_objectsw was filled correctly hold Waldo results in
# diagn_confset_strong_prior.pkl and an empty Waldo dict; delete them to regenerate.
try:
    with open('results/diagn_confset_strong_prior.pkl', 'rb') as f:
        diagn_objects = dill.load(f)
    with open('results/diagn_confset_strong_prior_waldo.pkl', 'rb') as f:
        diagn_objectsw = dill.load(f)
    with open('results/diagn_cred_strong_prior.pkl', 'rb') as f:
        diagn_objects_cred = dill.load(f)
    with open('results/b_double_prime.pkl', 'rb') as f:
        b_double_prime = dill.load(f)
    b_double_prime_params, b_double_prime_samples = b_double_prime['params'], b_double_prime['samples']
except FileNotFoundError:
    b_double_prime_params = REFERENCE.sample(sample_shape=(B_DOUBLE_PRIME, ))
    b_double_prime_samples = simulator(b_double_prime_params)
    with open('results/b_double_prime.pkl', 'wb') as f:
        dill.dump({
            'params': b_double_prime_params,
            'samples': b_double_prime_samples
        }, f)

    # FreB sets with the posterior test statistic.
    diagn_objects = {}
    for cl in CONFIDENCE_LEVEL[:1]:  # 0.954
        print(cl, flush=True)
        diagnostics_estimator_confset, out_parameters_confset, mean_proba_confset, upper_proba_confset, lower_proba_confset = lf2i.diagnostics(
            region_type='lf2i',
            confidence_level=cl,
            calibration_method='critical-values',
            coverage_estimator='splines',
            T_double_prime=(b_double_prime_params, b_double_prime_samples),
        )
        diagn_objects[cl] = (diagnostics_estimator_confset, out_parameters_confset, mean_proba_confset, upper_proba_confset, lower_proba_confset)
    with open('results/diagn_confset_strong_prior.pkl', 'wb') as f:
        dill.dump(diagn_objects, f)
    _save_coverage_scatter(out_parameters_confset, mean_proba_confset, 'Coverage of FreB confidence sets', 'results/freb_coverage.png')

    # FreB sets with the Waldo test statistic.
    diagn_objectsw = {}
    for cl in CONFIDENCE_LEVEL[:1]:  # 0.954
        print(cl, flush=True)
        diagnostics_estimator_confset, out_parameters_confset, mean_proba_confset, upper_proba_confset, lower_proba_confset = lf2iw.diagnostics(
            region_type='lf2i',
            confidence_level=cl,
            calibration_method='critical-values',
            coverage_estimator='splines',
            T_double_prime=(b_double_prime_params, b_double_prime_samples),
        )
        # Stored in its own dict so the posterior-statistic results above are not overwritten.
        diagn_objectsw[cl] = (diagnostics_estimator_confset, out_parameters_confset, mean_proba_confset, upper_proba_confset, lower_proba_confset)
    with open('results/diagn_confset_strong_prior_waldo.pkl', 'wb') as f:
        dill.dump(diagn_objectsw, f)
    _save_coverage_scatter(out_parameters_confset, mean_proba_confset, 'Coverage of Waldo confidence sets', 'results/waldo_coverage.png')

    # HPD credible regions of the NPE posterior.
    diagn_objects_cred = {}
    size_grid_for_sizes = 5_000
    for cl in CONFIDENCE_LEVEL[:1]:  # 0.954
        print(cl, flush=True)
        diagnostics_estimator_credible, out_parameters_credible, mean_proba_credible, upper_proba_credible, lower_proba_credible, sizes = lf2i.diagnostics(
            region_type='posterior',
            confidence_level=cl,
            coverage_estimator='splines',
            T_double_prime=(b_double_prime_params, b_double_prime_samples),
            posterior_estimator=lf2i.test_statistic.estimator,
            evaluation_grid=EVAL_GRID_DISTR.sample(sample_shape=(size_grid_for_sizes, )),
            num_level_sets=5_000,
            **POSTERIOR_KWARGS
        )
        diagn_objects_cred[cl] = (diagnostics_estimator_credible, out_parameters_credible, mean_proba_credible, upper_proba_credible, lower_proba_credible, sizes)
    with open('results/diagn_cred_strong_prior.pkl', 'wb') as f:
        dill.dump(diagn_objects_cred, f)
    _save_coverage_scatter(out_parameters_credible, mean_proba_credible, 'Coverage of credible regions', 'results/hpd_coverage.png')

In [None]:
# FreB (posterior statistic) confidence sets for every diagnostic observation in T'', used to
# map set size over parameter space. Results are cached; plots are only redrawn on a cache miss.
try:
    with open('results/confidence_sets_for_size.pkl', 'rb') as f:
        confidence_sets_for_size = dill.load(f)
except FileNotFoundError:
    size_grid_for_sizes = 5_000
    confidence_sets_for_size = lf2i.inference(
        x=b_double_prime_samples,
        evaluation_grid=EVAL_GRID_DISTR.sample(sample_shape=(size_grid_for_sizes, )),
        confidence_level=CONFIDENCE_LEVEL,
        calibration_method='critical-values',
    )
    # Size of each set at CONFIDENCE_LEVEL[0], as the percentage of evaluation-grid points it
    # contains (the grid is uniform over POI_BOUNDS, so roughly the percentage of that box).
    confset_sizes = np.array([100*cs.shape[0]/size_grid_for_sizes for cs in confidence_sets_for_size[0]])
    with open('results/confidence_sets_for_size.pkl', 'wb') as f:
        dill.dump(confidence_sets_for_size, f)

    # Same scatter twice: fixed 0-100 colour scale, then auto-scaled.
    for fixed_scale, title, path in [
        (True, r'FreB sizes, 0-100\%', 'results/freb_sizes_fixed_scale.png'),
        (False, 'FreB sizes', 'results/freb_sizes.png'),
    ]:
        plt.scatter(b_double_prime_params[:, 0], b_double_prime_params[:, 1], c=confset_sizes)
        if fixed_scale:
            plt.clim(0, 100)
        plt.colorbar()
        plt.title(title)
        plt.savefig(path)
        plt.close()

In [None]:
# Same as the previous cell but for FreB sets built with the Waldo statistic.
# NOTE: this reuses the name `confidence_sets_for_size`, replacing the posterior-statistic
# result from the previous cell.
try:
    with open('results/confidence_sets_for_size_waldo.pkl', 'rb') as f:
        confidence_sets_for_size = dill.load(f)
except FileNotFoundError:
    size_grid_for_sizes = 5_000
    confidence_sets_for_size = lf2iw.inference(
        x=b_double_prime_samples,
        evaluation_grid=EVAL_GRID_DISTR.sample(sample_shape=(size_grid_for_sizes, )),
        confidence_level=CONFIDENCE_LEVEL,
        calibration_method='critical-values',
    )
    # Size of each set at CONFIDENCE_LEVEL[0], as the percentage of evaluation-grid points it contains.
    confset_sizes = np.array([100*cs.shape[0]/size_grid_for_sizes for cs in confidence_sets_for_size[0]])
    with open('results/confidence_sets_for_size_waldo.pkl', 'wb') as f:
        dill.dump(confidence_sets_for_size, f)

    # Same scatter twice: fixed 0-100 colour scale, then auto-scaled.
    for fixed_scale, title, path in [
        (True, r'Waldo sizes, 0-100\%', 'results/waldo_sizes_fixed_scale.png'),
        (False, 'Waldo sizes', 'results/waldo_sizes.png'),
    ]:
        plt.scatter(b_double_prime_params[:, 0], b_double_prime_params[:, 1], c=confset_sizes)
        if fixed_scale:
            plt.clim(0, 100)
        plt.colorbar()
        plt.title(title)
        plt.savefig(path)
        plt.close()

### Final plots

In [None]:
plt.rc('text', usetex=True)  # Enable LaTeX
plt.rc('font', family='serif')  # Use a serif font (e.g., Computer Modern)
plt.rcParams['text.latex.preamble'] = r'''
    \usepackage{amsmath}  % For \mathbb
    \usepackage{amssymb}  % For \mathbb
    \usepackage{bm}       % For bold math symbols
    \usepackage{underscore} % If underscores are needed
'''

# Summary figure (2x3):
#   column 0: misaligned prior (true_theta[1])     -> HPD (top), FreB (bottom)
#   column 1: well-aligned prior (true_theta[-1])  -> HPD (top), FreB (bottom)
#   column 2: estimated local coverage at CONFIDENCE_LEVEL[0] -> HPD (top), FreB (bottom)

# Diagnostic parameters were drawn from REFERENCE (POI_BOUNDS widened by 1 on each side),
# so the coverage panels are drawn over that box.
COVERAGE_XLIMS = (POI_BOUNDS[r'$\theta_1$'][0] - 1, POI_BOUNDS[r'$\theta_1$'][1] + 1)
COVERAGE_YLIMS = (POI_BOUNDS[r'$\theta_2$'][0] - 1, POI_BOUNDS[r'$\theta_2$'][1] + 1)
HPD_COLORS = ['purple', 'deeppink']  # credible sets
FREB_COLORS = ['teal', 'mediumseagreen']  # confidence sets
# raw f-strings: "\%" is a LaTeX escape, not a (invalid) Python escape sequence
HPD_NAMES = [rf'HPD {int(cl*100):.0f}\%' for cl in CONFIDENCE_LEVEL]
FREB_NAMES = [rf'FreB {int(cl*100):.0f}\%' for cl in CONFIDENCE_LEVEL]
IDX_MISALIGNED, IDX_ALIGNED = 1, -1  # rows of true_theta shown in columns 0 and 1


def _plot_regions_panel(regions, idx_obs, colors, region_names, ax, remove_legend):
    """Draw the given nested regions for observation `idx_obs`, with its true parameter, on `ax`."""
    plot_parameter_regions(
        *regions,
        param_dim=2,
        true_parameter=true_theta[idx_obs, :],
        prior_samples=PRIOR.sample(sample_shape=(50_000, )).numpy(),
        parameter_space_bounds={
            r'$\theta_1$': dict(zip(['low', 'high'], POI_BOUNDS[r'$\theta_1$'])),
            r'$\theta_2$': dict(zip(['low', 'high'], POI_BOUNDS[r'$\theta_2$'])),
        },
        colors=colors,
        region_names=region_names,
        labels=[r'$\theta_1$', r'$\theta_2$'],
        linestyles=['-', '--'],
        param_names=[r'$\theta_1$', r'$\theta_2$'],
        alpha_shape=True,
        alpha=3,
        scatter=False,
        figsize=(5, 5),
        remove_legend=remove_legend,
        custom_ax=ax,
    )


def _plot_coverage_panel(diagn, ax):
    """Plot coverage at CONFIDENCE_LEVEL[0] from a diagnostics dict on `ax`; return the plot handle used for the colorbar."""
    level = CONFIDENCE_LEVEL[0]
    mappable = coverage_probability_plot(
        parameters=diagn[level][1],
        coverage_probability=diagn[level][2],
        confidence_level=level,
        param_dim=2,
        vmin_vmax=(0, 100),
        xlims=COVERAGE_XLIMS,
        ylims=COVERAGE_YLIMS,
        params_labels=(r'$\theta_1$', r'$\theta_2$'),
        title=None,
        show_text=False,
        custom_ax=ax,
    )
    ax.set_xlim(*COVERAGE_XLIMS)
    ax.set_ylim(*COVERAGE_YLIMS)
    ax.set_xticks(np.linspace(*COVERAGE_XLIMS, 5).astype(int))
    ax.set_yticks(np.linspace(*COVERAGE_YLIMS, 5).astype(int))
    ax.tick_params(labelsize=30)
    ax.set_yticklabels([])
    return mappable


fig, ax = plt.subplots(2, 3, figsize=(25, 16))
fig.subplots_adjust(hspace=0.25)

# Column 0: misaligned prior.
_plot_regions_panel(credible_sets[IDX_MISALIGNED], IDX_MISALIGNED, HPD_COLORS, HPD_NAMES, ax[0][0], remove_legend=False)
ax[0][0].set_xticklabels([])
ax[0][0].set_xlabel('')
ax[0][0].set_ylabel(r'$\theta_2$', fontsize=45)
ax[0][0].tick_params(labelsize=30)
ax[0][0].set_title(r'\textbf{Misaligned Prior}', size=50, pad=43)

_plot_regions_panel(
    [confidence_sets[j][IDX_MISALIGNED] for j in range(len(CONFIDENCE_LEVEL))],
    IDX_MISALIGNED, FREB_COLORS, FREB_NAMES, ax[1][0], remove_legend=False
)
ax[1][0].tick_params(labelsize=30)
ax[1][0].set_xlabel(r'$\theta_1$', fontsize=45)
ax[1][0].set_ylabel(r'$\theta_2$', fontsize=45)

# Column 1: well-aligned prior.
_plot_regions_panel(credible_sets[IDX_ALIGNED], IDX_ALIGNED, HPD_COLORS, HPD_NAMES, ax[0][1], remove_legend=True)
ax[0][1].set_xticklabels([])
ax[0][1].set_yticklabels([])
ax[0][1].set_xlabel('')
ax[0][1].set_ylabel('', fontsize=45)
ax[0][1].tick_params(labelsize=30)
ax[0][1].set_title(r'\textbf{Well-Aligned Prior}', size=50, pad=43)

_plot_regions_panel(
    [confidence_sets[j][IDX_ALIGNED] for j in range(len(CONFIDENCE_LEVEL))],
    IDX_ALIGNED, FREB_COLORS, FREB_NAMES, ax[1][1], remove_legend=True
)
ax[1][1].tick_params(labelsize=30)
ax[1][1].set_yticklabels([])
ax[1][1].set_xlabel(r'$\theta_1$', fontsize=45)
ax[1][1].set_ylabel('', fontsize=45)

# Column 2: local coverage (HPD on top, FreB with posterior statistic below).
hpd_diagn_plot = _plot_coverage_panel(diagn_objects_cred, ax[0][2])
ax[0][2].set_xticklabels([])
ax[0][2].set_xlabel('')
ax[0][2].set_ylabel('', fontsize=45, rotation=0)
ax[0][2].set_title(r'\textbf{Local Coverage}', size=50, pad=43)

_ = _plot_coverage_panel(diagn_objects, ax[1][2])
ax[1][2].set_xlabel(r'$\theta_1$', fontsize=45)
ax[1][2].set_ylabel('', rotation=0, fontsize=45)

# Shared coverage colorbar, with the nominal level added as a bold tick and a dashed line.
cax = fig.add_axes([0.97, 0.343, 0.015, 0.3])  # Adjust these values to move the colorbar
cbar = fig.colorbar(hpd_diagn_plot, format='%1.2f', cax=cax)
standard_ticks = np.round(np.linspace(0, 100, num=6), 1)
all_ticks = np.unique(np.sort(np.append(standard_ticks[:-1], CONFIDENCE_LEVEL[0] * 100)))
tick_labels = [rf"{label:.0f}\%" for label in all_ticks]
for i, label in enumerate(all_ticks):
    if abs(label - CONFIDENCE_LEVEL[0]*100) <= 1e-6:
        tick_labels[i] = r"$\mathbf{{{label}}}$\textbf{{\%}}".format(label=int(label))
cbar.ax.yaxis.set_ticks(all_ticks)
cbar.ax.set_yticklabels(tick_labels, fontsize=45)
cbar.ax.axhline(y=CONFIDENCE_LEVEL[0]*100, xmin=0, xmax=1, color="black", linestyle="--", linewidth=2.5)
cbar.ax.yaxis.set_ticks_position('left')
cbar.ax.yaxis.set_label_position('left')

# Grey background boxes behind the top and bottom rows (figure coordinates).
fig.patches.append(patches.Rectangle((0.079, 0.503), 0.917, 0.395, transform=fig.transFigure, edgecolor='black', linewidth=2, facecolor="gainsboro", zorder=-1))
fig.patches.append(patches.Rectangle((0.079, 0.035), 0.917, 0.447, transform=fig.transFigure, edgecolor='black', linewidth=2, facecolor="gainsboro", zorder=-1))

# Vertical separator between the region columns and the coverage column.
line = mlines.Line2D(
    [0.65, 0.65],   # x-coords (start, end)
    [0.15, 0.85],   # y-coords (start, end)
    transform=fig.transFigure,
    ls='-',         # solid
    lw=2.5,
    color='black',
    clip_on=False
)

fig.add_artist(line)

# Downward arrows from the HPD row to the FreB row in columns 0 and 1.
fig.add_artist(FancyArrowPatch(
    posA=(0.239, 0.523), posB=(0.239, 0.455),
    connectionstyle="arc3,rad=0", arrowstyle='-|>', mutation_scale=30, 
    color='black', linewidth=6, zorder=10
))
fig.add_artist(FancyArrowPatch(
    posA=(0.513, 0.523), posB=(0.513, 0.455),
    connectionstyle="arc3,rad=0", arrowstyle='-|>', mutation_scale=30, 
    color='black', linewidth=6, zorder=10
))
# fig.add_artist(FancyArrowPatch(
#     posA=(0.787, 0.523), posB=(0.787, 0.455),
#     connectionstyle="arc3,rad=0", arrowstyle='-|>', mutation_scale=30, 
#     color='black', linewidth=6, zorder=10
# ))

os.makedirs('./results/sbibm_example', exist_ok=True)
plt.savefig('./results/sbibm_example/example0_horizontal.pdf', bbox_inches='tight')
plt.show()

In [None]:
plt.rc('text', usetex=True)  # Enable LaTeX
plt.rc('font', family='serif')  # Use a serif font (e.g., Computer Modern)
plt.rcParams['text.latex.preamble'] = r'''
    \usepackage{amsmath}  % For \mathbb
    \usepackage{amssymb}  % For \mathbb
    \usepackage{bm}       % For bold math symbols
    \usepackage{underscore} % If underscores are needed
'''

# 2x3 summary figure: misaligned prior (true_theta[1]) in column 0 and well-aligned prior
# (true_theta[-1]) in column 1, HPD credible sets on top and FreB confidence sets below.
# Region names use raw f-strings: "\%" is a LaTeX escape, not a (invalid) Python escape sequence.
fig, ax = plt.subplots(2, 3, figsize=(25, 16))
fig.subplots_adjust(hspace=0.25)

plot_parameter_regions(
    *credible_sets[1],
    param_dim=2,
    true_parameter=true_theta[1, :],
    prior_samples=PRIOR.sample(sample_shape=(50_000, )).numpy(),
    parameter_space_bounds={
        r'$\theta_1$': dict(zip(['low', 'high'], POI_BOUNDS[r'$\theta_1$'])), 
        r'$\theta_2$': dict(zip(['low', 'high'], POI_BOUNDS[r'$\theta_2$'])), 
    },
    colors=[
        'purple', 'deeppink',  # credible sets
    ],
    region_names=[rf'HPD {int(cl*100):.0f}\%' for cl in CONFIDENCE_LEVEL],
    labels=[r'$\theta_1$', r'$\theta_2$'],
    linestyles=['-', '--'],
    param_names=[r'$\theta_1$', r'$\theta_2$'],
    alpha_shape=True,
    alpha=3,
    scatter=False,
    figsize=(5, 5),
    remove_legend=False,
    custom_ax=ax[0][0]
)
ax[0][0].set_xticklabels([])
ax[0][0].set_xlabel('')
ax[0][0].set_ylabel(r'$\theta_2$', fontsize=45)
ax[0][0].tick_params(labelsize=30)
ax[0][0].set_title(r'\textbf{Misaligned Prior}', size=50, pad=43)

plot_parameter_regions(
    # confidence_sets is indexed [level][observation]
    *[confidence_sets[j][1] for j in range(len(CONFIDENCE_LEVEL))],
    param_dim=2,
    true_parameter=true_theta[1, :],
    prior_samples=PRIOR.sample(sample_shape=(50_000, )).numpy(),
    parameter_space_bounds={
        r'$\theta_1$': dict(zip(['low', 'high'], POI_BOUNDS[r'$\theta_1$'])), 
        r'$\theta_2$': dict(zip(['low', 'high'], POI_BOUNDS[r'$\theta_2$'])), 
    },
    colors=[
        'teal', 'mediumseagreen',  # confidence sets
    ],
    region_names=[rf'FreB {int(cl*100):.0f}\%' for cl in CONFIDENCE_LEVEL],
    labels=[r'$\theta_1$', r'$\theta_2$'],
    linestyles=['-', '--'],
    param_names=[r'$\theta_1$', r'$\theta_2$'],
    alpha_shape=True,
    alpha=3,
    scatter=False,
    figsize=(5, 5),
    remove_legend=False,
    custom_ax=ax[1][0]
)
ax[1][0].tick_params(labelsize=30)
ax[1][0].set_xlabel(r'$\theta_1$', fontsize=45)
ax[1][0].set_ylabel(r'$\theta_2$', fontsize=45)


plot_parameter_regions(
    *credible_sets[-1],
    param_dim=2,
    true_parameter=true_theta[-1, :],
    prior_samples=PRIOR.sample(sample_shape=(50_000, )).numpy(),
    parameter_space_bounds={
        r'$\theta_1$': dict(zip(['low', 'high'], POI_BOUNDS[r'$\theta_1$'])), 
        r'$\theta_2$': dict(zip(['low', 'high'], POI_BOUNDS[r'$\theta_2$'])), 
    },
    colors=[
        'purple', 'deeppink',  # credible sets
    ],
    region_names=[rf'HPD {int(cl*100):.0f}\%' for cl in CONFIDENCE_LEVEL],
    labels=[r'$\theta_1$', r'$\theta_2$'],
    linestyles=['-', '--'],
    param_names=[r'$\theta_1$', r'$\theta_2$'],
    alpha_shape=True,
    alpha=3,
    scatter=False,
    figsize=(5, 5),
    remove_legend=True,
    custom_ax=ax[0][1]
)
ax[0][1].set_xticklabels([])
ax[0][1].set_yticklabels([])
ax[0][1].set_xlabel('')
ax[0][1].set_ylabel('', fontsize=45)
ax[0][1].tick_params(labelsize=30)
ax[0][1].set_title(r'\textbf{Well-Aligned Prior}', size=50, pad=43)

plot_parameter_regions(
    *[confidence_sets[j][-1] for j in range(len(CONFIDENCE_LEVEL))],
    param_dim=2,
    true_parameter=true_theta[-1, :],
    prior_samples=PRIOR.sample(sample_shape=(50_000, )).numpy(),
    parameter_space_bounds={
        r'$\theta_1$': dict(zip(['low', 'high'], POI_BOUNDS[r'$\theta_1$'])), 
        r'$\theta_2$': dict(zip(['low', 'high'], POI_BOUNDS[r'$\theta_2$'])), 
    },
    colors=[
        'teal', 'mediumseagreen', #'darkseagreen', # confidence sets
    ],
    region_names=[
        #*[f'HPD {cl*100:.1f}%' for cl in CONFIDENCE_LEVEL],
        *[f'FreB {int(cl*100):.0f}\%' for cl in CONFIDENCE_LEVEL],
    ],
    labels=[r'$\theta_1$', r'$\theta_2$'],
    linestyles=['-', '--'], #':', 
    param_names=[r'$\theta_1$', r'$\theta_2$'],
    alpha_shape=True,
    alpha=3,
    scatter=False,
    figsize=(5, 5),
    remove_legend=True,
    custom_ax=ax[1][1]
)
ax[1][1].tick_params(labelsize=30)
ax[1][1].set_yticklabels([])
ax[1][1].set_xlabel(r'$\theta_1$', fontsize=45)
ax[1][1].set_ylabel('', fontsize=45)


hpd_diagn_plot = coverage_probability_plot(
    parameters=diagn_objects_cred[CONFIDENCE_LEVEL[0]][1],
    coverage_probability=diagn_objects_cred[CONFIDENCE_LEVEL[0]][2],
    confidence_level=CONFIDENCE_LEVEL[0],
    param_dim=2,
    vmin_vmax=(0, 100),
    xlims=(-10, 10),
    ylims=(-10, 10),
    params_labels=(r'$\theta_1$', r'$\theta_2$'),
    #figsize=(9, 7),
    title=None,
    show_text=False,
    custom_ax=ax[0][2],
)
ax[0][2].set_xlim(-10, 10)
ax[0][2].set_ylim(-10, 10)
ax[0][2].set_xticks(np.linspace(-10, 10, 5).astype(int))
ax[0][2].set_yticks(np.linspace(-10, 10, 5).astype(int))
ax[0][2].set_xticklabels([])
ax[0][2].set_yticklabels([])
ax[0][2].set_xlabel('')
ax[0][2].set_ylabel('', fontsize=45, rotation=0)
ax[0][2].tick_params(labelsize=30)
ax[0][2].set_title(r'\textbf{Local Coverage}', size=50, pad=43)


_ = coverage_probability_plot(
    parameters=diagn_objects[CONFIDENCE_LEVEL[0]][1],
    coverage_probability=diagn_objects[CONFIDENCE_LEVEL[0]][2],
    confidence_level=CONFIDENCE_LEVEL[0],
    param_dim=2,
    vmin_vmax=(0, 100),
    params_labels=(r'$\theta_1$', r'$\theta_2$'),
    xlims=(-10, 10),
    ylims=(-10, 10),
    #figsize=(9, 7),
    title=None,
    show_text=False,
    custom_ax=ax[1][2],
)
ax[1][2].set_xlim(-10, 10)
ax[1][2].set_ylim(-10, 10)
ax[1][2].set_xticks(np.linspace(-10, 10, 5).astype(int))
ax[1][2].set_yticks(np.linspace(-10, 10, 5).astype(int))
ax[1][2].tick_params(labelsize=30)
ax[1][2].set_yticklabels([])
ax[1][2].set_xlabel(r'$\theta_1$', fontsize=45)
ax[1][2].set_ylabel('', rotation=0, fontsize=45)

cax = fig.add_axes([0.97, 0.343, 0.015, 0.3])  # Adjust these values to move the colorbar
cbar = fig.colorbar(hpd_diagn_plot, format='%1.2f', cax=cax)
standard_ticks = np.round(np.linspace(0, 100, num=6), 1)
all_ticks = np.unique(np.sort(np.append(standard_ticks[:-1], CONFIDENCE_LEVEL[0] * 100))) # all_ticks = standard_ticks # 
tick_labels = [f"{label:.0f}\%" for label in all_ticks]
for i, label in enumerate(all_ticks):
    if abs(label - CONFIDENCE_LEVEL[0]*100) <= 1e-6:
        tick_labels[i] = r"$\mathbf{{{label}}}$\textbf{{\%}}".format(label=int(label))
cbar.ax.yaxis.set_ticks(all_ticks)
cbar.ax.set_yticklabels(tick_labels, fontsize=45)
cbar.ax.axhline(y=CONFIDENCE_LEVEL[0]*100, xmin=0, xmax=1, color="black", linestyle="--", linewidth=2.5)
cbar.ax.yaxis.set_ticks_position('left')  # Move ticks to the left
cbar.ax.yaxis.set_label_position('left')  # Move label to the left


fig.patches.append(patches.Rectangle((0.079, 0.495), 0.917, 0.41, transform=fig.transFigure, edgecolor='black', linewidth=2, facecolor="gainsboro", zorder=-1))
fig.patches.append(patches.Rectangle((0.079, 0.035), 0.917, 0.453, transform=fig.transFigure, edgecolor='black', linewidth=2, facecolor="gainsboro", zorder=-1))

line = mlines.Line2D(
    [0.65, 0.65],   # x-coords (start, end)
    [0.15, 0.85],               # y-coords (start, end)
    transform=fig.transFigure,
    ls='-',                  # dashed
    lw=2.5,
    color='black',
    clip_on=False
)

fig.add_artist(line)

fig.add_artist(FancyArrowPatch(
    posA=(0.239, 0.523), posB=(0.239, 0.455),
    connectionstyle="arc3,rad=0", arrowstyle='-|>', mutation_scale=30, 
    color='black', linewidth=6, zorder=10
))
fig.add_artist(FancyArrowPatch(
    posA=(0.513, 0.523), posB=(0.513, 0.455),
    connectionstyle="arc3,rad=0", arrowstyle='-|>', mutation_scale=30, 
    color='black', linewidth=6, zorder=10
))
fig.add_artist(FancyArrowPatch(
    posA=(0.787, 0.523), posB=(0.787, 0.455),
    connectionstyle="arc3,rad=0", arrowstyle='-|>', mutation_scale=30, 
    color='black', linewidth=6, zorder=10
))

plt.savefig('./results/sbibm_example/example0_horizontal_mid_prior.pdf', bbox_inches='tight')
plt.show()

In [None]:
plt.rc('text', usetex=True)  # Enable LaTeX
plt.rc('font', family='serif')  # Use a serif font (e.g., Computer Modern)
plt.rcParams['text.latex.preamble'] = r'''
    \usepackage{amsmath}  % For \mathbb
    \usepackage{amssymb}  % For \mathbb
    \usepackage{bm}       % For bold math symbols
    \usepackage{underscore} % If underscores are needed
'''

# 2x3 summary figure:
#   columns 0-1: HPD credible sets (top row) and FreB confidence sets (bottom row) for
#                observation 1 ("Misaligned Prior") and the last observation
#                ("Well-Aligned Prior");
#   column 2:    local coverage of the HPD (top) and FreB (bottom) sets at CONFIDENCE_LEVEL[0].
# NOTE(review): apart from the output filename this cell reproduces the preceding figure
# cell; unless PRIOR and the set/diagnostic objects are recomputed in between, both PDFs
# contain the same figure -- confirm.
figure_path = './results/sbibm_example/example0_horizontal_weak_prior.pdf'
theta_labels = [r'$\theta_1$', r'$\theta_2$']
axis_label_size, tick_label_size, title_size, title_pad = 45, 30, 50, 43
# NOTE(review): POI_BOUNDS spans (-1, 1) per axis; confirm these coverage-plot limits
# match the range of the diagnostic parameters.
coverage_lims = (-10, 10)

fig, ax = plt.subplots(2, 3, figsize=(25, 16))
fig.subplots_adjust(hspace=0.25)

# Inputs shared by the four parameter-region panels. The prior is sampled once so every
# panel receives the same draws instead of a fresh 50k sample per panel.
region_prior_samples = PRIOR.sample(sample_shape=(50_000, )).numpy()
region_bounds = {
    name: dict(zip(['low', 'high'], POI_BOUNDS[name])) for name in theta_labels
}

# Parameter-region panels, drawn column by column: (0,0), (1,0), (0,1), (1,1).
for col, (obs_idx, col_title) in enumerate([
    (1, r'\textbf{Misaligned Prior}'),
    (-1, r'\textbf{Well-Aligned Prior}'),
]):
    for row, (method, colors, regions) in enumerate([
        ('HPD', ['purple', 'deeppink'], credible_sets[obs_idx]),
        ('FreB', ['teal', 'mediumseagreen'],
         [confidence_sets[j][obs_idx] for j in range(len(CONFIDENCE_LEVEL))]),
    ]):
        panel = ax[row][col]
        plot_parameter_regions(
            *regions,
            param_dim=2,
            true_parameter=true_theta[obs_idx, :],
            prior_samples=region_prior_samples,
            parameter_space_bounds=region_bounds,
            colors=colors,
            # Raw f-string: `\%` is for LaTeX, not a Python escape sequence.
            # `:.0f` rounds, so float error cannot turn e.g. 0.58 into "57%".
            region_names=[rf'{method} {cl * 100:.0f}\%' for cl in CONFIDENCE_LEVEL],
            labels=theta_labels,
            linestyles=['-', '--'],
            param_names=theta_labels,
            alpha_shape=True,
            alpha=3,
            scatter=False,
            figsize=(5, 5),
            remove_legend=(col != 0),  # legend only in the first column
            custom_ax=panel,
        )
        panel.tick_params(labelsize=tick_label_size)
        if row == 0:
            # Top row: no x tick labels / x label; carries the column title.
            panel.set_xticklabels([])
            panel.set_xlabel('')
            panel.set_title(col_title, size=title_size, pad=title_pad)
        else:
            panel.set_xlabel(r'$\theta_1$', fontsize=axis_label_size)
        if col == 0:
            panel.set_ylabel(r'$\theta_2$', fontsize=axis_label_size)
        else:
            panel.set_yticklabels([])
            panel.set_ylabel('', fontsize=axis_label_size)

# Local-coverage panels in the last column: HPD (top row), then FreB (bottom row).
coverage_ticks = np.linspace(*coverage_lims, 5).astype(int)
for row, diagn in enumerate([diagn_objects_cred, diagn_objects]):
    panel = ax[row][2]
    coverage_map = coverage_probability_plot(
        parameters=diagn[CONFIDENCE_LEVEL[0]][1],
        coverage_probability=diagn[CONFIDENCE_LEVEL[0]][2],
        confidence_level=CONFIDENCE_LEVEL[0],
        param_dim=2,
        vmin_vmax=(0, 100),
        xlims=coverage_lims,
        ylims=coverage_lims,
        params_labels=tuple(theta_labels),
        title=None,
        show_text=False,
        custom_ax=panel,
    )
    if row == 0:
        hpd_diagn_plot = coverage_map  # used as the mappable for the shared colorbar
    panel.set_xlim(*coverage_lims)
    panel.set_ylim(*coverage_lims)
    panel.set_xticks(coverage_ticks)
    panel.set_yticks(coverage_ticks)
    panel.tick_params(labelsize=tick_label_size)
    panel.set_yticklabels([])
    panel.set_ylabel('', fontsize=axis_label_size, rotation=0)
    if row == 0:
        panel.set_xticklabels([])
        panel.set_xlabel('')
        panel.set_title(r'\textbf{Local Coverage}', size=title_size, pad=title_pad)
    else:
        panel.set_xlabel(r'$\theta_1$', fontsize=axis_label_size)

# Shared colorbar: ticks at 0-80% in steps of 20 plus the nominal level, which is
# bolded and marked with a dashed line.
cax = fig.add_axes([0.97, 0.343, 0.015, 0.3])  # [left, bottom, width, height] in figure coords
cbar = fig.colorbar(hpd_diagn_plot, format='%1.2f', cax=cax)
target_pct = CONFIDENCE_LEVEL[0] * 100
standard_ticks = np.round(np.linspace(0, 100, num=6), 1)
all_ticks = np.unique(np.append(standard_ticks[:-1], target_pct))  # np.unique also sorts
tick_labels = [
    rf"$\mathbf{{{tick:.0f}}}$\textbf{{\%}}" if abs(tick - target_pct) <= 1e-6
    else rf"{tick:.0f}\%"
    for tick in all_ticks
]
cbar.ax.yaxis.set_ticks(all_ticks)
cbar.ax.set_yticklabels(tick_labels, fontsize=axis_label_size)
cbar.ax.axhline(y=target_pct, xmin=0, xmax=1, color="black", linestyle="--", linewidth=2.5)
cbar.ax.yaxis.set_ticks_position('left')
cbar.ax.yaxis.set_label_position('left')

# Grey background boxes behind the two rows (figure coordinates, beneath the axes).
for box_bottom, box_height in [(0.495, 0.41), (0.035, 0.453)]:
    fig.add_artist(patches.Rectangle(
        (0.079, box_bottom), 0.917, box_height, transform=fig.transFigure,
        edgecolor='black', linewidth=2, facecolor="gainsboro", zorder=-1,
    ))

# Solid vertical separator at x=0.65 (figure coordinates), between the second and
# third columns' arrow positions.
line = mlines.Line2D(
    [0.65, 0.65],
    [0.15, 0.85],
    transform=fig.transFigure,
    ls='-',
    lw=2.5,
    color='black',
    clip_on=False,
)
fig.add_artist(line)

# One downward arrow per column, pointing from the top row to the bottom row.
for arrow_x in (0.239, 0.513, 0.787):
    fig.add_artist(FancyArrowPatch(
        posA=(arrow_x, 0.523), posB=(arrow_x, 0.455),
        connectionstyle="arc3,rad=0", arrowstyle='-|>', mutation_scale=30,
        color='black', linewidth=6, zorder=10,
    ))

fig.savefig(figure_path, bbox_inches='tight')
plt.show()