# SLCP model (fig 2)

- fits APT (via `delfi.inference.SNPEC`) for figure 2a,c of the APT paper
- for evaluation of fits, see APT_eval.ipynb
- for plotting, see ICML_figure2.ipynb

In [None]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
import timeit

from delfi.utils.viz import plot_pdf
import delfi.inference as infer
import delfi.distribution as dd

from lfimodels.snl_exps.util import save_results, load_results
from lfimodels.snl_exps.util import init_g_gauss as init_g
from lfimodels.snl_exps.util import load_setup_gauss as load_setup
from lfimodels.snl_exps.util import load_gt_gauss as load_gt
from lfimodels.snl_exps.util import calc_all_lprob_errs


# Identifier of the model being fitted; also names the results sub-directory.
model_id = 'gauss'
save_path = 'results/{}'.format(model_id)


# Ten consecutive random seeds (42, 43, ..., 51), one fit per seed.
seeds = 42 + np.arange(10)

# Run one independent SNPE-C fit per seed and store each result under
# `save_path`, keyed by an experiment id derived from the seed.
for seed in seeds:

    exp_id = 'seed' + str(seed)

    # simulation setup: reload the setup for every seed, then apply the
    # overrides specific to this experiment
    setup_dict = load_setup()
    setup_dict['val_frac'] = 0.1
    # n_null is tied to the minibatch size (one less than it)
    setup_dict['n_null'] = setup_dict['minibatch'] - 1
    setup_dict['n_rounds'] = 40

    # ground-truth parameters and observed summary statistics, obtained from
    # a generator initialised with this seed
    pars_true, obs_stats = load_gt(generator=init_g(seed=seed))
    print('pars_true : ', pars_true)
    print('obs_stats : ', obs_stats)

    # epochs per round: when training on all rounds' data, use a per-round
    # list that shrinks as epochs // (round index + 1); otherwise pass the
    # configured value through unchanged
    if setup_dict['train_on_all']:
        n_epochs = setup_dict['epochs']
        epochs = [n_epochs // (r + 1) for r in range(setup_dict['n_rounds'])]
    else:
        epochs = setup_dict['epochs']

    # control MAF seed: this seeds NumPy's *global* random state and hands
    # the `np.random` module itself to SNPEC as its rng
    rng = np.random
    rng.seed(seed)

    # generator (a fresh one, separate from the one passed to load_gt above)
    g = init_g(seed=seed)

    res_C = infer.SNPEC(g,
                        obs=obs_stats,
                        n_hiddens=setup_dict['n_hiddens'],
                        seed=seed,
                        reg_lambda=setup_dict['reg_lambda'],
                        pilot_samples=setup_dict['pilot_samples'],
                        svi=setup_dict['svi'],
                        n_mades=setup_dict['n_mades'],
                        act_fun=setup_dict['act_fun'],
                        mode=setup_dict['mode'],
                        #upper=setup_dict['upper'],
                        #lower=setup_dict['lower'],
                        rng=rng,
                        batch_norm=setup_dict['batch_norm'],
                        verbose=setup_dict['verbose'],
                        prior_norm=setup_dict['prior_norm'])

    print('model class :', res_C.network)

    # train; timeit.default_timer() is the documented clock for measuring
    # elapsed time (the previous timeit.time.time() relied on timeit's
    # internal import of the time module)
    t_start = timeit.default_timer()

    print('fitting model with SNPC-C')
    logs_C, tds_C, posteriors_C = res_C.run(
                        n_train=setup_dict['n_train'],
                        proposal=setup_dict['proposal'],
                        moo=setup_dict['moo'],
                        n_null=setup_dict['n_null'],
                        n_rounds=setup_dict['n_rounds'],
                        train_on_all=setup_dict['train_on_all'],
                        minibatch=setup_dict['minibatch'],
                        val_frac=setup_dict['val_frac'],
                        silent_fail=False,
                        epochs=epochs)

    print('fitting time : ', timeit.default_timer() - t_start)

    # persist logs, training data and posteriors together with the exact
    # setup that produced them
    save_results(logs=logs_C, tds=tds_C, posteriors=posteriors_C,
                 setup_dict=setup_dict, exp_id=exp_id, path=save_path)

    # to reload a stored run:
    #logs, tds, posteriors, setup_dict = load_results(exp_id=exp_id, path=save_path)