In [None]:
%%capture
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import timeit

import delfi.distribution as dd
import delfi.inference as infer
import delfi.generator as dg

from delfi.simulator import TwoMoons
import delfi.summarystats as ds
from delfi.utils.viz import plot_pdf, probs2contours

from lfimodels.snl_exps.util import save_results, load_results

seed=46
return_abs = False

# analytic target posterior for plotting (y-axis flipped for overlay with imshow figures !)
p_true = dd.MoG(a=[0.5, 0.5], 
                ms=[np.asarray([.5, -.5]), np.asarray([-.5, .5])], 
                Ss=[0.01*np.eye(2), 0.01*np.eye(2)])
p_true.ndim=2

# Seed control: every inference run below calls init_g(seed) to get a fresh,
# identically seeded simulator + prior pair.
def init_g(seed):
    """Build the delfi ``Default`` generator for the two-moons problem.

    seed : int
        Handed to both the ``TwoMoons`` simulator and the uniform prior.

    Returns a generator with an identity summary statistic and a uniform prior
    on [-1, 1] x [-1, 1].
    """
    # TwoMoons options that are deliberately not passed:
    #   angle=np.pi/4.0                        # rotation angle in radians
    #   mapfunc=lambda theta, p: p + theta     # transforms noise dist.
    simulator = TwoMoons(mean_radius=0.1, sd_radius=0.01, baseoffset=0.25,
                         seed=seed)
    uniform_prior = dd.Uniform(lower=[-1, -1], upper=[1, 1], seed=seed)
    return dg.Default(model=simulator, prior=uniform_prior,
                      summary=ds.Identity())

g = init_g(seed=seed)

# observed summary statistics x_o that every method below conditions on; shape (1, 2)
obs_stats = np.array([[0., 0.]])

verbose = True

# Quick look at the simulator: g.gen returns (params, stats). With 2-D arrays,
# matplotlib pairs column k of the stats with column k of the params.
trn_data = g.gen(1000)
fig, ax = plt.subplots(figsize=(6, 5))
ax.plot(trn_data[1], trn_data[0], '.')
ax.set_xlabel('x')
ax.set_ylabel('theta(1) resp. theta(2)')
ax.set_title('marginals p(theta_i, x)')
plt.show()
        
# Settings shared by all methods; each method cell copies this dict and adds its own keys.
setup_dict = dict(
    seed=seed,
    obs_stats=obs_stats,
    # training schedule
    n_rounds=10,
    n_train=1000,
    # fitting setup
    minibatch=100,
    reg_lambda=0.001,
    pilot_samples=0,
    prior_norm=False,
    init_norm=False,
)

exp_id = 'seed{}'.format(setup_dict['seed'])
save_path = 'results/two_moons_runs/'


# discrete-proposal SNPE-C (MAF)

In [None]:
# MAF parameters
setup_dict_  = setup_dict.copy()
setup_dict_['proposal'] = 'discrete'
setup_dict_['moo'] = 'resample'
setup_dict_['mode'] = 'random'
setup_dict_['n_hiddens'] = [50,50]
setup_dict_['n_mades'] = 5
setup_dict_['act_fun'] = 'tanh'
setup_dict_['batch_norm'] = False # batch-normalization currently not supported
setup_dict_['train_on_all'] = True
setup_dict_['svi'] = False
setup_dict_['val_frac'] = 0.1
setup_dict_['n_null'] = setup_dict_['minibatch']-1

# control MAF seed: SNPEC receives the freshly seeded global NumPy RNG module as `rng`
rng = np.random
rng.seed(seed)

# generator, rebuilt so this run starts from an identically seeded simulator/prior
g = init_g(seed=seed)

setup_dict_['epochs'] = 1000
if setup_dict_['train_on_all']:
    # per-round epoch budget: round r (0-based) trains for epochs // (r + 1) epochs
    epochs=[setup_dict_['epochs'] // (r+1) for r in range(setup_dict_['n_rounds'])]
else:
    epochs=setup_dict_['epochs']

# inference object
res = infer.SNPEC(g, 
                obs=obs_stats, 
                n_hiddens=setup_dict_['n_hiddens'],
                seed=seed,
                reg_lambda=setup_dict_['reg_lambda'],
                pilot_samples=setup_dict_['pilot_samples'],
                svi=setup_dict_['svi'],
                n_mades=setup_dict_['n_mades'],
                act_fun=setup_dict_['act_fun'],
                mode=setup_dict_['mode'],
                rng=rng,
                batch_norm=setup_dict_['batch_norm'],
                verbose=verbose,
                prior_norm=setup_dict_['prior_norm'])

# train; timeit.default_timer is the documented timer (timeit.time is only an
# implementation-detail import inside the timeit module)
t = timeit.default_timer()

logs, tds, posteriors = res.run(
                    n_train=setup_dict_['n_train'],
                    proposal=setup_dict_['proposal'],
                    moo=setup_dict_['moo'],
                    n_null = setup_dict_['n_null'],
                    n_rounds=setup_dict_['n_rounds'],
                    train_on_all=setup_dict_['train_on_all'],
                    minibatch=setup_dict_['minibatch'],
                    val_frac=setup_dict_['val_frac'],
                    epochs=epochs)

print(timeit.default_timer() - t)  # training time in seconds

save_results(logs=logs, tds=tds, posteriors=posteriors, 
             setup_dict=setup_dict_, exp_id=exp_id, path=save_path + '_discrete_MAF_SNPEC')


In [None]:
n_rounds = setup_dict_['n_rounds']
# np.ceil returns a float, but plt.subplot needs integer grid sizes
n_rows = int(np.ceil(n_rounds / 3.))

# plot loss curves
plt.figure(figsize=(12, 8))
for r in range(n_rounds):
    plt.subplot(n_rows, 3, r+1)
    plt.plot(logs[r]['loss'])
    plt.title('loss for round r=' + str(r+1))
plt.show()

# plot the MAF SNPE-C posterior estimate after every round
xo = 1.*obs_stats.flatten()
lims = np.array([[-1,1], [-1,1]])
i,j,resolution = 0,1,100
xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
yy = np.linspace(lims[j, 0], lims[j, 1], resolution)
X, Y = np.meshgrid(xx, yy)
xy = np.concatenate(
    [X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=1)
plt.figure(figsize=(16, 12))
for r in range(n_rounds):
    plt.subplot(n_rows, 3, r + 1)
    posterior = posteriors[r]
    pp = posterior.eval(xy, log=False).reshape(list(X.shape))
    plt.imshow(pp.T, origin='lower',
               extent=[lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]],
               aspect='auto', interpolation='none')
    plt.title('posterior estimate after round r='+str(r+1))

plt.show()

In [None]:
"""
seeds = np.arange(42,52)
for seed in seeds:

    exp_id = 'seed' + str(seed)
    _,_,posteriors, _ = load_results(exp_id=exp_id, path=save_path + '_discrete_MAF_SNPEC')
    posterior = posteriors[-1]
    
    plt.figure(figsize=(16,16))

    lims = np.array([[-0.5, 0.5], [-0.5, 0.5]])
    i,j,resolution = 0,1,100
    xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
    yy = np.linspace(lims[j, 0], lims[j, 1], resolution)
    X, Y = np.meshgrid(xx, yy)
    xy = np.concatenate(
        [X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=1)
    pp = posterior.eval(xy, log=False).reshape(list(X.shape))
    plt.imshow(pp.T, origin='lower',
               extent=[-.5, 0.5, -0.5, 0.5],
               aspect='auto', interpolation='none')
    plt.title('posterior estimate after round r='+str(r+1))

    plt.show()
"""

# SNL (MAF)

In [None]:
import sys
import snl.inference.nde as nde
from snl.ml.models.mafs import ConditionalMaskedAutoregressiveFlow
from delfi.utils.delfi2snl import SNLprior, SNLmodel

# MAF parameters
setup_dict_  = setup_dict.copy()
setup_dict_['mode'] = 'random'
setup_dict_['n_hiddens'] = [50,50]
setup_dict_['n_mades'] = 5
setup_dict_['act_fun'] = 'tanh'
setup_dict_['batch_norm'] = False # batch-normalization currently not supported
setup_dict_['train_on_all'] = True
setup_dict_['thin'] = 10
# control MAF seed
rng = np.random
rng.seed(seed)

# explicit call to MAF constructor.
# The trained flow is evaluated as model.eval([theta, x]) (see the MCMC cell), i.e.
# it conditions on the parameters and models the summary statistics. So n_inputs
# (conditioning size) is theta's dimension and n_outputs is x's. Both appear to be
# 2-D for this problem, which is why a swap would go unnoticed here.
# NOTE(review): `g` is still the generator from the previous cell; it is only
# queried for array sizes and is rebuilt below.
theta, x = g.gen(1)
n_inputs, n_outputs = theta.size, x.size
model = ConditionalMaskedAutoregressiveFlow(
                n_inputs=n_inputs,
                n_outputs=n_outputs,
                n_hiddens=setup_dict_['n_hiddens'],
                act_fun=setup_dict_['act_fun'],
                n_mades=setup_dict_['n_mades'],
                mode=setup_dict_['mode'],
                rng=rng)

# generator
g = init_g(seed=seed)

# inference object
inf = nde.SequentialNeuralLikelihood(SNLprior(g.prior),               # method to draw parameters  
                                     SNLmodel(g.model, g.summary).gen # method to draw summary stats
                                    )

# train (timeit.default_timer is the documented timer API)
t = timeit.default_timer()

rng = np.random # control  
rng.seed(seed)  # MCMC seed
model = inf.learn_likelihood(obs_stats.flatten(), model, n_samples=setup_dict_['n_train'], 
                             n_rounds=setup_dict_['n_rounds'], 
                             train_on_all=setup_dict_['train_on_all'], thin=setup_dict_['thin'], save_models=False, 
                             logger=sys.stdout, rng=rng)

print(timeit.default_timer() - t)  # training time in seconds

# Materialise the per-round (params, stats) pairs. In Python 3 `zip` is a one-shot
# iterator and would be saved as an iterator rather than as a reusable list.
tds = list(zip(inf.all_ps, inf.all_xs))

save_results(logs=[], tds=tds, posteriors=[model], 
             setup_dict=setup_dict_, exp_id=exp_id, path=save_path + '_MAF_SNL')

### Visualize the learned likelihood

In [None]:
import snl.inference.mcmc as mcmc


def log_posterior(t):
    """Unnormalised log posterior at parameters `t`: the learned SNL model evaluated
    at (t, observed statistics) plus the prior log density of `t`."""
    observed = obs_stats.flatten()
    return model.eval([t, observed]) + inf.prior.eval(t)


# Slice-sample 1000 parameter vectors (thin=10), starting the chain at the last
# parameter vector SNL simulated in its third-to-last round.
sampler = mcmc.SliceSampler(x=inf.all_ps[-3][-1], lp_f=log_posterior, thin=10)
ps = sampler.gen(1000)

In [None]:
# Learned SNL density on a 100 x 100 grid over [-1, 1]^2, with the first
# argument of model.eval fixed to the observed statistics x_o.
# NOTE(review): the MCMC cell calls model.eval([theta, x]). Here x_o fills the
# first slot and the grid fills the second; confirm this argument order is intended.
xo = 1.*obs_stats.flatten()

lims = np.array([[-1, 1], [-1, 1]])
i, j, resolution = 0, 1, 100

xx, yy = (np.linspace(lims[k, 0], lims[k, 1], resolution) for k in (i, j))
X, Y = np.meshgrid(xx, yy)
xy = np.column_stack((X.ravel(), Y.ravel()))
pp = model.eval((xo, xy), log=False).reshape(X.shape)

plt.imshow(pp.T, origin='lower',
           extent=[lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]],
           aspect='auto', interpolation='none')
plt.show()

In [None]:
# Parameters SNL simulated in each round (one panel per round), followed by the
# MCMC samples `ps` drawn from the final SNL posterior in the MCMC cell.
n_snl_rounds = len(inf.all_ps)  # no dependence on `n_rounds` from another cell
n_rows = int(np.ceil(n_snl_rounds / 3. + 1))  # plt.subplot needs integer grid sizes

plt.figure(figsize=(16,11))
for r in range(n_snl_rounds):
    plt.subplot(n_rows, 3, r + 1)
    plt.plot(inf.all_ps[r][:,0],
             inf.all_ps[r][:,1], 'k.')
    plt.axis([-1,1,-1,1])
    plt.xlabel('theta1')
    plt.ylabel('theta2')
    plt.title('round r='+str(r))

plt.subplot(n_rows, 3, n_snl_rounds + 1)
plt.plot(ps[:,0],
         ps[:,1], 'k.')
plt.axis([-1,1,-1,1])
plt.xlabel('theta1')
plt.ylabel('theta2')
plt.title('round r='+str(n_snl_rounds))

plt.show()

# SNPE-A (always MDN)

In [None]:
setup_dict_  = setup_dict.copy()
setup_dict_['n_components'] = 20
setup_dict_['n_hiddens'] = [50,50]
setup_dict_['svi'] = False
setup_dict_['val_frac'] = 0.1
setup_dict_['epochs'] = 500

# generator, rebuilt so this run starts from an identically seeded simulator/prior
g = init_g(seed=seed)

# inference object
res_A = infer.CDELFI(g, 
                 obs=obs_stats, 
                 n_hiddens=setup_dict_['n_hiddens'], 
                 n_components=setup_dict_['n_components'],
                 seed=seed, 
                 reg_lambda=setup_dict_['reg_lambda'],
                 pilot_samples=setup_dict_['pilot_samples'],
                 svi=setup_dict_['svi'],
                 verbose=verbose,
                 init_norm=setup_dict_['init_norm'],
                 prior_norm=setup_dict_['prior_norm'])

# train (timeit.default_timer is the documented timer API)
t = timeit.default_timer()

logs_A, tds_A, posteriors_A = res_A.run(
                    n_train=setup_dict_['n_train'], 
                    n_rounds=setup_dict_['n_rounds'], 
                    val_frac=setup_dict_['val_frac'],
                    minibatch=setup_dict_['minibatch'], 
                    epochs=setup_dict_['epochs'])

print(timeit.default_timer() - t)  # training time in seconds

save_results(logs=logs_A, tds=tds_A, posteriors=posteriors_A, 
             setup_dict=setup_dict_, exp_id=exp_id, path=save_path + '_MDN_SNPEA')


In [None]:

xo = 1.*obs_stats.flatten()
lims = np.array([[-.8,.8], [-.8,.8]])
i,j,resolution = 0,1,100
xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
yy = np.linspace(lims[j, 0], lims[j, 1], resolution)
X, Y = np.meshgrid(xx, yy)
xy = np.concatenate(
    [X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=1)

# One figure per round. Rounds whose posterior is None are skipped; the former
# "broke" placeholder branch could never run because None was already filtered out.
for r, posterior in enumerate(posteriors_A):
    if posterior is None:
        continue
    plt.figure(figsize=(8, 8))
    pp = posterior.eval(xy, log=False).reshape(list(X.shape))
    plt.imshow(pp.T, origin='lower',
               extent=[lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]],
               aspect='auto', interpolation='none')
    plt.ylabel('SNPE-A')
    plt.xticks([])
    plt.yticks([])
    plt.title('N = ' + str( (r+1)*setup_dict_['n_train'] ))
    plt.show()
        

# SNPE-B (always MDN)

In [None]:
setup_dict_ = setup_dict.copy()
setup_dict_['n_components'] = 20
setup_dict_['n_hiddens'] = [50,50]
setup_dict_['svi'] = True
setup_dict_['epochs'] = 500
setup_dict_['round_cl'] = 1

# generator, rebuilt so this run starts from an identically seeded simulator/prior
g = init_g(seed=seed)

# inference object
res_B = infer.SNPE(g, 
                 obs=obs_stats, 
                 n_hiddens=setup_dict_['n_hiddens'], 
                 n_components=setup_dict_['n_components'],
                 seed=seed, 
                 reg_lambda=setup_dict_['reg_lambda'],
                 pilot_samples=setup_dict_['pilot_samples'],
                 svi=setup_dict_['svi'],
                 verbose=verbose,
                 init_norm=setup_dict_['init_norm'],
                 prior_norm=setup_dict_['prior_norm'])

# train (timeit.default_timer is the documented timer API)
t = timeit.default_timer()

logs_B, tds_B, posteriors_B = res_B.run(
                    n_train=setup_dict_['n_train'], 
                    n_rounds=setup_dict_['n_rounds'], 
                    minibatch=setup_dict_['minibatch'], 
                    round_cl=setup_dict_['round_cl'], 
                    epochs=setup_dict_['epochs'])

print(timeit.default_timer() - t)  # training time in seconds

save_results(logs=logs_B, tds=tds_B, posteriors=posteriors_B, 
             setup_dict=setup_dict_, exp_id=exp_id, path=save_path + '_MDN_SNPEB')


In [None]:

xo = 1.*obs_stats.flatten()
lims = np.array([[-0.5,0.5], [-0.5,0.5]])
i,j,resolution = 0,1,100
xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
yy = np.linspace(lims[j, 0], lims[j, 1], resolution)
X, Y = np.meshgrid(xx, yy)
xy = np.concatenate(
    [X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=1)

# One figure per round, iterating over the posteriors themselves rather than
# len(logs_B), so the two lists cannot get out of step. None posteriors are skipped.
for r, posterior in enumerate(posteriors_B):
    if posterior is None:
        continue
    plt.figure(figsize=(8, 8))
    pp = posterior.eval(xy, log=False).reshape(list(X.shape))
    plt.imshow(pp.T, origin='lower',
               extent=[lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]],
               aspect='auto', interpolation='none')
    plt.ylabel('SNPE-B')
    plt.xticks([])
    plt.yticks([])
    plt.title('N = ' + str( (r+1)*setup_dict_['n_train'] ))
    plt.show()
        

# Discrete-proposal SNPE-C (MDN)

In [None]:
"""
setup_dict_ = setup_dict.copy()
setup_dict_['proposal'] = 'discrete'
setup_dict_['moo'] = 'resample'
setup_dict_['n_components'] =20
setup_dict_['n_hiddens'] = [50,50]
setup_dict_['train_on_all'] = True
setup_dict_['svi'] = False
setup_dict_['n_null'] = setup_dict_['minibatch']-1
setup_dict_['epochs'] = 5000
setup_dict_['val_frac'] = 0.1

# generator
g = init_g(seed=seed)

# inference object
res_dC = infer.SNPEC(g, 
                 obs=obs_stats, 
                 n_hiddens=setup_dict_['n_hiddens'], 
                 n_components=setup_dict_['n_components'],
                 seed=seed, 
                 reg_lambda=setup_dict_['reg_lambda'],
                 pilot_samples=setup_dict_['pilot_samples'],
                 svi=setup_dict_['svi'],
                 verbose=verbose,
                 prior_norm=setup_dict_['prior_norm'])

# train
t = timeit.time.time()

logs_dC, tds_dC, posteriors_dC = res_dC.run(
                    n_train=setup_dict_['n_train'], 
                    proposal=setup_dict_['proposal'],
                    moo=setup_dict_['moo'],
                    n_null = setup_dict_['n_null'],
                    n_rounds=setup_dict_['n_rounds'], 
                    minibatch=setup_dict_['minibatch'], 
                    epochs=setup_dict_['epochs'],
                    silent_fail=False,
                    val_frac=setup_dict_['val_frac'],
                    verbose=True)

print(timeit.time.time() -  t)

save_results(logs=logs_dC, tds=tds_dC, posteriors=posteriors_dC, 
             setup_dict=setup_dict_, exp_id=exp_id, path=save_path + '_discrete_MDN_SNPEC')
"""

In [None]:
"""
xo = 1.*obs_stats.flatten()
lims = np.array([[-1,1], [-1,1]])
i,j,resolution = 0,1,100
xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
yy = np.linspace(lims[j, 0], lims[j, 1], resolution)
X, Y = np.meshgrid(xx, yy)
xy = np.concatenate(
    [X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=1)

for r in range(len(posteriors_dC)):

    posterior = posteriors_dC[r]
    if posterior is None:
        pass
    else:
        plt.figure(figsize=(8, 8))
    
        #posterior = posteriors_A[r] 
        if not posterior is None:
            pp = posterior.eval(xy, log=False).reshape(list(X.shape))
            plt.imshow(pp.T, origin='lower',
                       extent=[lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]],
                       aspect='auto', interpolation='none')

        else:
            plt.text(-0.1, 0., 'broke', color='w')
            plt.imshow(np.zeros((resolution,resolution)), origin='lower',
                       extent=[lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]],
                       aspect='auto', interpolation='none')    

        plt.ylabel('SNPE-C')
        plt.xticks([])
        plt.yticks([])
        plt.title('N = ' + str( (r+1)*setup_dict_['n_train'] ))
        plt.show()
"""

# Gaussian-proposal SNPE-C (MDN)

In [None]:
setup_dict_ = setup_dict.copy()
setup_dict_['proposal'] = 'gaussian'
setup_dict_['n_hiddens'] = [50,50]
setup_dict_['n_components'] = 20
setup_dict_['train_on_all'] = True
setup_dict_['svi'] = False
setup_dict_['n_null'] = setup_dict_['minibatch']-1
setup_dict_['epochs'] = 5000
setup_dict_['val_frac'] = 0.1

# generator, rebuilt so this run starts from an identically seeded simulator/prior
g = init_g(seed=seed)

# inference object
res_gC = infer.SNPEC(g, 
                 obs=obs_stats, 
                 n_hiddens=setup_dict_['n_hiddens'], 
                 n_components=setup_dict_['n_components'],
                 seed=seed, 
                 reg_lambda=setup_dict_['reg_lambda'],
                 pilot_samples=setup_dict_['pilot_samples'],
                 svi=setup_dict_['svi'],
                 verbose=verbose,
                 prior_norm=setup_dict_['prior_norm'])

# train (timeit.default_timer is the documented timer API)
t = timeit.default_timer()

# the proposal is read from setup_dict_ so the saved config always matches the run
logs_gC, tds_gC, posteriors_gC = res_gC.run(
                    n_train=setup_dict_['n_train'], 
                    proposal=setup_dict_['proposal'],
                    n_rounds=setup_dict_['n_rounds'], 
                    minibatch=setup_dict_['minibatch'], 
                    epochs=setup_dict_['epochs'],
                    val_frac=setup_dict_['val_frac'],
                    train_on_all=setup_dict_['train_on_all'],    
                    verbose=True)

print(timeit.default_timer() - t)  # training time in seconds

save_results(logs=logs_gC, tds=tds_gC, posteriors=posteriors_gC, 
             setup_dict=setup_dict_, exp_id=exp_id, path=save_path + '_continuous_MDN_SNPEC')


In [None]:

xo = 1.*obs_stats.flatten()
lims = np.array([[-0.5,0.5], [-0.5,0.5]])
i,j,resolution = 0,1,100
xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
yy = np.linspace(lims[j, 0], lims[j, 1], resolution)
X, Y = np.meshgrid(xx, yy)
xy = np.concatenate(
    [X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=1)

# One figure per round. Iterate over however many posteriors the run returned,
# instead of a hard-coded 10, so changing n_rounds cannot cause an IndexError or
# silently drop rounds. None posteriors are skipped.
for r, posterior in enumerate(posteriors_gC):
    if posterior is None:
        continue
    plt.figure(figsize=(8, 8))
    pp = posterior.eval(xy, log=False).reshape(list(X.shape))
    plt.imshow(pp.T, origin='lower',
               extent=[lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]],
               aspect='auto', interpolation='none')
    plt.ylabel('continuous-proposal SNPE-C')
    plt.xticks([])
    plt.yticks([])
    plt.title('N = ' + str( (r+1)*setup_dict_['n_train'] ))
    plt.show()
        

# MoG-proposal SNPE-C (MDN)

In [None]:
"""
n_rounds = 10
epochs=[500//(r+1) for r in range(n_rounds)]

# generator
g = init_g(seed=seed)

# inference object
res_mC = infer.SNPEC(g, 
                 obs=obs_stats, 
                 n_hiddens=n_hiddens_mdn, 
                 n_components=n_components,
                 seed=seed, 
                 reg_lambda=reg_lambda,
                 pilot_samples=pilot_samples,
                 svi=svi,
                 verbose=verbose,
                 prior_norm=prior_norm)

# train
t = timeit.time.time()

logs_mC, tds_mC, posteriors_mC = res_mC.run(n_train=n_train, 
                    proposal='mog',
                    n_rounds=n_rounds, 
                    minibatch=minibatch, 
                    epochs=np.max(epochs))

print(timeit.time.time() -  t)
"""

# Ground truth (rejection ABC reference posterior)

In [None]:
# Reference posterior by rejection ABC. The simulator settings match init_g, but
# the uniform prior is restricted to [-0.45, 0.45]^2. This presumably raises the
# acceptance rate, and is only harmless if essentially all posterior mass lies in that box.
m = TwoMoons(mean_radius=0.1, sd_radius=0.01, baseoffset=0.25,
             #angle= np.pi/4.0,  # rotation angle in radians
             #mapfunc=lambda theta, p: p + theta,  # transforms noise dist.
             seed=seed)
p = dd.Uniform(lower=[-0.45,-0.45], upper=[0.45,0.45], seed=seed)
s = ds.Identity()
ggt = dg.Default(model=m, prior=p, summary=s)

n_abc_batches = 10          # independent simulation batches
n_sims_per_batch = 1000000  # simulations per batch
n_keep = 100                # closest simulations accepted from each batch

p_all = []
# `batch` (not `i`) so the global grid index `i` used by the plotting cells is not clobbered
for batch in range(n_abc_batches):
    params, stats = ggt.gen(n_sims_per_batch)

    # squared Euclidean distance of every simulated summary to the observation
    dists = np.sum((obs_stats - stats)**2, axis=1)
    idx = np.argsort(dists)
    accepted = idx[:n_keep]

    plt.hist( dists )
    plt.xlabel('squared distance to obs_stats')
    plt.ylabel('count')
    plt.title('ABC batch ' + str(batch + 1) + ': distances')
    plt.show()

    plt.plot(params[accepted, 0], params[accepted, 1], 'k.')
    plt.xlabel('theta1')
    plt.ylabel('theta2')
    plt.title('ABC batch ' + str(batch + 1) + ': accepted parameters')
    plt.show()

    # effective acceptance threshold of this batch
    print(np.max(dists[accepted]))
    p_all.append(params[accepted])

p_all = np.vstack(p_all)
np.save(save_path+ 'ABC', p_all)

# Assemble the comparison figure

In [None]:
# Optional shortcut (all lines commented out): reload saved results instead of re-training.
# Seed 46 has particularly nice MAF SNPE-C posteriors!
# NOTE(review): the commented call below loads the seed-43 *MDN* SNPE-C run (the
# same call as in the next cell), not a MAF run. When uncommented it also rebinds
# `setup_dict`, replacing the shared configuration from the first cell. Its
# exp_id/path do not follow the 'seed<N>' / save_path + '_<method>' naming used by
# save_results in this notebook.

#logs_gC, tds_gC, posteriors_gC, setup_dict = load_results(
#    exp_id='continuousMDN_SNPEC_seed43', 
#    path='results/two_moons_runs/validationset')


In [None]:
# Optional shortcut (all lines commented out): reload a saved MDN SNPE-C run.
# Seed 43 worked quite nicely for MDN SNPE-C!
# NOTE(review): uncommenting rebinds `setup_dict` (the shared configuration),
# and the exp_id/path differ from the 'seed<N>' / save_path + '_continuous_MDN_SNPEC'
# naming that save_results uses in this notebook.

#logs_gC, tds_gC, posteriors_gC, setup_dict = load_results(
#    exp_id='continuousMDN_SNPEC_seed43', 
#    path='results/two_moons_runs/validationset')


In [None]:
from snl.pdfs import gaussian_kde

# Comparison figure on a 5 x 4 grid of panels. Rows: SNPE-A, SNPE-B, MDN SNPE-C,
# MAF SNPE-C, SNL. Columns 1-3 show the rounds listed in idx_r. Panel 12 holds the
# rejection-ABC reference built from `p_all`.
plt.figure(figsize=(9, 12), frameon=False)

xo = 1.*obs_stats.flatten()
lims = np.array([[-0.5,0.5], [-0.5,0.5]])
i,j,resolution = 0,1,100
xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
yy = np.linspace(lims[j, 0], lims[j, 1], resolution)
X, Y = np.meshgrid(xx, yy)
xy = np.concatenate(
    [X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=1)
extent = [lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]]

idx_r = [0,4,9]  # 0-based round indices shown in columns 1-3

# Set to True to overlay the 68% / 95% contours of the analytic target `p_true`
# on every method panel (off by default).
show_true_contours = False


def _plot_density(dist):
    """Evaluate `dist` (anything with .eval(points, log=False)) on the grid `xy`
    and draw it as an image on the current axes."""
    dens = dist.eval(xy, log=False).reshape(list(X.shape))
    plt.imshow(dens.T, origin='lower', extent=extent,
               aspect='auto', interpolation='none')


def _finish_panel(col, label):
    """Apply the optional true-posterior overlay, put the row label on the first
    column, and hide the ticks."""
    if show_true_contours:
        pp_true = p_true.eval(xy, log=False).reshape(list(X.shape))
        plt.contour(Y, X, probs2contours(pp_true, levels=(0.68, 0.95)),
                    levels=(0.68, 0.95), colors=('w', 'y'))
    if col == 0:
        plt.ylabel(label)
    plt.xticks([])
    plt.yticks([])


def _plot_posterior_row(row_posteriors, first_panel, label, mark_broken=False, titles=False):
    """Draw row_posteriors[idx] for idx in idx_r into panels first_panel, first_panel+1, ...

    A None posterior leaves its panel blank. If a posterior cannot be indexed or
    evaluated, the error is reported. The panel is then left blank or, when
    `mark_broken` is True, filled with a black image labelled 'broke'.
    When `titles` is True, each panel is titled with the number of simulations used.
    """
    for col, r_idx in enumerate(idx_r):
        plt.subplot(5, 4, first_panel + col)
        try:
            posterior = row_posteriors[r_idx]
            if posterior is not None:
                _plot_density(posterior)
        except Exception as err:
            print('{}: cannot plot round {}: {!r}'.format(label, r_idx + 1, err))
            if mark_broken:
                plt.imshow(np.zeros((resolution,resolution)), origin='lower',
                           extent=extent, aspect='auto', interpolation='none')
                plt.text(-0.1, 0., 'broke', color='w')
        _finish_panel(col, label)
        if titles:
            plt.title('N = ' + str((r_idx + 1) * setup_dict['n_train']))


# Disabled panel (string literal, never executed): analytic posterior on an older 7-row layout.
"""
plt.subplot(7, 3, 3)
pp = p_true.eval(xy, log=False).reshape(list(X.shape))
plt.imshow(pp.T, origin='lower',
               extent=[lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]],
               aspect='auto', interpolation='none')
plt.contour(Y, X, probs2contours(pp, levels=(0.68, 0.95)), levels=(0.68, 0.95), colors=('w', 'y'))
plt.ylabel('real posterior')
plt.xticks([])
plt.yticks([])
"""

_plot_posterior_row(posteriors_A, 1, 'SNPE-A', mark_broken=True, titles=True)
_plot_posterior_row(posteriors_B, 5, 'SNPE-B')
_plot_posterior_row(posteriors_gC, 9, 'MDN SNPE-C')

# Disabled row (string literal, never executed): discrete-proposal MDN SNPE-C on an older 7-row layout.
"""
for r in range(len(idx_r)):
    plt.subplot(7, 4, 17 +r)
    try:
        posterior = posteriors_dC[idx_r[r]] 
        if not posterior is None:
            pp = posterior.eval(xy, log=False).reshape(list(X.shape))
            plt.imshow(pp.T, origin='lower',
                       extent=[lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]],
                       aspect='auto', interpolation='none')
    except:
        plt.imshow(np.zeros((resolution,resolution)), origin='lower',
                   extent=[lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]],
                   aspect='auto', interpolation='none')
        plt.text(-0.1, 0., 'broke', color='w')
    if r == -2:
        pp = p_true.eval(xy, log=False).reshape(list(X.shape))
        plt.contour(Y, X, probs2contours(pp, levels=(0.68, 0.95)), levels=(0.68, 0.95), colors=('w', 'y'))
    if r == 0:
        plt.ylabel('discrete-trained MoG SNPE-C')
    plt.xticks([])
    plt.yticks([])
"""

_plot_posterior_row(posteriors, 13, 'MAF SNPE-C')

# SNL row: the parameters SNL simulated in the round after idx_r[r] (presumably
# drawn from its posterior estimate after that round). For the last shown round,
# use the MCMC samples `ps` from the final model instead.
for col, r_idx in enumerate(idx_r):
    plt.subplot(5, 4, 17 + col)
    if r_idx + 1 < len(inf.all_ps):
        snl_samples = inf.all_ps[r_idx + 1]
    else:
        snl_samples = ps
    plt.plot(snl_samples[:,0], snl_samples[:,1], 'k.')
    plt.axis([lims[0][0], lims[0][1], lims[1][0], lims[1][1]])
    _finish_panel(col, 'SNL')

# ABC reference panel, drawn once (it used to sit inside the SNL loop and was
# rebuilt three times). The snl Gaussian KDE (std 0.01) around the accepted ABC
# samples is converted to an equally weighted delfi mixture so it can be evaluated on `xy`.
plt.subplot(5,4,12)
snl_kde = gaussian_kde(xs=p_all, std=0.01)
n_abc = p_all.shape[0]
kde = dd.MoG(xs=[dd.Gaussian(m=snl_kde.xs[k].m, S=snl_kde.xs[k].S) for k in range(n_abc)],
             a=1./n_abc * np.ones(n_abc))
pp = kde.eval(xy, log=False).reshape(list(X.shape))
plt.imshow(pp.T, origin='lower', extent=extent,
           aspect='auto', interpolation='none')
plt.axis([lims[0][0], lims[0][1], lims[1][0], lims[1][1]])
plt.xticks([])
plt.yticks([])
# NOTE(review): the ground-truth cell runs 10 batches of 1e6 simulations; confirm the intended label.
plt.title('N = 1e6')
plt.ylabel('ABC')

# save next to the other results (the ABC cell has already written into save_path)
# instead of an absolute path that only exists on one machine
plt.savefig(save_path + 'fig1_v09.pdf', bbox_inches='tight')
plt.show()