In [None]:
%%capture
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt
import timeit
import theano
import pygpu
import imageio

from lfimodels.snl_exps.util import save_results, load_results

import delfi.distribution as dd
import delfi.inference as infer
import delfi.generator as dg

from delfi.simulator import TwoMoons
import delfi.summarystats as ds
from delfi.utils.viz import plot_pdf, probs2contours

# Global experiment settings shared by the cells below.
seed = 46
return_abs = False

# analytic target posterior for plotting (y-axis flipped for overlay with imshow figures !)
# Two components, each given its own mean vector and its own 0.01 * identity matrix.
p_true = dd.MoG(
    a=[0.5, 0.5],
    ms=[np.asarray(center) for center in ([.5, -.5], [-.5, .5])],
    Ss=[0.01 * np.eye(2) for _ in range(2)],
)
p_true.ndim = 2

# very basic approach to controlling generator seeds
def init_g(seed):
    """Build a `dg.Default` generator whose model and prior share one seed.

    Parameters
    ----------
    seed
        Value forwarded as `seed=` to both the `TwoMoons` simulator and the
        `dd.Uniform` prior.

    Returns
    -------
    The `dg.Default` instance combining the simulator, the prior with
    bounds lower=[-1, -1] / upper=[1, 1], and a `ds.Identity` summary.
    """
    simulator = TwoMoons(
        mean_radius=0.1,
        sd_radius=0.01,
        baseoffset=0.25,
        #angle= np.pi/4.0,  # rotation angle in radians
        #mapfunc=lambda theta, p: p + theta,  # transforms noise dist.
        seed=seed,
    )
    prior = dd.Uniform(lower=[-1, -1], upper=[1, 1], seed=seed)
    summary = ds.Identity()
    return dg.Default(model=simulator, prior=prior, summary=summary)

# Generator seeded with the notebook-wide seed.
g = init_g(seed=seed)

# Observed summary statistics: a single row [0., 0.].
obs_stats = np.zeros((1, 2))

verbose = True

# Identifier and directory prefix used when loading stored results.
exp_id = 'seed{}'.format(seed)
save_path = 'results/two_moons_runs/'

In [None]:
# MAF SNPE-C
# Load the stored run for the current experiment id. `exp_id` is derived from
# `seed` in the setup cell, so it is reused here instead of repeating the
# literal 'seed46', which would silently go stale if `seed` were changed.
# NOTE(review): `save_path` ends with '/', so the resulting path is
# 'results/two_moons_runs/_discrete_MAF_SNPEC' — confirm this is the intended location.
log, tds, posteriors, setup_dict = load_results(exp_id=exp_id,
                                                path=save_path + '_discrete_MAF_SNPEC')

In [None]:
# Flattened float copy of the observed statistics.
xo = 1. * obs_stats.flatten()

# Plotting window per parameter (one row per dimension) and grid size.
lims = np.array([[-0.5, 0.5], [-0.5, 0.5]])
i, j, resolution = 0, 1, 100

# Regular grid over dimension i (xx) and dimension j (yy).
xx = np.linspace(lims[i, 0], lims[i, 1], resolution)
yy = np.linspace(lims[j, 0], lims[j, 1], resolution)
X, Y = np.meshgrid(xx, yy)

# One grid point per row: column 0 holds X values, column 1 holds Y values.
xy = np.hstack([X.reshape(-1, 1), Y.reshape(-1, 1)])

# Density (not log-density) of each of the first 10 posteriors, reshaped
# back onto the meshgrid layout.
P = [posteriors[idx].eval(xy, log=False).reshape(X.shape)
     for idx in range(10)]

In [None]:
# Render every evaluated posterior density in P to an RGB image array and
# collect the arrays in I (one GIF frame per entry).
fig = plt.figure()
I = []
for pp in P:
    # Start each frame from an empty figure so images from earlier rounds do
    # not accumulate on the axes and get re-rendered on every draw.
    plt.clf()
    plt.imshow(pp.T, origin='lower',
               extent=[lims[j, 0], lims[j, 1], lims[i, 0], lims[i, 1]],
               aspect='auto', interpolation='none')
    plt.axis('equal')
    plt.axis('off')
    # Pass a real boolean: the string 'off' is truthy, so where `on` is
    # treated as a plain bool it would switch the frame on, not off.
    plt.box(False)
    fig.canvas.draw()
    # NOTE(review): tostring_rgb is deprecated in recent Matplotlib releases
    # (buffer_rgba is the replacement) — confirm against the pinned version.
    image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
    # get_width_height() is (width, height); reverse it to (height, width)
    # and append the 3 colour channels.
    I.append(image.reshape(fig.canvas.get_width_height()[::-1] + (3,)))

In [None]:
# Write the collected frames in I to an animated GIF.
# NOTE(review): only the 'fps' entry is passed to mimsave; 'quantizer' is
# defined here but not forwarded — confirm whether it should be.
kwargs_write = dict(fps=2.0, quantizer='nq')
imageio.mimsave('./twomoons_mafapt.gif', I, fps=kwargs_write['fps'])
