In [1]:
import pickle
import tensorflow as tf
import numpy as np

import matplotlib.pyplot as plt

from train import get_setup

from tqdm import tqdm


import sys
sys.path.append("../../")
from reference_posteriors.two_moons.two_moons_lueckmann_numpy import analytic_posterior_numpy 
from inverse_kinematics import InverseKinematicsModel


from cmdstanpy import CmdStanModel
import logging


from reference_posteriors.gmm_bimodal import GMM, GMMSimulator


# Silence cmdstanpy: raise its threshold to CRITICAL, keep its records away from the
# root logger, and give it a NullHandler of its own.
logger = logging.getLogger("cmdstanpy")
logger.setLevel(logging.CRITICAL)
logger.propagate = False
logger.addHandler(logging.NullHandler())

%load_ext autoreload
%autoreload 2

  from tqdm.autonotebook import tqdm


In [2]:
num_posterior_samples = 1_000  # posterior draws per method and task

In [3]:
# Benchmark tasks (internal keys) and the row labels used in the figure.
tasks = ['gmm', 'twomoons', 'invkinematics']
task_names = dict(gmm='GMM', twomoons='Two Moons', invkinematics='Kinematics')

# Training-set sizes to evaluate (a single budget in this notebook).
simulation_budgets = [1024]
# Estimator keys handed to `get_setup` in the evaluation loop.
estimators = ['ac', 'nsf', 'cmpe', 'fmpe']

# Palette: four distinct colours followed by the same yellow three times.
colors = ['#440154', '#3b528b', '#21908dff', '#5dc962ff'] + ['#fde725ff'] * 3

# NOTE(review): neither index is referenced in the visible notebook — confirm before removing.
gmm_idx = 1
invkinematics_idx = 0

tf.random.set_seed(1234)

# Observed data set for the GMM task, simulated at a fixed "true" parameter.
gmm_theta = np.array([-1.6, -1.0])
gmm_y = GMMSimulator(GMM)(tf.convert_to_tensor([gmm_theta]))[0].numpy().astype(np.float32)

iter_warmup = 2000

# NOTE(review): assumes gmm_y is 2-D (n_obs, data_dim) — the unpacking fails otherwise.
n_obs, data_dim = gmm_y.shape
param_dim = gmm_theta.shape[0]

# A single chain initialised at the true parameter is run, and its draws are then
# mirrored through the origin to obtain the second mode (presumably valid because the
# bimodal GMM posterior is symmetric under theta -> -theta — confirm). Round up so the
# mirrored set holds at least `num_posterior_samples` draws even when that is odd.
iter_sampling = (num_posterior_samples + 1) // 2

stan_data = {"n_obs": n_obs, "data_dim": data_dim, "x": gmm_y}
model = CmdStanModel(stan_file="../../reference_posteriors/gmm_bimodal/gmm.stan")
fit = model.sample(
    data=stan_data,
    iter_warmup=iter_warmup,
    iter_sampling=iter_sampling,
    chains=1,
    inits={"theta": gmm_theta.tolist()},
    seed=1234,  # make the reference posterior reproducible, like the TF seed above
    show_progress=False,
)
posterior_samples_chain = fit.stan_variable("theta")
gmm_reference_samples = np.concatenate(
    [posterior_samples_chain, -1.0 * posterior_samples_chain], axis=0
)[:num_posterior_samples]


# One fixed observation per task, in the form passed as `input_dict` to the sampler in
# the evaluation loop: the GMM observation gets a leading batch axis of size one under
# 'summary_conditions'; the other tasks use a (1, 2) float32 array under
# 'direct_conditions'.
test_instances = {
    'gmm': {'summary_conditions': gmm_y[None]},
    'twomoons': {'direct_conditions': np.array([[0, 0]], dtype=np.float32)},
    'invkinematics': {'direct_conditions': np.array([[0, 1.5]], dtype=np.float32)},
}

# Reference shown in the first column of the figure, per task.
reference_posteriors = {}
reference_posteriors['gmm'] = gmm_reference_samples
# Two-moons reference from `analytic_posterior_numpy`, with a fixed RNG seed.
reference_posteriors['twomoons'] = analytic_posterior_numpy(
    test_instances['twomoons']['direct_conditions'][0],
    num_posterior_samples,
    rng=np.random.default_rng(seed=1234),
)
# Not a sample set: the kinematics observation with its two coordinates reversed.
# `plot` forwards it as the `reference` argument of InverseKinematicsModel.update_plot_ax.
reference_posteriors['invkinematics'] = np.flip(test_instances['invkinematics']['direct_conditions'][0])

# ABC samples drawn in the reference column of the kinematics row, truncated to the
# same number of draws as the other methods. The `with` block closes the file handle.
# NOTE: pickle can execute arbitrary code on load — only use trusted, locally generated files.
with open('./data/invkinematics_showcase_abc.pkl', 'rb') as abc_file:
    inverse_kinematics_abc = pickle.load(abc_file)[:num_posterior_samples]

# Figure columns after the reference, in display order: (key in eval_dict, column
# title, index into `colors`). In titles, 'N#' is the number of sampling steps
# (n_steps for CMPE, 1/step_size for FMPE). The 1000 for plain 'fmpe' presumably
# reflects the sampler's default step size — confirm.
# Only the 'name' entries are read by the plotting code visible here.
plot_settings = {
    key: {'name': title, 'color': colors[color_idx]}
    for key, title, color_idx in [
        ('ac', 'ACF', 0),
        ('nsf', 'NSF', 1),
        ('fmpe10', 'FMPE 10#', 3),
        ('fmpe30', 'FMPE 30#', 4),
        ('fmpe', 'FMPE 1000#', 2),
        ('cmpe10', 'CMPE 10#', 5),
        ('cmpe30', 'CMPE 30#', 6),
    ]
}

In [4]:
import time

def sample_timed(trainer, num_runs=3, **kwargs):
    """Draw posterior samples ``num_runs`` times and keep the fastest run.

    The call is repeated and the minimum wall-clock time is kept. This reduces the
    influence of one-off warm-up costs and background load on the reported timing.

    Args:
        trainer: object exposing ``trainer.amortizer.sample(**kwargs)``.
        num_runs: number of repetitions; must be >= 1.
        **kwargs: forwarded unchanged to ``trainer.amortizer.sample``.

    Returns:
        Tuple ``(samples, seconds)``: the output of the fastest run and its duration.
        On ties, the earliest of the equally fast runs is kept.

    Raises:
        ValueError: if ``num_runs`` < 1 (there would be no samples to return).
    """
    if num_runs < 1:
        raise ValueError(f'num_runs must be >= 1, got {num_runs}')

    t_min = np.inf
    samples_t_min = None
    for _ in range(num_runs):
        # perf_counter is monotonic and high-resolution; time.time() can jump when the
        # system clock is adjusted and is therefore unsuitable for benchmarking.
        tic = time.perf_counter()
        samples = trainer.amortizer.sample(**kwargs)
        t = time.perf_counter() - tic
        if t < t_min:
            t_min = t
            samples_t_min = samples

    return samples_t_min, t_min

In [5]:
# Evaluate every (task, budget, estimator) combination on the fixed test instance.
# CMPE and FMPE are additionally sampled at several resolutions; each variant is
# stored under its own key so the figure cell can address it via `plot_settings`.
total = len(tasks) * len(simulation_budgets) * len(estimators)
run_idx = 0
eval_dict = {task: {budget: {estimator: {} for estimator in estimators} for budget in simulation_budgets} for task in tasks}
num_runs_timed = 3

# Per estimator: list of (result key, extra kwargs for `trainer.amortizer.sample`),
# sampled in the listed order. Estimators not listed are sampled once with defaults.
sampling_variants = {
    'cmpe': [('cmpe10', {'n_steps': 10}), ('cmpe30', {'n_steps': 30})],
    'fmpe': [('fmpe', {}), ('fmpe10', {'step_size': 1.0/10.0}), ('fmpe30', {'step_size': 1.0/30.0})],
}

with tqdm(total=total) as pbar:
    for task in tasks:
        # The training file depends only on the task: load it once per task (not once
        # per estimator) and close the handle.
        # NOTE: pickle can execute arbitrary code on load — only use trusted files.
        with open(f'./data/{task}_train_data.pkl', 'rb') as train_file:
            train_data_full = pickle.load(train_file)
        for budget in simulation_budgets:
            train_data = {
                'sim_data': train_data_full.get('sim_data')[:budget],
                'prior_draws': train_data_full.get('prior_draws')[:budget],
            }
            # Per-dimension variance of the first `budget` prior draws; depends only on
            # (task, budget), so it is computed once and shared by all estimators.
            sigma2 = tf.math.reduce_variance(tf.constant(train_data["prior_draws"], dtype=tf.float32), axis=0, keepdims=True)
            for estimator in estimators:
                ckpt_path = f'./checkpoints/{task}_{estimator}_{budget}_run{run_idx}'
                trainer, settings = get_setup(task, estimator, sigma2, budget, ckpt_path)
                eval_data = test_instances[task]
                results = eval_dict[task][budget]
                for key, sampler_kwargs in sampling_variants.get(estimator, [(estimator, {})]):
                    results.setdefault(key, {})['posterior_samples'] = sample_timed(
                        trainer,
                        num_runs=num_runs_timed,
                        input_dict=eval_data,
                        n_samples=num_posterior_samples,
                        to_numpy=False,
                        **sampler_kwargs,
                    )
                pbar.update(1)

  0%|          | 0/12 [00:00<?, ?it/s]INFO:root:Trainer initialization: No generative model provided. Only offline learning mode is available!
INFO:root:Initialized empty loss history.
INFO:root:Initialized networks from scratch.
  8%|▊         | 1/12 [00:00<00:09,  1.18it/s]INFO:root:Trainer initialization: No generative model provided. Only offline learning mode is available!
INFO:root:Initialized empty loss history.
INFO:root:Initialized networks from scratch.
 17%|█▋        | 2/12 [00:01<00:06,  1.44it/s]INFO:root:Trainer initialization: No generative model provided. Only offline learning mode is available!
INFO:root:Initialized empty loss history.
INFO:root:Initialized networks from scratch.
 25%|██▌       | 3/12 [00:02<00:08,  1.01it/s]INFO:root:Trainer initialization: No generative model provided. Only offline learning mode is available!
INFO:root:Initialized empty loss history.
INFO:root:Initialized networks from scratch.
 33%|███▎      | 4/12 [00:14<00:42,  5.36s/it]INFO:root:

### Example Grid

In [None]:
def plot(ax, reference, approximate, task, reference_color=(0.8, 0.4, 0.4), approximate_color=(1, 1, 1), **kwargs):
    """Draw one posterior panel on ``ax`` and return it.

    For the point-cloud tasks ('twomoons', 'gmm'), columns 0 and 1 of the reference
    and/or approximate samples are scattered, each in its own colour. The axis limits
    are then fixed symmetrically per task. For 'invkinematics', the approximate samples
    are rendered by ``InverseKinematicsModel.update_plot_ax`` with ``reference`` as
    its third argument; nothing is drawn when ``approximate`` is None.

    Args:
        ax: matplotlib axes to draw on.
        reference: reference samples, or None to skip them.
        approximate: approximate samples, or None to skip them.
        task: one of 'twomoons', 'gmm', 'invkinematics'.
        reference_color: colour of the reference scatter. For 'invkinematics' it only
            acts as a switch: the default value selects white/grey line colours, and
            any other value selects red ones.
        approximate_color: colour of the approximate scatter (point-cloud tasks only).
        **kwargs: forwarded to ``ax.scatter`` (point-cloud tasks only).

    Returns:
        The same ``ax``.

    Raises:
        ValueError: if ``task`` is not recognised.
    """
    # Half-width of the square plotting window per point-cloud task.
    half_widths = {'twomoons': 0.5, 'gmm': 2.5}

    if task in ('twomoons', 'gmm'):
        for samples, color in ((reference, reference_color), (approximate, approximate_color)):
            if samples is None:
                continue
            ax.scatter(samples[:, 0], samples[:, 1], color=color, **kwargs)
        half_width = half_widths[task]
        ax.set_xlim([-half_width, half_width])
        ax.set_ylim([-half_width, half_width])
    elif task == 'invkinematics':
        if approximate is not None:
            uses_default_color = reference_color == (0.8, 0.4, 0.4)
            line_palette = [(1, 1, 1), (0.8, 0.8, 0.8), (0.7, 0.7, 0.7)] if uses_default_color else [(0.8, 0.4, 0.4)] * 3
            InverseKinematicsModel(linecolors=line_palette).update_plot_ax(ax, approximate, reference)
    else:
        raise ValueError(f'Unknown task {task}')

    return ax


In [None]:
from pathlib import Path

nrows = len(tasks)
ncols = len(plot_settings) + 1  # first column shows the reference

scatter_kws = {
    "alpha": 0.20,
    "rasterized": True,
    "s": 0.7,
    "marker": "D",
}

# Results are stored per simulation budget; show the last (here: only) one. This used
# to rely on the `budget` loop variable leaking out of the evaluation cell, which
# breaks when cells are run out of order.
plot_budget = simulation_budgets[-1]

f, axes = plt.subplots(nrows, ncols, figsize=(ncols*2, nrows*2), subplot_kw=dict(box_aspect=1), squeeze=False)


for i, task in enumerate(tasks):
    axes[i, 0].set_ylabel(task_names[task], rotation=90, size='xx-large')

    # Plot Reference (for the kinematics row this column also draws the ABC samples)
    if task == 'invkinematics':
        axes[i, 0] = plot(axes[i, 0], reference_posteriors[task], inverse_kinematics_abc, task, reference_color='custom', **scatter_kws)
    else:
        axes[i, 0] = plot(axes[i, 0], reference_posteriors[task], None, task, **scatter_kws)

    # Plot Approximate
    for j, estimator in enumerate(plot_settings.keys(), 1):
        if i == 0:
            axes[i, j].set_title(plot_settings[estimator]['name'], size='xx-large')
        posterior_samples, sampling_time = eval_dict[task][plot_budget][estimator]['posterior_samples']
        # Sampling time in seconds, rescaled to 1000 posterior draws.
        sampling_time_1000 = sampling_time / posterior_samples.shape[0] * 1000
        axes[i, j] = plot(axes[i, j], reference_posteriors[task], posterior_samples.numpy(), task, **scatter_kws)

        # Fixed-width background: white placeholder text on a white box. Every panel
        # thus gets the same box size regardless of how many digits the real label has.
        axes[i, j].annotate(text='00000ms',
                            xy=(0.95, 0.06),
                            xycoords='axes fraction',
                            color='white',
                            horizontalalignment='right',
                            bbox=dict(facecolor='white'),
                            fontsize='x-large'
                            )

        # Actual timing label (milliseconds per 1000 draws), drawn on top of the box.
        axes[i, j].annotate(text=f'{int(sampling_time_1000*1000)}ms',
                            xy=(0.95, 0.06),
                            xycoords='axes fraction',
                            horizontalalignment='right',
                            fontsize='x-large'
                            )

axes[0, 0].set_title("Reference", size='xx-large')

for ax in axes.flat:
    ax.grid(False)
    ax.set_facecolor((0 / 255, 32 / 255, 64 / 255, 1.0))
    ax.get_xaxis().set_ticks([])
    ax.get_yaxis().set_ticks([])
    for side in ("bottom", "top", "right", "left"):
        ax.spines[side].set_alpha(0.0)
    ax.set_aspect('equal')

f.tight_layout()

# Make sure the output directory exists so a long evaluation run does not fail at the
# very last step.
Path('./figures').mkdir(parents=True, exist_ok=True)
f.savefig('./figures/benchmark_showcase.pdf', dpi=300, bbox_inches='tight')