In [None]:
# autoreload mode 2: re-import all modules before each cell executes, so edits
# to imported source files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import dill
import matplotlib.lines as mlines
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import sbibm
import seaborn as sns
import torch
from matplotlib.gridspec import GridSpec
from matplotlib.patches import FancyArrowPatch
from sbi.analysis import pairplot
from sbi.inference import FMPE, SNPE, NPSE
from sbi.utils import BoxUniform
from torch.distributions import MultivariateNormal
from tqdm import tqdm

from lf2i.calibration.critical_values import train_qr_algorithm
from lf2i.inference import LF2I
from lf2i.plot.coverage_diagnostics import coverage_probability_plot
from lf2i.plot.parameter_regions import plot_parameter_regions
from lf2i.plot.power_diagnostics import set_size_plot
from lf2i.test_statistics.posterior import Posterior
from lf2i.test_statistics.waldo import Waldo
from lf2i.utils.other_methods import hpd_region
from tsi.temp.utils import kdeplots2D

In [None]:
# Render all figure text through LaTeX, in a serif family, with the packages
# listed in the preamble available to every label.
plt.rcParams.update({
    'text.usetex': True,
    'font.family': 'serif',
    'text.latex.preamble': r'''
    \usepackage{amsmath}  % For \mathbb
    \usepackage{amssymb}  % For \mathbb
    \usepackage{bm}       % For bold math symbols
    \usepackage{underscore} % If underscores are needed
''',
})

In [None]:
### Settings
POI_DIM = 2  # parameter of interest
PRIOR_LOC = [0, 0]
PRIOR_VAR = 0.1 # (6*np.sqrt(2.0))**2
POI_BOUNDS = {r'$\theta_1$': (-1, 1), r'$\theta_2$': (-1, 1)}

# Lower / upper bound of each coordinate, in POI_BOUNDS insertion order.
_poi_low = tuple(bounds[0] for bounds in POI_BOUNDS.values())
_poi_high = tuple(bounds[1] for bounds in POI_BOUNDS.values())
# The same box widened by one unit on every side.
_wide_low = tuple(low - 1 for low in _poi_low)
_wide_high = tuple(high + 1 for high in _poi_high)

# Uniform prior over the widened box.
PRIOR = BoxUniform(low=torch.tensor(_wide_low), high=torch.tensor(_wide_high))

B = 100_000  # num simulations to estimate posterior and test statistics
B_PRIME = 50_000  # num simulations to estimate critical values
B_DOUBLE_PRIME = 30_000  # num simulations to do diagnostics
EVAL_GRID_SIZE = 1_000  # num evaluation points over parameter space to construct confidence sets
CONFIDENCE_LEVEL = (0.954, 0.683)  # 0.99

# REFERENCE covers the widened box (a separate instance built like PRIOR);
# EVAL_GRID_DISTR covers exactly POI_BOUNDS.
REFERENCE = BoxUniform(low=torch.tensor(_wide_low), high=torch.tensor(_wide_high))
EVAL_GRID_DISTR = BoxUniform(low=torch.tensor(_poi_low), high=torch.tensor(_poi_high))

POSTERIOR_KWARGS = {'norm_posterior': None}
DEVICE = 'cpu'
task = sbibm.get_task('two_moons')
simulator = task.get_simulator()

# Strong prior

In [None]:
# experiment_dir = 'results/snpe/uniform_prior'
experiment_dir = 'results/fmpe/strong_prior/2025-11-05 12:42:14.771084'

# NOTE: dill.load can execute arbitrary code stored in the file; only point
# experiment_dir at trusted result directories.

# Stored examples: `true_theta` and `obs_x`.
with open(f'{experiment_dir}/obs_x_theta.pkl', 'rb') as f:
    examples = dill.load(f)
true_theta, obs_x = examples['true_theta'], examples['obs_x']
with open(f'{experiment_dir}/lf2i_strong_prior.pkl', 'rb') as f:
    lf2i = dill.load(f)

# Gaussian PRIOR; it is sampled further down for the `prior_samples` argument
# of the region panels.
PRIOR_VAR = 0.1 # (6*np.sqrt(2.0))**2
POI_BOUNDS = {r'$\theta_1$': (-1, 1), r'$\theta_2$': (-1, 1)}
PRIOR = MultivariateNormal(
    loc=torch.Tensor([0.5, 0.5]),
    covariance_matrix=PRIOR_VAR*torch.eye(n=POI_DIM),
)

# Set sizes: 100 * (rows in each set) / size_grid_for_sizes.
# Presumably the percentage of evaluation-grid points retained -- TODO confirm.
size_grid_for_sizes = 1_000
with open(f'{experiment_dir}/confidence_sets_for_size_waldo.pkl', 'rb') as f:
    confidence_sets_for_size_waldo = dill.load(f)
confset_sizes_waldo = np.array(
    [100*region.shape[0]/size_grid_for_sizes for region in confidence_sets_for_size_waldo[0]]
)
with open(f'{experiment_dir}/confidence_sets_for_size.pkl', 'rb') as f:
    confidence_sets_for_size = dill.load(f)
confset_sizes = np.array(
    [100*region.shape[0]/size_grid_for_sizes for region in confidence_sets_for_size[0]]
)
with open(f'{experiment_dir}/set_for_size.pkl', 'rb') as f:
    set_for_size = dill.load(f)
params_for_size, samples_for_size = set_for_size['params'], set_for_size['samples']
colorbar_max = 50  # upper end of the set-size colour scale

# Region sets shown in the panels, plus the first set of coverage diagnostics.
with open(f'{experiment_dir}/confidence_sets.pkl', 'rb') as f:
    confidence_sets = dill.load(f)
with open(f'{experiment_dir}/confidence_sets_waldo.pkl', 'rb') as f:
    confidence_setsw = dill.load(f)
with open(f'{experiment_dir}/strong_diagnostics_posterior.pkl', 'rb') as f:
    estimator, out_parameters, mean_proba, upper_proba, lower_proba = dill.load(f)

# Indices (into true_theta and the region sets) of the two highlighted examples.
idx_1, idx_2 = 6, 3

# Figure scaffold: a 2 x 4 grid of panels on a 2 x 5 GridSpec.
# plt, patches, np and GridSpec all come from the notebook's import cell.
fig = plt.figure(figsize=(20, 10))
# Grid column 3 is a thin spacer that adds extra room after the third panel column.
gs = GridSpec(2, 5, figure=fig, left=0.08, top=0.91, bottom=0.12, right=0.92,
              hspace=0.35, wspace=0.4, width_ratios=[1, 1, 1, 0.06, 1])

# axs[row][col] with col in 0..3; the spacer column gets no axes.
axs = [[fig.add_subplot(gs[row, grid_col]) for grid_col in (0, 1, 2, 4)] for row in range(2)]

# Framed gainsboro backdrop behind each row, in figure coordinates
# (top row first, then bottom row; the two rectangles do not overlap).
for backdrop_anchor, backdrop_height in (((0.03, 0.5), 0.47), ((0.03, 0.03), 0.45)):
    fig.patches.append(patches.Rectangle(
        backdrop_anchor, 0.967, backdrop_height,
        transform=fig.transFigure,
        edgecolor='black',
        linewidth=2,
        facecolor="gainsboro",
        zorder=-1
    ))

# Region panels (panel columns 0 and 1): one column per highlighted example,
# one row per family of region sets.

# A single draw from PRIOR, shared by all four panels so they show the same
# prior samples and the 50k-sample draw is not repeated per panel.
prior_samples = PRIOR.sample(sample_shape=(50_000, )).numpy()

# (region sets, colours) for the top and bottom row.
# NOTE(review): region_names below labels both rows 'FreB ...' although the top
# row reads the sets loaded from confidence_sets_waldo.pkl. It has no visible
# effect while remove_legend=True -- confirm before enabling the legend.
region_rows = (
    (confidence_setsw, ('#64B5F6', '#1976D2')),
    (confidence_sets, ('teal', 'mediumseagreen')),
)
# (example index, column title, hand the truth to plot_parameter_regions?)
# The second example is instead marked by hand with a blue-edged star,
# presumably so the two examples are drawn in different colours -- confirm.
region_columns = (
    (idx_1, r'$\mathbf{First\ mode}$', True),
    (idx_2, r'$\mathbf{Second\ mode}$', False),
)

for col, (idx_obs, column_title, pass_truth) in enumerate(region_columns):
    true_parameter = true_theta[idx_obs, :]
    for row, (region_sets, region_colors) in enumerate(region_rows):
        truth_kwargs = {'true_parameter': true_parameter} if pass_truth else {}
        plot_parameter_regions(
            # One region per confidence level for this example.
            *[region_sets[level][idx_obs] for level in range(len(CONFIDENCE_LEVEL))],
            param_dim=2,
            prior_samples=prior_samples,
            parameter_space_bounds={
                name: dict(zip(['low', 'high'], bounds)) for name, bounds in POI_BOUNDS.items()
            },
            colors=list(region_colors),
            # Raw f-string: '\%' must reach LaTeX unchanged, and a non-raw '\%'
            # is an invalid escape sequence.
            region_names=[rf'FreB {int(cl*100):.0f}\%' for cl in CONFIDENCE_LEVEL],
            labels=[r'$\theta_1$', r'$\theta_2$'],
            linestyles=['-', '--'],
            param_names=[r'$\theta_1$', r'$\theta_2$'],
            alpha_shape=True,
            alpha=9,
            scatter=False,
            remove_legend=True,
            title=column_title if row == 0 else None,  # only the top row is titled
            custom_ax=axs[row][col],
            **truth_kwargs,
        )
        if not pass_truth:
            # White filled star first, blue outline on top of it.
            star_x, star_y = true_parameter.reshape(-1,)[0], true_parameter.reshape(-1,)[1]
            for star_face, star_edge in (('white', 'white'), ('none', 'blue')):
                axs[row][col].scatter(x=star_x, y=star_y,
                    alpha=1, marker='*', facecolor=star_face, edgecolor=star_edge,
                    s=300, linewidth=2, zorder=10)

# Coverage panels (panel column 2). Both rows share everything except the
# diagnostics being plotted, the title and the target axes.
# NOTE(review): the top row plots the diagnostics loaded earlier from
# strong_diagnostics_posterior.pkl next to the confidence_sets_waldo regions,
# while the bottom row loads strong_diagnostics_waldo.pkl -- confirm the pairing.
coverage_common = dict(
    confidence_level=CONFIDENCE_LEVEL[0],
    param_dim=2,
    vmin_vmax=(0, 100),
    params_labels=(r'$\theta_1$', r'$\theta_2$'),
    xlims=POI_BOUNDS[r'$\theta_1$'],
    ylims=POI_BOUNDS[r'$\theta_2$'],
    show_text=False,
)
# (example index, outline colour) of the star marking each highlighted example.
coverage_truth_markers = ((idx_1, 'red'), (idx_2, 'blue'))

# Top row: diagnostics already in memory.
coverage_plot_top = coverage_probability_plot(
    parameters=out_parameters,
    coverage_probability=mean_proba,
    title=r'$\mathbf{Local\ diagnostics}$',
    custom_ax=axs[0][2],
    **coverage_common,
)
for idx_obs, star_edge in coverage_truth_markers:
    true_parameter = true_theta[idx_obs, :]
    # White filled star first, coloured outline on top of it.
    for star_face, star_outline in (('white', 'white'), ('none', star_edge)):
        axs[0][2].scatter(x=true_parameter.reshape(-1,)[0], y=true_parameter.reshape(-1,)[1],
            alpha=1, marker='*', facecolor=star_face, edgecolor=star_outline,
            s=300, linewidth=2, zorder=10)

# Bottom row: this load overwrites the diagnostics variables used above.
with open(f'{experiment_dir}/strong_diagnostics_waldo.pkl', 'rb') as f:
    estimator, out_parameters, mean_proba, upper_proba, lower_proba = dill.load(f)
coverage_plot_bottom = coverage_probability_plot(
    parameters=out_parameters,
    coverage_probability=mean_proba,
    title=None,
    custom_ax=axs[1][2],
    **coverage_common,
)
for idx_obs, star_edge in coverage_truth_markers:
    true_parameter = true_theta[idx_obs, :]
    for star_face, star_outline in (('white', 'white'), ('none', star_edge)):
        axs[1][2].scatter(x=true_parameter.reshape(-1,)[0], y=true_parameter.reshape(-1,)[1],
            alpha=1, marker='*', facecolor=star_face, edgecolor=star_outline,
            s=300, linewidth=2, zorder=10)

# Set-size panels (panel column 3): confset_sizes_waldo on top, confset_sizes
# below, both on the colour scale 0..colorbar_max.
size_common = dict(
    parameters=params_for_size,
    param_dim=2,
    xlims=POI_BOUNDS[r'$\theta_1$'],
    ylims=POI_BOUNDS[r'$\theta_2$'],
    vmin_vmax=(0, colorbar_max),
    show_text=False,
)
# (example index, outline colour) of the star marking each highlighted example.
size_truth_markers = ((idx_1, 'red'), (idx_2, 'blue'))

# Top row (titled).
size_plot_top = set_size_plot(
    set_sizes=confset_sizes_waldo,
    title=r'$\mathbf{Local\ size}$',
    params_labels=[r'$\theta_1$', r'$\theta_2$'],
    custom_ax=axs[0][3],
    **size_common,
)
for idx_obs, star_edge in size_truth_markers:
    true_parameter = true_theta[idx_obs, :]
    # White filled star first, coloured outline on top of it.
    for star_face, star_outline in (('white', 'white'), ('none', star_edge)):
        axs[0][3].scatter(x=true_parameter.reshape(-1,)[0], y=true_parameter.reshape(-1,)[1],
            alpha=1, marker='*', facecolor=star_face, edgecolor=star_outline,
            s=300, linewidth=2, zorder=10)

# Bottom row (no title argument, so set_size_plot's default applies).
size_plot_bottom = set_size_plot(
    set_sizes=confset_sizes,
    params_labels=[r'$\theta_1$', r'$\theta_2$'],
    custom_ax=axs[1][3],
    **size_common,
)
for idx_obs, star_edge in size_truth_markers:
    true_parameter = true_theta[idx_obs, :]
    for star_face, star_outline in (('white', 'white'), ('none', star_edge)):
        axs[1][3].scatter(x=true_parameter.reshape(-1,)[0], y=true_parameter.reshape(-1,)[1],
            alpha=1, marker='*', facecolor=star_face, edgecolor=star_outline,
            s=300, linewidth=2, zorder=10)

# Colorbars live in hand-placed axes ([left, bottom, width, height] in figure
# coordinates): coverage bars at x=0.66, set-size bars at x=0.93.
cax_coverage_top = fig.add_axes([0.66, 0.55, 0.012, 0.35])
cbar_coverage_top = fig.colorbar(coverage_plot_top, cax=cax_coverage_top)
cax_coverage_bottom = fig.add_axes([0.66, 0.10, 0.012, 0.35])
cbar_coverage_bottom = fig.colorbar(coverage_plot_bottom, cax=cax_coverage_bottom)
cax_size_top = fig.add_axes([0.93, 0.55, 0.012, 0.35])
cbar_size_top = fig.colorbar(size_plot_top, cax=cax_size_top)
cax_size_bottom = fig.add_axes([0.93, 0.10, 0.012, 0.35])
cbar_size_bottom = fig.colorbar(size_plot_bottom, cax=cax_size_bottom)

# Six evenly spaced ticks per bar. The labels are raw f-strings because '\%'
# must reach LaTeX unchanged and is an invalid escape in a non-raw string.
coverage_ticks = np.linspace(0, 100, num=6)
coverage_tick_labels = [rf"{label:.0f}\%" for label in coverage_ticks]
standard_ticks = np.round(np.linspace(0, colorbar_max, num=6), 1)
tick_labels = [rf"{label:.0f}\%" for label in standard_ticks]

coverage_label = 'Coverage probability'
# Raw string for the same reason: '\text' and '\Theta' are LaTeX, not escapes.
size_label = r'Percentage of $\text{Vol}(\Theta)$'

# Identical styling for all four bars: ticks and a rotated label on the right.
for cbar, cbar_ticks, cbar_tick_labels, cbar_label in (
    (cbar_coverage_top, coverage_ticks, coverage_tick_labels, coverage_label),
    (cbar_coverage_bottom, coverage_ticks, coverage_tick_labels, coverage_label),
    (cbar_size_top, standard_ticks, tick_labels, size_label),
    (cbar_size_bottom, standard_ticks, tick_labels, size_label),
):
    cbar.ax.yaxis.set_ticks(cbar_ticks)
    cbar.ax.set_yticklabels(cbar_tick_labels, fontsize=14)
    cbar.ax.yaxis.set_ticks_position('right')
    cbar.ax.yaxis.set_label_position('right')
    cbar.set_label(cbar_label, rotation=270, labelpad=20, fontsize=18)

# Save next to the experiment's results, as PNG and PDF, then display.
plt.savefig(f'{experiment_dir}/freb_v_waldo_coverage_size.png', dpi=100)
plt.savefig(f'{experiment_dir}/freb_v_waldo_coverage_size.pdf', dpi=100)
plt.show()

# Uniform prior

In [None]:
experiment_dir = 'results/fmpe/uniform_prior/2025-11-05'

# NOTE: dill.load can execute arbitrary code stored in the file; only point
# experiment_dir at trusted result directories.

# Stored examples: `true_theta` and `obs_x`.
with open(f'{experiment_dir}/obs_x_theta.pkl', 'rb') as f:
    examples = dill.load(f)
true_theta, obs_x = examples['true_theta'], examples['obs_x']
# with open(f'{experiment_dir}/lf2i_uniform_prior.pkl', 'rb') as f:
#     lf2i = dill.load(f)

# NOTE(review): PRIOR is rebuilt here as a Gaussian, but the region panels of
# this cell pass a meshgrid as prior_samples, so PRIOR looks unused in this
# cell -- confirm whether this assignment is still needed.
PRIOR_VAR = 0.1 # (6*np.sqrt(2.0))**2
POI_BOUNDS = {r'$\theta_1$': (-1, 1), r'$\theta_2$': (-1, 1)}
PRIOR = MultivariateNormal(
    loc=torch.Tensor([0.5, 0.5]),
    covariance_matrix=PRIOR_VAR*torch.eye(n=POI_DIM),
)

# Set sizes: 100 * (rows in each set) / size_grid_for_sizes.
# Presumably the percentage of evaluation-grid points retained -- TODO confirm.
size_grid_for_sizes = 1_000
with open(f'{experiment_dir}/confidence_sets_for_size_waldo.pkl', 'rb') as f:
    confidence_sets_for_size_waldo = dill.load(f)
confset_sizes_waldo = np.array(
    [100*region.shape[0]/size_grid_for_sizes for region in confidence_sets_for_size_waldo[0]]
)
with open(f'{experiment_dir}/confidence_sets_for_size.pkl', 'rb') as f:
    confidence_sets_for_size = dill.load(f)
confset_sizes = np.array(
    [100*region.shape[0]/size_grid_for_sizes for region in confidence_sets_for_size[0]]
)
with open(f'{experiment_dir}/set_for_size.pkl', 'rb') as f:
    set_for_size = dill.load(f)
params_for_size, samples_for_size = set_for_size['params'], set_for_size['samples']
colorbar_max = 50  # upper end of the set-size colour scale

# Region sets shown in the panels, plus the first set of coverage diagnostics.
with open(f'{experiment_dir}/confidence_sets.pkl', 'rb') as f:
    confidence_sets = dill.load(f)
with open(f'{experiment_dir}/confidence_sets_waldo.pkl', 'rb') as f:
    confidence_setsw = dill.load(f)
with open(f'{experiment_dir}/uniform_diagnostics_posterior.pkl', 'rb') as f:
    estimator, out_parameters, mean_proba, upper_proba, lower_proba = dill.load(f)

# Indices (into true_theta and the region sets) of the two highlighted examples.
idx_1, idx_2 = 6, 3

# Figure scaffold: a 2 x 4 grid of panels on a 2 x 5 GridSpec.
# plt, patches, np and GridSpec all come from the notebook's import cell.
fig = plt.figure(figsize=(20, 10))
# Grid column 3 is a thin spacer that adds extra room after the third panel column.
gs = GridSpec(2, 5, figure=fig, left=0.08, top=0.91, bottom=0.12, right=0.92,
              hspace=0.35, wspace=0.4, width_ratios=[1, 1, 1, 0.06, 1])

# axs[row][col] with col in 0..3; the spacer column gets no axes.
axs = [[fig.add_subplot(gs[row, grid_col]) for grid_col in (0, 1, 2, 4)] for row in range(2)]

# Framed gainsboro backdrop behind each row, in figure coordinates
# (top row first, then bottom row; the two rectangles do not overlap).
for backdrop_anchor, backdrop_height in (((0.03, 0.5), 0.47), ((0.03, 0.03), 0.45)):
    fig.patches.append(patches.Rectangle(
        backdrop_anchor, 0.967, backdrop_height,
        transform=fig.transFigure,
        edgecolor='black',
        linewidth=2,
        facecolor="gainsboro",
        zorder=-1
    ))

# Region panels (panel columns 0 and 1): one column per highlighted example,
# one row per family of region sets.

# 101 x 101 regular grid over [-100, 100]^2, flattened to shape (10201, 2) and
# passed as `prior_samples` -- presumably a stand-in for a flat prior; confirm.
# Built once instead of once per panel. indexing='ij' is torch.meshgrid's
# current default, stated explicitly to avoid the missing-argument warning.
grid_axis = torch.linspace(-100, 100, 101)
prior_samples = torch.stack(
    torch.meshgrid(grid_axis, torch.linspace(-100, 100, 101), indexing='ij'), dim=-1
).reshape(-1, 2)

# (region sets, colours) for the top and bottom row.
# NOTE(review): region_names below labels both rows 'FreB ...' although the top
# row reads the sets loaded from confidence_sets_waldo.pkl. It has no visible
# effect while remove_legend=True -- confirm before enabling the legend.
region_rows = (
    (confidence_setsw, ('#64B5F6', '#1976D2')),
    (confidence_sets, ('teal', 'mediumseagreen')),
)
# (example index, column title, hand the truth to plot_parameter_regions?)
# The second example is instead marked by hand with a blue-edged star,
# presumably so the two examples are drawn in different colours -- confirm.
region_columns = (
    (idx_1, r'$\mathbf{First\ mode}$', True),
    (idx_2, r'$\mathbf{Second\ mode}$', False),
)

for col, (idx_obs, column_title, pass_truth) in enumerate(region_columns):
    true_parameter = true_theta[idx_obs, :]
    for row, (region_sets, region_colors) in enumerate(region_rows):
        truth_kwargs = {'true_parameter': true_parameter} if pass_truth else {}
        plot_parameter_regions(
            # One region per confidence level for this example.
            *[region_sets[level][idx_obs] for level in range(len(CONFIDENCE_LEVEL))],
            param_dim=2,
            prior_samples=prior_samples,
            parameter_space_bounds={
                name: dict(zip(['low', 'high'], bounds)) for name, bounds in POI_BOUNDS.items()
            },
            colors=list(region_colors),
            # Raw f-string: '\%' must reach LaTeX unchanged, and a non-raw '\%'
            # is an invalid escape sequence.
            region_names=[rf'FreB {int(cl*100):.0f}\%' for cl in CONFIDENCE_LEVEL],
            labels=[r'$\theta_1$', r'$\theta_2$'],
            linestyles=['-', '--'],
            param_names=[r'$\theta_1$', r'$\theta_2$'],
            alpha_shape=True,
            alpha=9,
            scatter=False,
            remove_legend=True,
            title=column_title if row == 0 else None,  # only the top row is titled
            custom_ax=axs[row][col],
            **truth_kwargs,
        )
        if not pass_truth:
            # White filled star first, blue outline on top of it.
            star_x, star_y = true_parameter.reshape(-1,)[0], true_parameter.reshape(-1,)[1]
            for star_face, star_edge in (('white', 'white'), ('none', 'blue')):
                axs[row][col].scatter(x=star_x, y=star_y,
                    alpha=1, marker='*', facecolor=star_face, edgecolor=star_edge,
                    s=300, linewidth=2, zorder=10)

# Coverage panels (panel column 2). Both rows share everything except the
# diagnostics being plotted, the title and the target axes.
# NOTE(review): the top row plots the diagnostics loaded earlier from
# uniform_diagnostics_posterior.pkl next to the confidence_sets_waldo regions,
# while the bottom row loads uniform_diagnostics_waldo.pkl -- confirm the pairing.
coverage_common = dict(
    confidence_level=CONFIDENCE_LEVEL[0],
    param_dim=2,
    vmin_vmax=(0, 100),
    params_labels=(r'$\theta_1$', r'$\theta_2$'),
    xlims=POI_BOUNDS[r'$\theta_1$'],
    ylims=POI_BOUNDS[r'$\theta_2$'],
    show_text=False,
)
# (example index, outline colour) of the star marking each highlighted example.
coverage_truth_markers = ((idx_1, 'red'), (idx_2, 'blue'))

# Top row: diagnostics already in memory.
coverage_plot_top = coverage_probability_plot(
    parameters=out_parameters,
    coverage_probability=mean_proba,
    title=r'$\mathbf{Local\ diagnostics}$',
    custom_ax=axs[0][2],
    **coverage_common,
)
for idx_obs, star_edge in coverage_truth_markers:
    true_parameter = true_theta[idx_obs, :]
    # White filled star first, coloured outline on top of it.
    for star_face, star_outline in (('white', 'white'), ('none', star_edge)):
        axs[0][2].scatter(x=true_parameter.reshape(-1,)[0], y=true_parameter.reshape(-1,)[1],
            alpha=1, marker='*', facecolor=star_face, edgecolor=star_outline,
            s=300, linewidth=2, zorder=10)

# Bottom row: this load overwrites the diagnostics variables used above.
with open(f'{experiment_dir}/uniform_diagnostics_waldo.pkl', 'rb') as f:
    estimator, out_parameters, mean_proba, upper_proba, lower_proba = dill.load(f)
coverage_plot_bottom = coverage_probability_plot(
    parameters=out_parameters,
    coverage_probability=mean_proba,
    title=None,
    custom_ax=axs[1][2],
    **coverage_common,
)
for idx_obs, star_edge in coverage_truth_markers:
    true_parameter = true_theta[idx_obs, :]
    for star_face, star_outline in (('white', 'white'), ('none', star_edge)):
        axs[1][2].scatter(x=true_parameter.reshape(-1,)[0], y=true_parameter.reshape(-1,)[1],
            alpha=1, marker='*', facecolor=star_face, edgecolor=star_outline,
            s=300, linewidth=2, zorder=10)

# Set-size panels (panel column 3): confset_sizes_waldo on top, confset_sizes
# below, both on the colour scale 0..colorbar_max.
size_common = dict(
    parameters=params_for_size,
    param_dim=2,
    xlims=POI_BOUNDS[r'$\theta_1$'],
    ylims=POI_BOUNDS[r'$\theta_2$'],
    vmin_vmax=(0, colorbar_max),
    show_text=False,
)
# (example index, outline colour) of the star marking each highlighted example.
size_truth_markers = ((idx_1, 'red'), (idx_2, 'blue'))

# Top row (titled).
size_plot_top = set_size_plot(
    set_sizes=confset_sizes_waldo,
    title=r'$\mathbf{Local\ size}$',
    params_labels=[r'$\theta_1$', r'$\theta_2$'],
    custom_ax=axs[0][3],
    **size_common,
)
for idx_obs, star_edge in size_truth_markers:
    true_parameter = true_theta[idx_obs, :]
    # White filled star first, coloured outline on top of it.
    for star_face, star_outline in (('white', 'white'), ('none', star_edge)):
        axs[0][3].scatter(x=true_parameter.reshape(-1,)[0], y=true_parameter.reshape(-1,)[1],
            alpha=1, marker='*', facecolor=star_face, edgecolor=star_outline,
            s=300, linewidth=2, zorder=10)

# Bottom row (no title argument, so set_size_plot's default applies).
size_plot_bottom = set_size_plot(
    set_sizes=confset_sizes,
    params_labels=[r'$\theta_1$', r'$\theta_2$'],
    custom_ax=axs[1][3],
    **size_common,
)
for idx_obs, star_edge in size_truth_markers:
    true_parameter = true_theta[idx_obs, :]
    for star_face, star_outline in (('white', 'white'), ('none', star_edge)):
        axs[1][3].scatter(x=true_parameter.reshape(-1,)[0], y=true_parameter.reshape(-1,)[1],
            alpha=1, marker='*', facecolor=star_face, edgecolor=star_outline,
            s=300, linewidth=2, zorder=10)

# Colorbars live in hand-placed axes ([left, bottom, width, height] in figure
# coordinates): coverage bars at x=0.66, set-size bars at x=0.93.
cax_coverage_top = fig.add_axes([0.66, 0.55, 0.012, 0.35])
cbar_coverage_top = fig.colorbar(coverage_plot_top, cax=cax_coverage_top)
cax_coverage_bottom = fig.add_axes([0.66, 0.10, 0.012, 0.35])
cbar_coverage_bottom = fig.colorbar(coverage_plot_bottom, cax=cax_coverage_bottom)
cax_size_top = fig.add_axes([0.93, 0.55, 0.012, 0.35])
cbar_size_top = fig.colorbar(size_plot_top, cax=cax_size_top)
cax_size_bottom = fig.add_axes([0.93, 0.10, 0.012, 0.35])
cbar_size_bottom = fig.colorbar(size_plot_bottom, cax=cax_size_bottom)

# Six evenly spaced ticks per bar. The labels are raw f-strings because '\%'
# must reach LaTeX unchanged and is an invalid escape in a non-raw string.
coverage_ticks = np.linspace(0, 100, num=6)
coverage_tick_labels = [rf"{label:.0f}\%" for label in coverage_ticks]
standard_ticks = np.round(np.linspace(0, colorbar_max, num=6), 1)
tick_labels = [rf"{label:.0f}\%" for label in standard_ticks]

coverage_label = 'Coverage probability'
# Raw string for the same reason: '\text' and '\Theta' are LaTeX, not escapes.
size_label = r'Percentage of $\text{Vol}(\Theta)$'

# Identical styling for all four bars: ticks and a rotated label on the right.
for cbar, cbar_ticks, cbar_tick_labels, cbar_label in (
    (cbar_coverage_top, coverage_ticks, coverage_tick_labels, coverage_label),
    (cbar_coverage_bottom, coverage_ticks, coverage_tick_labels, coverage_label),
    (cbar_size_top, standard_ticks, tick_labels, size_label),
    (cbar_size_bottom, standard_ticks, tick_labels, size_label),
):
    cbar.ax.yaxis.set_ticks(cbar_ticks)
    cbar.ax.set_yticklabels(cbar_tick_labels, fontsize=14)
    cbar.ax.yaxis.set_ticks_position('right')
    cbar.ax.yaxis.set_label_position('right')
    cbar.set_label(cbar_label, rotation=270, labelpad=20, fontsize=18)

# Save next to the experiment's results, as PNG and PDF, then display.
plt.savefig(f'{experiment_dir}/freb_v_waldo_coverage_size.png', dpi=100)
plt.savefig(f'{experiment_dir}/freb_v_waldo_coverage_size.pdf', dpi=100)
plt.show()