In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
import os
import csv

In [2]:
# Output folder for per-replicate result files (created if missing).
replicate_result_folder = 'replicate_results'
# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs
# and makes the cell safe to re-run.
os.makedirs(replicate_result_folder, exist_ok=True)

In [3]:
# Dimension of each sample vector; also the number of theta labels below.
d = 5

# Plot colour palette (hex strings) for truth / estimate / refined results.
truth_color = "#FF6B6B"
est_color = "#4D96FF"
refined_color = "#6BCB77"

# LaTeX axis labels \theta_1 ... \theta_d, one per dimension.
upper_labels = [f"\\theta_{j}" for j in range(1, d + 1)]

In [4]:
# Both sample folders live under <cwd>/result.
result_dir = os.path.join(os.getcwd(), 'result')
nn_ps_samples_path = os.path.join(result_dir, 'nn_ps')
nn_mcmc_samples_path = os.path.join(result_dir, 'nn_mcmc_samples')

In [5]:
# Reference ("true") posterior samples that every replicate is compared against.
true_posterior_path = os.path.join(os.getcwd(), 'data','ps.npy')
true_posterior_samples = np.load(true_posterior_path)
print(true_posterior_samples.shape)
# The scoring loop compares these against sample sets reshaped to (-1, d);
# fail early with a clear message instead of a matmul shape error later.
assert true_posterior_samples.ndim == 2 and true_posterior_samples.shape[1] == d, (
    f"expected reference samples of shape (N, {d}), got {true_posterior_samples.shape}"
)

# RBF kernel bandwidth used by compute_mmd.
bandwidth_path = os.path.join(os.getcwd(), 'data','h_mmd.npy')
bandwidth = np.load(bandwidth_path)
# compute_mmd divides by bandwidth**2, so a zero bandwidth would silently yield NaN.
assert np.all(bandwidth > 0), f"MMD bandwidth must be positive, got {bandwidth}"

(10000, 5)


In [8]:
def compute_mmd(x_,y_, bandwidth):
    """
    Estimate the squared Maximum Mean Discrepancy (MMD^2) between two samples.

    Uses a Gaussian (RBF) kernel k(a, b) = exp(-||a - b||^2 / (2 * bandwidth^2))
    and the biased V-statistic
        MMD^2 = mean(K_xx) - 2 * mean(K_xy) + mean(K_yy),
    where each mean runs over all pairs, including the diagonals of K_xx and K_yy.

    :param x_: array-like of shape (N_x, d); cast to float32.
    :param y_: array-like of shape (N_y, d); cast to float32.
    :param bandwidth: kernel length-scale; presumably a positive scalar
        (0-d array) -- TODO confirm.
    :return: scalar float32 tf.Tensor holding the MMD^2 estimate
        (call ``.numpy()`` to get a NumPy value).
    """
    x = tf.convert_to_tensor(x_, dtype=tf.float32)
    y = tf.convert_to_tensor(y_, dtype=tf.float32)

    # Squared row norms, shape (N, 1), reused by every distance matrix.
    rx = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
    ry = tf.reduce_sum(tf.square(y), axis=-1, keepdims=True)

    def mean_kernel(a, ra, b, rb):
        """Mean of the RBF kernel matrix between the rows of a and the rows of b."""
        # ||a_i - b_j||^2 = ||a_i||^2 - 2 a_i.b_j + ||b_j||^2. In float32 this
        # expansion can come out slightly negative through cancellation, which
        # would push kernel values above 1, so clamp at zero.
        sq_dist = tf.maximum(
            ra - 2.0 * tf.matmul(a, b, transpose_b=True) + tf.transpose(rb), 0.0
        )
        return tf.reduce_mean(tf.exp(-0.5 * sq_dist / (bandwidth**2)))

    # Each term is reduced to a scalar before the next kernel matrix is built,
    # so the three N x N matrices are not all held in memory at once.
    return mean_kernel(x, rx, x, rx) - 2*mean_kernel(x, rx, y, ry) + mean_kernel(y, ry, y, ry)

## Compute and record MMD for each replicate

In [6]:
# Start a fresh results CSV that holds only the header row; one row per
# replicate is written afterwards.
nn_replicate_mmd_csv = os.path.join(replicate_result_folder, 'mmd_replicate_mmd.csv')
csv_header = ["Iteration", "MMD", "refined MMD"]
with open(nn_replicate_mmd_csv, "w", newline="") as out_file:
    csv.writer(out_file).writerow(csv_header)

In [7]:
# MCMC schedule: skip the first `burn_in` iterations, then keep every
# `thin`-th iteration among the following `n_samples` iterations.
burn_in = 250
n_samples = 100
thin = 20

# Iterations burn_in, burn_in + thin, ... strictly below burn_in + n_samples.
refinement_index = list(range(burn_in, burn_in + n_samples, thin))

print(f"Refinement index: {refinement_index}")

Refinement index: [250, 270, 290, 310, 330]


In [9]:
# For each replicate k, score the NN posterior samples and the thinned MCMC
# samples against the reference posterior, and record both MMD values.
# The CSV is rewritten from scratch (header + rows) so that re-running this
# cell on its own does not append duplicate rows to an earlier run's results.
with open(nn_replicate_mmd_csv, "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["Iteration", "MMD","refined MMD"])

    for k in range(10):
        print(f"Processing iteration {k+1}...")

        nn_kth_ps_path = os.path.join(nn_ps_samples_path, f'nn_50_ps_{k}.npy')
        nn_ps_samples = np.load(nn_kth_ps_path)

        kth_mcmc_samples_path = os.path.join(nn_mcmc_samples_path, f'nn_50_mcmc_samples_{k}.npy')
        mcmc_samples_total = np.load(kth_mcmc_samples_path)
        # Keep only the post-burn-in, thinned iterations (first axis), then
        # flatten everything else into rows of length d.
        # NOTE(review): assumes the iteration axis comes first -- confirm.
        mcmc_samples = mcmc_samples_total[refinement_index, :].reshape(-1, d)

        mmd = compute_mmd(nn_ps_samples, true_posterior_samples, bandwidth)
        refined_mmd = compute_mmd(mcmc_samples, true_posterior_samples, bandwidth)
        mmd_value = mmd.numpy()
        refined_mmd_value = refined_mmd.numpy()

        print(f"MMD: {mmd_value}, Refined MMD: {refined_mmd_value}")
        writer.writerow([k+1, mmd_value, refined_mmd_value])
        # Persist each row immediately so partial results survive a crash mid-loop.
        f.flush()

Processing iteration 1...
MMD: 0.032776251435279846, Refined MMD: 0.021373651921749115
Processing iteration 2...
MMD: 0.05185917019844055, Refined MMD: 0.019209280610084534
Processing iteration 3...
MMD: 0.03253833204507828, Refined MMD: 0.02461998164653778
Processing iteration 4...
MMD: 0.032508596777915955, Refined MMD: 0.01012122631072998
Processing iteration 5...
MMD: 0.04901681840419769, Refined MMD: 0.02413634955883026
Processing iteration 6...
MMD: 0.047604821622371674, Refined MMD: 0.01594504714012146
Processing iteration 7...
MMD: 0.06179584562778473, Refined MMD: 0.025123149156570435
Processing iteration 8...
MMD: 0.022596657276153564, Refined MMD: 0.013932481408119202
Processing iteration 9...
MMD: 0.03743452578783035, Refined MMD: 0.01525270938873291
Processing iteration 10...
MMD: 0.022383302450180054, Refined MMD: 0.012264631688594818
