In [1]:
import matplotlib.pyplot as plt
from pathlib import Path
import pickle
import numpy as np
import seaborn as sns
import os
import bayesflow as bf
import tensorflow as tf
from sc_abi.sc_amortizers import AmortizedPosteriorLikelihoodSC
from sc_abi.sc_schedules import ZeroOneSchedule
from tasks.two_moons import generative_model, prior, get_amortizer_arguments

from tasks.two_moons import analytic_posterior_numpy

  from tqdm.autonotebook import tqdm
INFO:root:Performing 2 pilot runs with the anonymous model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 2)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 2)
INFO:root:No optional prior non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional simulation batchable context provided.


In [2]:
# Experiment configuration.
TASK_NAME = "two_moons"
run_id = 1                                    # index of the training run to evaluate
simulation_budgets = (512, 1024, 2048, 4096)  # one checkpointed model per budget and method

# Output locations; the computations directory is created up front.
PLOT_DIR = Path("plots") / TASK_NAME
CHECKPOINT_DIR = Path("checkpoints/") / TASK_NAME
os.makedirs(f'./computations/{TASK_NAME}', exist_ok=True)

In [3]:
def load_trainers(budget):
    """Build the NPLE and SC-NPLE trainers for one simulation budget.

    Each trainer gets its own checkpoint directory,
    ``CHECKPOINT_DIR / <budget> / {"nple", "sc"} / <run_id>``. Judging by the
    "Networks loaded from ..." log lines, the BayesFlow ``Trainer`` restores the
    networks stored there on construction.

    Parameters
    ----------
    budget : int
        Simulation budget. Only used to select the checkpoint sub-directory.

    Returns
    -------
    tuple
        ``(trainer_nple, trainer_sc10)``: the plain NPLE trainer and the
        self-consistency (SC-ABI) trainer using 10 consistency samples.
    """
    # Build the benchmark once and share its configurator. Constructing a
    # Benchmark performs pilot simulations (the "Performing 2 pilot runs with
    # the two_moons model" logs), so building one per trainer was wasted work.
    configurator = bf.benchmarks.Benchmark(TASK_NAME, 'joint').configurator

    # Settings shared by both trainers.
    common_kwargs = dict(
        generative_model=generative_model,
        default_lr=5e-4,
        memory=False,
        configurator=configurator,
        max_to_keep=1,
    )
    budget_dir = CHECKPOINT_DIR / str(budget)

    # Baseline NPLE trainer.
    trainer_nple = bf.trainers.Trainer(
        amortizer=bf.amortizers.AmortizedPosteriorLikelihood(**get_amortizer_arguments()),
        checkpoint_path=budget_dir / "nple" / str(run_id),
        **common_kwargs,
    )

    # SC-ABI trainer. The ZeroOneSchedule presumably switches the consistency
    # weight from 0 to 1 at `threshold_step` (32 * 100 steps); see
    # sc_abi.sc_schedules to confirm.
    lambda_scheduler = ZeroOneSchedule(threshold_step=32 * 100)
    trainer_sc10 = bf.trainers.Trainer(
        amortizer=AmortizedPosteriorLikelihoodSC(
            **get_amortizer_arguments(),
            prior=prior,
            n_consistency_samples=10,
            lambda_schedule=lambda_scheduler,
            theta_clip_value_min=-2.0,
            theta_clip_value_max=2.0 - 1e-5,  # uniform distribution excludes upper limit
        ),
        checkpoint_path=budget_dir / "sc" / str(run_id),
        **common_kwargs,
    )
    return trainer_nple, trainer_sc10

In [4]:
# Held-out evaluation set: n_eval simulated datasets. Each dataset is later
# paired with n_draws posterior samples per method.
n_eval, n_draws = 1000, 1000
eval_data = generative_model(n_eval)

# Methods compared in the results table (order matches load_trainers' return).
names = ['NPLE', 'SC-NPLE']

# Per-method, per-budget placeholder slots.
# NOTE(review): not read by any visible cell; remove if unused elsewhere.
likelihood_estimates = {
    method: {b: {'samples': None} for b in simulation_budgets}
    for method in names
}


In [5]:
# Estimate log p(x) for every evaluation dataset with Bayes' rule:
#   log p(x) = log p(theta) + log p(x | theta) - log p(theta | x),
# evaluated at n_draws posterior samples theta ~ q(theta | x). For exactly
# consistent likelihood/posterior approximations every draw gives the same
# value, so the spread across draws measures (in)consistency.
logml_dict = {name: {budget: {'logml': None} for budget in simulation_budgets} for name in names}

# Loop-invariant inputs, hoisted out of all loops. A local name is used so the
# notebook-level config constant `n_eval` is not silently rebound.
sim_data = eval_data['sim_data']
n_datasets = eval_data['prior_draws'].shape[0]

for budget in simulation_budgets:
    trainer_nple, trainer_sc10 = load_trainers(budget)
    for name, trainer in zip(names, [trainer_nple, trainer_sc10]):
        eval_data_config = trainer.configurator(eval_data)
        theta_est = trainer.amortizer.sample_parameters(eval_data_config, n_samples=n_draws)
        # Indexing below assumes (n_datasets, n_draws, n_params).
        assert tuple(theta_est.shape[:2]) == (n_datasets, n_draws), theta_est.shape

        logml = np.zeros((n_datasets, n_draws), dtype=np.float32)
        for i in range(n_draws):
            theta_i = theta_est[:, i, :]
            est_dict = {
                'posterior_inputs': {'direct_conditions': sim_data, 'parameters': theta_i},
                'likelihood_inputs': {'conditions': theta_i, 'observables': sim_data},
            }
            log_prior = np.array(prior.log_prob(theta_i))
            log_likelihood = np.array(trainer.amortizer.log_likelihood(est_dict))
            log_posterior = np.array(trainer.amortizer.log_posterior(est_dict))
            logml[:, i] = log_prior + log_likelihood - log_posterior
        logml_dict[name][budget]['logml'] = logml
        

INFO:root:Performing 2 pilot runs with the two_moons model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 2)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 2)
INFO:root:No optional prior non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional simulation batchable context provided.
INFO:root:Loaded loss history from checkpoints/two_moons/512/nple/1/history_201.pkl.
INFO:root:Networks loaded from checkpoints/two_moons/512/nple/1/ckpt-201
INFO:root:Performing a consistency check with provided components...
INFO:root:Done.
INFO:root:Performing 2 pilot runs with the two_moons model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 2)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 2)
INFO:root:No optional prior non-batchable context provided.
I

In [6]:
# Sharpness of the log marginal likelihood estimates: for each evaluation
# dataset, the width of the central 95% interval of log p(x) across posterior
# draws (0 for a perfectly self-consistent model). Prints one LaTeX table row
# per method, with one "mean ± standard error" cell per simulation budget.
for name in names:
    row = name
    for budget in simulation_budgets:
        logml = logml_dict[name][budget]['logml']
        # Mask every non-finite estimate WITHOUT mutating the stored array, so
        # re-running this cell never changes `logml_dict`. -inf presumably
        # arises when the prior log density is -inf outside its support.
        # +inf/nan are masked too so they cannot leak into the percentiles.
        logml_finite = np.where(np.isfinite(logml), logml, np.nan)
        lower_ci, upper_ci = np.nanpercentile(logml_finite, [2.5, 97.5], axis=1)
        sharpness = upper_ci - lower_ci
        # A dataset with no finite estimate yields NaN width; exclude it rather
        # than letting it turn the whole summary into NaN.
        sharpness = sharpness[np.isfinite(sharpness)]
        sharpness_mean = np.mean(sharpness)
        # Standard error of the mean uses the sample standard deviation (ddof=1).
        sharpness_se = np.std(sharpness, ddof=1) / np.sqrt(sharpness.size)
        row += rf' & ${sharpness_mean:.2f} {{\pm}} {sharpness_se:.2f}$'
    print(row + r'\\')

NPLE & $6.51 {\pm} 0.11$ & $7.28 {\pm} 0.10$ & $9.07 {\pm} 0.06$ & $10.21 {\pm} 0.08$\\
SC-NPLE & $1.70 {\pm} 0.02$ & $1.37 {\pm} 0.02$ & $1.21 {\pm} 0.01$ & $1.14 {\pm} 0.01$\\
