In [1]:
import matplotlib.pyplot as plt
from pathlib import Path
import pickle
import numpy as np
import seaborn as sns
import os
import bayesflow as bf
import tensorflow as tf
from sc_abi.sc_amortizers import AmortizedPosteriorLikelihoodSC
from sc_abi.sc_schedules import ZeroOneSchedule
from tasks.two_moons import generative_model, prior, get_amortizer_arguments

from tasks.two_moons import analytic_posterior_numpy

  from tqdm.autonotebook import tqdm
INFO:root:Performing 2 pilot runs with the anonymous model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 2)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 2)
INFO:root:No optional prior non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional simulation batchable context provided.


In [7]:
# Experiment configuration: which trained checkpoints to evaluate and where
# artefacts live. All paths are relative to the notebook's working directory.
simulation_budgets = (512, 1024, 2048, 4096)
run_id = 1
TASK_NAME = "two_moons"

# Output folder for computed results (created if missing, like `mkdir -p`).
Path("computations", TASK_NAME).mkdir(parents=True, exist_ok=True)
PLOT_DIR = Path("plots") / TASK_NAME
CHECKPOINT_DIR = Path("checkpoints") / TASK_NAME

In [3]:

def _checkpoint_path(budget, method):
    """Return the checkpoint directory of `method` for `budget`, failing fast if it holds no checkpoint.

    Reads the notebook globals ``CHECKPOINT_DIR`` and ``run_id``.

    Raises
    ------
    FileNotFoundError
        If ``tf.train.latest_checkpoint`` finds no checkpoint in the directory.
        NOTE(review): without this guard the Trainer presumably falls back to
        freshly initialised networks, so the evaluation would silently score
        untrained models -- confirm against the bayesflow Trainer.
    """
    path = CHECKPOINT_DIR / str(budget) / method / str(run_id)
    if tf.train.latest_checkpoint(str(path)) is None:
        raise FileNotFoundError(f"No TensorFlow checkpoint found in '{path}'.")
    return path


def load_trainers(budget):
    """Build the NPLE and SC-NPLE trainers for one simulation budget and restore them from disk.

    Checkpoints are read from ``CHECKPOINT_DIR / str(budget) / {"nple", "sc"} / str(run_id)``.

    Parameters
    ----------
    budget : int
        Simulation budget the networks were trained with; selects the
        checkpoint sub-directory.

    Returns
    -------
    tuple
        ``(trainer_nple, trainer_sc10)``: the plain NPLE trainer and the
        self-consistency (SC-ABI) trainer using 10 consistency samples.

    Raises
    ------
    FileNotFoundError
        If either checkpoint directory contains no checkpoint.
    """
    # Build the benchmark once. Each construction appears to re-run pilot
    # simulations (see the cell log). Its configurator is shared by both
    # trainers. NOTE(review): assumes the configurator is stateless.
    configurator = bf.benchmarks.Benchmark('two_moons', 'joint').configurator
    shared_kwargs = dict(
        generative_model=generative_model,
        default_lr=5e-4,
        memory=False,
        configurator=configurator,
        max_to_keep=1,
    )

    # Baseline neural posterior + likelihood estimation.
    trainer_nple = bf.trainers.Trainer(
        amortizer=bf.amortizers.AmortizedPosteriorLikelihood(**get_amortizer_arguments()),
        checkpoint_path=_checkpoint_path(budget, "nple"),
        **shared_kwargs,
    )

    # SC-ABI trainer. The self-consistency weight follows a 0 -> 1 schedule.
    # The threshold is presumably batch size 32 x 100 steps -- TODO confirm.
    lambda_scheduler = ZeroOneSchedule(threshold_step=32 * 100)
    trainer_sc10 = bf.trainers.Trainer(
        amortizer=AmortizedPosteriorLikelihoodSC(
            **get_amortizer_arguments(),
            prior=prior,
            n_consistency_samples=10,
            lambda_schedule=lambda_scheduler,
            theta_clip_value_min=-2.0,
            theta_clip_value_max=2.0 - 1e-5,  # uniform distribution excludes upper limit
        ),
        checkpoint_path=_checkpoint_path(budget, "sc"),
        **shared_kwargs,
    )
    return trainer_nple, trainer_sc10

In [4]:
# Evaluation settings: number of simulations in the shared evaluation set and
# number of samples requested from each learned likelihood via `sample_data`.
n_eval = 1000
n_draws = 1000

# One evaluation set shared by every method and budget, so results are comparable.
# NOTE(review): no random seed is set here, so this set (and every number derived
# from it) changes on each kernel restart.
eval_data = generative_model(n_eval)

# Order must match the (trainer_nple, trainer_sc10) tuple returned by load_trainers.
names = ['NPLE', 'SC-NPLE']

# Pre-allocate one result slot per (method, budget) with the keys the
# evaluation loop fills in.
likelihood_estimates = {
    name: {budget: {'y_est': None, 'log_lik_y_true': None} for budget in simulation_budgets}
    for name in names
}


In [5]:
# For each budget, restore both trainers and query their learned likelihoods
# on the shared evaluation set. The results are stored in likelihood_estimates.
for budget in simulation_budgets:
    trainer_nple, trainer_sc10 = load_trainers(budget)
    for name, trainer in zip(names, (trainer_nple, trainer_sc10)):
        eval_data_config = trainer.configurator(eval_data)
        # Sampling happens before scoring, as in the original run order.
        y_est = trainer.amortizer.sample_data(eval_data_config, n_samples=n_draws)
        log_lik_y_true = trainer.amortizer.log_likelihood(eval_data_config)
        likelihood_estimates[name][budget].update(y_est=y_est, log_lik_y_true=log_lik_y_true)

INFO:root:Performing 2 pilot runs with the two_moons model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 2)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 2)
INFO:root:No optional prior non-batchable context provided.
INFO:root:No optional prior batchable context provided.
INFO:root:No optional simulation non-batchable context provided.
INFO:root:No optional simulation batchable context provided.
INFO:root:Loaded loss history from checkpoints/two_moons/256/nple/1/history_201.pkl.
INFO:root:Networks loaded from checkpoints/two_moons/256/nple/1/ckpt-201
INFO:root:Performing a consistency check with provided components...
INFO:root:Done.
INFO:root:Performing 2 pilot runs with the two_moons model...
INFO:root:Shape of parameter batch after 2 pilot simulations: (batch_size = 2, 2)
INFO:root:Shape of simulation batch after 2 pilot simulations: (batch_size = 2, 2)
INFO:root:No optional prior non-batchable context provided.
I

In [11]:
# Print one LaTeX table row per method. Each column is a simulation budget (in
# the order of `simulation_budgets`) and shows the mean log-likelihood of the
# evaluation data under the learned likelihood, with its standard error.
y_true = eval_data['sim_data']
for name in names:
    row_cells = [name]
    for budget in simulation_budgets:
        log_lik_y_true = np.asarray(likelihood_estimates[name][budget]['log_lik_y_true'])
        log_lik_y_true_mean = np.mean(log_lik_y_true)
        # Standard error of the mean. ddof=1 gives the sample standard deviation,
        # and .size counts every element that np.std/np.mean (no axis) average over.
        log_lik_y_true_se = np.std(log_lik_y_true, ddof=1) / np.sqrt(log_lik_y_true.size)
        row_cells.append(rf'${log_lik_y_true_mean:.2f} {{\pm}} {log_lik_y_true_se:.2f}$')
    print(' & '.join(row_cells) + r'\\')

NPLE & $3.15 {\pm} 0.03$ & $3.18 {\pm} 0.03$ & $2.88 {\pm} 0.04$ & $2.91 {\pm} 0.05$\\
SC-NPLE & $3.14 {\pm} 0.02$ & $3.45 {\pm} 0.02$ & $3.71 {\pm} 0.02$ & $3.90 {\pm} 0.02$\\
