## Code for testing analytical division for MoG proposals with well-separated modes

Assume mixture-of-Gaussian proposals 
$$ \tilde{p}(\theta) = \sum_k \tilde{\alpha}_{k} \mathcal{N}(\theta | \tilde{\mu}_{k}, \tilde{\Sigma}_{k}) $$

and mixture-of-Gaussian uncorrected posterior estimates
$$ q_\phi(\theta|x_0) = \sum_k \alpha_{\phi,k} \mathcal{N}(\theta | \mu_{\phi,k}, \Sigma_{\phi,k}) $$

with the same number of components, indexed by $k = 1, 2, \ldots, K$.

If, near each posterior component, only one proposal component has non-negligible mass, then the prior/proposal-prior correction of that posterior component can ignore all other proposal components. Hence we can apply Papamakarios' analytical correction step component by component.

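Since there are several posterior components to be corrected, however, we cannot just rely on the corrected Gaussians being normalized. We actually have to work out the relative mixture weights $\alpha_k$ of the corrected components, which (for a single Gaussian prior $p(\theta)$) are
$$ \alpha_k \propto \frac{\alpha_{\phi,k}}{\tilde{\alpha}_k} \int \mathcal{N}(\theta | \mu_{\phi,k}, \Sigma_{\phi,k}) \frac{p(\theta)}{\mathcal{N}(\theta | \tilde{\mu}_{k}, \tilde{\Sigma}_{k})} \, d\theta . $$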

We illustrate this on an example with two proposal/posterior modes and one-dimensional $\theta$.

In [None]:
%%capture 
import util
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import delfi.distribution as dd
from numpy.linalg import det as det


def get_alphas(Pm0s, P0s, a0s,
               Pmts, Pts, ats,
               Pmps, Pps, aps):
    """Return the unnormalized mixture weights of the analytically corrected posterior.

    Assumes well-separated modes: posterior component ``i`` is corrected only by
    proposal component ``i`` and by the first prior component (the prior is
    assumed to be a single Gaussian). For each ``i`` this evaluates, in closed
    form,

        aps[i] * a0s[0] / ats[i]
            * integral N(th | mp_i, Sp_i) * N(th | m0, S0) / N(th | mt_i, St_i) dth

    Tested mostly for 1-dimensional problems.

    All Gaussians are passed in natural parameters:
        Pm : precision * mean (vector)
        P  : precision matrix
        a  : mixture weights
    with suffixes ``0`` = prior, ``t`` = proposal ('tilde'),
    ``p`` = uncorrected posterior.

    Returns
    -------
    numpy.ndarray of shape ``(len(Pmps),)`` with the unnormalized weights;
    divide by their sum to obtain the corrected mixture weights.

    Raises
    ------
    ValueError
        If the combined precision ``Pps[i] + P0s[0] - Pts[i]`` is not positive
        definite, in which case the integral above diverges.
    """
    Pm0 = np.asarray(Pm0s[0])
    P0 = np.asarray(P0s[0])
    m0 = np.linalg.solve(P0, Pm0)  # prior mean recovered from natural parameters
    logdet_P0 = np.linalg.slogdet(P0)[1]

    a = np.zeros(len(Pmps))
    for i in range(a.size):
        Pmp, Pp = np.asarray(Pmps[i]), np.asarray(Pps[i])
        Pmt, Pt = np.asarray(Pmts[i]), np.asarray(Pts[i])

        # natural parameters of the (unnormalized) Gaussian posterior * prior / proposal
        sPm = Pmp + Pm0 - Pmt
        sP = Pp + P0 - Pt
        if np.any(np.linalg.eigvalsh(sP) <= 0):
            raise ValueError(
                f"combined precision of component {i} is not positive definite; "
                "the proposal is narrower than posterior * prior there"
            )

        mp = np.linalg.solve(Pp, Pmp)
        mt = np.linalg.solve(Pt, Pmt)

        # log sqrt(|Pp| |P0| / (|Pt| |sP|)): ratio of the Gaussian normalizers
        log_norm = 0.5 * (np.linalg.slogdet(Pp)[1] + logdet_P0
                          - np.linalg.slogdet(Pt)[1] - np.linalg.slogdet(sP)[1])
        # both quadratic terms in one exponent, so large terms of opposite sign
        # cannot under/overflow separately
        log_quad = 0.5 * (sPm.dot(np.linalg.solve(sP, sPm))
                          - (Pmp.dot(mp) + Pm0.dot(m0) - Pmt.dot(mt)))

        a[i] = aps[i] * a0s[0] / ats[i] * np.exp(log_norm + log_quad)

    return a


In [None]:
plt.figure(figsize=(16, 16))

# roughly gives 'separated'-ness of proposal components (assuming unit variances)
dists = [1., 2., 4., 8.]

# evaluation grid: a column of 300 one-dimensional parameter vectors
ths = np.linspace(-10, 10, 300).reshape(-1, 1)


def plot_normalized(grid, density, fmt, **kwargs):
    """Plot `density` (values on the uniform column grid `grid`) rescaled to unit Riemann sum."""
    step = grid[1, 0] - grid[0, 0]
    plt.plot(grid, density / (density.sum() * step), fmt, **kwargs)


# prior: a single broad Gaussian, identical for every panel
m0s = [2 * np.ones(1)]
S0s = [20 * np.eye(1)]
a0s = np.ones(len(m0s)) / len(m0s)
P0s = [np.linalg.inv(S) for S in S0s]
Pm0s = [P.dot(m) for P, m in zip(P0s, m0s)]
prior = dd.MoG(a=a0s, ms=m0s, Ss=S0s)

for i, dist in enumerate(dists):

    # proposal ('\tilde')
    mts = [-dist / 2 * np.ones(1), dist / 2 * np.ones(1)]
    Sts = [np.eye(1) / 0.8, np.eye(1) / 1.3]
    ats = [0.3, 0.7]

    # uncorrected posterior estimate q_phi
    mps = [-0.8 * dist / 2 * np.ones(1), 1.3 * dist / 2 * np.ones(1)]
    Sps = [np.eye(1) / 3., np.eye(1) / 3.5]
    aps = np.ones(len(mps)) / len(mps)

    # natural parameters (precision, precision * mean), as also used internally
    # by dd.Gaussian for the analytic correction
    Pts = [np.linalg.inv(S) for S in Sts]
    Pmts = [P.dot(m) for P, m in zip(Pts, mts)]
    Pps = [np.linalg.inv(S) for S in Sps]
    Pmps = [P.dot(m) for P, m in zip(Pps, mps)]

    proposal = dd.MoG(a=ats, ms=mts, Ss=Sts)
    posterior = dd.MoG(a=aps, ms=mps, Ss=Sps)

    plt.subplot(2, 2, i + 1)

    plot_normalized(ths, np.exp(prior.eval(ths)), 'b')
    plot_normalized(ths, np.exp(proposal.eval(ths)), 'g')
    plot_normalized(ths, np.exp(posterior.eval(ths)), 'k')

    # brute-force correction on the grid: prior / proposal * q_phi
    true_post = (prior.eval(ths, log=False) / proposal.eval(ths, log=False)
                 * posterior.eval(ths, log=False))
    plot_normalized(ths, true_post, 'r', linewidth=1.5)

    # analytic correction: posterior component k is corrected by proposal component k only
    corrected = [post_k / prop_k * prior.xs[0]
                 for post_k, prop_k in zip(posterior.xs, proposal.xs)]
    ac = get_alphas(Pm0s, P0s, a0s,
                    Pmts, Pts, ats,
                    Pmps, Pps, aps)
    print(ac)
    ac = ac / ac.sum()

    posterior_x = dd.MoG(a=ac, ms=[x.m for x in corrected], Ss=[x.S for x in corrected])
    plot_normalized(ths, np.exp(posterior_x.eval(ths)), 'm--', linewidth=2)

    plt.title(f'd` = {dist}')

    if i == 0:
        plt.legend(['prior', 'proposal', 'q_phi', 'real posterior', 'analytic correction'], loc=2)
    plt.yticks([])

# optional: plt.savefig('analytical_correction_MoGproposal_1d_example.pdf')
plt.show()