In [2]:
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.summarystats as ds
import delfi.kernel as dk

from delfi.simulator.BaseSimulator import BaseSimulator

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np
from delfi.utils.viz import plot_pdf

from tqdm import tqdm_notebook as tqdm

from parameters import ParameterSet

%matplotlib notebook

from delfi.utils.viz import plot_pdf, plot_hist_marginals

In [3]:
class ShapeModel(BaseSimulator):
    def __init__(self, f, sigma, ndim, seed=None):
        super().__init__(dim_param=ndim, seed=seed)
        self.f = f
        
        self.sigma = sigma
        
    def gen_single(self, params):
        m = self.f(params)
        draw = self.rng.multivariate_normal(m, self.sigma)
        return { 'data' : draw }
    
class ShapeDistribution(dd.BaseDistribution.BaseDistribution):
    def __init__(self, f, params, sm, obs_stats, seed=None):
        super().__init__(sm.dim_param, seed=seed)
        self.params = params
        self.f = f
        self.sm = sm
        self.obs_stats = obs_stats
        self.m = np.linalg.norm(obs_stats)
        self.sigma = sm.sigma
        
    def eval(self, params, log=True):
        logl = -0.5 * np.log(2 * np.pi * self.sigma) -0.5 * ([self.f(param) for param in params] - self.m) ** 2 / (self.sigma ** 2)
        if np.any(np.abs(params) > self.params.prior_width):
            logl = float('-inf')
        return logl if log else np.exp(logl)

In [4]:
class ModelConcat(BaseSimulator):
    def __init__(self, mlist, seed=None):
        dim_params = [ m.dim_param for m in mlist ]
        super().__init__(dim_param=np.sum(dim_params), seed=seed)
        
        self.mlist = mlist
        self.idx_list = np.insert(np.cumsum(dim_params), 0, 0)
        
    def gen_single(self, params):
        params_list = [ params[self.idx_list[i]:self.idx_list[i+1]] for i in range(len(self.mlist)) ]
        rets = [ m.gen_single(p) for m, p in zip(self.mlist, params_list) ]
        
        data_ret = np.concatenate([ r['data'] for r in rets ])
        return { 'data' : data_ret }

In [5]:
def plot_pdf_2d(params, dist, label = None):
    lims = 2 * np.linalg.norm(params.obs_stats)
    xlist = np.linspace(-lims, lims, 80)
    ylist = np.linspace(-lims, lims, 80)
    
    fig = plt.figure(figsize=(8,4))
    ax = fig.add_subplot(111, projection='3d')
    
    X, Y = np.meshgrid(xlist, ylist)
    xys = np.asarray((X.ravel(), Y.ravel())).T
    
    if label is not None:
        fig.suptitle(label)
        
    Z = dist.eval(xys, log=False).reshape(X.shape)
    ax.plot_surface(X, Y, Z, cmap=cm.plasma, rstride=1, cstride=1, linewidth=0, antialiased=False)
        
    ax.set_xlabel("$\\theta_1$")
    ax.set_ylabel("$\\theta_2$")
    ax.set_zlabel("$p(\\theta)$")
    plt.show()
    
    return fig

In [6]:
def run_SNPE(params):
    seeds = np.arange(0, params.n_copies) + params.seed 
    mlist = [ ShapeModel(params.f, sigma=params.sigma, ndim=params.ndim, seed=s) for s in seeds ]
    mc = ModelConcat(mlist, seed=params.seed)
    m = mlist[0]
    p = dd.Uniform(-params.prior_width * np.ones(params.n_copies * params.ndim), params.prior_width * np.ones(params.n_copies * params.ndim), seed=params.seed)
    s = ds.Identity(params.n_copies)
    g = dg.Default(model=mc, prior=p, summary=s)
    
    res = infer.SNPE(g, obs=[params.obs_stats], 
                     n_hiddens=params.n_hiddens, 
                     n_components=params.n_components, 
                     seed=params.seed, 
                     prior_mixin=params.prior_mixin, 
                     svi=False,
                     convert_to_T=default_params.convert_to_T,
                     verbose=True)
    
    logs, tds, posteriors = res.run(n_train=params.n_train, n_rounds = params.n_rounds, round_cl=params.round_cl, es_thresh=0.01)

    return mc, p, s, res, posteriors, logs, tds

In [7]:
default_params = ParameterSet({})

default_params.seed = 39
default_params.ndim = 5
default_params.n_components = 30
default_params.n_copies = 1
default_params.n_summ = 4
default_params.n_hiddens = 2 * [ default_params.n_components * default_params.ndim * default_params.n_copies * 5 ]

default_params.sigma = 0.3 * np.eye(default_params.n_summ)

default_params.true_params = np.tile(2 * np.eye(default_params.ndim)[0], default_params.n_copies)
default_params.obs_stats = np.zeros(default_params.n_copies * default_params.n_summ)
default_params.prior_width = 2

default_params.prior_mixin = 0.0
default_params.convert_to_T = 1
default_params.round_cl = 10

default_params.kernel = None

default_params.n_train = 300000
default_params.n_rounds = 2

cubic = lambda x: (x[0] - 2) * (x[0] + 2) * x[0] - x[1]
ushape = lambda x: x[0] ** 2 - x[1] - 1
ring = lambda x: np.linalg.norm(x) - 1

ring_ring = lambda x: [ring(x[:2]), ushape(x[2:]), x[1] - x[2]]

letter_S = lambda x: [ 3 * (x[0] ** 3 - x[0] - 0.5 * x[1]) ] 
letter_N = lambda x: [ (x[1] ** 2 - 3) * (x[0] + x[1]) ]
letter_P = lambda x: [ (x[1] + (x[0] - 1.2) ** 2 - 1) * (x[1] + 1) ] 
letter_E = lambda x: [ (x[0] ** 3 - 2 * x[0]) * (x[1] + 1.5) ]

text_SNPE = lambda x: letter_S(x[0:2]) + letter_N(x[1:3]) + letter_P(x[2:4]) + letter_E(x[3:])
default_params.f = text_SNPE

In [8]:
std_params = default_params

m, p, s, res, posteriors, logs, tds = run_SNPE(std_params)

A Jupyter Widget




A Jupyter Widget




A Jupyter Widget




A Jupyter Widget




A Jupyter Widget


Cannot predict posterior after round 0 due to NaNs


In [9]:
for posterior in posteriors:
    if posterior is None:
        continue
        
    fig, _ = plot_pdf(posterior,  lims=[-default_params.prior_width, default_params.prior_width])

In [28]:
seeds = default_params.seed + np.arange(default_params.n_copies)
gts = [ ShapeDistribution(default_params.f, default_params, m.mlist[0], default_params.obs_stats, seed=s) for s in seeds ]
gt = dd.MixedDistribution(gts, seed=default_params.seed)
    

In [29]:
def mcmc(start, pdf, var=0.1, nsamples=50000):
    ndim = len(start)
    S = np.eye(ndim) * var
    mu_current = start
    l_current = pdf.eval([mu_current], log=False)
    
    samples = np.empty((nsamples, ndim))
    for i in tqdm(range(nsamples)):
        mu_proposal = np.random.multivariate_normal(mean=mu_current, cov=S)

        l_proposal = pdf.eval([mu_proposal], log=False)
        p_accept = l_proposal / l_current
        accept = np.random.rand() < p_accept
        
        if accept:
            mu_current = mu_proposal
            assert(np.all(np.abs(mu_current) <= default_params.prior_width))
            l_current = l_proposal
        
        samples[i,:] = mu_current
        
    return samples

start = np.zeros_like(default_params.true_params)

gt.eval([start])
samples = mcmc(start, pdf=gt)

A Jupyter Widget




ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

In [31]:
%debug

> [0;32m<ipython-input-29-27e4408a7589>[0m(15)[0;36mmcmc[0;34m()[0m
[0;32m     13 [0;31m        [0maccept[0m [0;34m=[0m [0mnp[0m[0;34m.[0m[0mrandom[0m[0;34m.[0m[0mrand[0m[0;34m([0m[0;34m)[0m [0;34m<[0m [0mp_accept[0m[0;34m[0m[0m
[0m[0;32m     14 [0;31m[0;34m[0m[0m
[0m[0;32m---> 15 [0;31m        [0;32mif[0m [0maccept[0m[0;34m:[0m[0;34m[0m[0m
[0m[0;32m     16 [0;31m            [0mmu_current[0m [0;34m=[0m [0mmu_proposal[0m[0;34m[0m[0m
[0m[0;32m     17 [0;31m            [0;32massert[0m[0;34m([0m[0mnp[0m[0;34m.[0m[0mall[0m[0;34m([0m[0mnp[0m[0;34m.[0m[0mabs[0m[0;34m([0m[0mmu_current[0m[0;34m)[0m [0;34m<=[0m [0mdefault_params[0m[0;34m.[0m[0mprior_width[0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0m
[0m
ipdb> p accept
array([[ True, False, False, False],
       [False,  True, False, False],
       [False, False,  True, False],
       [False, False, False,  True]], dtype=bool)
ipdb> p l_proposal
array([

In [256]:
fig, _ = plot_pdf(pdf1=posterior, samples=samples.T, lims=[-default_params.prior_width, default_params.prior_width])


<IPython.core.display.Javascript object>