# figure 1 (two Moons)
- loads model fits as produced by twoMoons_allAlgs_fits.ipynb
- (the original fits used for the paper were produced with the 'playground' version of the twoMoons notebook)

In [None]:
%%capture
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import timeit

import delfi.distribution as dd
import delfi.inference as infer
import delfi.generator as dg

from delfi.simulator import TwoMoons
import delfi.summarystats as ds
from delfi.utils.viz import plot_pdf, probs2contours

from lfimodels.snl_exps.util import save_results, load_results

# Seed shared by the generator (init_g) and the experiment id built from it.
seed = 46
return_abs = False

# analytic target posterior for plotting (y-axis flipped for overlay with imshow figures !)
# Two equally weighted, isotropic Gaussian modes.
_p_true_means = [np.asarray([.5, -.5]), np.asarray([-.5, .5])]
_p_true_covs = [0.01*np.eye(2), 0.01*np.eye(2)]
p_true = dd.MoG(a=[0.5, 0.5], ms=_p_true_means, Ss=_p_true_covs)
p_true.ndim = 2

# very basic approach to controlling generator seeds
def init_g(seed):
    """Build a fresh two-moons generator whose simulator and prior share ``seed``.

    Parameters
    ----------
    seed : int
        Seed handed to both the ``TwoMoons`` simulator and the uniform prior.

    Returns
    -------
    delfi.generator.Default
        Generator combining the simulator, a uniform prior on [-1, 1]^2 and
        identity summary statistics.
    """
    # TwoMoons is used with its default ``angle`` / ``mapfunc``; earlier
    # experiments toggled them here (angle=np.pi/4.0 as rotation in radians,
    # mapfunc=lambda theta, p: p + theta to transform the noise distribution).
    simulator = TwoMoons(mean_radius=0.1, sd_radius=0.01, baseoffset=0.25,
                         seed=seed)
    prior = dd.Uniform(lower=[-1, -1], upper=[1, 1], seed=seed)
    summary = ds.Identity()
    return dg.Default(model=simulator, prior=prior, summary=summary)

# Notebook-wide settings.
obs_stats = np.array([[0., 0.]])      # observed summary statistics
verbose = True
exp_id = 'seed' + str(seed)
save_path = 'results/two_moons_runs/'

# Draw a batch from the seeded generator and scatter it as a quick sanity
# check. The plot puts trn_data[1] on the x-axis and trn_data[0] on the y-axis
# (presumably summary statistics vs. parameters -- confirm against g.gen).
g = init_g(seed=seed)
trn_data = g.gen(1000)

fig_marginals, ax_marginals = plt.subplots(figsize=(6, 5))
ax_marginals.plot(trn_data[1], trn_data[0], '.')
ax_marginals.set_xlabel('x')
ax_marginals.set_ylabel('theta(1) resp. theta(2)')
ax_marginals.set_title('marginals p(theta_i, x)')
plt.show()


In [None]:
def _load_run(run_id, suffix):
    """Load one stored inference run from ``save_path + suffix``."""
    return load_results(exp_id=run_id, path=save_path + suffix)


def _load_samples(filename):
    """Load a sample array stored under ``save_path``."""
    return np.load(save_path + filename)


# --- neural density-estimation runs (each was stored under its own seed id) ---

# MAF SNPE-C
log, tds, posteriors, setup_dict = _load_run('seed46', '_discrete_MAF_SNPEC')
# Unwrap the stored setup: indexing with an empty tuple presumably extracts
# the dict from a 0-d object array -- confirm against save_results.
setup_dict = setup_dict[()]

# SNPE-A
logs_A, tds_A, posteriors_A, _ = _load_run('seed42', '_MDN_SNPEA')

# SNPE-B
logs_B, tds_B, posteriors_B, _ = _load_run('seed42', '_MDN_SNPEB')

# Gaussian-proposal SNPE-C
logs_gC, tds_gC, posteriors_gC, _ = _load_run('seed43', '_continuous_MDN_SNPEC')

# SNL (ps_snl is the fallback sample set for rounds not stored in tds_snl)
_, tds_snl, posteriors_snl, _ = _load_run('seed42', '_MAF_SNL')
ps_snl = _load_samples('SNL_final_ps.npy')

# --- sample-based baselines ---

# rej. ABC
p_all = _load_samples('ABC.npy')

# SMC-ABC, one file per tolerance level
ps_SMC_100 = _load_samples('SMC_eps_1_00.npy')
ps_SMC_010 = _load_samples('SMC_eps_0_10.npy')
ps_SMC_001 = _load_samples('SMC_eps_0_01.npy')

# true samples
samples_gt = np.load('results/two_moons_runs/5k_gtposterior_2moons.npy')

# assemble figure

In [None]:
from snl.pdfs import gaussian_kde

import warnings

# Figure 1: a 7 x 3 grid of 2-d density images over the parameter grid `xy`.
#   row 1 (panel 3 only): KDE of the ground-truth posterior samples
#   row 2: SNPE-A     row 3: SNPE-B     row 4: SNL (KDE of samples)
#   row 5: SMC-ABC (KDE of particles, one tolerance per column)
#   row 6: APT (MDN)  row 7: APT (MAF)
# Columns of the model rows show the round indices listed in `idx_r`.
#
# Changes relative to the earlier version of this cell:
#   * the KDE mixtures are sized from the samples being smoothed, not from the
#     unrelated rejection-ABC array `p_all`;
#   * frames are switched off with `plt.box(False)`; the string 'off' is
#     truthy and no longer parsed as a boolean by matplotlib;
#   * the stray `plt.subplot(2,3,1)` is gone; it only vanished because old
#     matplotlib auto-removed overlapping axes;
#   * plotting failures are reported through `warnings.warn` instead of being
#     swallowed by a bare `except:`;
#   * disabled draft panels kept in string literals (analytic true-posterior
#     panel, discrete-trained MoG SNPE-C row) were dropped.

plt.figure(figsize=(5, 12), frameon=False)

xo = 1.*obs_stats.flatten()
lims = np.array([[-0.5, 0.5], [-0.5, 0.5]])
i, j, resolution = 0, 1, 100
xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
yy = np.linspace(lims[j, 0], lims[j, 1], resolution)
X, Y = np.meshgrid(xx, yy)
xy = np.concatenate(
    [X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=1)

# Images are drawn transposed, hence dimension j on the horizontal axis.
# Computed once so later reuse of the names i / j cannot change it.
img_extent = [lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]]

n_rows, n_cols = 7, 3
idx_r = [0, 4, 9]

# Column index on which to overlay 68% / 95% contours of the analytic
# posterior `p_true`; None disables the overlay (it was disabled before, too).
contour_col = None


def _eval_on_grid(dist):
    """Evaluate `dist` (non-log density) on the grid, shaped like `X`."""
    return dist.eval(xy, log=False).reshape(list(X.shape))


def _show_density(pp):
    """Draw a gridded density as an image in the current axes."""
    plt.imshow(pp.T, origin='lower', extent=img_extent,
               aspect='auto', interpolation='none')


def _kde_on_grid(samples, std=0.01):
    """Gridded density of a Gaussian KDE fitted to `samples`.

    The snl KDE is re-wrapped as a delfi MoG with one equally weighted
    component per KDE component, so it can be evaluated like the posteriors.
    """
    kde = gaussian_kde(xs=samples, std=std)
    components = [dd.Gaussian(m=comp.m, S=comp.S) for comp in kde.xs]
    n_comp = len(components)
    return _eval_on_grid(dd.MoG(xs=components, a=np.full(n_comp, 1./n_comp)))


def _overlay_true_contours():
    """Overlay credible-region contours of the analytic posterior `p_true`."""
    pp_true = _eval_on_grid(p_true)
    plt.contour(Y, X, probs2contours(pp_true, levels=(0.68, 0.95)),
                levels=(0.68, 0.95), colors=('w', 'y'))


def _finish_panel(row_label=None, clip_to_lims=False, label_kwargs=None):
    """Apply the shared cosmetics: optional row label, no ticks, no frame."""
    if row_label is not None:
        plt.ylabel(row_label, **(label_kwargs or {}))
    plt.xticks([])
    plt.yticks([])
    plt.axis('equal')
    if clip_to_lims:
        plt.xlim(lims[0, 0], lims[0, 1])
        plt.ylim(lims[1, 0], lims[1, 1])
    plt.box(False)


def _plot_posterior_row(posterior_list, first_panel, row_label,
                        mark_failure=False, show_n=False, label_kwargs=None):
    """Plot `posterior_list[k]` for every round index k in `idx_r` in one row.

    A missing (None) posterior leaves its panel blank. Any other failure is
    reported as a warning; with `mark_failure` the panel is additionally
    filled with a black image labelled 'broke'.
    """
    for r, rnd in enumerate(idx_r):
        plt.subplot(n_rows, n_cols, first_panel + r)
        try:
            posterior = posterior_list[rnd]
            if posterior is not None:
                _show_density(_eval_on_grid(posterior))
        except Exception as err:
            warnings.warn('could not plot {0} panel for round index {1}: {2!r}'
                          .format(row_label, rnd, err))
            if mark_failure:
                _show_density(np.zeros((resolution, resolution)))
                plt.text(-0.1, 0., 'broke', color='w')
        if r == contour_col:
            _overlay_true_contours()
        if show_n:
            plt.title('N = ' + str((rnd + 1)*setup_dict['n_train']))
        _finish_panel(row_label if r == 0 else None, label_kwargs=label_kwargs)


# --- density-estimator rows --------------------------------------------------
_plot_posterior_row(posteriors_A, 4, 'SNPE-A', mark_failure=True, show_n=True)
_plot_posterior_row(posteriors_B, 7, 'SNPE-B')
_plot_posterior_row(posteriors_gC, 16, 'APT (MDN)',
                    label_kwargs={'fontweight': 'bold'})
_plot_posterior_row(posteriors, 19, 'APT (MAF)',
                    label_kwargs={'fontweight': 'bold'})

# --- SNL row: KDE of the samples stored for the round after idx_r[r] ---------
for r, rnd in enumerate(idx_r):
    plt.subplot(n_rows, n_cols, 10 + r)

    # Fall back to the separately stored final samples when tds_snl holds no
    # entry for round rnd + 1.
    if rnd + 1 < len(tds_snl[0]):
        ps = np.asarray(tds_snl[0][rnd + 1])
    else:
        ps = ps_snl
    _show_density(_kde_on_grid(ps))
    plt.axis([lims[0][0], lims[0][1], lims[1][0], lims[1][1]])

    if r == contour_col:
        _overlay_true_contours()
    _finish_panel('SNL' if r == 0 else None, clip_to_lims=True)

# --- SMC-ABC row: one tolerance level per column -----------------------------
ps_smc = [ps_SMC_100, ps_SMC_010, ps_SMC_001]
Ns = [r'N$\approx$1000', r'N$\approx$1e5', r'N$\approx$5e6']
eps_lvls = [r'$\epsilon=1$', r'$\epsilon=0.1$', r'$\epsilon=0.01$']
for r in range(len(idx_r)):
    plt.subplot(n_rows, n_cols, 13 + r)

    _show_density(_kde_on_grid(ps_smc[r]))
    plt.axis([lims[0][0], lims[0][1], lims[1][0], lims[1][1]])
    # Hand-tuned text positions so the annotations stay inside each panel.
    plt.text(-0.02 - (r == 0) * 0.04, -0.4, Ns[r], color='w')
    plt.text(-0.02 + (2 - r)*0.11, -0.25, eps_lvls[r], color='w')
    _finish_panel('SMC' if r == 0 else None, clip_to_lims=True)

# --- reference panel: KDE of the ground-truth posterior samples --------------
plt.subplot(n_rows, n_cols, 3)
pp = _kde_on_grid(samples_gt)
_show_density(pp)
plt.axis([lims[0][0], lims[0][1], lims[1][0], lims[1][1]])
plt.title('true posterior')
_finish_panel()

# NOTE(review): machine-specific absolute output path -- adjust before running
# elsewhere.
plt.savefig('/home/marcel/code/lfi_experiments/snpec/results/figs/fig1.pdf', bbox_inches='tight')
plt.show()

# rejection-corrected SMC

In [None]:
from snl.inference.abc import SMC, calc_dist
import snl.util as util

class rejSMC(SMC):
    """SMC-ABC sampler that re-draws proposals the prior refuses to evaluate.

    Both population-sampling hooks pass every proposal to ``self.prior.eval``
    before simulating it. If that call raises, the proposal is discarded and a
    new one is drawn, so no simulation is spent on it. Presumably the prior
    raises for parameters outside its support -- confirm against the prior
    wrapper in use.
    """

    def sample_initial_population(self, obs_data, n_particles, eps, logger, rng):
        """
        Sample an initial population of n_particles, with tolerance eps.

        :param obs_data: observed data the simulations are compared to
        :param n_particles: number of particles to accept
        :param eps: distance tolerance; a proposal is accepted once its
            simulated data lies within eps of obs_data
        :param logger: object with a ``write`` method, called once per particle
        :param rng: random number generator passed to prior and simulator
        :return: (array of accepted parameters, number of simulations run)
        """

        ps = []
        n_sims = 0

        for i in range(n_particles):

            dist = float('inf')
            prop_ps = None

            while dist > eps:
                # Re-draw until the prior can evaluate the proposal.
                while True:
                    prop_ps = self.prior.gen(rng=rng)
                    try:
                        self.prior.eval(prop_ps, log=True)
                    except Exception:
                        # Deliberate rejection step. ``Exception`` (not a bare
                        # except) keeps KeyboardInterrupt able to stop the loop.
                        continue
                    break
                data = self.sim_model(prop_ps, rng=rng)
                dist = calc_dist(data, obs_data)
                n_sims += 1

            ps.append(prop_ps)

            logger.write('particle {0}\n'.format(i + 1))

        return np.array(ps), n_sims

    def sample_next_population(self, ps, log_weights, obs_data, eps, logger, rng):
        """
        Samples a new population of particles by perturbing an existing one. Uses a gaussian perturbation kernel.

        :param ps: current particles, one per row
        :param log_weights: log weights of the current particles
        :param obs_data: observed data the simulations are compared to
        :param eps: distance tolerance for accepting a perturbed particle
        :param logger: object with a ``write`` method, called once per particle
        :param rng: random number generator (must provide ``randn``)
        :return: (new particles, normalized new log weights, number of simulations run)
        """
        # Imported here so the class does not depend on a later cell having run
        # ``import scipy``. ``scipy.misc.logsumexp`` no longer exists in
        # current SciPy; ``scipy.special.logsumexp`` is its replacement.
        from scipy.linalg import solve_triangular
        from scipy.special import logsumexp

        n_particles, n_dim = ps.shape
        n_sims = 0
        weights = np.exp(log_weights)

        # calculate population covariance (kernel uses twice that covariance)
        mean = np.mean(ps, axis=0)
        cov = 2.0 * (np.dot(ps.T, ps) / n_particles - np.outer(mean, mean))
        std = np.linalg.cholesky(cov)

        new_ps = np.empty_like(ps)
        new_log_weights = np.empty_like(log_weights)

        for i in range(n_particles):

            dist = float('inf')

            while dist > eps:
                # Resample a parent and perturb it until the prior can evaluate
                # the result.
                while True:
                    idx = util.math.discrete_sample(weights, rng=rng)
                    new_ps[i] = ps[idx] + np.dot(std, rng.randn(n_dim))
                    try:
                        self.prior.eval(new_ps[i], log=True)
                    except Exception:
                        # Deliberate rejection step. ``Exception`` (not a bare
                        # except) keeps KeyboardInterrupt able to stop the loop.
                        continue
                    break
                data = self.sim_model(new_ps[i], rng=rng)
                dist = calc_dist(data, obs_data)
                n_sims += 1

            # calculate unnormalized weights
            log_kernel = -0.5 * np.sum(solve_triangular(std, (new_ps[i] - ps).T, lower=True) ** 2, axis=0)
            new_log_weights[i] = self.prior.eval(new_ps[i], log=True) - logsumexp(log_weights + log_kernel)

            logger.write('particle {0}\n'.format(i + 1))

        # normalize weights
        new_log_weights -= logsumexp(new_log_weights)

        return new_ps, new_log_weights, n_sims

In [None]:
from snl.inference.abc import SMC
from delfi.utils.delfi2snl import SNLmodel, SNLprior
import scipy 

# One rejection-corrected SMC-ABC run per final tolerance.
# Each entry is (eps_init, eps_last, file stem); np.save appends '.npy', which
# yields the SMC_eps_*.npy files that the figure cell loads.
smc_runs = [
    (1, 1, 'SMC_eps_1_00'),
    (1., 0.1, 'SMC_eps_0_10'),
    (1, 0.01, 'SMC_eps_0_01'),
]

for smc_eps_init, smc_eps_last, smc_file in smc_runs:
    # Rebuild the generator so every run starts from the same seeded simulator.
    g = init_g(seed=seed)
    sampler = rejSMC(SNLprior(dd.Uniform(lower=[-1,-1], upper=[1,1])),
                     SNLmodel(g.model, g.summary).gen)
    ps, log_weights = sampler.run(obs_data=obs_stats.flatten(),
                                  n_particles=1000,
                                  eps_init=smc_eps_init,
                                  eps_last=smc_eps_last,
                                  eps_decay=0.9)
    np.save(save_path + smc_file, ps)

    # Quick look at the accepted particles.
    plt.plot(ps[:,0], ps[:,1], '.')
    plt.show()