# computation of evaluation metrics for SNL

- computes MMDs for SNL fits in figure 3 (SLCP with added noise)
- MMDs (and neg. log-probs, median distances) for figures 2, 4, 8 are computed with the SNL package (python 2)

# MMD figure (only 'Gauss' model)

In [None]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
import timeit
import os
import pickle
from delfi.utils.viz import plot_pdf
from delfi.utils.delfi2snl import SNLprior

from lfimodels.snl_exps.util import save_results, load_results

from lfimodels.snl_exps.util import load_setup_gauss as load_setup
from lfimodels.snl_exps.util import load_gt_gauss as load_gt
from lfimodels.snl_exps.util import init_g_gauss as init_g


import snl.inference.mcmc as mcmc
import snl.inference.diagnostics.two_sample as two_sample
from snl.util import math

def sampl_snl(save_path, exp_id, rng, prior, obs_stats, init_pars, N=5000):
    """Return posterior samples for every SNL round, re-using a cache if valid.

    The per-round SNL networks are unpickled from
    ``<save_path>/<exp_id>/SNL_posteriors.pkl``. If
    ``<save_path>/<exp_id>/ps.npy`` can be loaded and every round in it holds
    exactly ``N`` samples, those samples are used. Otherwise each round is
    re-sampled with a slice sampler targeting
    ``network.eval([t, obs_stats]) + prior.eval(t)``.

    Whatever is returned is also written to
    ``<save_path>/<exp_id>/SNL_samples.npy``.

    Parameters
    ----------
    save_path, exp_id : str
        Joined to form the experiment directory.
    rng :
        Random number generator handed to ``sampler.gen``.
    prior :
        Object exposing ``eval(t)``; added to the network output.
    obs_stats :
        Observed summary statistics the networks are evaluated on.
    init_pars :
        Initial state of the slice sampler.
    N : int
        Number of samples to draw (and to expect in the cache) per round.

    Returns
    -------
    The loaded array if the cache was valid, else a list with one sample
    array per SNL round.
    """
    exp_dir = os.path.join(save_path, exp_id)

    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # result files that were produced locally.
    fn = os.path.join(exp_dir, 'SNL_posteriors')
    with open(fn + '.pkl', 'rb') as f:
        snl_posteriors = pickle.load(f)

    # NOTE(review): the cache is read from 'ps.npy' but results are written to
    # 'SNL_samples.npy' below -- confirm whether 'ps.npy' is produced elsewhere
    # or whether both names were meant to be the same file.
    try:
        all_ps = np.load(os.path.join(exp_dir, 'ps.npy'))
        # Built-in all(): np.all() on a generator object is always truthy, so
        # it would never detect a cache with the wrong number of samples.
        if not all(np.shape(ps)[0] == N for ps in all_ps):
            raise ValueError('cached samples do not hold N=' + str(N) +
                             ' samples per round')
        print('loaded results successfully')
    except Exception as err:  # best-effort cache: any failure means re-sample
        print('failed to load results, re-sampling !')
        print('reason: ' + repr(err))
        all_ps = None

    if all_ps is None:
        # Start from an empty list so a rejected cache is never appended to.
        all_ps = []
        for r, network in enumerate(snl_posteriors, 1):

            print('\n')
            print('round r=' + str(r))
            print('\n')

            log_post = lambda t: network.eval([t, obs_stats]) + prior.eval(t)
            sampler = mcmc.SliceSampler(init_pars, log_post, thin=10)
            sampler.gen(100, rng=rng)  # burn in
            all_ps.append(sampler.gen(N, rng=rng))

    fn = os.path.join(exp_dir, 'SNL_samples')
    np.save(fn, all_ps)

    return all_ps


def calc_all_mmds_snl(samples_true, samples_snl):
    """Compute the squared MMD between true samples and each round's samples.

    Only called for the 'Gaussian' simulator.

    Parameters
    ----------
    samples_true :
        Reference samples; also used to set the kernel scale via
        ``math.median_distance``.
    samples_snl :
        Sized iterable with one sample array per SNL round.

    Returns
    -------
    numpy.ndarray
        Flat array with one entry per round: the value returned by
        ``two_sample.sq_maximum_mean_discrepancy``, or ``np.inf`` for rounds
        whose samples contain NaNs.
    """
    n_rounds = len(samples_snl)

    # The kernel scale depends only on samples_true, so it is computed at most
    # once (lazily, on the first round that actually needs it).
    scale = None

    all_mmds = []
    for ct, samples in enumerate(samples_snl, 1):

        print('\n round #' + str(ct) + '/' + str(n_rounds))

        if np.any(np.isnan(samples)):  # fail to sample n_sample times
            all_mmds.append(np.inf)
            continue

        print('- computing MMD')
        if scale is None:
            scale = math.median_distance(samples_true)
        mmd = two_sample.sq_maximum_mean_discrepancy(samples, samples_true, scale=scale)
        if isinstance(mmd, np.ndarray):
            # reduce array-valued results to a single scalar
            mmd = mmd.flatten()[0]
        all_mmds.append(mmd)

    return np.array(all_mmds).flatten()


def run_mmd_plots_snl(seed, model_id, save_path, exp_id, fig_path, N):
    """Sample the SNL posteriors of one experiment, compute and plot their MMDs.

    Steps:
      1. load the 'gauss' setup and ground-truth parameters,
      2. load the observed statistics from
         ``<save_path>/<exp_id>/obs_stats_noise.npy``,
      3. obtain ``N`` posterior samples per SNL round via ``sampl_snl``,
      4. compute the squared MMD of every round against the samples stored in
         ``results/<model_id>/seed<seed-10>/samples.npy``,
      5. save the MMDs to ``<save_path>/<exp_id>/all_mmds_snl_N<N>.npy`` and
         plot their square roots against the number of simulations.

    Parameters
    ----------
    seed : int
        Seed for the generator and for ``np.random``.
    model_id : str
        Ignored: always overridden with 'gauss' (see below).
    save_path, exp_id : str
        Joined to form the experiment directory.
    fig_path : str
        Only referenced by the commented-out ``savefig`` call.
    N : int
        Number of posterior samples per SNL round.
    """
    # The setup/ground-truth loaders imported by this file are the 'gauss'
    # variants, so the model is pinned regardless of the argument.
    model_id = 'gauss'

    # simulation setup
    setup_dict = load_setup()

    pars_true, _ = load_gt(generator=init_g(seed=seed))
    print('pars_true : ', pars_true)

    #_,_,posteriors, setup_dict = load_results(exp_id=exp_id, path=save_path)

    #obs_stats = posteriors[0].obs_stats # get full obs_stats (including noise) from parallel SNPE-C fits
    fn = os.path.join(save_path, exp_id, 'obs_stats_noise')
    obs_stats = np.load(fn + '.npy')

    print('obs_stats : ', obs_stats)

    prior = SNLprior(init_g(seed=seed).prior)

    # Note: this seeds numpy's global random state.
    rng = np.random
    rng.seed(seed)

    samples_snl = sampl_snl(save_path, exp_id, rng=rng, prior=prior,
                            obs_stats=obs_stats.flatten(), init_pars=pars_true, N=N)

    # mmd figure
    out_dir = os.path.join(save_path, exp_id)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    # NOTE(review): presumably the ground-truth runs use seeds offset by 10
    # from the SNL runs -- confirm.
    samples_true = np.load(os.path.join('results/' + model_id, 'seed' + str(seed-10), 'samples.npy'))

    all_mmds = calc_all_mmds_snl(samples_true, samples_snl)

    print('all_mmds', all_mmds)

    np.save(os.path.join(save_path, exp_id, 'all_mmds_snl_N' + str(N)),
            all_mmds)

    # Plotting is best-effort: a failure must not abort the run (the MMDs are
    # already saved), but it is reported instead of silently discarded.
    try:
        n_train = setup_dict['n_train']
        plt.figure(figsize=(8,5))
        plt.semilogx(np.arange(1, all_mmds.size+1,1) * n_train, np.sqrt(all_mmds), 'kd:')
        plt.xlim([0.6*n_train, (setup_dict['n_rounds']+2)*n_train])
        plt.xlabel('Number of simulations (log scale)')
        plt.ylabel('maximum mean discrepancy')
        #plt.savefig(fig_path + model_id + '_maf_snl' + '_N' + str(N) +'_mmds.pdf')
        plt.show()
    except Exception as err:
        print('plotting MMDs failed: ' + repr(err))


In [None]:
# Driver cell: evaluate the SNL fits of the 'gauss' model for a range of seeds.
model_id = 'gauss'
N = 1000  # posterior samples drawn per SNL round

fig_path = 'results/'

# Result directory of the experiment variant to evaluate; swap in one of the
# commented alternatives to process another variant.
save_path = 'results/{}_noisedims_v1'.format(model_id)
#save_path = 'results/' + model_id + '_noisedims_v2'
#save_path = 'results/' + model_id + '_noisedims_v3'
#save_path = 'results/' + model_id + '_noisedims_v4'

seeds = np.arange(52, 62)
for seed in seeds:
    # each seed has its own sub-directory 'seed<seed>' below save_path
    exp_id = 'seed{}'.format(seed)
    run_mmd_plots_snl(
        seed=seed,
        model_id=model_id,
        save_path=save_path,
        exp_id=exp_id,
        fig_path=fig_path,
        N=N,
    )