# Gaussian model

- The simulator is taken from https://github.com/mackelab/SNL_py3port, a Python 3 port of the original https://github.com/gpapamak/snl code (converted with 2to3, plus minimal edits that deactivate the generator-internal normalization of summary statistics).

In [None]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
import timeit

from delfi.utils.viz import plot_pdf
import delfi.inference as infer
import delfi.distribution as dd

from lfimodels.snl_exps.util import save_results, load_results
from lfimodels.snl_exps.util import init_g_gauss as init_g
from lfimodels.snl_exps.util import load_setup_gauss as load_setup
from lfimodels.snl_exps.util import load_gt_gauss as load_gt
from lfimodels.snl_exps.util import calc_all_lprob_errs


# Training cell: fit SNPE-C (delfi's `infer.SNPEC`) on the Gaussian model once
# per random seed and store each fit under `save_path/<exp_id>`.
model_id = 'gauss'
save_path = 'results/' + model_id


# ten independent repetitions: seeds 42, 43, ..., 51
seeds = np.arange(42,52)

for seed in seeds:

    # sub-directory / identifier under which this seed's results are stored
    exp_id = 'seed'+str(seed)

    # simulation setup
    # (re-loaded on every iteration because it is modified in place below)
    setup_dict = load_setup()
    
    
    
    # overrides applied on top of the loaded setup for this experiment
    setup_dict['n_null'] =  setup_dict['minibatch'] - 1
    setup_dict['n_rounds'] = 40

    # ground-truth parameters and observed summary statistics; a dedicated
    # generator instance is created just for this call
    pars_true, obs_stats = load_gt(generator=init_g(seed=seed))
    print('pars_true : ', pars_true)
    print('obs_stats : ', obs_stats)

    # when training on data from all rounds, the epoch budget of round r
    # (0-based) is the base number of epochs integer-divided by (r + 1);
    # otherwise a single epoch count is used for every round
    if setup_dict['train_on_all']:
        epochs=[setup_dict['epochs']//(r+1) for r in range(setup_dict['n_rounds'])]
    else:
        epochs=setup_dict['epochs']

    # control MAF seed
    # NOTE: `rng` is the numpy.random module itself, so this seeds numpy's
    # global random state rather than a private RandomState object
    rng = np.random
    rng.seed(seed)

    # generator
    # (fresh instance with the same seed as the one passed to load_gt above)
    g = init_g(seed=seed)

    # inference object; all network hyper-parameters come from setup_dict
    res_C = infer.SNPEC(g,
                        obs=obs_stats,
                        n_hiddens=setup_dict['n_hiddens'],
                        seed=seed,
                        reg_lambda=setup_dict['reg_lambda'],
                        pilot_samples=setup_dict['pilot_samples'],
                        svi=setup_dict['svi'],
                        n_mades=setup_dict['n_mades'],
                        act_fun=setup_dict['act_fun'],
                        mode=setup_dict['mode'],
                        #upper=setup_dict['upper'],
                        #lower=setup_dict['lower'],
                        rng=rng,
                        batch_norm=setup_dict['batch_norm'],
                        verbose=setup_dict['verbose'],
                        prior_norm=setup_dict['prior_norm'])
    
    print('model class :', res_C.network)


    # train
    # wall-clock start time (the `time` module is reached through `timeit`)
    t = timeit.time.time()

    print('fitting model with SNPC-C')
    # run() returns three values, unpacked here as training logs, training
    # data sets and posteriors (posteriors_C is later iterated round by round)
    logs_C, tds_C, posteriors_C = res_C.run(
                        n_train=setup_dict['n_train'],
                        proposal=setup_dict['proposal'],
                        moo=setup_dict['moo'],
                        n_null = setup_dict['n_null'],
                        n_rounds=setup_dict['n_rounds'],
                        train_on_all=setup_dict['train_on_all'],
                        minibatch=setup_dict['minibatch'],
                        epochs=epochs)

    # elapsed wall-clock time of the run() call, in seconds
    print('fitting time : ', timeit.time.time() - t)

    # persist everything needed to evaluate this fit later
    save_results(logs=logs_C, tds=tds_C, posteriors=posteriors_C, 
                 setup_dict=setup_dict, exp_id=exp_id, path=save_path)

    # NOTE(review): the disabled reload below refers to `path`, which is not
    # defined in this cell; `save_path` is presumably what was meant.
    #logs, tds, posteriors, setup_dict = load_results(exp_id=exp_id, path=path)

In [None]:
import numpy as np
from matplotlib import pyplot as plt
import timeit
import os

from delfi.utils.viz import plot_pdf
import delfi.inference as infer
import delfi.distribution as dd

from lfimodels.snl_exps.util import save_results, load_results
from lfimodels.snl_exps.util import calc_all_lprob_errs

import snl.inference.diagnostics.two_sample as two_sample
from snl.util import math

# Path prefix for figure files; handed to run_mmd_plots as its `fig_path` argument.
fig_path = 'results/'


def draw_sample_uniform_prior_33(post, n_samples, batch=None):
    """Rejection-sample `n_samples` points from `post` inside the box (-3, 3)^d.

    Samples are drawn from `post` in batches; rows with any coordinate whose
    absolute value is >= 3 are discarded. Drawing continues until at least
    `n_samples` rows have been accepted, and the first `n_samples` accepted
    rows are returned.

    Parameters
    ----------
    post : object
        Anything exposing ``gen(n)`` that returns a 2-D array with one sample
        per row.
    n_samples : int
        Number of accepted samples to return.
    batch : int, optional
        Number of candidates requested from ``post.gen`` per iteration.
        Defaults to `n_samples`.

    Returns
    -------
    numpy.ndarray
        The first `n_samples` accepted rows, in the order they were drawn.

    Notes
    -----
    The loop has no iteration cap: if `post` puts (almost) no mass inside the
    box, this function will run for a very long time.
    """
    # Fix: the original expression ended in `else None`, which threw away any
    # explicitly supplied batch size and called post.gen(None).
    batch = n_samples if batch is None else batch

    n_drawn, samples, ct = 0, [], 0
    while n_drawn < n_samples:
        minibatch = post.gen(batch)
        # keep only rows whose coordinates all lie strictly inside (-3, 3)
        inside = np.all(np.abs(minibatch) < 3., axis=1)
        samples.append(minibatch[inside])
        n_drawn += int(np.count_nonzero(inside))
        ct += 1

    print('sampling, (itercount, n_drawn) = ', (ct,n_drawn))
    return np.concatenate(samples, axis=0)[:n_samples]


def calc_all_mmds(samples_true, n_samples, posteriors, init_g, rej=True):
    """Compute one MMD value per entry of `posteriors` against `samples_true`.

    For every posterior, `n_samples` samples are drawn (through
    `draw_sample_uniform_prior_33` when `rej` is true, otherwise directly via
    ``posterior.gen``) and compared to `samples_true` with
    ``two_sample.sq_maximum_mean_discrepancy``, using
    ``math.median_distance(samples_true)`` as the `scale` argument.

    Parameters
    ----------
    samples_true : array
        Reference samples the posteriors are compared against.
    n_samples : int
        Number of samples drawn from each posterior.
    posteriors : sequence
        Objects exposing ``gen(n)``; must support ``len()``.
    init_g : object
        Not used by this function; kept for interface compatibility.
    rej : bool
        Whether to restrict the drawn samples via
        `draw_sample_uniform_prior_33`.

    Returns
    -------
    numpy.ndarray
        Flattened array of the per-posterior MMD values, in input order.
    """
    mmd_per_round = []

    for round_no, posterior in enumerate(posteriors, start=1):

        print('\n round #' + str(round_no) + '/' + str(len(posteriors)))

        print('- sampling')
        draws = (draw_sample_uniform_prior_33(posterior, n_samples)
                 if rej else posterior.gen(n_samples))

        print('- computing MMD')
        kernel_scale = math.median_distance(samples_true)
        discrepancy = two_sample.sq_maximum_mean_discrepancy(
            draws, samples_true, scale=kernel_scale)
        mmd_per_round.append(discrepancy.flatten())

    return np.array(mmd_per_round).flatten()



def run_mmd_plots(seed, model_id, save_path, exp_id, fig_path, N):
    """Compute, save and plot per-round MMDs for one stored Gaussian-model fit.

    Loads the results saved under `save_path`/`exp_id`, computes the MMD
    between each round's posterior and the reference samples, saves the MMD
    values as ``all_mmds_N<N>.npy`` in the same directory, and shows a plot
    of MMD against the number of simulations.

    Parameters
    ----------
    seed : int
        Seed for the generator used to load the ground truth.
    model_id : str
        Currently unused; the function always uses the Gaussian-model loaders.
        Only referenced by the disabled `savefig` line below.
    save_path : str
        Directory containing the stored results; MMD values are written here.
    exp_id : str
        Name of the experiment sub-directory (e.g. ``'seed42'``).
    fig_path : str
        Only referenced by the disabled `savefig` line below.
    N : int
        Number of samples drawn from each posterior for the MMD estimate.

    Returns
    -------
    None
    """
    from lfimodels.snl_exps.util import load_gt_gauss as load_gt
    from lfimodels.snl_exps.util import init_g_gauss as init_g

    # ground truth is loaded for logging only
    pars_true, obs_stats = load_gt(generator=init_g(seed=seed))
    print('pars_true : ', pars_true)
    print('obs_stats : ', obs_stats)

    # The stored setup is the only one used, so the original's extra
    # load_setup() call (whose result was overwritten here) was dropped.
    logs, tds, posteriors, setup_dict = load_results(exp_id=exp_id, path=save_path)
    # presumably a 0-d object array wrapping the dict (as np.load returns for
    # a saved dict); `[()]` extracts the wrapped object - TODO confirm
    setup_dict = setup_dict[()]

    # mmd figure
    out_dir = os.path.join(save_path, exp_id)
    os.makedirs(out_dir, exist_ok=True)

    # NOTE(review): reference samples are always read from 'results/gauss',
    # independent of `save_path` - confirm this is intended.
    samples_true = np.load(os.path.join('results/gauss', exp_id, 'samples.npy'))
    all_mmds = calc_all_mmds(samples_true,
                             n_samples=N,
                             posteriors=posteriors,
                             init_g=init_g,
                             rej=True)

    np.save(os.path.join(out_dir, 'all_mmds_N' + str(N)), all_mmds)

    # Plotting is best-effort: the MMD values are already on disk, so a
    # plotting failure is reported but must not abort the surrounding loop.
    # Only Exception is caught so Ctrl-C / SystemExit still propagate.
    try:
        n_train = setup_dict['n_train']
        plt.figure(figsize=(8,5))
        # x-axis: round index (1-based) times the per-round simulation count
        plt.semilogx(np.arange(1, all_mmds.size+1,1) * n_train, all_mmds, 'kd:')
        plt.xlim([0.6*n_train, (setup_dict['n_rounds']+2)*n_train])
        plt.xlabel('Number of simulations (log scale)')
        plt.ylabel('maximum mean discrepancy')
        #plt.savefig(fig_path + model_id + '_snpec_maf_n_null_' + str(setup_dict['n_null']) + '_N' + str(N) +'_mmds.pdf')
        plt.show()
    except Exception as err:
        print('plotting failed:', repr(err))


In [None]:
# Evaluation cell: compute and plot the per-round MMDs for every stored fit.
N = 5000                                    # samples drawn per posterior
model_id = 'gauss'
save_path = 'results/{}'.format(model_id)   # where the fits were stored
seeds = np.arange(42, 52)                   # seeds 42, ..., 51

for seed in seeds:
    exp_id = 'seed{}'.format(seed)
    run_mmd_plots(
        seed=seed,
        model_id=model_id,
        save_path=save_path,
        exp_id=exp_id,
        fig_path=fig_path,
        N=N,
    )