In [1]:
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.summarystats as ds

from delfi.simulator.BaseSimulator import BaseSimulator

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np
from delfi.utils.viz import plot_pdf

from tqdm import tqdm_notebook as tqdm

from parameters import ParameterSet

%matplotlib notebook

In [2]:
class ShapeModel(BaseSimulator):
    """Toy simulator producing one noisy scalar observation of ``f(params)``.

    The observation is drawn from a normal distribution centred on
    ``f(params)`` with standard deviation ``sigma``.
    """

    def __init__(self, f, sigma, ndim, seed=None):
        """Store the shape function and the noise level.

        Parameters
        ----------
        f : callable
            Maps one parameter vector to the mean of the observation.
        sigma : float
            Standard deviation of the Gaussian observation noise.
        ndim : int
            Forwarded to the base class as ``dim_param``.
        seed : int or None
            Forwarded to the base class.
        """
        super().__init__(dim_param=ndim, seed=seed)
        self.sigma = sigma
        self.f = f

    def gen_single(self, params):
        """Simulate once for ``params``.

        Returns a dict whose ``'data'`` entry is an array of shape ``(1,)``.
        """
        # self.rng is not set in this class; presumably provided by BaseSimulator.
        noisy_value = self.rng.normal(loc=self.f(params), scale=self.sigma, size=(1,))
        return dict(data=noisy_value)
    
class ShapeDistribution(dd.BaseDistribution.BaseDistribution):
    """Analytic Gaussian likelihood of the observation, viewed as a function of the parameters.

    For each parameter vector ``theta`` the value is the normal density
    ``N(m | f(theta), sigma**2)``, where ``m`` is derived from the observed
    statistics and ``sigma`` is taken from the simulator. It serves as the
    reference surface the learned posteriors are compared against.
    """

    def __init__(self, f, sm, obs_stats, seed=None):
        """
        Parameters
        ----------
        f : callable
            Shape function mapping one parameter vector to a scalar mean.
        sm : ShapeModel
            Simulator; supplies ``dim_param`` and the noise level ``sigma``.
        obs_stats : sequence of float
            Observed summary statistics.
        seed : int or None
            Forwarded to the base class.
        """
        super().__init__(sm.dim_param, seed=seed)
        self.f = f
        self.sm = sm
        self.obs_stats = obs_stats
        # NOTE(review): the norm collapses obs_stats to a non-negative scalar,
        # so a negative single observation would lose its sign — confirm that
        # observations are always non-negative for the shapes in use.
        self.m = np.linalg.norm(obs_stats)
        self.sigma = sm.sigma

    def eval(self, params, log=True):
        """Evaluate the (log-)likelihood at every row of ``params``.

        Parameters
        ----------
        params : iterable of parameter vectors
            Each element is passed to ``f`` individually.
        log : bool
            Return log-densities when True, densities otherwise.

        Returns
        -------
        numpy.ndarray
            One value per element of ``params``.
        """
        # Explicit array so the arithmetic below does not depend on numpy's
        # reflected operators turning a Python list into an array.
        means = np.asarray([self.f(param) for param in params])
        variance = self.sigma ** 2
        # Gaussian log-density: the normaliser uses the variance (sigma**2),
        # not sigma itself.
        logl = -0.5 * np.log(2 * np.pi * variance) - 0.5 * (means - self.m) ** 2 / variance
        return logl if log else np.exp(logl)

In [75]:
def plot_pdf_2d(params, dist, label = None):
    """Render the density of ``dist`` as a 3-D surface over a square 2-D grid.

    Parameters
    ----------
    params : object with an ``obs_stats`` attribute
        Twice the norm of ``obs_stats`` sets the half-width of the plotted square.
    dist : object with ``eval(points, log=False)``
        Evaluated on an 80 x 80 grid of (theta_1, theta_2) points.
    label : str or None
        Figure title; omitted when None.

    Returns
    -------
    matplotlib.figure.Figure
    """
    half_width = 2 * np.linalg.norm(params.obs_stats)
    ticks_x = np.linspace(-half_width, half_width, 80)
    ticks_y = np.linspace(-half_width, half_width, 80)
    grid_x, grid_y = np.meshgrid(ticks_x, ticks_y)

    fig = plt.figure(figsize=(8,4))
    ax = fig.add_subplot(111, projection='3d')
    if label is not None:
        fig.suptitle(label)

    # One (theta_1, theta_2) row per grid node.
    grid_points = np.vstack((grid_x.ravel(), grid_y.ravel())).T
    density = dist.eval(grid_points, log=False).reshape(grid_x.shape)

    surface_style = dict(cmap=cm.plasma, rstride=1, cstride=1, linewidth=0, antialiased=False)
    ax.plot_surface(grid_x, grid_y, density, **surface_style)

    ax.set_xlabel("$\\theta_1$")
    ax.set_ylabel("$\\theta_2$")
    ax.set_zlabel("$p(\\theta)$")
    plt.show()

    return fig

In [76]:
def run_SNPE(params):
    """Assemble the toy problem described by ``params`` and train SNPE on it.

    Reads from ``params``: ``f``, ``sigma``, ``ndim``, ``seed``,
    ``prior_width``, ``obs_stats``, ``n_hiddens``, ``n_components``,
    ``prior_mixin``, ``n_train`` and ``n_rounds``.

    Returns
    -------
    tuple
        ``(inference, posteriors, ground_truth)``: the ``infer.SNPE`` object,
        the third item returned by its ``run`` call, and the analytic
        ``ShapeDistribution`` for the same observation.
    """
    simulator = ShapeModel(params.f, sigma=params.sigma, ndim=params.ndim, seed=params.seed)

    # Uniform box prior on [-prior_width, prior_width] in every dimension.
    unit = np.ones(params.ndim)
    prior = dd.Uniform(-params.prior_width * unit, params.prior_width * unit, seed=params.seed)

    generator = dg.Default(model=simulator, prior=prior, summary=ds.Identity(1))
    ground_truth = ShapeDistribution(params.f, simulator, params.obs_stats, seed=params.seed)

    network_options = dict(
        n_hiddens=params.n_hiddens,
        n_components=params.n_components,
        seed=params.seed,
        prior_mixin=params.prior_mixin,
        verbose=True,
    )
    inference = infer.SNPE(generator, obs=[params.obs_stats], **network_options)

    # run() yields three items; only the last one is handed back to the caller.
    _logs, _train_data, posteriors = inference.run(n_train=params.n_train, n_rounds = params.n_rounds, round_cl=3)

    return inference, posteriors, ground_truth

In [87]:
# Baseline experiment configuration; the variant parameter sets in later cells
# are built from this one.
default_params = ParameterSet({})

default_params.seed = 394
default_params.ndim = 2
default_params.n_components = 10
# One-element list holding n_components * ndim * 5 (presumably one hidden layer
# of that width — confirm against delfi's SNPE).
default_params.n_hiddens = [ default_params.n_components * default_params.ndim * 5 ]

# Passed to ShapeModel as the standard deviation of its Gaussian noise.
default_params.sigma = 0.2

# First unit vector scaled by 2, i.e. (2, 0, ..., 0).
default_params.true_params = 2 * np.eye(default_params.ndim)[0]
# NOTE(review): the observation is the norm of true_params regardless of which
# shape function is selected as `f` below — confirm this is intended for `ushape`.
default_params.obs_stats = [np.linalg.norm(default_params.true_params)]
# run_SNPE uses a uniform prior on [-prior_width, prior_width] per dimension.
default_params.prior_width = 2

default_params.prior_mixin = 0.0


default_params.n_train = 1500
default_params.n_rounds = 3

# Candidate shape functions; only `ushape` is selected here.
ushape = lambda x: x[0] ** 2 - x[1]
ring = np.linalg.norm
default_params.f = ushape

In [88]:
# Baseline run. Build the parameter set with ParameterSet(...) like the other
# experiment cells, instead of binding a second name to default_params itself,
# so a later edit to std_params cannot silently change the shared defaults.
std_params = ParameterSet(default_params)

res, posteriors, gt = run_SNPE(std_params)


































In [89]:
# Ground-truth surface first, then one figure per SNPE round.
plot_pdf_2d(std_params, gt, "Ground truth");

for i, p in enumerate(posteriors):
    plot_pdf_2d(std_params, p, "Round {}".format(i+1));



<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [90]:
# Variant: default settings with prior_mixin raised from 0.0 to 0.1.
prior_mixin_params = ParameterSet(default_params)
prior_mixin_params.prior_mixin = 0.1

# Rebinds res / posteriors / gt from the previous experiment.
res, posteriors, gt = run_SNPE(prior_mixin_params)


































In [91]:
# Ground-truth surface first, then one figure per SNPE round.
plot_pdf_2d(prior_mixin_params, gt, "Ground truth");

for i, p in enumerate(posteriors):
    plot_pdf_2d(prior_mixin_params, p, "Round {}".format(i+1));



<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [92]:
# Variant: a single round whose n_train is the default n_train multiplied by
# the default number of rounds.
one_round_params = ParameterSet(default_params)
one_round_params.n_train *= one_round_params.n_rounds
one_round_params.n_rounds = 1

# Rebinds res / posteriors / gt from the previous experiment.
res, posteriors, gt = run_SNPE(one_round_params)
















In [93]:
# Ground-truth surface first, then one figure per SNPE round.
plot_pdf_2d(one_round_params, gt, "Ground truth");

for i, p in enumerate(posteriors):
    plot_pdf_2d(one_round_params, p, "Round {}".format(i+1));



<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [94]:
# Variant: default settings with prior_width multiplied by 5.
wide_prior_params = ParameterSet(default_params)
wide_prior_params.prior_width *= 5

# Rebinds res / posteriors / gt from the previous experiment.
res, posteriors, gt = run_SNPE(wide_prior_params)


































In [95]:
# Ground-truth surface first, then one figure per SNPE round.
plot_pdf_2d(wide_prior_params, gt, "Ground truth");

for i, p in enumerate(posteriors):
    plot_pdf_2d(wide_prior_params, p, "Round {}".format(i+1));



<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [96]:
# Variant: the wide-prior settings with prior_mixin set to 0.1.
wide_prior_prior_mixin_params = ParameterSet(wide_prior_params)
wide_prior_prior_mixin_params.prior_mixin = 0.1

# Rebinds res / posteriors / gt from the previous experiment.
res, posteriors, gt = run_SNPE(wide_prior_prior_mixin_params)


































In [97]:
# Ground-truth surface first, then one figure per SNPE round.
plot_pdf_2d(wide_prior_prior_mixin_params, gt, "Ground truth");

for i, p in enumerate(posteriors):
    plot_pdf_2d(wide_prior_prior_mixin_params, p, "Round {}".format(i+1));



<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [98]:
# Variant: the single-round settings combined with the prior_width taken from
# wide_prior_params.
wide_prior_one_round_params = ParameterSet(one_round_params)
wide_prior_one_round_params.prior_width = wide_prior_params.prior_width

# Rebinds res / posteriors / gt from the previous experiment.
res, posteriors, gt = run_SNPE(wide_prior_one_round_params)
















In [99]:
# Ground-truth surface first, then one figure per SNPE round.
plot_pdf_2d(wide_prior_one_round_params, gt, "Ground truth");

for i, p in enumerate(posteriors):
    plot_pdf_2d(wide_prior_one_round_params, p, "Round {}".format(i+1));



<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [100]:
def plot_pdfs_1d(dists, labels = None):
    """Overlay the densities of several 1-D distributions on a single axis.

    Parameters
    ----------
    dists : sequence of objects with ``eval(points, log=False)``
        Each is evaluated on 200 evenly spaced points in [-2, 2], passed as a
        column of shape ``(200, 1)``.
    labels : sequence of str or None
        Legend entries, indexed in step with ``dists``; no legend is drawn
        when this is None or empty.

    Returns
    -------
    matplotlib.figure.Figure
    """
    lims = 2
    grid = np.linspace(-lims, lims, 200)

    fig, ax = plt.subplots(1, figsize=(12, 3))

    for idx, dist in enumerate(dists):
        density = dist.eval(grid[:, np.newaxis], log=False)
        line_options = {'label': labels[idx]} if labels else {}
        ax.plot(grid, density, **line_options)

    # Narrow the axes to 80% of their width; the legend is anchored to the
    # right of the axes.
    bounds = ax.get_position()
    ax.set_position([bounds.x0, bounds.y0, bounds.width * 0.8, bounds.height])
    if labels:
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    ax.set_xlabel("$\\theta$")
    ax.set_ylabel("$p(\\theta)$")
    plt.show()

    return fig

In [34]:
#plot_pdfs_1d([gt, p] + posteriors, ["Ground truth", "Prior"] + [ "Round {}".format(i + 1) for i in range(n_rounds) ]);