In [None]:
import jax
import jax.numpy as jnp
from jax import grad, vmap, jit
import jax.random as random
from functools import partial
from typing import Callable, Union
from collections import namedtuple
import numpy as np
from scipy.stats import gaussian_kde

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.cm as cm
from matplotlib.lines import Line2D
from matplotlib.ticker import FormatStrFormatter
import matplotlib.ticker as ticker
from mpl_toolkits.axes_grid1.inset_locator import zoomed_inset_axes, mark_inset
import dill

from lib.kernel import imq_kernel, rbf_kernel
from lib.model import model
from lib.methods import VGD
from lib.experiment import experiment, diagnostic_experiment
from lib.calculate_mmd import calculate_mmd_squared
from lib.plot_functions import *


def quadratic_reg(theta, x):
    """Quadratic regression mean function: theta * x**2 (broadcasts over arrays)."""
    x_squared = x ** 2
    return x_squared * theta

# Quadratic model y = theta * x^2 with a 1-D parameter.
# NOTE(review): the first argument (0.5) is presumably the observation-noise std
# (it matches `sigma` used to generate data_quad_m) — confirm against lib.model.model.
quad_model = model(0.5, quadratic_reg, theta_dim=1)

def my_sigmoid(x):
    """Logistic sigmoid 1 / (1 + exp(-x)).

    Uses `jax.nn.sigmoid`, which has the same values but a stable derivative.
    The hand-written form overflows `exp(-x)` to inf for very negative x, and its
    autodiff gradient then evaluates 0 * inf = NaN.
    """
    return jax.nn.sigmoid(x)

def sigmoid_reg(theta, x):
    """Sigmoid regression mean function: sigmoid(theta * x).

    `jax.nn.sigmoid` gives the same values as 1 / (1 + exp(-theta * x)) but keeps
    the gradient w.r.t. theta finite. The hand-written form returns NaN once
    -theta * x overflows exp, and this function is differentiated when fitting
    particles.
    """
    return jax.nn.sigmoid(theta * x)

# Sigmoid model y = sigmoid(theta * x) with a 1-D parameter.
# NOTE(review): the first argument (0.05) is presumably the observation-noise std — confirm against lib.model.model.
sig_model = model(0.05, sigmoid_reg, theta_dim=1)

def linear_reg(theta, x):
    """Linear regression mean function: theta[0] + theta[1] * x (intercept, slope)."""
    intercept, slope = theta[0], theta[1]
    return intercept + slope * x

# Linear model with a 2-D parameter theta = (intercept, slope).
# NOTE(review): the first argument (0.8) is presumably the observation-noise std
# (it matches `sigma` used to generate data_linear_w) — confirm against lib.model.model.
linear_model = model(0.8, linear_reg, theta_dim=2)

In [None]:
# Well-specified datasets: sampled from the models themselves via model.generate_data.
# NOTE(review): the positional arguments (1000, 5) are presumably (n_data, true theta) — confirm against lib.model.
data_quad_w = quad_model.generate_data(1000, 5, x_min=0.0, x_max=1.0, key=random.PRNGKey(12))
data_sig_w = sig_model.generate_data(1000, 5, x_min=-1.0, x_max=1.0, key=random.PRNGKey(34))

# Misspecified quadratic data: each observation gets its own random coefficient
# theta_i = true_theta + phi * eps_i (eps_i ~ N(0, 1)), i.e.
# y_i = theta_i * x_i^2 + sigma * noise_i, which a single fixed theta does not model.
# The order of the key splits below fixes the sampled data; do not reorder.
key = random.PRNGKey(12)
rng, key = random.split(key)
n_data = 1000
xs = random.uniform(rng, (n_data,))  # x ~ Uniform[0, 1)
true_theta = 5.
sigma = 0.5   # additive Gaussian noise std
phi = 3       # std of the per-observation perturbation of theta
rng, key = random.split(key)
ys = (true_theta + phi * random.normal(rng, xs.shape)) * jnp.square(xs)
rng, key = random.split(key)
ys += sigma * random.normal(rng, xs.shape)
data_quad_m = (xs, ys)

def sigmoid(z):
    """Logistic sigmoid 1 / (1 + exp(-z)); same function as `my_sigmoid`.

    Uses `jax.nn.sigmoid` so the gradient stays finite for very negative z.
    The hand-written form produces NaN there.
    """
    return jax.nn.sigmoid(z)
# Misspecified sigmoid data: a noisy step rather than a smooth sigmoid curve.
# y ~ Uniform[0, 0.5) where x < 0 and y ~ Uniform[0.5, 1) where x >= 0.
key = jax.random.PRNGKey(42) 
n_data = 1000
key, x_key = random.split(key)
x = random.uniform(x_key, (n_data,), minval=-1.0, maxval=1.0)
key, y_neg_key, y_pos_key = random.split(key, 3)
# Draw both candidate responses for every x, then select by the sign of x.
y_neg = random.uniform(y_neg_key, (n_data,), minval=0.0, maxval=0.5)
y_pos = random.uniform(y_pos_key, (n_data,), minval=0.5, maxval=1.0)
y_observed = jnp.where(x < 0, y_neg, y_pos)
data_sig_m = (x, y_observed)


# Linear-model datasets. Both use the same PRNGKey(12) split sequence, so they share
# identical inputs `xs` and the identical noise draw. The misspecified set adds a
# quadratic term 2 * x^2 that a linear model cannot capture.
# (Previously two copy-pasted blocks regenerated the same xs/noise; the values are unchanged.)
true_theta = jnp.array([5.0, 3.0])   # (intercept, slope) used to generate the data
prior_mean = jnp.array([0.0, 0.0])   # NOTE(review): not used in the visible notebook cells
n_data = 1000
sigma = 0.8                          # std of the additive Gaussian observation noise
key = random.PRNGKey(12)
rng, key = random.split(key)
xs = random.uniform(rng, (n_data,), minval=-2.0, maxval=2.0)
rng, key = random.split(key)
noise = sigma * random.normal(rng, xs.shape)

ys_linear = true_theta[0] + true_theta[1] * xs
data_linear_w = (xs, ys_linear + noise)

ys = ys_linear + 2 * jnp.square(xs) + noise
data_linear_m = (xs, ys)

## Quadratic model

In [None]:
# Fit 20 particles to the well-specified quadratic data with the IMQ kernel
# (every experiment in this notebook is given key=PRNGKey(11)).
experiment_quad_w = experiment(quad_model, data_quad_w, n_particles=20, kernel=imq_kernel, key=random.PRNGKey(11))
# Alternative setting kept for reference: explicit lengthscale=0.1 (as used by the misspecified run below).
# experiment_quad_w.run(n_steps=5000, step_size=0.001, lengthscale=0.1)
experiment_quad_w.run(n_steps=5000, step_size=0.001)

In [None]:
# Plot the KSD trace of the well-specified run.
experiment_quad_w.plot_KSD()

In [None]:
# MMD diagnostic; both return values are saved below and drawn in the main figure's MMD column.
diagnostic_experiment_quad_w = diagnostic_experiment(experiment_quad_w)
all_mmd_values_quad_w, actual_mmd_quad_w = diagnostic_experiment_quad_w.plot_diagnostic()

In [None]:
# Same setup on the misspecified quadratic data, with a smaller step and an explicit lengthscale.
experiment_quad_m = experiment(quad_model, data_quad_m, n_particles=20, kernel=imq_kernel, key=random.PRNGKey(11))
experiment_quad_m.run(n_steps=5000, step_size=0.0005, lengthscale=0.1)

In [None]:
# Plot the KSD trace of the misspecified run.
experiment_quad_m.plot_KSD()

In [None]:
# MMD diagnostic for the misspecified run (shown with a broken x-axis in the main figure).
diagnostic_experiment_quad_m = diagnostic_experiment(experiment_quad_m)
all_mmd_values_quad_m, actual_mmd_quad_m = diagnostic_experiment_quad_m.plot_diagnostic()

In [None]:
# Side-by-side predictives for the well-specified and misspecified fits.
plot_predictives(experiment_quad_w, experiment_quad_m)

## Sigmoid model

In [None]:
# Fit 20 particles to the well-specified sigmoid data (IMQ kernel, key=PRNGKey(11)).
experiment_sig_w = experiment(sig_model, data_sig_w, n_particles=20, kernel=imq_kernel, key=random.PRNGKey(11))
# Alternative setting kept for reference: smaller step with explicit lengthscale=0.1.
# experiment_sig_w.run(n_steps=10000, step_size=0.001, lengthscale=0.1)
experiment_sig_w.run(n_steps=10000, step_size=0.005)

In [None]:
# Plot the KGD trace of the well-specified run.
experiment_sig_w.plot_KGD()

In [None]:
# Quick check: mean of particles_VGD over axis 0.
experiment_sig_w.particles_VGD.mean(axis=0)

In [None]:
# MMD diagnostic; both return values are saved below and drawn in the main figure's MMD column.
diagnostic_experiment_sig_w = diagnostic_experiment(experiment_sig_w)
all_mmd_values_sig_w, actual_mmd_sig_w = diagnostic_experiment_sig_w.plot_diagnostic()

In [None]:
# Same setup on the misspecified (step-function) sigmoid data.
experiment_sig_m = experiment(sig_model, data_sig_m, n_particles=20, kernel=imq_kernel, key=random.PRNGKey(11))
experiment_sig_m.run(n_steps=5000, step_size=0.0005, lengthscale=0.1)

In [None]:
# Plot the KGD trace of the misspecified run.
experiment_sig_m.plot_KGD()

In [None]:
# Plot the KSD trace of the misspecified run.
experiment_sig_m.plot_KSD()

In [None]:
# MMD diagnostic for the misspecified run (shown with a broken x-axis in the main figure).
diagnostic_experiment_sig_m = diagnostic_experiment(experiment_sig_m)
all_mmd_values_sig_m, actual_mmd_sig_m = diagnostic_experiment_sig_m.plot_diagnostic()

In [None]:
# Debug output: raw VGD particles of the misspecified run.
experiment_sig_m.particles_VGD

In [None]:
# Side-by-side predictives for the well-specified and misspecified fits.
plot_predictives(experiment_sig_w, experiment_sig_m)

## Linear model (misspecified data adds a quadratic term)

In [None]:
# Fit 20 particles to the well-specified linear data (2-D theta; IMQ kernel, key=PRNGKey(11)).
experiment_landq_w = experiment(linear_model, data_linear_w, n_particles=20, kernel=imq_kernel, key=random.PRNGKey(11))
# Alternative setting kept for reference: explicit lengthscale=0.1.
# experiment_landq_w.run(n_steps=10000, step_size=0.0005, lengthscale=0.1)
experiment_landq_w.run(n_steps=10000, step_size=0.0005)

In [None]:
# Plot the KGD trace of the well-specified run.
experiment_landq_w.plot_KGD()

In [None]:
# MMD diagnostic; both return values are saved below and drawn in the main figure's MMD column.
diagnostic_experiment_landq_w = diagnostic_experiment(experiment_landq_w)
all_mmd_values_landq_w, actual_mmd_landq_w = diagnostic_experiment_landq_w.plot_diagnostic()

In [None]:
# Same setup on the linear-plus-quadratic (misspecified) data.
experiment_landq_m = experiment(linear_model, data_linear_m, n_particles=20, kernel=imq_kernel, key=random.PRNGKey(11))
# Alternative setting kept for reference: explicit lengthscale=0.1.
# experiment_landq_m.run(n_steps=10000, step_size=0.0005, lengthscale=0.1)
experiment_landq_m.run(n_steps=10000, step_size=0.0005)

In [None]:
# Plot the KSD trace of the misspecified run.
experiment_landq_m.plot_KSD()

In [None]:
# MMD diagnostic for the misspecified run (shown with a broken x-axis in the main figure).
diagnostic_experiment_landq_m = diagnostic_experiment(experiment_landq_m)
all_mmd_values_landq_m, actual_mmd_landq_m = diagnostic_experiment_landq_m.plot_diagnostic()

In [None]:
# Side-by-side predictives for the well-specified and misspecified fits.
plot_predictives(experiment_landq_w, experiment_landq_m)

In [None]:
# Persist every fitted experiment and its MMD diagnostic outputs with dill, so the figure
# cells below can be re-run without re-fitting. Suffix _w = well-specified data,
# _m = misspecified data. (`dill` is already imported in the first cell.)
data_to_save = {
    'experiment_quad_w': experiment_quad_w,
    'experiment_quad_m': experiment_quad_m,
    'all_mmd_values_quad_w': all_mmd_values_quad_w,
    'actual_mmd_quad_w': actual_mmd_quad_w,
    'all_mmd_values_quad_m': all_mmd_values_quad_m,
    'actual_mmd_quad_m': actual_mmd_quad_m,
    'experiment_sig_w': experiment_sig_w,
    'experiment_sig_m': experiment_sig_m,
    'all_mmd_values_sig_w': all_mmd_values_sig_w,
    'actual_mmd_sig_w': actual_mmd_sig_w,
    'all_mmd_values_sig_m': all_mmd_values_sig_m,
    'actual_mmd_sig_m': actual_mmd_sig_m,
    'experiment_landq_w': experiment_landq_w,
    'experiment_landq_m': experiment_landq_m,
    'all_mmd_values_landq_w': all_mmd_values_landq_w,
    'actual_mmd_landq_w': actual_mmd_landq_w,
    'all_mmd_values_landq_m': all_mmd_values_landq_m,
    'actual_mmd_landq_m': actual_mmd_landq_m
}
with open('main_fig.dill', 'wb') as f:
    dill.dump(data_to_save, f)

In [None]:
# Update (rather than replace) the saved results: load the existing dict if present,
# overwrite/add this session's entries, and write it back.
# NOTE: dill.load can execute arbitrary code — only load files this notebook produced.
try:
    with open('main_fig.dill', 'rb') as f:
        data_to_save = dill.load(f)
    # "Loaded successfully: the old dictionary has been read into memory."
    print("加载成功：旧的字典已读入内存。")
except FileNotFoundError:
    # "Warning: old file not found. A new dictionary will be created."
    print("警告：未找到旧文件。将创建一个新的字典。")
    data_to_save = {}

# Previously `data_to_save = {...}` here silently discarded the dict loaded above.
# update() keeps older entries that this session did not recompute.
data_to_save.update({
    'experiment_quad_w': experiment_quad_w,
    'experiment_quad_m': experiment_quad_m,
    'all_mmd_values_quad_w': all_mmd_values_quad_w,
    'actual_mmd_quad_w': actual_mmd_quad_w,
    'all_mmd_values_quad_m': all_mmd_values_quad_m,
    'actual_mmd_quad_m': actual_mmd_quad_m,
    'experiment_sig_w': experiment_sig_w,
    'experiment_sig_m': experiment_sig_m,
    'all_mmd_values_sig_w': all_mmd_values_sig_w,
    'actual_mmd_sig_w': actual_mmd_sig_w,
    'all_mmd_values_sig_m': all_mmd_values_sig_m,
    'actual_mmd_sig_m': actual_mmd_sig_m,
    'experiment_landq_w': experiment_landq_w,
    'experiment_landq_m': experiment_landq_m,
    'all_mmd_values_landq_w': all_mmd_values_landq_w,
    'actual_mmd_landq_w': actual_mmd_landq_w,
    'all_mmd_values_landq_m': all_mmd_values_landq_m,
    'actual_mmd_landq_m': actual_mmd_landq_m
})
with open('main_fig.dill', 'wb') as f:
    dill.dump(data_to_save, f)

In [None]:
# Reload the saved results to check what was written.
with open('main_fig.dill', 'rb') as f:
    loaded_data = dill.load(f)

In [None]:
# List the saved entry names.
loaded_data.keys()

## main figure

In [None]:
def plot_main_figure(file_name='main_fig.dill'):
    """Render the 3 x 7 main comparison figure from saved experiment results.

    Rows are the quadratic, sigmoid and 2-D linear examples (entries prefixed
    'quad', 'sig', 'landq' in the saved dict). In each row:
      * columns 0-1: Bayes (SVGD particles) and PrO (VGD particles) predictive
        bands on the well-specified (``_w``) data;
      * column 2: the MMD diagnostic for the ``_w`` data;
      * columns 3-4: the same predictive bands on the misspecified (``_m``) data;
      * columns 5-6: a broken-x-axis MMD diagnostic for the ``_m`` data.

    Parameters
    ----------
    file_name : str
        Dill file written by the saving cells. It must contain 'experiment_*',
        'all_mmd_values_*' and 'actual_mmd_*' entries for every prefix.

    Returns
    -------
    None
        Shows the figure. If ``file_name`` does not exist, prints an error and
        returns early (it previously called ``exit()``, which shuts down the
        Jupyter kernel).
    """
    sns.set_theme(
        style="white",
        rc={
            "font.family": "serif",
            "font.serif": ["Computer Modern Roman", "CMU Serif", "Times New Roman"],
            "mathtext.fontset": "cm",
            "legend.frameon": True,
        },
    )

    # NOTE(review): this switches JAX to 64-bit globally for the rest of the session;
    # kept for backward compatibility — consider setting it once at the top of the notebook.
    jax.config.update("jax_enable_x64", True)
    try:
        with open(file_name, 'rb') as f:
            data = dill.load(f)
        print(f"Successfully loaded file containing {len(data.keys())} arrays: {list(data.keys())}")
    except FileNotFoundError:
        print(f"Error, '{file_name}' not found.")
        return

    # --- Layout: 5 plain columns + a last column split 0.75 : 0.25 into two axes. ---
    n_rows = 3
    fig = plt.figure(figsize=(24, 12.5))
    main_gs = fig.add_gridspec(
        nrows=n_rows,
        ncols=6,
        height_ratios=[1, 1, 1],
        width_ratios=[1, 1, 1, 1, 1, 1],
    )
    axes = np.empty((n_rows, 7), dtype=object)
    for r in range(n_rows):
        for c in range(5):
            axes[r, c] = fig.add_subplot(main_gs[r, c])
        nested_gs = main_gs[r, 5].subgridspec(1, 2, width_ratios=[0.75, 0.25], wspace=0.05)
        axes[r, 5] = fig.add_subplot(nested_gs[0, 0])
        axes[r, 6] = fig.add_subplot(nested_gs[0, 1])

    data_prefixes = ['quad', 'sig', 'landq']
    row_titles = [r'$\mathrm{Quadratic}$', r'$\mathrm{Sigmoid}$', r'$\mathrm{Linear (d=2)}$']
    col_titles = [r'$P_\mathrm{Bayes}$', r'$P_\mathrm{PrO}$', r'$\mathrm{MMD}$',
                  r'$P_\mathrm{Bayes}$', r'$P_\mathrm{PrO}$', r'$\mathrm{MMD}$']
    VGD_color = '#ff7f0e'
    SVGD_color = '#1f77b4'
    # Per-row x-range and ticks for the left (zoomed) half of the broken-axis MMD panel.
    broken_axis_limits = {
        'quad': ((0.0, 1.6e-5), [0.0, 1e-5]),
        'sig': ((0.0, 1e-5), [0.0, 7e-6]),
        'landq': ((0.0, 1.2e-5), [0.0, 8e-6]),
    }

    def plot_predictive_pair(ax_bayes, ax_pro, exp):
        # Bayes (SVGD) on the left, PrO (VGD) on the right; the right panel's y tick
        # labels are hidden because it sits directly next to the left one.
        plot_shaded_region_predictive(ax_bayes, exp, exp.particles_SVGD, SVGD_color)
        plot_shaded_region_predictive(ax_pro, exp, exp.particles_VGD, VGD_color)
        ax_pro.yaxis.set_ticks_position('none')
        plt.setp(ax_pro.get_yticklabels(), visible=False)

    for row_idx, prefix in enumerate(data_prefixes):
        # Well-specified (_w) data: predictive bands, then the MMD diagnostic.
        experiment_w = data[f'experiment_{prefix}_w']
        all_mmd_values_w = data[f'all_mmd_values_{prefix}_w']
        actual_mmd_w = data[f'actual_mmd_{prefix}_w'].item()
        plot_predictive_pair(axes[row_idx, 0], axes[row_idx, 1], experiment_w)
        plot_diagnostic(axes[row_idx, 2], all_mmd_values=all_mmd_values_w, actual_mmd=actual_mmd_w)

        # Misspecified (_m) data: predictive bands, then a broken-axis MMD diagnostic.
        experiment_m = data[f'experiment_{prefix}_m']
        all_mmd_values_m = data[f'all_mmd_values_{prefix}_m']
        actual_mmd_m = data[f'actual_mmd_{prefix}_m'].item()
        plot_predictive_pair(axes[row_idx, 3], axes[row_idx, 4], experiment_m)

        ax_left, ax_right = axes[row_idx, 5], axes[row_idx, 6]
        ax_right.sharey(ax_left)
        xlim_left, xticks_left = broken_axis_limits[prefix]
        plot_diagnostic_manual_broken(
            ax_left,
            ax_right,
            all_mmd_values=all_mmd_values_m,
            actual_mmd=actual_mmd_m,
            width_ratios=[0.9, 0.3],
            xlim_left=xlim_left,
            xticks_left=xticks_left,
        )

    plt.tight_layout(rect=[0.03, 0, 1, 0.94])
    # Finalise the spacing BEFORE reading axes positions for the titles. This used to
    # run last, which moved the axes after the titles were placed and misaligned them.
    fig.subplots_adjust(wspace=0.16, hspace=0.12)

    for i, title in enumerate(col_titles):
        if i == 5:
            # The last MMD title is centred over both halves of the broken-axis panel.
            x0 = axes[0, 5].get_position().x0
            x1 = axes[0, 6].get_position().x1
        else:
            pos = axes[0, i].get_position()
            x0, x1 = pos.x0, pos.x1
        fig.text((x0 + x1) / 2, 0.96, title, ha='center', va='top', fontsize=28)

    for i, title in enumerate(row_titles):
        pos = axes[i, 0].get_position()
        fig.text(0.02, (pos.y0 + pos.y1) / 2, title, ha='left', va='center', fontsize=28, rotation=90)

    plt.show()

In [None]:
# Render the main figure from the default 'main_fig.dill' written by the save cells above.
plot_main_figure()

## trajectory figure

In [None]:
def plot_trajectory_2d(particle_history, true_theta, n_paths_to_plot=20,
                       base_color='orange', gradient_color=True, ax=None,
                       mark_true_theta=False):
    """Plot 2-D particle trajectories over optimisation steps.

    Parameters
    ----------
    particle_history : array-like
        Indexed as [step, particle, coordinate]. Coordinates 0 and 1 are drawn on
        the x and y axes.
    true_theta : array-like
        Reference parameter. Drawn as a red star only when ``mark_true_theta`` is True.
    n_paths_to_plot : int
        Maximum number of particles to draw; clamped to the number available.
    base_color : color
        Path/marker colour, or the base of a light palette when ``gradient_color``.
    gradient_color : bool
        If True, each path gets its own shade from ``sns.light_palette(base_color)``.
    ax : matplotlib Axes, optional
        Target axes. If None, a new 8 x 6 figure (seaborn "whitegrid" axes style)
        is created and shown.
    mark_true_theta : bool
        If True, mark ``true_theta`` with a red star (default False, as before).

    Start positions are drawn as hollow circles and final positions as filled circles.
    """
    particle_history_np = np.asarray(particle_history)

    # Clamp so per-path indexing and the colour list never exceed the particle count.
    n_paths = min(n_paths_to_plot, particle_history_np.shape[1])
    paths = particle_history_np[:, :n_paths, :]

    show_plot = ax is None
    if show_plot:
        # Scope the style to this figure only. The previous sns.set_theme() call
        # globally reset the caller's theme (fonts, mathtext) mid-figure.
        with sns.axes_style("whitegrid"):
            _, ax = plt.subplots(figsize=(8, 6))

    if gradient_color:
        colors = sns.light_palette(base_color, n_colors=n_paths)
    else:
        colors = [base_color] * n_paths

    for i in range(n_paths):
        ax.plot(paths[:, i, 0], paths[:, i, 1], color=colors[i], alpha=0.7, linewidth=1.5)

    ax.scatter(paths[-1, :, 0], paths[-1, :, 1],
               c=colors, s=50, zorder=5,
               edgecolor='black', linewidth=0.5, label='Final Positions')
    ax.scatter(paths[0, :, 0], paths[0, :, 1],
               facecolors='none', edgecolors=colors, s=50, zorder=5,
               linewidth=1.5, label='Start Positions')

    if mark_true_theta:
        true_theta_np = np.asarray(true_theta)
        ax.scatter(true_theta_np[0], true_theta_np[1], c='red', s=100, marker='*',
                   zorder=6, label='True Value')

    ax.set_xlabel(r'$\theta_1$', fontsize=20)
    ax.set_ylabel(r'$\theta_2$', fontsize=20)
    ax.axis('equal')

    if show_plot:
        plt.show()


In [None]:
# Trajectory figure for the 2-D linear example: Bayes (SVGD) vs PrO (VGD) particle paths
# on the well-specified (_w) and misspecified (_m) data.
with open('main_fig.dill', 'rb') as f:
    data = dill.load(f)

sns.set_theme(
    style="white",
    rc={
        "font.family": "serif",
        "font.serif": ["Computer Modern Roman", "CMU Serif", "Times New Roman"],
        "mathtext.fontset": "cm",
        "legend.frameon": True,
    },
)

VGD_color = '#ff7f0e'
SVGD_color = '#1f77b4'
true_theta = jnp.array([5.0, 3.0])

# One entry per panel, left to right: (experiment key, history attribute, colour, title).
trajectory_panels = [
    ('experiment_landq_w', 'history_SVGD', SVGD_color, r'$Q_\mathrm{Bayes}$'),
    ('experiment_landq_w', 'history_VGD', VGD_color, r'$Q_{\mathrm{PrO}}$'),
    ('experiment_landq_m', 'history_SVGD', SVGD_color, r'$Q_\mathrm{Bayes}$'),
    ('experiment_landq_m', 'history_VGD', VGD_color, r'$Q_{\mathrm{PrO}}$'),
]

fig, axes = plt.subplots(nrows=1, ncols=len(trajectory_panels), figsize=(28, 7))
for ax, (exp_key, history_attr, panel_color, panel_title) in zip(axes, trajectory_panels):
    plot_trajectory_2d(getattr(data[exp_key], history_attr), true_theta, n_paths_to_plot=20,
                       base_color=panel_color, gradient_color=False, ax=ax)
    ax.set_title(panel_title, fontsize=34, pad=20)
    ax.set(xlabel=None, ylabel=None)
    ax.tick_params(axis='both', which='major', labelsize=26)

plt.tight_layout()
plt.show()

In [None]:
# NOTE(review): reloads the same file the trajectory-figure cell already loaded into `data`;
# this is only needed if that cell was skipped.
with open('main_fig.dill', 'rb') as f:
    data = dill.load(f)


In [None]:
# Quick check: mean of the SVGD particles for the well-specified linear fit
# (the data were generated with true_theta = [5.0, 3.0]).
data['experiment_landq_w'].particles_SVGD.mean(axis=0)

## Posteriors and KGD figure

In [None]:
# Colours for the posterior/KGD helpers below (same hex values as VGD_color / SVGD_color
# used earlier): orange for VGD (Q_PrO), blue for SVGD (Q_Bayes).
VGD_colour = '#ff7f0e'
SVGD_colour = '#1f77b4'

# Tick-label font size used by plot_KGD, plot_kde_q and plot_kde_q_2d.
fontsize_axis = 20

def plot_KGD(ax, experiment: experiment):
    """Draw the log-KGD and log-KSD histories of `experiment` against iteration index on `ax`.

    KGD is drawn in VGD_colour and KSD in SVGD_colour. Axis labels are cleared and
    tick labels use `fontsize_axis`.
    """
    traces = (
        (experiment.history_KGD, 'Log KGD', VGD_colour),
        (experiment.history_KSD, 'Log KSD', SVGD_colour),
    )
    for history, trace_label, trace_colour in traces:
        steps = range(len(history))
        ax.plot(steps, jnp.log(history), label=trace_label, color=trace_colour)
    ax.set(xlabel=None, ylabel=None)
    ax.tick_params(axis='both', labelsize=fontsize_axis)


def plot_kde_q(ax, experiment: experiment, ylims=None, misspecified=False):
    """Overlay 1-D density estimates of the SVGD (Q_Bayes) and VGD (Q_PrO) particles.

    Parameters
    ----------
    ax : matplotlib Axes
    experiment : experiment
        Fitted experiment. Its `particles_SVGD` / `particles_VGD` are flattened to 1-D.
    ylims : sequence of two floats, optional
        y-axis limits to apply after plotting.
    misspecified : bool
        If True, evaluate `scipy.stats.gaussian_kde` on a dense grid spanning both
        particle sets. Each curve is trimmed to where its density exceeds 1e-12 of
        its own peak, plus 5 grid points of padding. If False, use `sns.kdeplot`.
    """
    x1 = jnp.array(experiment.particles_SVGD).flatten()
    x2 = jnp.array(experiment.particles_VGD).flatten()

    def generate_plot_data(x_grid, y_vals, valid_indices):
        # Slice to [first valid - padding, last valid + padding], both ends inclusive.
        if len(valid_indices) == 0:
            return x_grid, y_vals
        padding = 5
        start_idx = max(0, valid_indices[0] - padding)
        end_idx = min(len(x_grid) - 1, valid_indices[-1] + padding)
        # end_idx is an inclusive index (clamped to len - 1), hence the +1;
        # the previous exclusive slice dropped the last point.
        return x_grid[start_idx:end_idx + 1], y_vals[start_idx:end_idx + 1]

    if misspecified:
        # Span BOTH particle sets. The grid used to cover only the VGD particles,
        # which could clip the SVGD curve when the two posteriors are far apart.
        grid_min = min(float(x1.min()), float(x2.min())) - 10
        grid_max = max(float(x1.max()), float(x2.max())) + 10
        x_grid = np.linspace(grid_min, grid_max, 100000)
        for samples, label, colour in ((x1, 'Q_Bayes', SVGD_colour), (x2, 'Q_PrO', VGD_colour)):
            y_vals = gaussian_kde(samples)(x_grid)
            valid_indices = np.where(y_vals > np.max(y_vals) * 1e-12)[0]
            x_plot, y_plot = generate_plot_data(x_grid, y_vals, valid_indices)
            ax.plot(x_plot, y_plot, label=label, color=colour, linewidth=1.5)
    else:
        sns.kdeplot(x=x1, fill=False, label='Q_Bayes', ax=ax, color=SVGD_colour)
        sns.kdeplot(x=x2, fill=False, label='Q_PrO', ax=ax, color=VGD_colour)

    ax.set(xlabel=None, ylabel=None)
    ax.tick_params(axis='both', labelsize=fontsize_axis)

    if ylims is not None:
        ax.set_ylim(ylims)

def plot_kde_q_2d(ax, experiment: experiment, xlims=(-4, 4), ylims=(-4, 4), bw_adjust=1, zoomin=False,
                  zoom_xlims=(7.45, 7.85), zoom_ylims=(2.75, 3.25), zoom=16):
    """Draw 2-D KDE contour lines of the SVGD (blue) and VGD (orange) particles on `ax`.

    Parameters
    ----------
    ax : matplotlib Axes
    experiment : experiment
        Fitted experiment. Columns 0 and 1 of `particles_SVGD` / `particles_VGD`
        are used as x and y.
    xlims, ylims : pair of floats
        Axis limits of the main panel.
    bw_adjust : float
        Passed to `sns.kdeplot` (bandwidth scaling).
    zoomin : bool
        If True, add an upper-left inset magnified by `zoom` over the window
        `zoom_xlims` x `zoom_ylims`, connected to `ax` by `mark_inset`.
        The defaults reproduce the previously hard-coded inset.
    zoom_xlims, zoom_ylims : pair of floats
        Data window shown in the inset.
    zoom : float
        Magnification factor for `zoomed_inset_axes`.
    """
    samples = (
        (experiment.particles_SVGD, SVGD_colour),
        (experiment.particles_VGD, VGD_colour),
    )

    def draw_contours(target_ax):
        # Same contour settings for the main panel and the inset.
        for particles, colour in samples:
            sns.kdeplot(
                x=particles[:, 0],
                y=particles[:, 1],
                thresh=0.05,
                levels=8,
                ax=target_ax,
                color=colour,
                bw_adjust=bw_adjust,
                fill=False,
            )

    draw_contours(ax)
    ax.set_xlim(xlims)
    ax.set_ylim(ylims)
    ax.set(xlabel=None, ylabel=None)
    ax.tick_params(axis='both', labelsize=fontsize_axis)

    if zoomin:
        axins = zoomed_inset_axes(ax, zoom=zoom, loc='upper left')
        draw_contours(axins)
        axins.set_xlim(*zoom_xlims)
        axins.set_ylim(*zoom_ylims)
        axins.set_xticklabels([])
        axins.set_yticklabels([])
        # loc1/loc2 pick the inset corners joined to the marked region (1 = upper right, 4 = lower right).
        mark_inset(ax, axins, loc1=1, loc2=4, fc="none", ec="0.5")

# Posteriors and log-KGD traces: columns 0-1 use the well-specified (_w) data,
# columns 2-3 the misspecified (_m) data; rows are quadratic, sigmoid, 2-D linear.
with open('main_fig.dill', 'rb') as f:
    data = dill.load(f)

sns.set_theme(
    style="white",
    rc={
        "font.family": "serif",
        "font.serif": ["Computer Modern Roman", "CMU Serif", "Times New Roman"],
        "mathtext.fontset": "cm",
        "legend.frameon": True,
    },
)

fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(20, 15))

data_prefixes = ['quad', 'sig', 'landq']
row_titles = [r'$\mathrm{Quadratic}$', r'$\mathrm{Sigmoid}$', r'$\mathrm{Linear (d=2)}$']
col_titles = [r'$\mathrm{Posteriors}$', r'$\log(\mathrm{KGD})$', r'$\mathrm{Posteriors}$', r'$\log(\mathrm{KGD})$']

# Rows 0-1: 1-D models, shown as 1-D KDEs.
for row_idx, prefix in enumerate(data_prefixes[:2]):
    experiment_w = data[f'experiment_{prefix}_w']
    plot_kde_q(axes[row_idx, 0], experiment_w)
    plot_KGD(axes[row_idx, 1], experiment_w)

    experiment_m = data[f'experiment_{prefix}_m']
    ylims = [0, 0.16] if prefix == 'quad' else [0, 0.1]
    plot_kde_q(axes[row_idx, 2], experiment_m, ylims=ylims, misspecified=True)
    plot_KGD(axes[row_idx, 3], experiment_m)

# Row 2: 2-D linear model, shown as contour KDEs (with a zoomed inset on the misspecified fit).
experiment_w = data['experiment_landq_w']
plot_kde_q_2d(axes[2, 0], experiment_w, xlims=[4.75, 5.25], ylims=[2.7, 3.3])
plot_KGD(axes[2, 1], experiment_w)

experiment_m = data['experiment_landq_m']
plot_kde_q_2d(axes[2, 2], experiment_m, xlims=[-3, 10], ylims=[-6, 10], zoomin=True)
plot_KGD(axes[2, 3], experiment_m)

# Column and row titles, centred on the axes they label.
for i, title in enumerate(col_titles):
    pos = axes[0, i].get_position()
    fig.text((pos.x0 + pos.x1) / 2, 0.92, title, ha='center', va='top', fontsize=28)

for i, title in enumerate(row_titles):
    pos = axes[i, 0].get_position()
    fig.text(0.075, (pos.y0 + pos.y1) / 2, title, ha='left', va='center', fontsize=28, rotation=90)

plt.show()
