In [None]:
%%capture
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import timeit

import delfi.distribution as dd
import delfi.inference as infer
import delfi.generator as dg

from delfi.simulator import TwoMoons
import delfi.summarystats as ds
from delfi.utils.viz import plot_pdf, probs2contours

from lfimodels.snl_exps.util import save_results, load_results

# Global seed shared by the simulator and the prior (see init_g).
# return_abs is not read anywhere in this notebook.
seed = 46
return_abs = False

# Reference posterior for optional contour overlays: an equal-weight mixture of
# two isotropic Gaussians (covariance 0.01*I) centred at (0.5, -0.5) and (-0.5, 0.5).
# The density images below are drawn as imshow(pp.T), i.e. with theta_2 on the
# horizontal axis; this mixture is symmetric under swapping theta_1 and theta_2.
p_true = dd.MoG(
    a=[0.5, 0.5],
    ms=[np.asarray(mean) for mean in ([.5, -.5], [-.5, .5])],
    Ss=[0.01 * np.eye(2) for _ in range(2)],
)
p_true.ndim = 2
def init_g(seed):
    """Return a delfi ``Default`` generator for the two-moons simulator.

    Both the TwoMoons simulator and the Uniform prior on [-1, 1]^2 are seeded
    with ``seed``; this is a very basic way of controlling generator seeds.
    Summary statistics are the identity (``ds.Identity``).
    """
    simulator = TwoMoons(
        mean_radius=0.1, sd_radius=0.01, baseoffset=0.25,
        # Options tried earlier, currently not passed (left at their defaults):
        # angle=np.pi/4.0,                     # rotation angle in radians
        # mapfunc=lambda theta, p: p + theta,  # transforms noise dist.
        seed=seed,
    )
    prior = dd.Uniform(lower=[-1, -1], upper=[1, 1], seed=seed)
    return dg.Default(model=simulator, prior=prior, summary=ds.Identity())

g = init_g(seed=seed)

# Observed summary statistics x_o: a single row at the origin.
obs_stats = np.array([[0., 0.]])

verbose = True

# Sanity plot: 1000 simulated pairs, trn_data[0] plotted against trn_data[1]
# (column k against column k).
# NOTE(review): the %%capture magic at the top of this cell presumably captures
# this figure instead of displaying it; drop %%capture (or use --no-display) to see it.
trn_data = g.gen(1000)
fig_marginals, ax_marginals = plt.subplots(figsize=(6, 5))
ax_marginals.plot(trn_data[1], trn_data[0], '.')
ax_marginals.set_xlabel('x')
ax_marginals.set_ylabel('theta(1) resp. theta(2)')
ax_marginals.set_title('marginals p(theta_i, x)')
plt.show()

# Identifier and directory prefix of the stored runs loaded below.
exp_id = 'seed{}'.format(seed)
save_path = 'results/two_moons_runs/'


In [None]:
# Load the stored runs of each method for this seed. load_results returns four
# values; the fourth is kept as `setup_dict` because the figure cell reads
# setup_dict['n_train'] for its panel titles.
# NOTE(review): assumes the fourth value is the run's setup dict and that all
# methods were run with the same n_train -- confirm against lfimodels.snl_exps.util.

# MAF SNPE-C (plotted as 'APT (MAF)')
log, tds, posteriors, setup_dict = load_results(exp_id=exp_id, path=save_path + '_discrete_MAF_SNPEC')

# SNL
_, tds_snl, posteriors_snl, _ = load_results(exp_id=exp_id, path=save_path + '_MAF_SNL')

# SNPE-A
logs_A, tds_A, posteriors_A, _ = load_results(exp_id=exp_id, path=save_path + '_MDN_SNPEA')

# SNPE-B
logs_B, tds_B, posteriors_B, _ = load_results(exp_id=exp_id, path=save_path + '_MDN_SNPEB')

# Gaussian-proposal SNPE-C (plotted as 'APT (MDN)')
logs_gC, tds_gC, posteriors_gC, _ = load_results(exp_id=exp_id, path=save_path + '_continuous_MDN_SNPEC')

# Rejection-ABC parameter samples (smoothed with a KDE for the 'ABC' panel).
p_all = np.load(save_path + 'ABC.npy')

# assemble figure

In [None]:
from snl.pdfs import gaussian_kde

# One 9x12-inch figure; every panel below lives in a 5x4 subplot grid. No other
# subplot layout is created here: since matplotlib 3.6, plt.subplot no longer
# removes overlapping Axes, so a second layout would leave a stray empty frame.
plt.figure(figsize=(9, 12), frameon=False)

# Evaluation grid for the density images: resolution x resolution points over lims.
# i and j index the two parameter dimensions (theta_1, theta_2) within lims.
xo = 1.*obs_stats.flatten()  # flattened copy of x_o; not used elsewhere in this notebook
lims = np.array([[-0.5, 0.5], [-0.5, 0.5]])
i, j, resolution = 0, 1, 100
xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
yy = np.linspace(lims[j, 0], lims[j, 1], resolution)
X, Y = np.meshgrid(xx, yy)
# (resolution**2, 2) array of (theta_1, theta_2) grid points.
xy = np.concatenate(
    [X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=1)

# Disabled (a bare string literal is a no-op): true-posterior panel from an
# earlier 7x3 layout, kept for reference.
"""
plt.subplot(7, 3, 3)
pp = p_true.eval(xy, log=False).reshape(list(X.shape))
plt.imshow(pp.T, origin='lower',
               extent=[lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]],
               aspect='auto', interpolation='none')
plt.contour(Y, X, probs2contours(pp, levels=(0.68, 0.95)), levels=(0.68, 0.95), colors=('w', 'y'))
plt.ylabel('real posterior')
plt.xticks([])
plt.yticks([])
"""

import warnings

# Training rounds shown (0-based): rounds 1, 5 and 10, one grid column each.
idx_r = [0, 4, 9]

# Shared imshow extent: the horizontal axis spans lims[j] (theta_2) and the
# vertical axis spans lims[i] (theta_1), matching the transposed grid pp.T.
panel_extent = [lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]]


def _plot_posterior_row(row_posteriors, first_subplot, label,
                        show_titles=False, overlay_true=False):
    """Draw one figure row: a posterior density image for each round in ``idx_r``.

    Parameters
    ----------
    row_posteriors : sequence
        Per-round posterior objects exposing ``eval(xy, log=False)``; entries may be None.
    first_subplot : int
        Index (in the 5x4 subplot grid) of this row's first panel.
    label : str
        Row label, drawn as the y-label of the row's first panel.
    show_titles : bool
        If True, title each panel ``N = (round + 1) * setup_dict['n_train']``.
    overlay_true : bool
        If True, overlay the 68%/95% contours of ``p_true`` on each panel.

    A panel whose posterior is None is left blank. A panel whose posterior is
    missing (index error) or fails to evaluate/plot is drawn as a dark placeholder
    labelled 'broke', and a warning carrying the error is issued.
    """
    pp_true = (p_true.eval(xy, log=False).reshape(list(X.shape))
               if overlay_true else None)
    for col, round_idx in enumerate(idx_r):
        plt.subplot(5, 4, first_subplot + col)
        try:
            posterior = row_posteriors[round_idx]
            if posterior is not None:
                pp = posterior.eval(xy, log=False).reshape(list(X.shape))
                plt.imshow(pp.T, origin='lower', extent=panel_extent,
                           aspect='auto', interpolation='none')
        except Exception as err:
            # Best effort: keep building the figure, but make the failure visible.
            warnings.warn('{} (round {}): could not plot posterior: {!r}'.format(
                label, round_idx + 1, err))
            plt.imshow(np.zeros((resolution, resolution)), origin='lower',
                       extent=panel_extent, aspect='auto', interpolation='none')
            plt.text(-0.1, 0., 'broke', color='w')
        if pp_true is not None:
            plt.contour(Y, X, probs2contours(pp_true, levels=(0.68, 0.95)),
                        levels=(0.68, 0.95), colors=('w', 'y'))
        if col == 0:
            plt.ylabel(label)
        plt.xticks([])
        plt.yticks([])
        if show_titles:
            plt.title('N = ' + str((round_idx + 1) * setup_dict['n_train']))


# Grid rows: 1 SNPE-A (subplots 1-3), 2 SNPE-B (5-7), 3 APT-MDN (9-11),
# 4 APT-MAF (13-15); row 5 (SNL) and the ABC panel (12) are drawn in the next block.
_plot_posterior_row(posteriors_A, 1, 'SNPE-A', show_titles=True)
_plot_posterior_row(posteriors_B, 5, 'SNPE-B')
_plot_posterior_row(posteriors_gC, 9, 'APT (MDN)')
# Disabled row from an earlier 7x4 layout: posteriors_dC ('discrete-trained MoG
# SNPE-C') is never loaded in this notebook. To restore it, load those results and
# give it a free grid row, e.g.
# _plot_posterior_row(posteriors_dC, <first_subplot>, 'discrete-trained MoG SNPE-C')
_plot_posterior_row(posteriors, 13, 'APT (MAF)')
    
# SNL row (subplots 17-19): scatter of SNL parameter samples after the rounds in idx_r.
# NOTE(review): `inf` (an SNL inference object with per-round samples `all_ps`)
# and `ps` (final posterior samples) are not defined anywhere in this notebook,
# so this loop raises NameError on a fresh kernel. The loaded `tds_snl` /
# `posteriors_snl` are presumably meant to provide these samples -- TODO wire them in.
for r, round_idx in enumerate(idx_r):
    plt.subplot(5, 4, 17 + r)
    if round_idx + 1 < len(inf.all_ps):
        snl_samples = inf.all_ps[round_idx + 1]
    else:
        snl_samples = ps
    # theta_2 horizontal, theta_1 vertical: same orientation as the imshow(pp.T) panels.
    plt.plot(snl_samples[:, 1], snl_samples[:, 0], 'k.')
    plt.axis([lims[1][0], lims[1][1], lims[0][0], lims[0][1]])
    if r == 0:
        plt.ylabel('SNL')
    plt.xticks([])
    plt.yticks([])

# Rejection-ABC panel (grid row 3, column 4), drawn once after the SNL loop.
# The ABC samples p_all are smoothed with gaussian_kde(std=0.01); its kernels are
# re-expressed as an equal-weight delfi MoG so the density can be evaluated on xy.
plt.subplot(5, 4, 12)
i, j = 0, 1
abc_kde = gaussian_kde(xs=p_all, std=0.01)
n_abc = p_all.shape[0]
kde = dd.MoG(xs=[dd.Gaussian(m=abc_kde.xs[k].m, S=abc_kde.xs[k].S) for k in range(n_abc)],
             a=1. / n_abc * np.ones(n_abc))
pp = kde.eval(xy, log=False).reshape(list(X.shape))
plt.imshow(pp.T, origin='lower',
           extent=[lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]],
           aspect='auto', interpolation='none')
plt.axis([lims[j][0], lims[j][1], lims[i][0], lims[i][1]])
plt.xticks([])
plt.yticks([])
plt.title('N = 1e6')  # presumably the ABC simulation budget -- confirm
plt.ylabel('ABC')
    
# Save next to the loaded runs (save_path is the directory ABC.npy was read from)
# rather than to a hard-coded, user-specific absolute path.
fig_path = save_path + 'fig1_v09.pdf'
plt.savefig(fig_path, bbox_inches='tight')
plt.show()