In [None]:
%%capture
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import timeit

import delfi.distribution as dd
import delfi.inference as infer
import delfi.generator as dg

from delfi.simulator import TwoMoons
import delfi.summarystats as ds
from delfi.utils.viz import plot_pdf, probs2contours

# Seed value handed to the generators and inference objects in this notebook.
seed = 42
return_abs = False

# Analytic two-component target posterior, used for plotting only
# (y-axis flipped so that it can be overlaid on the imshow figures).
_mode_means = [np.asarray([.5, -.5]), np.asarray([-.5, .5])]
_mode_covs = [0.01*np.eye(2) for _ in _mode_means]
p_true = dd.MoG(a=[0.5, 0.5], ms=_mode_means, Ss=_mode_covs)
p_true.ndim = 2


# very basic approach to controlling generator seeds
def init_g(seed):
    """Build a fresh two-moons generator with seeded simulator and prior.

    Parameters
    ----------
    seed : value forwarded as ``seed`` to both ``TwoMoons`` and ``dd.Uniform``.

    Returns
    -------
    A ``dg.Default`` generator combining the simulator, a uniform prior on
    [-1, 1] x [-1, 1] and identity summary statistics.
    """
    def shift_by_theta(theta, p):
        # transforms noise dist.: the noise sample is offset by the parameters
        return p + theta

    simulator = TwoMoons(mean_radius=0.1, sd_radius=0.01, baseoffset=0.25,
                         angle=np.pi/4.0,  # rotation angle in radians
                         mapfunc=shift_by_theta,
                         seed=seed)
    prior = dd.Uniform(lower=[-1, -1], upper=[1, 1], seed=seed)
    return dg.Default(model=simulator, prior=prior, summary=ds.Identity())

# Fresh seeded generator and the observed summary statistics (a single row).
g = init_g(seed=seed)
obs_stats = np.array([[0., -0.]])

# Draw a batch of simulations and plot the parameters against the data.
trn_data = g.gen(1000)
sim_params, sim_stats = trn_data[0], trn_data[1]

plt.figure(figsize=(6, 5))
plt.plot(sim_stats, sim_params, '.')
plt.xlabel('x')
plt.ylabel('theta(1) resp. theta(2)')
plt.title('marginals p(theta_i, x)')
plt.show()
        
# training schedule (shared by the inference runs below)
# n_train is passed as `n_train` to the delfi runs and as `n_samples` to SNL;
# n_rounds is the number of sequential rounds in every run.
n_train=1000
n_rounds=10

# fitting setup
minibatch=100
# per-round epoch budget, shrinking as 1000/(r+1) over rounds; only the MAF
# SNPE-C run uses the full list, the MDN-based runs pass np.max(epochs)
epochs=[1000//(r+1) for r in range(n_rounds)]

# network setup
reg_lambda=0.001

# convenience (forwarded unchanged to the delfi inference constructors)
pilot_samples=0
svi=True
verbose=True
prior_norm=False
# init_norm is only passed to the CDELFI (SNPE-A) and SNPE (SNPE-B) objects
init_norm=False

# SNPE-C parameters
# n_null is passed to the discrete-proposal SNPE-C runs only
n_null = 10

# MAF parameters
n_hiddens_maf=[50,50]
mode='random' # ordering of variables for MADEs
n_mades = 5 # number of MADES
act_fun = 'tanh'
batch_norm = False # batch-normalization currently not supported
# train_on_all is passed to the MAF SNPE-C run
train_on_all = True

# MDN parameters (used by the SNPE-A, SNPE-B and MDN-based SNPE-C objects)
n_hiddens_mdn=[50,50]
n_components = 8


# discrete-proposal SNPE-C (MAF)

In [None]:
# control MAF seed: the global NumPy RNG is seeded and the np.random module
# itself is handed to the inference object as `rng`
rng = np.random
rng.seed(seed)

# generator (re-created so the run starts from a freshly seeded simulator/prior)
g = init_g(seed=seed)

# inference object
maf_snpec_options = dict(
    obs=obs_stats,
    n_hiddens=n_hiddens_maf,
    seed=seed,
    reg_lambda=reg_lambda,
    pilot_samples=pilot_samples,
    svi=svi,
    n_mades=n_mades,  # providing this argument triggers usage of MAFs (vs. MDNs)
    act_fun=act_fun,
    mode=mode,
    rng=rng,
    batch_norm=batch_norm,
    verbose=verbose,
    prior_norm=prior_norm,
)
res = infer.SNPEC(g, **maf_snpec_options)

# train with the discrete proposal and print the wall-clock duration
maf_run_options = dict(
    n_train=n_train,
    proposal='discrete',
    moo='resample',
    n_null=n_null,
    n_rounds=n_rounds,
    train_on_all=train_on_all,
    minibatch=minibatch,
    epochs=epochs,
)

t = timeit.time.time()
logs, tds, posteriors = res.run(**maf_run_options)
print(timeit.time.time() - t)


In [None]:
# One separate figure per round, each showing that round's training loss.
for r in range(n_rounds):
    round_loss = logs[r]['loss']
    plt.plot(round_loss)
    plt.show()

In [None]:
from delfi.utils.viz import probs2contours

# Both figures below share one panel grid: three columns and as many rows as
# are needed for n_rounds panels. plt.subplot needs integer grid sizes, so the
# row count uses integer ceiling division rather than np.ceil, whose float
# result is rejected by current matplotlib versions.
n_cols = 3
n_rows = -(-n_rounds // n_cols)

# plot loss curves, one panel per round
plt.figure(figsize=(12, 8))
for r in range(n_rounds):
    plt.subplot(n_rows, n_cols, r+1)
    plt.plot(logs[r]['loss'])
    plt.title('loss for round r=' + str(r+1))
plt.show()


# plot posterior estimates, evaluated on a regular grid over `lims`
xo = 1.*obs_stats.flatten()
lims = np.array([[-0.5,0.5], [-0.5,0.5]])
i,j,resolution = 0,1,100
xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
yy = np.linspace(lims[j, 0], lims[j, 1], resolution)
X, Y = np.meshgrid(xx, yy)
xy = np.concatenate(
    [X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=1)
plt.figure(figsize=(16, 6))
for r in range(n_rounds):
    plt.subplot(n_rows, n_cols, r + 1)
    posterior = posteriors[r]
    pp = posterior.eval(xy, log=False).reshape(list(X.shape))
    plt.imshow(pp.T, origin='lower',
               extent=[lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]],
               aspect='auto', interpolation='none')
    # optional overlay of the analytic posterior's 68% / 95% contours:
    #pp = p_true.eval(xy, log=False).reshape(list(X.shape))
    #plt.contour(Y, X, probs2contours(pp, levels=(0.68, 0.95)), levels=(0.68, 0.95), colors=('w', 'y'))
    plt.title('posterior estimate after round r='+str(r+1))

plt.show()

# SNL (MAF)

In [None]:
import sys
import snl.inference.nde as nde
from snl.ml.models.mafs import ConditionalMaskedAutoregressiveFlow
from delfi.utils.delfi2snl import SNLprior, SNLmodel


# control MAF seed: np.random (the module) is seeded globally and then passed
# to the flow constructor as its `rng`
rng = np.random
rng.seed(seed)

# explicit call to MAF constructor
# `g` is still the generator left behind by an earlier cell at this point; the
# single draw is only used to read off the sizes of theta and x.
theta, x = g.gen(1)
# NOTE(review): n_inputs is taken from x and n_outputs from theta, while the
# log-posterior cell further down calls model.eval([t, obs_stats.flatten()]),
# i.e. parameters first. Presumably both sizes are equal in this problem so
# the swap has no effect, but the order should be confirmed against the snl
# constructor.
n_inputs, n_outputs  = x.size, theta.size
model = ConditionalMaskedAutoregressiveFlow(
                n_inputs=n_inputs,
                n_outputs=n_outputs,
                n_hiddens=n_hiddens_maf,
                act_fun=act_fun,
                n_mades=n_mades,
                mode=mode,
                rng=rng)


# generator (re-created after the size probe above, so training starts from a
# freshly seeded simulator and prior)
g = init_g(seed=seed)

# inference object: wraps the delfi prior and simulator in the adapters from
# delfi.utils.delfi2snl
inf = nde.SequentialNeuralLikelihood(SNLprior(g.prior),               # method to draw parameters
                                     SNLmodel(g.model, g.summary).gen # method to draw summary stats
                                    )

# train (the timer covers only the learn_likelihood call below)
t = timeit.time.time()

# the global RNG is re-seeded here, after the flow has been constructed
rng = np.random # control
rng.seed(seed)  # MCMC seed
# train_on_all is hard-coded to True here rather than read from the
# train_on_all configuration variable; training log lines go to stdout
model = inf.learn_likelihood(obs_stats.flatten(), model, n_samples=n_train, n_rounds=n_rounds, 
                             train_on_all=True, thin=10, save_models=False, 
                             logger=sys.stdout, rng=rng)

print(timeit.time.time() -  t)


### visualize learned likelihood

In [None]:
import snl.inference.mcmc as mcmc
def log_posterior(t):
    """Sum of the learned model's value at (t, observation) and the prior's value at t.

    Presumably both terms are log-densities, making this an unnormalised
    log-posterior; this is not visible from this cell alone.
    """
    return model.eval([t, obs_stats.flatten()]) + inf.prior.eval(t)


# Start the slice sampler at the last entry of the last element of inf.all_ps.
last_round_params = inf.all_ps[-1]
sampler = mcmc.SliceSampler(x=last_round_params[-1], lp_f=log_posterior, thin=10)
ps = sampler.gen(1000)

In [None]:
xo = 1.*obs_stats.flatten()

# Regular evaluation grid over [-1, 1] x [-1, 1], flattened to one point per row.
lims = np.array([[-1, 1], [-1, 1]])
i, j, resolution = 0, 1, 100
xx = np.linspace(*lims[i], num=resolution)
yy = np.linspace(*lims[j], num=resolution)
X, Y = np.meshgrid(xx, yy)
xy = np.column_stack([X.ravel(), Y.ravel()])

# NOTE(review): the pair passed to model.eval here is (xo, xy), whereas the
# MCMC cell passes [parameters, observation]; confirm which slot the grid is
# meant to occupy.
pp = model.eval((xo, xy), log=False).reshape(list(X.shape))

image_extent = list(lims[j]) + list(lims[i])
plt.imshow(pp.T, origin='lower', extent=image_extent,
           aspect='auto', interpolation='none')
plt.show()

In [None]:
# One panel per SNL round plus a final panel for the MCMC samples `ps` drawn
# from the trained model.
sample_sets = [inf.all_ps[r] for r in range(n_rounds)] + [ps]

# Three columns, with one row more than n_rounds alone would need so that the
# extra final panel always fits. plt.subplot needs integer grid sizes, hence
# integer ceiling division instead of np.ceil (which returns a float).
n_cols = 3
n_rows = -(-n_rounds // n_cols) + 1

plt.figure(figsize=(16,11))
for r, samples in enumerate(sample_sets):
    plt.subplot(n_rows, n_cols, r + 1)
    plt.plot(samples[:,0],
             samples[:,1], 'k.')
    plt.axis([-1,1,-1,1])
    plt.xlabel('theta1')
    # was a second plt.xlabel call, which overwrote 'theta1' and left the
    # y-axis unlabelled
    plt.ylabel('theta2')
    plt.title('round r='+str(r))

plt.show()

# SNPE-A (always MDN)

In [None]:
# generator
g = init_g(seed=seed)
# SNPE-A is run with a wider prior than the one built in init_g. The
# replacement prior gets the same seed as every other component, so that this
# run is reproducible like the others.
g.prior = dd.Uniform(lower=[-2,-2], upper=[2,2], seed=seed)

# n_train and n_rounds are taken from the shared configuration; they are not
# re-assigned here, so a changed configuration is not silently overridden for
# all following cells.

# inference object
res_A = infer.CDELFI(g, 
                 obs=obs_stats, 
                 n_hiddens=n_hiddens_mdn, 
                 n_components=n_components,
                 seed=seed, 
                 reg_lambda=reg_lambda,
                 pilot_samples=pilot_samples,
                 svi=svi,
                 verbose=verbose,
                 init_norm=init_norm,
                 prior_norm=prior_norm)

# train (a single epoch count, the largest of the per-round schedule, is used)
t = timeit.time.time()

logs_A, tds_A, posteriors_A = res_A.run(n_train=n_train, 
                    n_rounds=n_rounds, 
                    minibatch=minibatch, 
                    epochs=np.max(epochs))

print(timeit.time.time() -  t)


In [None]:

# Regular evaluation grid over `lims`, flattened to one point per row.
xo = 1.*obs_stats.flatten()
lims = np.array([[-0.5,0.5], [-0.5,0.5]])
i,j,resolution = 0,1,100
xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
yy = np.linspace(lims[j, 0], lims[j, 1], resolution)
X, Y = np.meshgrid(xx, yy)
xy = np.concatenate(
    [X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=1)

# One figure per round that produced a posterior. Iterating over the result
# list itself (instead of a hard-coded range(10)) keeps this in step with the
# number of rounds that were actually run.
for r, posterior in enumerate(posteriors_A):
    if posterior is None:
        # this round did not yield a posterior; nothing to draw
        # (the former 'broke' placeholder branch could never be reached)
        continue

    plt.figure(figsize=(8, 8))
    pp = posterior.eval(xy, log=False).reshape(list(X.shape))
    plt.imshow(pp.T, origin='lower',
               extent=[lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]],
               aspect='auto', interpolation='none')

    plt.ylabel('SNPE-A')
    plt.xticks([])
    plt.yticks([])
    # cumulative number of simulations after round r
    plt.title('N = ' + str( (r+1)*n_train ))
    plt.show()
        

# SNPE-B (always MDN)

In [None]:
# generator (freshly seeded for this method)
g = init_g(seed=seed)

# inference object: MDN-based SNPE with the shared network settings
snpe_b_options = dict(
    obs=obs_stats,
    n_hiddens=n_hiddens_mdn,
    n_components=n_components,
    seed=seed,
    reg_lambda=reg_lambda,
    pilot_samples=pilot_samples,
    svi=svi,
    verbose=verbose,
    init_norm=init_norm,
    prior_norm=prior_norm,
)
res_B = infer.SNPE(g, **snpe_b_options)

# train; every round uses the largest epoch count of the schedule
snpe_b_run_options = dict(
    n_train=n_train,
    n_rounds=n_rounds,
    minibatch=minibatch,
    round_cl=1,
    epochs=np.max(epochs),
)

t = timeit.time.time()
logs_B, tds_B, posteriors_B = res_B.run(**snpe_b_run_options)
print(timeit.time.time() - t)


In [None]:

# Regular evaluation grid over `lims`, flattened to one point per row.
xo = 1.*obs_stats.flatten()
lims = np.array([[-0.5,0.5], [-0.5,0.5]])
i,j,resolution = 0,1,100
xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
yy = np.linspace(lims[j, 0], lims[j, 1], resolution)
X, Y = np.meshgrid(xx, yy)
xy = np.concatenate(
    [X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=1)

# One figure per round that produced a posterior. Iterating over the result
# list itself (instead of a hard-coded range(10)) keeps this in step with the
# number of rounds that were actually run.
for r, posterior in enumerate(posteriors_B):
    if posterior is None:
        # this round did not yield a posterior; nothing to draw
        # (the former 'broke' placeholder branch could never be reached)
        continue

    plt.figure(figsize=(8, 8))
    pp = posterior.eval(xy, log=False).reshape(list(X.shape))
    plt.imshow(pp.T, origin='lower',
               extent=[lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]],
               aspect='auto', interpolation='none')

    plt.ylabel('SNPE-B')
    plt.xticks([])
    plt.yticks([])
    # cumulative number of simulations after round r
    plt.title('N = ' + str( (r+1)*n_train ))
    plt.show()
        

# discrete-proposal SNPE-C (MDN)

In [None]:
# generator (freshly seeded for this method)
g = init_g(seed=seed)

# inference object: SNPE-C with the MDN settings (no n_mades argument is given)
discrete_snpec_options = dict(
    obs=obs_stats,
    n_hiddens=n_hiddens_mdn,
    n_components=n_components,
    seed=seed,
    reg_lambda=reg_lambda,
    pilot_samples=pilot_samples,
    svi=svi,
    verbose=verbose,
    prior_norm=prior_norm,
)
res_dC = infer.SNPEC(g, **discrete_snpec_options)

# train with the discrete proposal; every round uses the largest epoch count
discrete_run_options = dict(
    n_train=n_train,
    proposal='discrete',
    moo='resample',
    n_null=n_null,
    n_rounds=n_rounds,
    minibatch=minibatch,
    epochs=np.max(epochs),
)

t = timeit.time.time()
logs_dC, tds_dC, posteriors_dC = res_dC.run(**discrete_run_options)
print(timeit.time.time() - t)


# Gaussian-proposal SNPE-C (MDN)

In [None]:
# generator (freshly seeded for this method)
g = init_g(seed=seed)

# inference object: SNPE-C with the MDN settings (no n_mades argument is given)
gaussian_snpec_options = dict(
    obs=obs_stats,
    n_hiddens=n_hiddens_mdn,
    n_components=n_components,
    seed=seed,
    reg_lambda=reg_lambda,
    pilot_samples=pilot_samples,
    svi=svi,
    verbose=verbose,
    prior_norm=prior_norm,
)
res_gC = infer.SNPEC(g, **gaussian_snpec_options)

# train with the Gaussian proposal; every round uses the largest epoch count
gaussian_run_options = dict(
    n_train=n_train,
    proposal='gaussian',
    n_rounds=n_rounds,
    minibatch=minibatch,
    epochs=np.max(epochs),
)

t = timeit.time.time()
logs_gC, tds_gC, posteriors_gC = res_gC.run(**gaussian_run_options)
print(timeit.time.time() - t)


# MoG-proposal SNPE-C (MDN)

In [None]:
# generator (freshly seeded for this method)
g = init_g(seed=seed)

# inference object: SNPE-C with the MDN settings (no n_mades argument is given)
mog_snpec_options = dict(
    obs=obs_stats,
    n_hiddens=n_hiddens_mdn,
    n_components=n_components,
    seed=seed,
    reg_lambda=reg_lambda,
    pilot_samples=pilot_samples,
    svi=svi,
    verbose=verbose,
    prior_norm=prior_norm,
)
res_mC = infer.SNPEC(g, **mog_snpec_options)

# train with the mixture-of-Gaussians proposal; every round uses the largest
# epoch count of the schedule
mog_run_options = dict(
    n_train=n_train,
    proposal='mog',
    n_rounds=n_rounds,
    minibatch=minibatch,
    epochs=np.max(epochs),
)

t = timeit.time.time()
logs_mC, tds_mC, posteriors_mC = res_mC.run(**mog_run_options)
print(timeit.time.time() - t)


## direct comparison SNPE-C against SNPE-A (both MDN)

In [None]:
# SNPE-A reference density: the posterior of the last round that produced one.
# It does not depend on r, so it is looked up once instead of once per round.
posterior_A = next((p for p in reversed(posteriors_A) if p is not None), None)
if posterior_A is None:
    raise ValueError('SNPE-A did not produce a posterior in any round')

for r in range(n_rounds):
    # NOTE(review): this cell originally read `posteriors_C`, a name that is
    # never defined in the notebook (NameError on a fresh kernel). The
    # discrete-proposal MDN results `posteriors_dC` are used here; confirm
    # that this is the intended SNPE-C variant.
    fig,_ = plot_pdf(posterior_A,
             samples=posteriors_dC[r].gen(1000).T,
             lims=[[-1,1],[-1,1]],
             ticks=True,
             resolution=100,
             figsize=(16,16));
    # title corrected: the samples come from the SNPE-C posterior of round r,
    # not from MCMC, and no prior is drawn
    fig.suptitle('SNPE-A final posterior vs samples from discrete-proposal SNPE-C, round r=' + str(r+1),
                 fontsize=14)
    # plt.show() instead of fig.show(): the latter only warns under the
    # non-interactive inline backend
    plt.show()


In [None]:
# SNPE-A reference density: the posterior of the last round that produced one.
# It does not depend on r, so it is looked up once instead of once per round.
posterior_A = next((p for p in reversed(posteriors_A) if p is not None), None)
if posterior_A is None:
    raise ValueError('SNPE-A did not produce a posterior in any round')

for r in range(n_rounds):
    fig,_ = plot_pdf(posterior_A,
             pdf2=posteriors_gC[r],
             lims=[[-1,1],[-1,1]],
             ticks=True,
             resolution=100,
             figsize=(16,16));
    # title corrected: the second density is the Gaussian-proposal SNPE-C
    # posterior of round r; neither MCMC samples nor the prior are drawn
    fig.suptitle('SNPE-A final posterior vs Gaussian-proposal SNPE-C posterior, round r=' + str(r+1),
                 fontsize=14)
    # plt.show() instead of fig.show(): the latter only warns under the
    # non-interactive inline backend
    plt.show()


# assemble figure

In [None]:
import os

# The stray plt.subplot(2,3,1) that used to follow this call is gone: it
# created an empty axes overlapping the panel grid built below.
plt.figure(figsize=(11, 16))

# Evaluation grid shared by all density panels.
xo = 1.*obs_stats.flatten()
lims = np.array([[-0.5,0.5], [-0.5,0.5]])
i,j,resolution = 0,1,100
xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
yy = np.linspace(lims[j, 0], lims[j, 1], resolution)
X, Y = np.meshgrid(xx, yy)
xy = np.concatenate(
    [X.reshape([-1, 1]), Y.reshape([-1, 1])], axis=1)
extent = [lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]]

# Rounds shown as columns (0-based indices into the per-round results).
idx_r = [0,3,6,8]

# Panel grid: one column per selected round. Row 0 is left empty and each
# method occupies one of the rows 1..6.
n_grid_rows, n_grid_cols = 7, len(idx_r)

# Column in which to overlay the 68% / 95% contours of the analytic posterior
# p_true. None disables the overlay; the original per-row `r == -2` checks
# could never be true, so the overlay was never drawn.
truth_overlay_col = None


def _select_panel(row, col):
    """Make the panel at (row, col) of the figure grid the current axes."""
    plt.subplot(n_grid_rows, n_grid_cols, row*n_grid_cols + col + 1)


def _finish_panel(col, row_label):
    """Apply the decorations common to all panels: overlay, row label, no ticks."""
    if col == truth_overlay_col:
        truth = p_true.eval(xy, log=False).reshape(list(X.shape))
        plt.contour(Y, X, probs2contours(truth, levels=(0.68, 0.95)),
                    levels=(0.68, 0.95), colors=('w', 'y'))
    if col == 0:
        plt.ylabel(row_label)
    plt.xticks([])
    plt.yticks([])


def _draw_density_row(row, posteriors_by_round, row_label, with_titles=False):
    """Fill one grid row with the posterior densities of the rounds in idx_r.

    A round whose posterior is missing (None, or index out of range) or cannot
    be evaluated is drawn as a blank panel marked 'broke', and the reason is
    printed, instead of being silently skipped.
    """
    for col, round_idx in enumerate(idx_r):
        _select_panel(row, col)
        density = None
        try:
            posterior = posteriors_by_round[round_idx]
            if posterior is not None:
                density = posterior.eval(xy, log=False).reshape(list(X.shape))
        except Exception as err:
            # Exception rather than a bare except, so that KeyboardInterrupt
            # and SystemExit still propagate
            print('%s, round index %d: posterior could not be evaluated (%r)'
                  % (row_label, round_idx, err))
        if density is None:
            plt.imshow(np.zeros((resolution,resolution)), origin='lower',
                       extent=extent, aspect='auto', interpolation='none')
            plt.text(-0.1, 0., 'broke', color='w')
        else:
            plt.imshow(density.T, origin='lower',
                       extent=extent, aspect='auto', interpolation='none')
        _finish_panel(col, row_label)
        if with_titles:
            # cumulative number of simulations after this round
            plt.title('N = ' + str( (round_idx+1)*n_train ))


_draw_density_row(1, posteriors_A, 'SNPE-A', with_titles=True)
_draw_density_row(2, posteriors_B, 'SNPE-B')
_draw_density_row(3, posteriors_gC, 'Gaussian SNPE-C')
_draw_density_row(4, posteriors_dC, 'discrete-trained MoG SNPE-C')
_draw_density_row(5, posteriors, 'MAF SNPE-C')

# SNL row: parameter samples instead of densities. Once the selected round
# runs past the stored per-round samples, the final MCMC samples `ps` are
# shown. Note that these panels span [-1, 1] while the density rows span `lims`.
for col, round_idx in enumerate(idx_r):
    _select_panel(6, col)
    if round_idx + 1 < len(inf.all_ps):
        samples = inf.all_ps[round_idx + 1]
    else:
        samples = ps
    plt.plot(samples[:,0],
             samples[:,1], 'k.')
    plt.axis([-1,1,-1,1])
    _finish_panel(col, 'SNL')

# The output location is a machine-specific absolute path. Saving is skipped
# (with a message) where that directory does not exist, so the cell still runs
# and shows the figure on other machines.
out_path = '/home/marcel/Desktop/fig1_v05.pdf'
out_dir = os.path.dirname(out_path)
if os.path.isdir(out_dir):
    plt.savefig(out_path)
else:
    print('figure not saved, directory does not exist: ' + out_dir)
plt.show()

In [None]:
# Log mixture weights of the discrete-proposal MDN posterior for the first
# three rounds, one curve per round.
for r in range(3):
    posterior = posteriors_dC[r]
    if posterior is None:
        # no posterior for this round: skip it and keep the remaining legend
        # entries attached to the right curves
        continue
    plt.plot(np.log(posterior.a), label='round ' + str(r+1))
plt.legend()
plt.xlabel('i')
# the curves show log(alpha), so the axis label says so
plt.ylabel('log alpha[i]')
plt.title('MoG mixture coefficients MDN SNPE')
plt.show()