In [None]:
import os
import pickle
import sys
import time

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats
import tqdm

sys.path.append('../../')
from model_comparison.utils import *
from model_comparison.mdns import *
from model_comparison.models import PoissonModel, NegativeBinomialModel
%matplotlib inline

In [None]:
# Notebook-wide matplotlib defaults: legend, axis text sizes and figure size.
mpl_params = {}
mpl_params['legend.fontsize'] = 15
mpl_params['legend.frameon'] = False
mpl_params['axes.titlesize'] = 20
mpl_params['axes.labelsize'] = 17
mpl_params['xtick.labelsize'] = 12
mpl_params['ytick.labelsize'] = 12
mpl_params['figure.figsize'] = (18, 5)

# Apply the settings to every figure created from here on.
mpl.rcParams.update(mpl_params)

### Load trained NB posterior

In [None]:
# Location and file name of the pickled training result.
folder = '../data/'
fn = 'learned_posterior_nbmodel_ntrain100000.p'
# Prefix (minute resolution) prepended to every file this notebook writes.
time_stamp = time.strftime('%Y%m%d%H%M_')

# NOTE: unpickling executes code stored in the file; only load trusted files.
posterior_path = os.path.join(folder, fn)
with open(posterior_path, 'rb') as f:
    d = pickle.load(f)

# Fix the numpy seed before the new test data are generated below.
seed = 5
np.random.seed(seed)

In [None]:
# Priors over the two parameters, as stored with the training result.
prior_k, prior_theta = d['prior_k'], d['prior_theta']

# Bookkeeping values and the normalization constants that are passed to
# normalize() / ztrans_inv() further down.
sample_size, ntrain = d['sample_size'], d['ntrain']
param_norm, data_norm = d['param_norm'], d['data_norm']

# Generative model, trained density network and its training loss trace.
model_nb, model_params_mdn = d['model'], d['mdn']
loss_trace = d['trainer'].loss_trace

In [None]:
# Show the training loss trace in a single wide panel.
loss_fig, loss_ax = plt.subplots(figsize=(18, 5))
loss_ax.plot(loss_trace);

## Sample new test data 

In [None]:
ntest = 400

# Draw k first and theta second (same call order as before), then put them
# side by side so that every row is one (k, theta) pair.
k_test = prior_k.rvs(size=ntest)
theta_test = prior_theta.rvs(size=ntest)
params_test = np.transpose(np.vstack((k_test, theta_test)))

# Simulate one data set per parameter pair.
x_test = model_nb.gen(params_test)

In [None]:
# Summary statistics of the simulated test data.
sx_test = calculate_stats_toy_examples(x_test)

# Normalize statistics and parameters with the constants stored at training
# time; only the first return value of normalize() is needed here.
sx_test_zt = normalize(sx_test, data_norm)[0]
params_test_zt = normalize(params_test, param_norm)[0]

## Loop over test samples to compute
- Quantiles
- Posterior mean differences
- $D_{KL}$
- Credible intervals

In [None]:
# Containers for the per-sample check results.
# Rows 0 and 1 belong to the two marginals (k, theta); row 2 belongs to the
# joint posterior and is only filled for the quantiles and the DKL ratios.
dkl_ratios = np.zeros((3, ntest))

qis = np.zeros_like(dkl_ratios)
qis_hat = np.zeros_like(dkl_ratios)

mus_hat = np.zeros_like(dkl_ratios)
mus_exact = np.zeros_like(dkl_ratios)
stds_hat = np.zeros_like(dkl_ratios)
stds_exact = np.zeros_like(dkl_ratios)

# credible levels 0.05, 0.10, ..., 0.95 and the per-sample results for them
credible_intervals = np.arange(0.05, 1., 0.05)
marginal_ci_counts = np.zeros((2, ntest, credible_intervals.size))
marginal_ci_counts_hat = np.zeros((2, ntest, credible_intervals.size))

# NOTE: covariances are appended only for samples that pass all checks, so
# these two lists are shorter than ntest whenever `fails` is not empty.
covariances = []
covariances_hat = []

# exact / predicted marginals and posteriors, one entry per test sample
ms = []
ms_hat = []
ps = []
ps_hat = []

priors = [prior_k, prior_theta]
joint_prior = JointGammaPrior(prior_k, prior_theta)

# Indices of test samples for which a check raised, plus the error that was
# raised, so that failures can be diagnosed instead of silently dropped.
fails = []
fail_errors = {}

with tqdm.tqdm(total=ntest) as pbar:
    for ii, (thetao_i, sxo_i, xo_i) in enumerate(zip(params_test, sx_test_zt, x_test)):

        # predict the posterior
        post_hat_zt = model_params_mdn.predict(sxo_i.reshape(1, -1))
        # NOTE(review): `dd` is not used anywhere below; kept in case the call matters.
        dd = post_hat_zt.get_dd_object()
        # transform back to original parameter range
        post_hat = post_hat_zt.ztrans_inv(param_norm[0], param_norm[1])
        marginals_hat = post_hat.get_marginals()

        # the exact posterior
        post_exact = NBExactPosterior(xo_i, prior_k, prior_theta)
        post_exact.calculat_exact_posterior(verbose=False, prec=1e-6, n_samples=300)

        marginals_exact = post_exact.get_marginals()
        # stored before the checks run, so ps[ii] / ms[ii] exist for failed samples too
        ps.append(post_exact)
        ps_hat.append(post_hat)
        ms.append(marginals_exact)
        ms_hat.append(marginals_hat)

        try:
            # perform check for marginals
            for vi, (m, mhat, th) in enumerate(zip(marginals_exact, marginals_hat, thetao_i)):
                # samples from the exact marginal, used for the Monte Carlo DKL below
                ss = m.gen(10000)

                # means and std
                mus_exact[vi, ii], stds_exact[vi, ii] = m.mean, m.std
                mus_hat[vi, ii], stds_hat[vi, ii] = mhat.mean, mhat.std

                # quantiles
                qis[vi, ii] = m.cdf(th)
                qis_hat[vi, ii] = mhat.get_quantile(th)

                # credible intervals
                marginal_ci_counts[vi, ii, :] = m.get_credible_interval_counts(th, credible_intervals)
                marginal_ci_counts_hat[vi, ii, :] = mhat.get_credible_interval_counts(th, credible_intervals)

                # DKL of the predicted marginal, relative to the DKL of the prior
                baseline = calculate_dkl_1D_scipy(m.pdf_array, priors[vi].pdf(m.support))
                (dkl, err) = calculate_dkl_monte_carlo(np.array(ss), m.pdf, mhat.eval_numpy)
                dkl_ratios[vi, ii] = dkl / baseline

            # perform checks for joint
            vi = 2

            # quantiles
            qis[vi, ii] = post_exact.cdf(thetao_i.reshape(1, -1))
            qis_hat[vi, ii] = post_hat.get_quantile(thetao_i.reshape(1, -1))

            # DKL: calculate_dkl_monte_carlo returns (estimate, error) as in the
            # marginal case, so unpack it before forming the ratio. The previous
            # code divided the whole return value by the baseline, which raised
            # for every sample and marked all of them as failed.
            post_samples = post_exact.gen(20000)
            (baseline, err) = calculate_dkl_monte_carlo(post_samples, post_exact.pdf, joint_prior.pdf)
            (dkl, err) = calculate_dkl_monte_carlo(post_samples, post_exact.pdf, post_hat.eval_numpy)
            dkl_ratios[vi, ii] = dkl / baseline

            # covariances
            covariances.append(post_exact.cov)
            covariances_hat.append(post_hat.get_covariance_matrix())
        except Exception as exc:
            # Best effort: remember the sample and the reason and carry on.
            # Rows filled before the error keep their values. KeyboardInterrupt
            # is no longer swallowed, so the loop can be interrupted.
            fails.append(ii)
            fail_errors[ii] = repr(exc)
        finally:
            pbar.update()
fails

In [None]:
def _center_rows(values):
    """Return a copy of `values` with each row's own mean subtracted."""
    values = np.array(values)
    return values - values.mean(axis=1, keepdims=True)


# Center the posterior standard deviations and means row by row
# (one row per marginal / joint, one column per test sample).
stds_exact_zt = _center_rows(stds_exact)
stds_hat_zt = _center_rows(stds_hat)

mus_exact_zt = _center_rows(mus_exact)
mus_hat_zt = _center_rows(mus_hat)

In [None]:
# save posterior checks results 
result_dict = dict(qis=qis, qis_hat=qis_hat, dkl_ratios=dkl_ratios,
                  marginal_ci_counts=marginal_ci_counts, 
                  marginal_ci_counts_hat=marginal_ci_counts_hat, 
                  fails=fails, 
                  ntest=ntest, 
                  mus_exact=mus_exact, mus_hat=mus_hat, 
                  stds_exact=stds_exact, stds_hat=stds_hat, 
                  credible_intervals=credible_intervals, 
                  covariances=covariances, 
                  covariances_hat=covariances_hat, 
                  params_test=params_test, 
                  sx_test_zt=sx_test_zt, 
                  x_test=x_test)

fn = time_stamp + 'posterior_checks_results_NB_ntrain{}_ns{}'.format(ntrain, sample_size) + '.p'
with open(os.path.join('../data', fn), 'wb') as outfile: 
    pickle.dump(result_dict, outfile, pickle.HIGHEST_PROTOCOL)

### Analyze failures

In [None]:
# Recompute the exact posterior of every failed test sample and plot the CDF
# and PDF of its first marginal, three panels per row.
# The grid grows with the number of failures; the previous fixed 1x3 grid
# raised as soon as there were more than three failed samples.
n_fail_cols = 3
n_fail_rows = -(-len(fails) // n_fail_cols)  # ceiling division

if fails:
    plt.figure(figsize=(18, 5 * n_fail_rows))

for i, pi in enumerate(fails):
    p = ps[pi]
    tho = params_test[pi]
    print(tho)
    # reset the flag before recomputing the posterior
    p.calculated = False
    # NOTE(review): the main loop calls calculat_exact_posterior without a
    # positional argument; confirm that passing `tho` here matches the
    # method's signature.
    p.calculat_exact_posterior(tho, n_samples=200, prec=1e-6)
    margs = p.get_marginals()
    plt.subplot(n_fail_rows, n_fail_cols, i + 1)
#     plt.imshow(p.joint_pdf, origin='lower')
    plt.plot(margs[0].support, margs[0].cdf_array, label='cdf')
    plt.plot(margs[0].support, margs[0].pdf_array, label='pdf')
    plt.title('failed test sample {}'.format(pi))
    plt.legend()
    

In [None]:
# Summary figure of the posterior checks (2 x 3 panels):
#   top row:    centered marginal means | centered marginal stds | DKL ratios of the marginals
#   bottom row: Q-Q plot of exact quantiles | credible interval coverage | DKL ratio of the joint
fig1, ax = plt.subplots(2, 3, figsize=(18, 8))
labels = [r'$k$', r'$\theta$']

# exclude failed test samples from every panel
mask = np.ones(ntest, dtype=bool)
mask[fails] = False
if not mask.any():
    raise RuntimeError('all {} test samples are marked as failed, nothing to plot'.format(ntest))

# Average of the per-sample credible interval results over the valid test
# samples, one row per marginal and one column per credible level.
# These were previously only present as commented-out code, so the plots
# below failed with a NameError on a fresh kernel.
ci_probs = marginal_ci_counts[:, mask, :].mean(axis=1)
ci_probs_hat = marginal_ci_counts_hat[:, mask, :].mean(axis=1)

dkl_bins = np.linspace(0, 1, 20)
# Bin edges covering all of [0, 1] for the empirical CDF of the quantiles.
# Using the credible levels (0.05 ... 0.95) as bins dropped the tails.
quantile_edges = np.linspace(0, 1, 21)

# identity lines spanning the range of the centered exact values
mu_line = np.linspace(mus_exact_zt[:, mask].min(), mus_exact_zt[:, mask].max(), 100)
std_line = np.linspace(stds_exact_zt[:, mask].min(), stds_exact_zt[:, mask].max(), 100)

for i in range(2):
    # exact vs. predicted posterior means and standard deviations
    ax[0, 0].scatter(x=mus_exact_zt[i, mask], y=mus_hat_zt[i, mask],
                     label=labels[i] + r', ($\mu$, $\hat{\mu}$)')
    ax[0, 1].scatter(x=stds_exact_zt[i, mask], y=stds_hat_zt[i, mask],
                     label=labels[i] + r', ($\sigma$, $\hat{\sigma}$)')

    # DKL ratios of the marginals
    ax[0, 2].hist(dkl_ratios[i, mask], bins=dkl_bins, alpha=0.6, label=labels[i])

    # Q-Q plot: cumulative fraction of quantiles up to each right bin edge,
    # plotted against that edge (the uniform CDF evaluated there).
    n, bins = np.histogram(qis[i, mask], bins=quantile_edges)
    sample_quantiles = np.cumsum(n) / np.sum(n)
    theo_quantiles = bins[1:]
    ax[1, 0].plot(theo_quantiles, sample_quantiles, 'x-', label=r'marginal ' + labels[i])

    # credible interval results of the predicted marginals
    ax[1, 1].plot(credible_intervals, ci_probs_hat[i], 'x-', label=r'marginal ' + labels[i])

# means
ax[0, 0].plot(mu_line, mu_line, 'C2')
ax[0, 0].set_title('Marginal means')
ax[0, 0].legend()

# standard deviations (the panel shows sigma, not sigma squared)
ax[0, 1].plot(std_line, std_line, 'C2')
ax[0, 1].set_title('Marginal standard deviations')
ax[0, 1].legend()

# DKL of marginals (top) and joint (bottom)
ax[0, 2].set_title(r'$D_{KL}$ of marginals')
ax[0, 2].set_ylabel('count')
ax[0, 2].legend(fontsize=15)

ax[1, 2].hist(dkl_ratios[2, mask], bins=dkl_bins, alpha=0.6, label='joint')
ax[1, 2].set_title(r'$D_{KL}$ of joint')
ax[1, 2].set_xlabel(r'$D_{KL} / D_{KL}^{prior}$')
ax[1, 2].set_ylabel('count')

# Q-Q plot
ax[1, 0].plot(theo_quantiles, theo_quantiles)
ax[1, 0].grid()
ax[1, 0].set_title('Q-Q plot')
ax[1, 0].legend()
ax[1, 0].set_ylabel('empirical quantile')
ax[1, 0].set_xlabel(r'quantile of $U(0, 1)$')

# credible intervals
ax[1, 1].plot(credible_intervals, credible_intervals, '-')
ax[1, 1].set_ylabel('relative frequency')
ax[1, 1].set_xlabel('credible interval')
ax[1, 1].grid()
ax[1, 1].set_title('Credible intervals')
ax[1, 1].legend()

plt.tight_layout()

In [None]:
# Credible interval results of the predicted marginals (left) and of the
# joint posterior (right), each compared against the identity line.
fig, ax = plt.subplots(1, 2, figsize=(15, 5))

# Average of the per-sample results over the test samples that did not fail.
# Computed here because no other cell assigns `ci_probs_hat`.
ci_probs_hat = marginal_ci_counts_hat[:, mask, :].mean(axis=1)

ax[0].set_title('Credible intervals')
ax[0].plot(credible_intervals, ci_probs_hat[0], 'x-', label=r'marginal $\hat{p}(k | x)$')
ax[0].plot(credible_intervals, ci_probs_hat[1], 'x-', label=r'marginal $\hat{p}(\theta | x)$')
ax[0].plot(credible_intervals, credible_intervals, '-', label='identity')
ax[0].set_xlabel('credible interval')
ax[0].set_ylabel('relative frequency')
ax[0].legend()

ax[1].set_title('Joint')
# NOTE(review): `cr_probs` is not assigned in any visible cell, and the loop
# above only collects credible interval results for the two marginals. The
# joint curve is therefore drawn only if `cr_probs` exists; otherwise the
# panel says so instead of the cell failing. The unfinished
# `theo_quantiles =` statement that made this cell a SyntaxError was removed.
try:
    ax[1].plot(credible_intervals, cr_probs, 'x-', label='joint')
except NameError:
    ax[1].text(0.5, 0.1, 'joint coverage (cr_probs) not computed',
               transform=ax[1].transAxes, ha='center')
ax[1].plot(credible_intervals, credible_intervals, '-', label='identity')
ax[1].set_xlabel('credible interval')
ax[1].set_ylabel('relative frequency')
ax[1].legend();

plt.tight_layout()

In [None]:
# Write the summary figure to the figures folder under a time-stamped name.
fig1_name = 'posterior_checks_marginals_NB_ntrain{}_nsamples{}.png'.format(int(ntrain), int(sample_size))
fn = time_stamp + fig1_name
fig1_path = os.path.join('../figures', fn)
fig1.savefig(fig1_path, dpi=300)

In [None]:
# Per-marginal diagnostics of the exact posterior (one row per parameter):
# histogram of posterior quantiles with a K-S test against U(0, 1),
# the corresponding Q-Q plot, and the credible interval results.
fig2, ax = plt.subplots(2, 3, figsize=(18, 10), sharex=True, sharey='col')
labels = ['k', 'theta']

# Average of the per-sample credible interval results over the test samples
# that did not fail. Computed here because no other cell assigns `ci_probs`.
ci_probs = marginal_ci_counts[:, mask, :].mean(axis=1)

# Fixed bin edges on [0, 1], so the right bin edges are exactly the
# theoretical quantiles of U(0, 1) used in the Q-Q plot.
quantile_edges = np.linspace(0, 1, 21)

for i in range(2):
    # leave out failed test samples, consistent with the summary figure
    quantiles_i = qis[i, mask]

    (ks_stat, kst_p) = scipy.stats.kstest(quantiles_i, cdf='uniform')
    n, bins, patches = ax[i, 0].hist(quantiles_i, bins=quantile_edges,
                                     label='K-S test, p={:1.3}'.format(kst_p))
    ax[i, 0].set_title('Posterior quantile distribution ' + labels[i])
    ax[i, 0].set_xlabel('quantile')
    ax[i, 0].set_ylabel('counts')
    ax[i, 0].legend()

    ax[i, 1].set_title('Posterior quantile Q-Q plot ' + labels[i])
    # cumulative fraction of quantiles up to each right bin edge
    sample_quantiles = np.cumsum(n) / np.sum(n)
    theo_quantiles = bins[1:]

    ax[i, 1].plot(theo_quantiles, sample_quantiles, 'x-', label='empirical')
    ax[i, 1].plot(theo_quantiles, theo_quantiles, label='identity line')
    ax[i, 1].set_ylabel('empirical quantile')
    ax[i, 1].set_xlabel(r'quantile of $U(0, 1)$')
    ax[i, 1].legend()
    ax[i, 1].grid();

    ax[i, 2].set_title('Posterior credible intervals ' + labels[i])

    ax[i, 2].plot(credible_intervals, ci_probs[i], 'x-', label='empirical')
    ax[i, 2].plot(credible_intervals, credible_intervals, '-', label='identity line')
    ax[i, 2].set_ylabel('relative frequency')
    ax[i, 2].set_xlabel('credible interval')
    ax[i, 2].legend()
    ax[i, 2].grid();

# lay out once, after all panels are drawn
plt.tight_layout();

In [None]:
# Save the quantile diagnostics figure.
# The previous file name was built from `k2`, which is not defined anywhere
# in this notebook; use the same ntrain / sample_size tags as the first figure.
fn = time_stamp + 'posterior_checks_2_NB_ntrain{}_nsamples{}.png'.format(int(ntrain), int(sample_size))
fig2.savefig(os.path.join('../figures', fn), dpi=300)