In [None]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
import timeit
import os

from delfi.utils.viz import plot_pdf
import delfi.inference as infer
import delfi.distribution as dd

from lfimodels.snl_exps.util import save_results, load_results
from lfimodels.snl_exps.util import calc_all_lprob_errs

def _load_model_utils(model_id):
    """Import the model-specific helpers for ``model_id``.

    Returns a ``(load_gt, init_g, fast_sampler)`` tuple; ``fast_sampler`` is
    ``None`` for the models that have no dedicated prior sampler here.

    Raises ValueError for an unsupported ``model_id``.
    """
    if model_id == 'gauss':
        from lfimodels.snl_exps.util import load_gt_gauss as load_gt
        from lfimodels.snl_exps.util import init_g_gauss as init_g
        from lfimodels.snl_exps.util import draw_sample_uniform_prior_33 as fast_sampler
    elif model_id == 'lv':
        from lfimodels.snl_exps.util import load_gt_lv as load_gt
        from lfimodels.snl_exps.util import init_g_lv as init_g
        fast_sampler = None
    elif model_id == 'mg1':
        from lfimodels.snl_exps.util import load_gt_mg1 as load_gt
        from lfimodels.snl_exps.util import init_g_mg1 as init_g
        fast_sampler = None
    else:
        raise ValueError("unknown model_id %r; expected 'gauss', 'lv' or 'mg1'"
                         % (model_id,))
    return load_gt, init_g, fast_sampler


def _whitening_fn(model_id):
    """Return a callable mapping summary statistics into the space in which
    distances to the observed statistics are measured.

    'gauss' statistics are used unchanged. 'lv' and 'mg1' statistics are
    transformed with parameters read once from
    ``results/whiten_params_<model_id>.npy`` (path relative to the current
    working directory).
    """
    if model_id == 'gauss':
        return lambda stats: stats
    if model_id not in ('lv', 'mg1'):
        raise ValueError("unknown model_id %r; expected 'gauss', 'lv' or 'mg1'"
                         % (model_id,))
    # The file stores a dict inside a 0-d object array ([()] unwraps it), which
    # numpy >= 1.16.3 refuses to load unless allow_pickle=True. Unpickling can
    # execute arbitrary code: only load whitening files you generated yourself.
    params = np.load('results/whiten_params_%s.npy' % model_id,
                     encoding='latin1', allow_pickle=True)[()]
    if model_id == 'lv':
        # per-dimension standardisation
        return lambda stats: (stats - params['means']) / params['stds']
    # mg1: centre, project onto U, then rescale each projected dimension
    return lambda stats: (stats - params['means']).dot(params['U']) * params['istds']


def run_basic_plots(seed, model_id, save_path, exp_id, fig_path, N, average='median'):
    """Compute, save and plot per-round diagnostics for one SNPE-C experiment.

    Loads the results stored under ``save_path`` for ``exp_id`` and produces
    two figures.

    1. lprobs: the per-round errors returned by ``calc_all_lprob_errs``
       (plotted as '- log probability of true parameters'). They are computed
       with rejection sampling (``rej=True``) and naive sampling
       (``rej=False``). The two arrays are saved to
       ``<save_path>/<exp_id>/all_prop_errs_N<N>.npy`` and
       ``all_prop_errs_raw_N<N>.npy``, then plotted against the number of
       simulations.
    2. dists: for each round, the ``average`` of the Euclidean distances
       between the (whitened) statistics of that round's last ``n_train``
       training rows and the observed statistics. The result is saved as
       ``avg_dist.npy`` (``avg_dist_<average>.npy`` for non-median averages)
       and plotted against the round index.

    Parameters
    ----------
    seed : int
        Seed for the generator, the SNPEC object and the global numpy RNG.
    model_id : {'gauss', 'lv', 'mg1'}
        Simulator used by the experiment.
    save_path : str
        Directory passed to ``load_results``; outputs go to ``save_path/exp_id``.
    exp_id : str
        Experiment identifier.
    fig_path : str
        Prefix prepended verbatim to figure file names (include a trailing
        separator if it is a directory).
    N : int
        ``n_samples`` passed to ``calc_all_lprob_errs``.
    average : {'median', 'mean'}, optional
        NaN-ignoring statistic used to summarise each round's distances.

    Raises
    ------
    ValueError
        If ``model_id`` or ``average`` is not one of the supported values.
    """
    average_fns = {'mean': np.nanmean, 'median': np.nanmedian}
    if average not in average_fns:
        raise ValueError("average must be 'mean' or 'median', got %r" % (average,))
    avg_f = average_fns[average]
    load_gt, init_g, fast_sampler = _load_model_utils(model_id)

    pars_true, obs_stats = load_gt(generator=init_g(seed=seed))
    print('pars_true : ', pars_true)
    print('obs_stats : ', obs_stats)

    logs, tds, posteriors, setup_dict = load_results(exp_id=exp_id, path=save_path)
    # setup_dict comes back wrapped (presumably a 0-d object array from
    # np.load); [()] unwraps the underlying dict.
    setup_dict = setup_dict[()]
    n_rounds, n_train = setup_dict['n_rounds'], setup_dict['n_train']
    fig_prefix = fig_path + model_id + '_snpec_maf_n_null_' + str(setup_dict['n_null'])

    out_dir = os.path.join(save_path, exp_id)
    os.makedirs(out_dir, exist_ok=True)

    # ---- lprobs figure: rejection vs. naive sampling from each round's posterior
    lprob_errs = {}
    for rej, stem in ((True, 'all_prop_errs'), (False, 'all_prop_errs_raw')):
        errs = calc_all_lprob_errs(pars_true,
                                   n_samples=N,
                                   posteriors=posteriors,
                                   init_g=init_g,
                                   rej=rej,
                                   fast_sampler=fast_sampler)
        np.save(os.path.join(out_dir, stem + '_N' + str(N)), errs)
        lprob_errs[rej] = errs

    # simulation count after each round, assuming n_train simulations per round
    n_sims = np.arange(1, n_rounds + 1) * n_train
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.semilogx(n_sims, lprob_errs[True], 'bd:', label='rej. sampling')
    ax.semilogx(n_sims, lprob_errs[False], 'kd:', label='naive sampling')
    ax.legend()
    ax.set_xlim([0.6 * n_train, (n_rounds + 2) * n_train])
    ax.set_xlabel('Number of simulations (log scale)')
    ax.set_ylabel('- log probability of true parameters')
    fig.savefig(fig_prefix + '_N' + str(N) + '_lprobs.pdf')
    plt.show()

    # ---- dists figure
    # SNPEC is only used here for stats_mean / stats_std, which map the stored
    # (normalised) training statistics back to raw scale. The global numpy RNG
    # is seeded and passed in so this reconstruction is reproducible.
    rng = np.random
    rng.seed(seed)
    res = infer.SNPEC(init_g(seed=seed),
                      obs=obs_stats,
                      n_hiddens=setup_dict['n_hiddens'],
                      seed=seed,
                      reg_lambda=setup_dict['reg_lambda'],
                      pilot_samples=setup_dict['pilot_samples'],
                      svi=setup_dict['svi'],
                      n_mades=setup_dict['n_mades'],
                      act_fun=setup_dict['act_fun'],
                      mode=setup_dict['mode'],
                      rng=rng,
                      batch_norm=setup_dict['batch_norm'],
                      verbose=setup_dict['verbose'],
                      prior_norm=setup_dict['prior_norm'])

    # Whitening parameters and whitened observed stats are loop-invariant.
    whiten = _whitening_fn(model_id)
    obz_stats = whiten(obs_stats.flatten())
    avg_dist = []
    for trn_data in tds:
        xs = trn_data[1]  # assumes index 1 holds the normalised stats — TODO confirm
        print('xs.shape', xs.shape)
        stats = xs[-n_train:, :] * res.stats_std + res.stats_mean
        print('stats.shape', stats.shape)
        dist = np.linalg.norm(whiten(stats) - obz_stats, axis=1)
        avg_dist.append(avg_f(dist))
    avg_dist = np.array(avg_dist)

    # median keeps the historical file names; other averages get a suffix so
    # they do not overwrite the median results.
    dist_suffix = '' if average == 'median' else '_' + average
    np.save(os.path.join(out_dir, 'avg_dist' + dist_suffix), avg_dist)

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(np.arange(1, n_rounds + 1), avg_dist, 'kd:')
    ax.set_xlim([0, n_rounds + 2])
    # x values are round indices on a linear axis
    ax.set_xlabel('Round')
    ax.set_ylabel(average + ' distance')
    fig.savefig(fig_prefix + dist_suffix + '_dists.pdf')
    plt.show()


In [None]:
# Per-seed lprob / distance diagnostics for the Gaussian model.
model_id = 'gauss'
seeds = np.arange(42, 52)  # seeds 42..51
N = 5000                   # n_samples handed to calc_all_lprob_errs

fig_path = 'results/'
save_path = 'results/' + model_id

for seed in seeds:
    exp_id = 'seed{}'.format(seed)
    run_basic_plots(seed=seed, model_id=model_id, save_path=save_path,
                    exp_id=exp_id, fig_path=fig_path, N=N)

# MMD figure (only for the 'gauss' model; requires reference samples in samples.npy)

In [None]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
import timeit
import os

from delfi.utils.viz import plot_pdf

from lfimodels.snl_exps.util import save_results, load_results
from lfimodels.snl_exps.util import calc_all_mmds

from lfimodels.snl_exps.util import load_setup_gauss as load_setup
from lfimodels.snl_exps.util import load_gt_gauss as load_gt
from lfimodels.snl_exps.util import init_g_gauss as init_g
from lfimodels.snl_exps.util import draw_sample_uniform_prior_33

import snl.inference.diagnostics.two_sample as two_sample
from snl.util import math

def run_mmd_plots(seed, model_id, save_path, exp_id, fig_path, N):
    """Compute, save and plot the per-round MMD against reference samples.

    Only the Gaussian model is supported. The loaders used here (``load_gt``,
    ``init_g``) are the Gaussian-model variants imported at notebook level.
    Reference samples are read from ``<save_path>/<exp_id>/samples.npy``.
    The MMD values from ``calc_all_mmds`` (rejection sampling) are saved to
    ``<save_path>/<exp_id>/all_mmds_N<N>.npy``. Their square root is then
    plotted against the number of simulations.

    Parameters
    ----------
    seed : int
        Seed for the generator used to load the ground truth.
    model_id : str
        Must be ``'gauss'``; also used in the figure file name.
    save_path : str
        Directory passed to ``load_results``; outputs go to ``save_path/exp_id``.
    exp_id : str
        Experiment identifier.
    fig_path : str
        Prefix prepended verbatim to the figure file name.
    N : int
        ``n_samples`` passed to ``calc_all_mmds``.

    Raises
    ------
    ValueError
        If ``model_id`` is not ``'gauss'`` (the loaders in scope are
        Gaussian-specific, so other models would silently give wrong results).
    """
    if model_id != 'gauss':
        raise ValueError("run_mmd_plots only supports model_id='gauss', got %r"
                         % (model_id,))

    pars_true, obs_stats = load_gt(generator=init_g(seed=seed))
    print('pars_true : ', pars_true)
    print('obs_stats : ', obs_stats)

    logs, tds, posteriors, setup_dict = load_results(exp_id=exp_id, path=save_path)
    # setup_dict comes back wrapped (presumably a 0-d object array from
    # np.load); [()] unwraps the underlying dict.
    setup_dict = setup_dict[()]
    n_rounds, n_train = setup_dict['n_rounds'], setup_dict['n_train']

    # ---- mmd figure
    out_dir = os.path.join(save_path, exp_id)
    os.makedirs(out_dir, exist_ok=True)
    samples_true = np.load(os.path.join(out_dir, 'samples.npy'))
    all_mmds = calc_all_mmds(samples_true,
                             n_samples=N,
                             posteriors=posteriors,
                             init_g=init_g,
                             rej=True)
    np.save(os.path.join(out_dir, 'all_mmds_N' + str(N)), all_mmds)

    fig, ax = plt.subplots(figsize=(8, 5))
    # calc_all_mmds presumably returns squared MMD estimates; the sqrt puts
    # them on the MMD scale named in the y-label — TODO confirm
    ax.semilogx(np.arange(1, n_rounds + 1) * n_train, np.sqrt(all_mmds), 'kd:')
    ax.set_xlim([0.6 * n_train, (n_rounds + 2) * n_train])
    ax.set_xlabel('Number of simulations (log scale)')
    ax.set_ylabel('maximum mean discrepancy')
    fig.savefig(fig_path + model_id + '_snpec_maf_n_null_' + str(setup_dict['n_null'])
                + '_N' + str(N) + '_mmds.pdf')
    plt.show()



In [None]:
# MMD diagnostics for the Gaussian model (a single seed).
model_id = 'gauss'
seeds = np.arange(42, 43)  # just seed 42
N = 5000                   # n_samples handed to calc_all_mmds

fig_path = 'results/'
save_path = 'results/' + model_id

for seed in seeds:
    exp_id = 'seed%d' % seed
    run_mmd_plots(seed=seed, model_id=model_id, save_path=save_path,
                  exp_id=exp_id, fig_path=fig_path, N=N)