# Plots and MMD evaluation for the ABC-MCMC replicate experiments

In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
import os
import csv

In [2]:
# Output folder for all replicate-level artefacts (CSV tables and figures).
replicate_result_folder = 'replicate_results'
# exist_ok avoids the check-then-create race and keeps the cell safe to re-run.
os.makedirs(replicate_result_folder, exist_ok=True)

In [3]:
# n appears in the result file names (e.g. nn_{n}_ps_{k}.npy);
# d is the number of parameters, i.e. columns per sample and panels per figure.
n, d = 25, 5

# Line colours shared by every density plot: reference posterior, raw estimate,
# and ABC-MCMC-refined estimate.
truth_color, est_color, refined_color = "#FF6B6B", "#4D96FF", "#6BCB77"

# LaTeX panel titles \theta_1 .. \theta_d (the plot helpers wrap them in $...$).
upper_labels = [f"\\theta_{idx}" for idx in range(1, d + 1)]

In [4]:
# Precomputed samples live under ./result, one sub-folder per method and kind
# (`*_ps` = posterior samples, `*_mcmc_samples` = ABC-MCMC chains).
_result_root = os.path.join(os.getcwd(), 'result')

# NN method: raw posterior samples and their ABC-MCMC chains.
nn_ps_samples_path = os.path.join(_result_root, 'nn_ps')
nn_mcmc_samples_path = os.path.join(_result_root, 'nn_mcmc_samples')

# BF method: raw posterior samples and their ABC-MCMC chains.
bf_ps_samples_path = os.path.join(_result_root, 'bf_ps')
bf_mcmc_samples_path = os.path.join(_result_root, 'bf_mcmc_samples')

# Baselines with posterior samples only (no ABC-MCMC refinement).
dnnabc_ps_samples_path = os.path.join(_result_root, 'dnnabc_ps')
w2abc_ps_samples_path = os.path.join(_result_root, 'w2abc_ps')

In [5]:
# Reference posterior samples that every method is compared against.
true_posterior_path = os.path.join(os.getcwd(), 'data', 'ps.npy')
true_posterior_samples = np.load(true_posterior_path)
print(true_posterior_samples.shape)
# Fail fast: later cells slice columns 0..d-1 and reshape chains to (-1, d).
assert true_posterior_samples.ndim == 2 and true_posterior_samples.shape[1] == d, (
    f"expected reference samples of shape (N, {d}), got {true_posterior_samples.shape}"
)

# Gaussian-kernel bandwidth used by compute_mmd (it divides by bandwidth**2).
bandwidth_path = os.path.join(os.getcwd(), 'data', 'h_mmd.npy')
bandwidth = np.load(bandwidth_path)
assert np.all(np.asarray(bandwidth) > 0), f"MMD bandwidth must be positive, got {bandwidth}"

(5000, 5)


In [6]:
def plot(mcmc_samples, ps_samples, true_posterior_samples, address, model="NN", x_limits=None):
    """Save a 1 x D grid of marginal KDEs comparing three sample sets.

    Panel j overlays kernel density estimates of column j of
    ``true_posterior_samples`` (reference posterior), ``ps_samples`` (the
    ``model`` estimate) and ``mcmc_samples`` (the ``model`` estimate after
    ABC-MCMC refinement). The figure is written to ``address`` at 300 dpi and
    then closed, so nothing is displayed inline.

    :param mcmc_samples: 2-D array of refined samples, one column per parameter.
    :param ps_samples: 2-D array of unrefined samples from ``model``.
    :param true_posterior_samples: 2-D array of reference posterior samples.
    :param address: output path passed to ``Figure.savefig``.
    :param model: method name used in the legend labels.
    :param x_limits: optional list of ``[lo, hi]`` x-ranges, one per panel; its
        length sets the number of panels D. Defaults to the five ranges used
        for theta_1..theta_5.

    Uses the module-level ``upper_labels``, ``truth_color``, ``est_color`` and
    ``refined_color`` from the configuration cell.
    """
    if x_limits is None:
        # x-axis range for each parameter panel.
        x_limits = [
            [0.5, 1.5],   # theta_1
            [0.3, 1.7],   # theta_2
            [-2.0, 2.0],  # theta_3
            [-2.0, 2.0],  # theta_4
            [-0.3, 1.5],  # theta_5
        ]
    n_panels = len(x_limits)

    sns.set_style("whitegrid")
    # squeeze=False keeps `axs` indexable even for a single panel.
    fig, axs = plt.subplots(1, n_panels, figsize=(5 * n_panels, 6), squeeze=False)
    axs = axs[0]

    # (samples, legend label, colour, line style) for the curves in every panel.
    curves = (
        (true_posterior_samples, "posterior", truth_color, "-."),
        (ps_samples, f"{model}", est_color, "-"),
        (mcmc_samples, f"{model}+ABC-MCMC", refined_color, "--"),
    )

    for j, (ax, (lo, hi), label) in enumerate(zip(axs, x_limits, upper_labels)):
        # set_xlim turns x-autoscaling off, so the KDE curves drawn next keep this range.
        ax.set_xlim(lo, hi)
        ax.set_xticks(np.linspace(lo, hi, 5))
        for samples, curve_label, color, linestyle in curves:
            sns.kdeplot(
                samples[:, j],
                ax=ax,
                fill=False,
                label=curve_label,
                color=color,
                linestyle=linestyle,
                linewidth=1.5,
            )
        ax.set_title(f"${label}$", pad=15)
        ax.set_ylabel("")

    # One shared legend: every panel draws the same three curves.
    handles, labels = axs[0].get_legend_handles_labels()
    fig.legend(handles, labels, loc="lower center", ncol=3)
    fig.tight_layout(pad=3.0)
    fig.savefig(address, dpi=300)
    plt.close(fig)

In [7]:
def subplot(mcmc_samples, ps_samples, true_posterior_samples, address, model="NN", element_index=0, x_limits=None):
    """Save a single-panel marginal KDE comparison for one parameter.

    Overlays kernel density estimates of column ``element_index`` of
    ``true_posterior_samples`` (reference posterior), ``ps_samples`` (the
    ``model`` estimate) and ``mcmc_samples`` (the ``model`` estimate after
    ABC-MCMC refinement). The figure is written to ``address`` at 300 dpi and
    then closed.

    :param mcmc_samples: 2-D array of refined samples, one column per parameter.
    :param ps_samples: 2-D array of unrefined samples from ``model``.
    :param true_posterior_samples: 2-D array of reference posterior samples.
    :param address: output path passed to ``Figure.savefig``.
    :param model: method name used in the legend labels.
    :param element_index: column (parameter) to plot. It also selects the x-range
        and the title from the module-level ``upper_labels``.
    :param x_limits: optional list of ``[lo, hi]`` ranges indexed by
        ``element_index``. Defaults to the ranges used for theta_1..theta_5.
    """
    if x_limits is None:
        # x-axis range for each parameter.
        x_limits = [
            [0.5, 1.5],   # theta_1
            [0.3, 1.7],   # theta_2
            [-2.0, 2.0],  # theta_3
            [-2.0, 2.0],  # theta_4
            [-0.3, 1.5],  # theta_5
        ]
    lo, hi = x_limits[element_index]

    sns.set_style("whitegrid")
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))

    # set_xlim turns x-autoscaling off, so the KDE curves drawn next keep this range.
    ax.set_xlim(lo, hi)
    ax.set_xticks(np.linspace(lo, hi, 5))

    # (samples, legend label, colour, line style) for each curve.
    curves = (
        (true_posterior_samples, "posterior", truth_color, "-."),
        (ps_samples, f"{model}", est_color, "-"),
        (mcmc_samples, f"{model}+ABC-MCMC", refined_color, "--"),
    )
    for samples, curve_label, color, linestyle in curves:
        sns.kdeplot(
            samples[:, element_index],
            ax=ax,
            fill=False,
            label=curve_label,
            color=color,
            linestyle=linestyle,
            linewidth=1.5,
        )

    ax.set_title(f"${upper_labels[element_index]}$", pad=15)
    ax.set_ylabel("")

    handles, labels = ax.get_legend_handles_labels()
    fig.legend(handles, labels, loc="lower center", ncol=3)
    fig.tight_layout(pad=3.0)
    fig.savefig(address, dpi=300)
    plt.close(fig)

In [8]:
def compute_mmd(x_, y_, bandwidth):
    """
    Estimate the squared MMD between two samples with a Gaussian (RBF) kernel.

    Uses the biased (V-statistic) estimator

        MMD^2 = mean(K_xx) - 2 * mean(K_xy) + mean(K_yy),
        k(a, b) = exp(-||a - b||^2 / (2 * bandwidth^2)),

    where the means include the diagonal terms. The result is a single
    squared-MMD value; no standard deviation is computed.

    All three kernel matrices are materialised in full, so memory is
    O((N_x + N_y)^2).

    :param x_: array-like of shape (N_x, d); converted to float32.
    :param y_: array-like of shape (N_y, d); converted to float32.
    :param bandwidth: Gaussian kernel bandwidth (scalar, or anything that broadcasts like one).
    :return: scalar float32 tf.Tensor holding the squared-MMD estimate.
    """
    x = tf.convert_to_tensor(x_, dtype=tf.float32)
    y = tf.convert_to_tensor(y_, dtype=tf.float32)
    # Explicit cast, in case the bandwidth was loaded from disk as float64.
    two_h2 = 2.0 * tf.square(tf.cast(bandwidth, tf.float32))

    def _sq_dists(a, b):
        """Pairwise squared Euclidean distances ||a_i - b_j||^2, shape (N_a, N_b)."""
        sq_a = tf.reduce_sum(tf.square(a), axis=-1, keepdims=True)  # (N_a, 1)
        sq_b = tf.reduce_sum(tf.square(b), axis=-1, keepdims=True)  # (N_b, 1)
        d2 = sq_a - 2.0 * tf.matmul(a, b, transpose_b=True) + tf.transpose(sq_b)
        # The |a|^2 - 2ab + |b|^2 expansion can dip slightly below zero in float32.
        return tf.maximum(d2, 0.0)

    k_xx = tf.exp(-_sq_dists(x, x) / two_h2)
    k_xy = tf.exp(-_sq_dists(x, y) / two_h2)
    k_yy = tf.exp(-_sq_dists(y, y) / two_h2)

    return tf.reduce_mean(k_xx) - 2.0 * tf.reduce_mean(k_xy) + tf.reduce_mean(k_yy)

## Local ABC-MCMC replicates

In [9]:
# One MMD table per method. Each table is (re)created containing only its header row.
MMD_CSV_HEADER = ["Iteration", "MMD", "refined MMD"]


def _init_mmd_csv(method):
    """Create (or truncate) `<method>_replicate_mmd.csv` with just the header row; return its path."""
    path = os.path.join(replicate_result_folder, f'{method}_replicate_mmd.csv')
    with open(path, "w", newline="") as f:
        csv.writer(f).writerow(MMD_CSV_HEADER)
    return path


nn_replicate_mmd_csv = _init_mmd_csv('nn')
bf_replicate_mmd_csv = _init_mmd_csv('bf')
dnnabc_replicate_mmd_csv = _init_mmd_csv('dnnabc')
w2abc_replicate_mmd_csv = _init_mmd_csv('w2abc')

# Figure folders for the two methods that have an ABC-MCMC refinement step.
nn_fig_folder = os.path.join(replicate_result_folder, 'nn_figures')
os.makedirs(nn_fig_folder, exist_ok=True)

bf_fig_folder = os.path.join(replicate_result_folder, 'bf_figures')
os.makedirs(bf_fig_folder, exist_ok=True)

In [10]:
# ABC-MCMC post-processing: drop the first `burn_in` iterations, then keep
# every `thin`-th iteration among the next `n_samples`.
burn_in = 350  # an earlier setting used 250
n_samples = 100
thin = 20

refinement_index = list(range(burn_in, burn_in + n_samples, thin))

print(f"Refinement index: {refinement_index}")

Refinement index: [350, 370, 390, 410, 430]


### Compute MMD for each replicate

In [11]:
# NN: MMD to the reference posterior before and after ABC-MCMC refinement,
# one row per replicate. The CSV is rewritten in full (header + rows), so
# re-running this cell never duplicates rows.
with open(nn_replicate_mmd_csv, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["Iteration", "MMD", "refined MMD"])
    for k in range(10):
        print(f"Processing iteration {k+1}...")

        nn_kth_ps_path = os.path.join(nn_ps_samples_path, f'nn_{n}_ps_{k}.npy')
        nn_ps_samples = np.load(nn_kth_ps_path)

        # The chain's first axis is indexed by MCMC iteration. Keep the thinned
        # post-burn-in iterations and flatten them into a (num_draws, d) matrix.
        kth_mcmc_samples_path = os.path.join(nn_mcmc_samples_path, f'nn_{n}_mcmc_samples_{k}.npy')
        mcmc_samples_total = np.load(kth_mcmc_samples_path)
        mcmc_samples = mcmc_samples_total[refinement_index, :].reshape(-1, d)

        mmd = compute_mmd(nn_ps_samples, true_posterior_samples, bandwidth)
        refined_mmd = compute_mmd(mcmc_samples, true_posterior_samples, bandwidth)
        mmd_value, refined_mmd_value = mmd.numpy(), refined_mmd.numpy()

        print(f"MMD: {mmd_value}, Refined MMD: {refined_mmd_value}")
        writer.writerow([k+1, mmd_value, refined_mmd_value])

Processing iteration 1...
MMD: 0.035132795572280884, Refined MMD: 0.01695062220096588
Processing iteration 2...
MMD: 0.03247249126434326, Refined MMD: 0.02067171037197113
Processing iteration 3...
MMD: 0.054976433515548706, Refined MMD: 0.028363153338432312
Processing iteration 4...
MMD: 0.03751438111066818, Refined MMD: 0.019171930849552155
Processing iteration 5...
MMD: 0.08657106757164001, Refined MMD: 0.05443992465734482
Processing iteration 6...
MMD: 0.025587663054466248, Refined MMD: 0.021399185061454773
Processing iteration 7...
MMD: 0.033127471804618835, Refined MMD: 0.0167626291513443
Processing iteration 8...
MMD: 0.0724191963672638, Refined MMD: 0.02242761105298996
Processing iteration 9...
MMD: 0.02372489869594574, Refined MMD: 0.011162415146827698
Processing iteration 10...
MMD: 0.013257727026939392, Refined MMD: 0.014337599277496338


In [13]:
# BF: MMD to the reference posterior before and after ABC-MCMC refinement,
# one row per replicate. The CSV is rewritten in full (header + rows), so
# re-running this cell never duplicates rows.
with open(bf_replicate_mmd_csv, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["Iteration", "MMD", "refined MMD"])
    for k in range(10):
        print(f"Processing iteration {k+1}...")

        bf_kth_ps_path = os.path.join(bf_ps_samples_path, f'bf_{n}_ps_{k}.npy')
        bf_ps_samples = np.load(bf_kth_ps_path)

        # The chain's first axis is indexed by MCMC iteration. Keep the thinned
        # post-burn-in iterations and flatten them into a (num_draws, d) matrix.
        kth_mcmc_samples_path = os.path.join(bf_mcmc_samples_path, f'bf_{n}_mcmc_samples_{k}.npy')
        mcmc_samples_total = np.load(kth_mcmc_samples_path)
        mcmc_samples = mcmc_samples_total[refinement_index, :].reshape(-1, d)

        mmd = compute_mmd(bf_ps_samples, true_posterior_samples, bandwidth)
        refined_mmd = compute_mmd(mcmc_samples, true_posterior_samples, bandwidth)
        mmd_value, refined_mmd_value = mmd.numpy(), refined_mmd.numpy()

        print(f"MMD: {mmd_value}, Refined MMD: {refined_mmd_value}")
        writer.writerow([k+1, mmd_value, refined_mmd_value])

Processing iteration 1...
MMD: 0.021606475114822388, Refined MMD: 0.00946788489818573
Processing iteration 2...
MMD: 0.028450310230255127, Refined MMD: 0.020548544824123383
Processing iteration 3...
MMD: 0.019553527235984802, Refined MMD: 0.015402182936668396
Processing iteration 4...
MMD: 0.04083957523107529, Refined MMD: 0.03538571298122406
Processing iteration 5...
MMD: 0.05628848820924759, Refined MMD: 0.04439729452133179
Processing iteration 6...
MMD: 0.017284676432609558, Refined MMD: 0.009562507271766663
Processing iteration 7...
MMD: 0.03714586794376373, Refined MMD: 0.03212156891822815
Processing iteration 8...
MMD: 0.019086435437202454, Refined MMD: 0.02360723912715912
Processing iteration 9...
MMD: 0.11926950514316559, Refined MMD: 0.11988699436187744
Processing iteration 10...
MMD: 0.024847984313964844, Refined MMD: 0.020445458590984344


In [14]:
# Baselines without an ABC-MCMC step: only the unrefined MMD exists, so the
# "refined MMD" column is "N/A". Each CSV is rewritten in full (header + rows)
# so the cell can be re-run without duplicating rows, and a method line is
# printed so the two otherwise identical progress logs can be told apart.
print("Method: dnnabc")
with open(dnnabc_replicate_mmd_csv, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["Iteration", "MMD", "refined MMD"])
    for k in range(10):
        print(f"Processing iteration {k+1}...")

        dnnabc_kth_ps_path = os.path.join(dnnabc_ps_samples_path, f'dnnabc_ps_{k}.npy')
        dnnabc_ps_samples = np.load(dnnabc_kth_ps_path)

        mmd = compute_mmd(dnnabc_ps_samples, true_posterior_samples, bandwidth)
        mmd_value = mmd.numpy()

        print(f"MMD: {mmd_value}")
        writer.writerow([k+1, mmd_value, "N/A"])

print("Method: w2abc")
with open(w2abc_replicate_mmd_csv, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["Iteration", "MMD", "refined MMD"])
    for k in range(10):
        print(f"Processing iteration {k+1}...")

        w2abc_kth_ps_path = os.path.join(w2abc_ps_samples_path, f'w2_ps_{k}.npy')
        w2abc_ps_samples = np.load(w2abc_kth_ps_path)

        mmd = compute_mmd(w2abc_ps_samples, true_posterior_samples, bandwidth)
        mmd_value = mmd.numpy()

        print(f"MMD: {mmd_value}")
        writer.writerow([k+1, mmd_value, "N/A"])

Processing iteration 1...
MMD: 0.1252700686454773
Processing iteration 2...
MMD: 0.12448101490736008
Processing iteration 3...
MMD: 0.13632312417030334
Processing iteration 4...
MMD: 0.1113986074924469
Processing iteration 5...
MMD: 0.1152988076210022
Processing iteration 6...
MMD: 0.14161251485347748
Processing iteration 7...
MMD: 0.1372334361076355
Processing iteration 8...
MMD: 0.11662165820598602
Processing iteration 9...
MMD: 0.1264556497335434
Processing iteration 10...
MMD: 0.1227903813123703
Processing iteration 1...
MMD: 0.10402525216341019
Processing iteration 2...
MMD: 0.09879248589277267
Processing iteration 3...
MMD: 0.09921525418758392
Processing iteration 4...
MMD: 0.09667855501174927
Processing iteration 5...
MMD: 0.10927680879831314
Processing iteration 6...
MMD: 0.09087586402893066
Processing iteration 7...
MMD: 0.10613139718770981
Processing iteration 8...
MMD: 0.10045488178730011
Processing iteration 9...
MMD: 0.09289664030075073
Processing iteration 10...
MMD: 0.10

### Plot marginal posteriors for each replicate

In [59]:
# One 1 x d marginal-KDE figure per NN replicate. The refined draws use the
# same thinned post-burn-in iterations (`refinement_index`) as the MMD cells.
for k in range(10):
    print(f"Processing iteration {k+1}...")

    ps_file = os.path.join(nn_ps_samples_path, f'nn_{n}_ps_{k}.npy')
    chain_file = os.path.join(nn_mcmc_samples_path, f'nn_{n}_mcmc_samples_{k}.npy')
    nn_ps_samples = np.load(ps_file)
    mcmc_samples_total = np.load(chain_file)
    mcmc_samples = mcmc_samples_total[refinement_index, :].reshape(-1, d)

    plot(
        mcmc_samples,
        nn_ps_samples,
        true_posterior_samples,
        os.path.join(nn_fig_folder, f'nn_{n}_{k}.png'),
        model="NN",
    )


Processing iteration 1...
Processing iteration 2...
Processing iteration 3...
Processing iteration 4...
Processing iteration 5...
Processing iteration 6...
Processing iteration 7...
Processing iteration 8...
Processing iteration 9...
Processing iteration 10...


In [72]:
# One 1 x d marginal-KDE figure per BF replicate.
# NOTE(review): this cell reads and writes the n=50 BF files, while the MMD cells
# and the NN figure cell use `n` (=25). Confirm this is intended.
for k in range(10):
    print(f"Processing iteration {k+1}...")

    ps_file = os.path.join(bf_ps_samples_path, f'bf_50_ps_{k}.npy')
    chain_file = os.path.join(bf_mcmc_samples_path, f'bf_50_mcmc_samples_{k}.npy')
    bf_ps_samples = np.load(ps_file)
    mcmc_samples_total = np.load(chain_file)
    mcmc_samples = mcmc_samples_total[refinement_index, :].reshape(-1, d)

    plot(
        mcmc_samples,
        bf_ps_samples,
        true_posterior_samples,
        os.path.join(bf_fig_folder, f'bf_50_{k}.png'),
        model="BF",
    )


Processing iteration 1...
Processing iteration 2...
Processing iteration 3...
Processing iteration 4...
Processing iteration 5...
Processing iteration 6...
Processing iteration 7...
Processing iteration 8...
Processing iteration 9...
Processing iteration 10...


### ABC-MCMC trajectory of the $i$-th parameter

In [73]:
# Output folders for the per-parameter ABC-MCMC trajectory frames of each method.
# exist_ok avoids the check-then-create race and keeps the cell re-runnable.
nn_mcmc_trajectory_folder = os.path.join(replicate_result_folder, 'nn_mcmc_trajectory')
os.makedirs(nn_mcmc_trajectory_folder, exist_ok=True)
bf_mcmc_trajectory_folder = os.path.join(replicate_result_folder, 'bf_mcmc_trajectory')
os.makedirs(bf_mcmc_trajectory_folder, exist_ok=True)

In [74]:
# For each parameter i and each NN replicate k, draw one single-panel frame
# every `thin` chain iterations. Frame j shows the KDE of the most recent
# `window` thinned draws up to and including iteration j.
# NOTE(review): this cell reads the n=50 files while the MMD cells use `n` (=25).
# Confirm this is intended.
window = 5  # most recent thinned draws per frame (= len(refinement_index) with current settings)
frame_iters = range(thin - 1, burn_in + n_samples, thin)  # iterations j with (j+1) % thin == 0

for i in range(d):
    print(f"Processing element {i+1}...")
    ith_nn_mcmc_trajectory_subfolder = os.path.join(nn_mcmc_trajectory_folder, f'mcmc_trajectory_{i}')
    os.makedirs(ith_nn_mcmc_trajectory_subfolder, exist_ok=True)

    for k in range(10):
        print(f"Processing iteration {k+1}...")

        nn_kth_ps_path = os.path.join(nn_ps_samples_path, f'nn_50_ps_{k}.npy')
        nn_ps_samples = np.load(nn_kth_ps_path)

        kth_mcmc_samples_path = os.path.join(nn_mcmc_samples_path, f'nn_50_mcmc_samples_{k}.npy')
        mcmc_samples_total = np.load(kth_mcmc_samples_path)

        ik_subfolder = os.path.join(ith_nn_mcmc_trajectory_subfolder, f'{k+1}rounds')
        os.makedirs(ik_subfolder, exist_ok=True)

        for pos, j in enumerate(frame_iters):
            # Sliding window built from THIS replicate's frame iterations only.
            # It must not carry indices over from the previous replicate's loop,
            # or the early frames of replicate k would mix in late-chain draws.
            mcmc_index = list(frame_iters[max(0, pos + 1 - window):pos + 1])
            mcmc_samples = mcmc_samples_total[mcmc_index, :].reshape(-1, d)
            ikj_address = os.path.join(ik_subfolder, f'{j+1}th_mcmc_samples.png')
            subplot(mcmc_samples, nn_ps_samples, true_posterior_samples, ikj_address,
                    model="NN", element_index=i)


Processing element 1...
Processing iteration 1...
Processing iteration 2...
Processing iteration 3...
Processing iteration 4...
Processing iteration 5...
Processing iteration 6...
Processing iteration 7...
Processing iteration 8...
Processing iteration 9...
Processing iteration 10...
Processing element 2...
Processing iteration 1...
Processing iteration 2...
Processing iteration 3...
Processing iteration 4...
Processing iteration 5...
Processing iteration 6...
Processing iteration 7...
Processing iteration 8...
Processing iteration 9...
Processing iteration 10...
Processing element 3...
Processing iteration 1...
Processing iteration 2...
Processing iteration 3...
Processing iteration 4...
Processing iteration 5...
Processing iteration 6...
Processing iteration 7...
Processing iteration 8...
Processing iteration 9...
Processing iteration 10...
Processing element 4...
Processing iteration 1...
Processing iteration 2...
Processing iteration 3...
Processing iteration 4...
Processing iterat

In [75]:
# For each parameter i and each BF replicate k, draw one single-panel frame
# every `thin` chain iterations. Frame j shows the KDE of the most recent
# `window` thinned draws up to and including iteration j.
# NOTE(review): this cell reads the n=50 files while the MMD cells use `n` (=25).
# Confirm this is intended.
window = 5  # most recent thinned draws per frame (= len(refinement_index) with current settings)
frame_iters = range(thin - 1, burn_in + n_samples, thin)  # iterations j with (j+1) % thin == 0

for i in range(d):
    print(f"Processing element {i+1}...")

    ith_bf_mcmc_trajectory_subfolder = os.path.join(bf_mcmc_trajectory_folder, f'mcmc_trajectory_{i}')
    os.makedirs(ith_bf_mcmc_trajectory_subfolder, exist_ok=True)

    for k in range(10):
        print(f"Processing iteration {k+1}...")

        bf_kth_ps_path = os.path.join(bf_ps_samples_path, f'bf_50_ps_{k}.npy')
        bf_ps_samples = np.load(bf_kth_ps_path)

        kth_mcmc_samples_path = os.path.join(bf_mcmc_samples_path, f'bf_50_mcmc_samples_{k}.npy')
        mcmc_samples_total = np.load(kth_mcmc_samples_path)

        ik_subfolder = os.path.join(ith_bf_mcmc_trajectory_subfolder, f'{k+1}rounds')
        os.makedirs(ik_subfolder, exist_ok=True)

        for pos, j in enumerate(frame_iters):
            # Sliding window built from THIS replicate's frame iterations only.
            # It must not carry indices over from the previous replicate's loop,
            # or the early frames of replicate k would mix in late-chain draws.
            mcmc_index = list(frame_iters[max(0, pos + 1 - window):pos + 1])
            mcmc_samples = mcmc_samples_total[mcmc_index, :].reshape(-1, d)
            ikj_address = os.path.join(ik_subfolder, f'{j+1}th_mcmc_samples.png')
            subplot(mcmc_samples, bf_ps_samples, true_posterior_samples, ikj_address,
                    model="BF", element_index=i)


Processing element 1...
Processing iteration 1...
Processing iteration 2...
Processing iteration 3...
Processing iteration 4...
Processing iteration 5...
Processing iteration 6...
Processing iteration 7...
Processing iteration 8...
Processing iteration 9...
Processing iteration 10...
Processing element 2...
Processing iteration 1...
Processing iteration 2...
Processing iteration 3...
Processing iteration 4...
Processing iteration 5...
Processing iteration 6...
Processing iteration 7...
Processing iteration 8...
Processing iteration 9...
Processing iteration 10...
Processing element 3...
Processing iteration 1...
Processing iteration 2...
Processing iteration 3...
Processing iteration 4...
Processing iteration 5...
Processing iteration 6...
Processing iteration 7...
Processing iteration 8...
Processing iteration 9...
Processing iteration 10...
Processing element 4...
Processing iteration 1...
Processing iteration 2...
Processing iteration 3...
Processing iteration 4...
Processing iterat