In [None]:
%%capture
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import timeit

import os
import sys
import pickle

import snl.inference.nde as nde
from snl.ml.models.mafs import ConditionalMaskedAutoregressiveFlow
from delfi.utils.delfi2snl import SNLprior, SNLmodel

import delfi.distribution as dd
import delfi.inference as infer
import delfi.generator as dg

from delfi.simulator import TwoMoons
import delfi.summarystats as ds
from delfi.utils.viz import plot_pdf, probs2contours

from snl.pdfs import gaussian_kde

from lfimodels.snl_exps.util import save_results, load_results

def init_g(seed):
    """Return a delfi ``Default`` generator for the two-moons simulator.

    This is a very basic way of controlling generator seeds: the simulator and
    the prior are both constructed with the same ``seed``. Summary statistics
    are the identity, so the generator emits the raw simulator output.

    TwoMoons also accepts ``angle`` (rotation angle in radians) and ``mapfunc``
    (transform of the noise distribution, e.g. ``lambda theta, p: p + theta``);
    both are deliberately left at their defaults here.
    """
    simulator = TwoMoons(mean_radius=0.1, sd_radius=0.01, baseoffset=0.25,
                         seed=seed)
    prior = dd.Uniform(lower=[-1, -1], upper=[1, 1], seed=seed)
    return dg.Default(model=simulator, prior=prior, summary=ds.Identity())

# Single reference seed: drives the generator below and is recorded in the setup.
seed = 42
g = init_g(seed=seed)

# Observed summary statistics: one 2-D observation at the origin.
obs_stats = np.array([[0., 0.]])

verbose = True

setup_dict = {
    'seed': seed,
    'obs_stats': obs_stats,
    # training schedule
    'n_rounds': 10,
    'n_train': 1000,
    # fitting setup
    'minibatch': 100,
    'reg_lambda': 0.001,
    'pilot_samples': 0,
    'prior_norm': False,
    'init_norm': False,
}

exp_id = 'seed{}'.format(setup_dict['seed'])
save_path = 'results/two_moons_runs/'
    


In [None]:


# Qualitative comparison: for every seed, plot the final-round posterior of each
# algorithm next to a KDE of ground-truth posterior samples.
seeds = np.arange(42, 52)

xo = 1. * obs_stats.flatten()
lims = np.array([[-0.5, 0.5], [-0.5, 0.5]])
# Parameter dims i / j end up on the vertical / horizontal image axes
# (the density grid is transposed before plotting, see _show_density).
i, j, resolution = 0, 1, 100
xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
yy = np.linspace(lims[j, 0], lims[j, 1], resolution)
X, Y = np.meshgrid(xx, yy)
# (resolution**2, 2) evaluation grid; column 0 is dim i, column 1 is dim j.
xy = np.concatenate([X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=1)

algs = ['_MDN_SNPEA', '_MDN_SNPEB', '_continuous_MDN_SNPEC',
        '_discrete_MAF_SNPEC', '_MAF_SNL', 'ground-truth']


def _show_density(pp):
    """Draw a density grid ``pp`` (shape ``X.shape``) into the current axes."""
    plt.imshow(pp.T, origin='lower',
               extent=[lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]],
               aspect='auto', interpolation='none')


def _kde_density(samples, std=0.01):
    """Evaluate an equally weighted Gaussian KDE of ``samples`` on the grid ``xy``.

    The comprehension index is deliberately not called ``i``: on Python 2 it
    would leak and overwrite the global plotting index used in ``extent``.
    """
    kde = gaussian_kde(xs=samples, std=std)
    n_samples = samples.shape[0]
    mog = dd.MoG(xs=[dd.Gaussian(m=kde.xs[n].m, S=kde.xs[n].S)
                     for n in range(n_samples)],
                 a=1. / n_samples * np.ones(n_samples))
    return mog.eval(xy, log=False).reshape(list(X.shape))


# The ground-truth samples do not depend on the seed: load and evaluate them once.
samples_gt = np.load('results/two_moons_runs/5k_gtposterior_2moons.npy')
pp_gt = _kde_density(samples_gt)

for seed in seeds:
    plt.figure(figsize=(8, 12))
    exp_id = 'seed' + str(seed)

    for k, alg in enumerate(algs):
        if k < 4:
            # SNPE variants: density of the last stored posterior.
            _, _, posteriors, _ = load_results(exp_id=exp_id, path=save_path + alg)
            plt.subplot(3, 2, k + 1)
            posterior = posteriors[-1]
            if posterior is None:
                # Training broke down: show an empty panel with a visible marker
                # (previously unreachable because of a duplicated None check).
                plt.text(-0.1, 0., 'broke', color='w')
                _show_density(np.zeros((resolution, resolution)))
            else:
                _show_density(posterior.eval(xy, log=False).reshape(list(X.shape)))
        elif k == 4:
            # SNL: KDE of stored samples; presumably ps[-1] holds the final
            # round's posterior samples -- TODO confirm against the SNL runner.
            plt.subplot(3, 2, k + 1)
            ps = np.load(os.path.join(save_path, exp_id, 'ps.npy'))
            _show_density(_kde_density(ps[-1]))
        elif k == 5:
            plt.subplot(3, 2, k + 1)
            _show_density(pp_gt)
        else:
            continue

        plt.title(alg)
        plt.xticks([])
        plt.yticks([])

    plt.suptitle('seed #' + str(seed))
    plt.show()



In [None]:
%matplotlib inline
import numpy as np
from matplotlib import pyplot as plt
import timeit
import os

from delfi.utils.viz import plot_pdf

from lfimodels.snl_exps.util import save_results, load_results
from lfimodels.snl_exps.util import calc_all_mmds

from lfimodels.snl_exps.util import load_setup_gauss as load_setup
from lfimodels.snl_exps.util import load_gt_gauss as load_gt
from lfimodels.snl_exps.util import init_g_gauss as init_g
from lfimodels.snl_exps.util import draw_sample_uniform_prior_33

def run_mmd_plots(seed, model_id, save_path, exp_id, fig_path, N,
                  n_train=1000,
                  gt_path='results/two_moons_runs/5k_gtposterior_2moons.npy'):
    """Compute per-round MMDs against ground-truth samples, save them, and plot them.

    Parameters
    ----------
    seed, model_id, fig_path :
        Accepted for call-site compatibility; not used by this function.
    save_path : str
        Results directory passed to ``load_results``. The MMD array is saved
        under ``<save_path>/<exp_id>/`` (created if missing).
    exp_id : str
        Experiment identifier, e.g. ``'seed42'``.
    N : int
        Forwarded to ``calc_all_mmds`` as ``n_samples`` and used in the file name.
    n_train : int, optional
        Number of simulations per round, used to scale the x-axis
        (default 1000, the previously hard-coded value).
    gt_path : str, optional
        ``.npy`` file holding ground-truth posterior samples.

    Returns
    -------
    numpy.ndarray
        The MMD values returned by ``calc_all_mmds`` (one per stored posterior,
        presumably one per round -- confirm against ``calc_all_mmds``).
    """
    out_dir = os.path.join(save_path, exp_id)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    _, _, posteriors, _ = load_results(exp_id=exp_id, path=save_path)

    samples_true = np.load(gt_path)
    all_mmds = np.asarray(calc_all_mmds(samples_true,
                                        n_samples=N,
                                        posteriors=posteriors,
                                        init_g=None))
    print('all_mmds', all_mmds)

    np.save(os.path.join(out_dir, 'all_mmds_N' + str(N)), all_mmds)

    n_rounds = all_mmds.size
    n_sims = np.arange(1, n_rounds + 1) * n_train
    plt.figure(figsize=(8, 5))
    # NOTE(review): sqrt is taken before plotting, which suggests calc_all_mmds
    # returns squared MMD -- confirm before relabelling the y-axis.
    plt.semilogx(n_sims, np.sqrt(all_mmds), 'kd:')
    # Same margins as before (0.6x first, 1.2x last round) but derived from the
    # actual round count instead of assuming 10 rounds; max() keeps the log-axis
    # limits positive when there are no rounds.
    plt.xlim([0.6 * n_train, 1.2 * n_train * max(n_rounds, 1)])
    plt.xlabel('Number of simulations (log scale)')
    plt.ylabel('maximum mean discrepancy')
    plt.title(exp_id)
    plt.show()

    return all_mmds



In [None]:
# MMD curves for the SNPE-A runs, one figure per seed.
model_id = 'two_moons_runs/'
seeds = np.arange(42, 52)
N = 1000

fig_path = 'results/'
save_path = 'results/{}_MDN_SNPEA'.format(model_id)

for seed in seeds:
    exp_id = 'seed{}'.format(seed)
    run_mmd_plots(seed=seed,
                  model_id=model_id,
                  save_path=save_path,
                  exp_id=exp_id,
                  fig_path=fig_path,
                  N=N)