In [None]:
import os, sys
sys.path.insert(0, os.path.abspath("../scr"))
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt
from sbi import analysis as analysis
from sbi.inference import prepare_for_sbi
from sbi.inference import prepare_for_sbi, simulate_for_sbi
from sbi.analysis import check_sbc, run_sbc, get_nltp, sbc_rank_plot
import sbi.utils as utils

In [None]:
from simulator import smfe_simulator_mm
import config

In [None]:
# Box-uniform prior over [logD, k, spline knot values]. All spline knots share the
# same bounds, so those bounds are repeated config.N_knots_prior times after the
# logD and k bounds.
lower_limits = (
    [config.logD_lims[0], config.k_lims[0]]
    + [config.spline_lims[0] for _ in range(config.N_knots_prior)]
)
upper_limits = (
    [config.logD_lims[1], config.k_lims[1]]
    + [config.spline_lims[1] for _ in range(config.N_knots_prior)]
)

prior = utils.BoxUniform(
    low=torch.tensor(lower_limits),
    high=torch.tensor(upper_limits),
)

# Wrap the simulator and prior into the form sbi's inference utilities expect.
simulator, prior = prepare_for_sbi(smfe_simulator_mm, prior)

In [None]:
# SBC test set: parameters sampled from the prior plus their simulated observations.
# A few hundred runs is the minimum; ~1000 or more is preferable.
num_sbc_runs = 2000
thetas, xs = simulate_for_sbi(
    simulator,
    prior,
    num_simulations=num_sbc_runs,
    num_workers=20,
    show_progress_bar=True,
)

In [None]:
# Load the previously trained posterior from disk.
# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
POSTERIOR_PATH = "../scr/mmatrix_posterior.pkl"
with open(POSTERIOR_PATH, "rb") as handle:
    posterior = pickle.load(handle)

In [None]:
# Number of posterior samples drawn per SBC observation. The rank plots and
# check_sbc below must be given this same value.
num_posterior_samples = 2000
ranks, dap_samples = run_sbc(
    thetas,
    xs,
    posterior,
    num_posterior_samples=num_posterior_samples,
)

In [None]:
# Empirical CDF of SBC ranks per parameter. sbc_rank_plot must receive the same
# num_posterior_samples that run_sbc used to compute the ranks; a hard-coded
# value that differs from it mis-scales the plot.
f, ax = sbc_rank_plot(ranks, num_posterior_samples, plot_type="cdf")

In [None]:
# Summary statistics for the SBC run. The 'c2st_ranks' and 'c2st_dap' entries are
# plotted further down. thetas are passed as the reference prior samples.
check_stats = check_sbc(
    ranks,
    thetas,
    dap_samples,
    num_posterior_samples=num_posterior_samples,
    num_c2st_repetitions=10,
)

In [None]:
# Display the SBC check statistics (indexed by 'c2st_dap' / 'c2st_ranks' in the plot below).
check_stats

In [None]:
# Per-parameter C2ST scores from check_sbc. The red dashed line marks 0.5, the
# chance-level accuracy expected when the two compared sample sets are indistinguishable.
n_params = len(check_stats['c2st_ranks'])
# Parameter order follows the prior: logD, k, then one value per spline knot.
# The knot index is wrapped in braces so multi-digit subscripts (x_{10}) render
# fully in mathtext instead of only subscripting the first digit.
param_labels = [
    r'$D_q/D_x$',
    '$k$',
    *(f'$G_0(x_{{{i}}})$' for i in range(n_params - 2)),
]

fig, ax = plt.subplots()
ax.hlines(0.5, xmin=0, xmax=n_params, color='red', linestyle='--')
# Plain 'o' marker: the colour comes from `color`, avoiding a conflicting 'b' in the fmt string.
ax.plot(check_stats['c2st_dap'], 'o', color='green', label='C2ST DAP vs Prior')
ax.plot(check_stats['c2st_ranks'], 'ob', label='C2ST Ranks vs Uniform')
# NOTE(review): the y-axis is inverted (1 at the bottom, 0.4 at the top), as in the
# original; confirm this is intended.
ax.set_ylim(1, 0.4)
ax.legend(fontsize=16)
ax.set_xticks(np.arange(n_params))
ax.set_xticklabels(param_labels, rotation='vertical', fontsize=16)
ax.set_ylabel('C2ST', fontsize=16);

In [None]:
# Per-parameter histograms of SBC ranks, computed with the same num_posterior_samples
# as run_sbc. num_bins is fixed at 30 here; passing None instead lets sbi choose the
# number of bins with a heuristic.
f, ax = sbc_rank_plot(
    ranks=ranks,
    num_posterior_samples=num_posterior_samples,
    plot_type="hist",
    num_bins=30,
)
