# analytical broadsword approach

- analytically tractable problem: Gaussian prior, Gaussian proposal, _linear-Gaussian likelihood_
- analytically tractable MDN: linear-affine network
- analytically tractable gradients and closed-form solution for MDN parameters for given dataset


TO DO:
- check gradients again
- seems like the proposal-adjusted finite-size effective joint distribution (eq. 4, p. 11 in JakobsNotes.pdf) is badly off right now (_currently not in use_!)
- write down the CDELFI version of the derivation below and import the analytical division for direct comparison


prior: 

$p(\theta) = \mathcal{N}(\theta \ | \ 0, \eta^2)$

proposal prior: 

$\tilde{p}(\theta) = \mathcal{N}(\theta \ | \ \nu, \xi^2)$

simulator: 

$p(x \ | \ \theta) =  \mathcal{N}(x \ | \ \theta, \sigma^2)$

analytic posteriors: 

$p(\theta \ | \ x) = \mathcal{N}(\theta \ | \frac{\eta^2}{\eta^2 + \sigma^2} x, \eta^2 - \frac{\eta^4}{\eta^2 + \sigma^2})$ 

$\tilde{p}(\theta \ | \ x) = \mathcal{N}(\theta \ | \frac{\xi^2}{\xi^2 + \sigma^2} x + \frac{\sigma^2}{\xi^2 + \sigma^2} \nu, \xi^2 - \frac{\xi^4}{\xi^2 + \sigma^2})$

Data:

$(x_n, \theta_n) \sim \tilde{p}(\theta) p(x \ | \ \theta) = \mathcal{N}\left( (x_n, \theta_n) \ | \ (\nu, \nu), 
\begin{pmatrix}
\xi^{2} + \sigma^{2} &  \xi^{2}  \\
\xi^{2} & \xi^{2}  \\
\end{pmatrix}\right)$

Loss: 

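$ \mathcal{L}(\phi) = -2 \sum_n \frac{{p}(\theta_n)}{\tilde{p}(\theta_n)} K_\epsilon(x_n | x_0) \ \log q_\phi(\theta_n | x_n)$ (the factor $-2$ makes $\mathcal{L}$ a loss to be minimized and is the convention used by the gradients below)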

Model: 

$ q_\phi(\theta_n | x_n) = \mathcal{N}(\theta_n \ | \ \mu_\phi(x_n), \sigma^2_\phi(x_n))$

$ (\mu_\phi(x), \sigma^2_\phi(x)) = MDN_\phi(x) = \begin{pmatrix} \beta \\ 0 \end{pmatrix} x + \begin{pmatrix} \alpha \\ \gamma^2 \end{pmatrix}$

Gradients: 

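$\mathcal{N}_n := \mathcal{N}((x_n, \theta_n) \ | \ \mu_y, \Sigma_y) \propto \frac{p(\theta_n)}{\tilde{p}(\theta_n)} K_\epsilon(x_n | x_0)$, with Gaussian kernel $K_\epsilon(x | x_0) = \mathcal{N}(x \ | \ x_0, \epsilon^2)$ (for $\xi^2 < \eta^2$ the $\theta$-variance below is negative, so this is an unnormalizable Gaussian-shaped weight)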

$\Sigma_y = 
\begin{pmatrix}
\epsilon^2  &  0  \\
0 & \left( \eta^{-2} - \xi^{-2} \right)^{-1}  \\
\end{pmatrix}$


$\mu_y = \begin{pmatrix} x_0  \\ \frac{\eta^2}{\eta^2 - \xi^2}\nu \end{pmatrix}$

$\frac{\partial{}\mathcal{L}}{\partial{}\alpha} = -2 \sum_n \mathcal{N}_n \frac{\theta_n - \mu_\phi(x_n)}{\sigma^2_\phi(x_n)}$

$\frac{\partial{}\mathcal{L}}{\partial{}\beta} = -2 \sum_n \mathcal{N}_n \frac{\theta_n - \mu_\phi(x_n)}{\sigma^2_\phi(x_n)} x_n$

$\frac{\partial{}\mathcal{L}}{\partial{}\gamma^2} 
= \sum_n \mathcal{N}_n \left( \frac{1}{\sigma^2_\phi(x_n)} 
- \frac{\left(\theta_n - \mu_\phi(x_n) \right)^2}{\sigma^4_\phi(x_n)} \right) 
= \frac{1}{\gamma^2} \sum_n \mathcal{N}_n \left( 1 
- \frac{\left(\theta_n - \mu_\phi(x_n) \right)^2}{\gamma^2} \right) $

Optima: 

$\hat{\alpha} = 
\frac{\sum_n \mathcal{N}_n \left(\theta_n - \frac{\sum_m \mathcal{N}_m \theta_m x_m}{\sum_m \mathcal{N}_m x_m^2} x_n \right)}{\sum_n \mathcal{N}_n - \frac{\left( \sum_n \mathcal{N}_n x_n \right)^2}{\sum_n \mathcal{N}_n x_n^2}}$

$\hat{\beta} = 
\frac{\sum_n \mathcal{N}_n \theta_n x_n - \hat{\alpha} \sum_n \mathcal{N}_n x_n}{\sum_n \mathcal{N}_n x_n^2}$

$\hat{\gamma}^2 = 
\frac{\sum_n \mathcal{N}_n \left( \theta_n - \hat{\alpha} - \hat{\beta} x_n \right)^2}{\sum_n \mathcal{N}_n}$

In [None]:
%%capture
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.summarystats as ds
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

from scipy.stats import multivariate_normal as mvn

from delfi.simulator.Gauss import Gauss

def gauss_weights(params, stats, mu_y, Sig_y):
    """Score each (stats, params) sample under a joint Gaussian N(mu_y, Sig_y).

    Rows of ``stats`` and ``params`` are concatenated column-wise (stats
    first, then params) and the resulting vectors are evaluated with
    ``mvn.pdf``.  ``mu_y`` is flattened to 1-D before use, and
    ``allow_singular=True`` keeps a singular ``Sig_y`` from raising.
    """
    flat_mean = mu_y.reshape(-1)
    joint_samples = np.hstack((stats, params))
    return mvn.pdf(x=joint_samples, mean=flat_mean, cov=Sig_y,
                   allow_singular=True)

def gauss_weights_eps0(params, stats, mu_y, Sig_y):
    """Unnormalised theta-only Gaussian weights for a (near-)flat kernel.

    Stand-in for ``gauss_weights`` when eps^2 is huge: the x-part of the joint
    Gaussian is then effectively constant, so ``stats`` is ignored and only
    exp(-(params - mu_y[1])^2 / (2 * Sig_y[1, 1])) is returned, flattened to
    1-D.  It is evaluated by hand because mvn.pdf misbehaves for 1-D problems
    with a negative (co-)variance, which Sig_y[1, 1] can be here.
    """
    offset = params - mu_y[1]
    exponent = -0.5 * offset**2 / Sig_y[1, 1]
    return np.exp(exponent.reshape(-1))


def dL_dalpha( params, stats, normals, beta, gamma2, alphas):
    """Gradient of the weighted loss w.r.t. alpha, on a grid of alpha values.

    Implements dL/dalpha = -2 * sum_n N_n (theta_n - mu_phi(x_n)) / gamma^2
    with mu_phi(x) = alpha + beta * x, evaluated for every candidate in
    ``alphas`` while ``beta`` and ``gamma2`` are held fixed.  (Unlike the
    beta gradient, the residual is *not* multiplied by x_n.)

    Parameters
    ----------
    params : array_like
        Parameter samples theta_n (flattened internally).
    stats : array_like
        Summary statistics x_n, one per sample in ``params``.
    normals : array_like
        Per-sample weights N_n.
    beta : float
        Fixed slope of the linear mean model.
    gamma2 : float
        Fixed model variance gamma^2.
    alphas : array_like or float
        Candidate intercepts at which to evaluate the gradient.

    Returns
    -------
    numpy.ndarray
        1-D array with one gradient value per entry of ``alphas``.
    """
    theta = np.asarray(params).reshape(-1, 1)
    x = np.asarray(stats).reshape(-1, 1)
    weights = np.asarray(normals).reshape(-1, 1)
    alpha_grid = np.asarray(alphas).reshape(1, -1)

    # residuals: one row per sample, one column per candidate alpha
    residuals = theta - beta * x - alpha_grid
    return -2 * (weights * residuals / gamma2).sum(axis=0)


def dL_dbeta( params, stats, normals, alpha, gamma2, betas):
    """Gradient of the weighted loss w.r.t. beta, on a grid of beta values.

    Implements dL/dbeta = -2 * sum_n N_n (theta_n - mu_phi(x_n)) / gamma^2 * x_n
    with mu_phi(x) = alpha + beta * x, evaluated for every candidate in
    ``betas`` while ``alpha`` and ``gamma2`` are held fixed.

    Parameters
    ----------
    params : array_like
        Parameter samples theta_n (flattened internally).
    stats : array_like
        Summary statistics x_n, one per sample in ``params``.
    normals : array_like
        Per-sample weights N_n.
    alpha : float
        Fixed intercept of the linear mean model.
    gamma2 : float
        Fixed model variance gamma^2.
    betas : array_like or float
        Candidate slopes at which to evaluate the gradient.

    Returns
    -------
    numpy.ndarray
        1-D array with one gradient value per entry of ``betas``.
    """
    theta = np.asarray(params).reshape(-1, 1)
    x = np.asarray(stats).reshape(-1, 1)
    weights = np.asarray(normals).reshape(-1, 1)
    beta_grid = np.asarray(betas).reshape(1, -1)

    # residuals: one row per sample, one column per candidate beta; the
    # beta derivative of mu_phi is x_n, hence the extra factor x.
    residuals = theta - beta_grid * x - alpha
    return -2 * (weights * residuals / gamma2 * x).sum(axis=0)

def dL_dgamma2( params, stats, normals, alpha, beta, gamma2s):
    """Gradient of the weighted loss w.r.t. gamma^2, on a grid of gamma^2 values.

    Implements
        dL/dgamma^2 = 1/gamma^2 * sum_n N_n (1 - (theta_n - mu_phi(x_n))^2 / gamma^2)
    with mu_phi(x) = alpha + beta * x, for every candidate in ``gamma2s``
    while ``alpha`` and ``beta`` are held fixed.

    Parameters
    ----------
    params : array_like
        Parameter samples theta_n (flattened internally).
    stats : array_like
        Summary statistics x_n, one per sample in ``params``.
    normals : array_like
        Per-sample weights N_n.
    alpha, beta : float
        Fixed intercept and slope of the linear mean model.
    gamma2s : array_like or float
        Candidate variances at which to evaluate the gradient.

    Returns
    -------
    numpy.ndarray
        1-D array with one gradient value per entry of ``gamma2s``.
    """
    theta = np.asarray(params).reshape(-1, 1)
    x = np.asarray(stats).reshape(-1, 1)
    weights = np.asarray(normals).reshape(-1, 1)
    gamma2_grid = np.asarray(gamma2s, dtype=float).reshape(-1)

    # squared residual / gamma^2: one row per sample, one column per candidate
    scaled_sq = (theta - beta * x - alpha)**2 / gamma2_grid.reshape(1, -1)
    return 1 / gamma2_grid * (weights * (1 - scaled_sq)).sum(axis=0)
    
def alpha(params, stats, normals):
    """Closed-form weighted least-squares optimum of the intercept alpha.

    Solves the weighted normal equations of the linear mean model
    mu_phi(x) = alpha + beta * x jointly for (alpha, beta) with weights N_n
    and returns alpha:

        alpha_hat = (S_xx * S_t - S_xt * S_x) / (S_1 * S_xx - S_x**2)

    where S_1 = sum N_n, S_x = sum N_n x_n, S_xx = sum N_n x_n^2,
    S_t = sum N_n theta_n and S_xt = sum N_n x_n theta_n.  This is the
    optimum from the notes multiplied through by S_xx (Cramer's rule), which
    avoids dividing by S_xx twice.

    Parameters
    ----------
    params : numpy.ndarray
        Parameter samples theta_n.
    stats : numpy.ndarray
        Summary statistics x_n, same shape as ``params``.
    normals : numpy.ndarray
        Per-sample weights N_n, same shape as ``params``.

    Returns
    -------
    numpy scalar
        The joint optimum alpha_hat (nan if the system is degenerate, e.g.
        all weighted x_n equal).
    """
    s_1 = normals.sum()
    s_x = (normals * stats).sum()
    s_xx = (normals * stats**2).sum()
    s_t = (normals * params).sum()
    s_xt = (normals * stats * params).sum()

    return (s_xx * s_t - s_xt * s_x) / (s_1 * s_xx - s_x**2)

def beta(params, stats, normals, ahat=None):
    """Weighted least-squares slope beta for a given (or optimal) intercept.

    Computes (sum N_n x_n theta_n - ahat * sum N_n x_n) / sum N_n x_n^2, the
    root of dL/dbeta when the intercept is fixed at ``ahat``.  When ``ahat``
    is omitted, the joint optimum from ``alpha`` is used, which makes the
    result the joint optimum beta_hat.
    """
    if ahat is None:
        ahat = alpha(params, stats, normals)

    weighted_x = normals * stats
    numerator = (weighted_x * params).sum() - ahat * weighted_x.sum()
    denominator = (normals * stats**2).sum()
    return numerator / denominator
    
def gamma2(params, stats, normals, ahat=None, bhat=None):
    """Weighted mean squared residual, i.e. the optimal model variance gamma^2.

    Returns sum N_n (theta_n - ahat - bhat x_n)^2 / sum N_n.  Missing
    ``ahat`` / ``bhat`` are filled in with the closed-form optima from
    ``alpha`` and ``beta`` (in that order, so ``beta`` sees the chosen
    ``ahat``).
    """
    if ahat is None:
        ahat = alpha(params, stats, normals)
    if bhat is None:
        bhat = beta(params, stats, normals, ahat)

    residuals = params - ahat - bhat * stats
    return (normals * residuals**2).sum() / normals.sum()

# --- problem setup: 1-D linear-Gaussian toy problem ---
n_params = 1
seed = 42

sig2 = 1.0   # simulator noise variance sigma^2
eta2 = 1.0   # prior variance eta^2
eps2 = 1e20  # kernel variance epsilon^2 (huge value -> effectively flat kernel)

# For very wide kernels the x-part of the weights is effectively constant and
# mvn.pdf misbehaves, so switch to the theta-only weight function.
if eps2 > 1000:
    gauss_weights = gauss_weights_eps0

x0 = 0.8 * np.ones((1,1)) #_,obs = g.gen(1)

assert n_params ==  1
m = Gauss(dim=n_params, seed=seed, noise_cov=sig2)
p = dd.Gaussian(m=0. * np.ones(n_params), 
                S=eta2 * np.eye(n_params),
                seed=seed)
# Analytic posterior N(theta | eta2/(eta2+sig2) x0, eta2 - eta2^2/(eta2+sig2)).
# The scalar variance is parenthesised before scaling the identity; otherwise
# only the second term would multiply np.eye, which is wrong for n_params > 1.
post   = dd.Gaussian(m = np.ones(n_params) * eta2/(eta2+sig2)*x0[0], 
                     S=(eta2 - eta2**2 / (eta2 + sig2)) * np.eye(n_params))    

#else:
#    Sig_y = (eps2*eta2*sig2)/(eps2+eta2+sig2) * np.array([[1/eta2+1/sig2, 1/sig2],[1/sig2, 1/eps2+1/sig2]])
#    mu_y = np.dot( Sig_y, np.array([x0**2/eps2, 0]) )


In [None]:
from delfi.utils.progress import no_tqdm, progressbar

n_fits = 100
N = 5000

ksi2s = np.array([0.48]) * eta2
nus = 0.38 * np.ones(len(ksi2s))

plt.figure(figsize=(4*len(ksi2s),8))
n_bins = 20

track_rp = True # track real posterior: if False, will compare with 'proposal-posterior'

# Running axis limits shared by all rows: min/max of the histogrammed means and
# variances, and the tallest histogram bar seen so far.
m_m, m_v, M_m, M_v, hh_m, hh_v = np.inf,np.inf,-np.inf,-np.inf,-np.inf,-np.inf
for i in range(len(ksi2s)):
    nu, ksi2 = nus[i], ksi2s[i]
    # proposal prior N(nu, ksi2) and the matching analytic proposal posterior
    ppr = dd.Gaussian(m=nu * np.ones(n_params), 
                    S=ksi2 * np.eye(n_params),
                    seed=seed)
    postpr = dd.Gaussian(m = np.ones(n_params) * (ksi2/(ksi2+sig2)*x0[0] + sig2/(ksi2+sig2)*nu), 
                         S=(ksi2 - ksi2**2 / (ksi2 + sig2)) * np.eye(n_params))
    # Joint Gaussian over (x, theta) proportional to prior/proposal ratio times
    # kernel: theta-variance (1/eta2 - 1/ksi2)^-1 (negative for ksi2 < eta2),
    # theta-mean eta2/(eta2 - ksi2) * nu.
    eta2p = 1/(1/eta2 - 1/ksi2)
    Sig_y = np.array([[eps2,0], [0,eta2p]])    
    mu_y = np.array([ [x0[0][0]], [eta2/(eta2-ksi2)*nu]])

    s = ds.Identity()
    g = dg.Default(model=m, prior=ppr, summary=s)
    out_snpe   = np.zeros((n_fits,2))  # columns: (mu_hat, sig2_hat) per fit
    pbar = progressbar(total=n_fits)
    desc = 'repeated fits'
    pbar.set_description(desc)
    with pbar:
        for idx_seed in range(n_fits):
            # local name, so the global `seed` of the setup cell is not clobbered
            fit_seed = 42 + idx_seed
            g.model.seed = fit_seed
            g.prior.seed = fit_seed

            data = g.gen(N, verbose=False)
            params, stats = data[0].reshape(-1), data[1].reshape(-1)

            normals = gauss_weights(data[0], data[1], mu_y, Sig_y) if track_rp else np.ones(N)/N
            ahat =       alpha(params, stats, normals)
            bhat =        beta(params, stats, normals, ahat)
            gamma2hat = gamma2(params, stats, normals, ahat, bhat)

            # Evaluate the fitted linear MDN at x0. Convert to plain floats:
            # x0 is (1, 1), and writing a (1, 1) array into a row slot relies on
            # deprecated implicit array->scalar conversion.
            mu_hat   = float(ahat + bhat * x0[0, 0])
            sig2_hat = float(gamma2hat)

            out_snpe[idx_seed,:] = (mu_hat, sig2_hat)
            pbar.update(1)

    post_disp = post if track_rp else postpr

    # Left column: fitted means. Red = analytic target; green = mean +/- 1 std
    # of the fits. `density` replaces `normed`, which Matplotlib removed in 3.1.
    plt.subplot(len(ksi2s), 2, 2*i+1)
    m_m, M_m = np.min((m_m, out_snpe[:,0].min())), np.max((M_m, out_snpe[:,0].max()))
    plt.hist(out_snpe[:,0], bins=np.linspace(m_m, M_m, n_bins), density=True)
    hh_m = np.max((hh_m, plt.axis()[3]))
    plt.plot([post_disp.mean, post_disp.mean], [0, hh_m], 'r', linewidth=2)
    plt.plot(out_snpe[:,0].mean() + out_snpe[:,0].std()*np.array([-1,-1]), [0, hh_m/2], 'g', linewidth=2)
    plt.plot(out_snpe[:,0].mean() + out_snpe[:,0].std()*np.array([0,0]), [0, hh_m/2], 'g', linewidth=2)
    plt.plot(out_snpe[:,0].mean() + out_snpe[:,0].std()*np.array([1,1]), [0, hh_m/2], 'g', linewidth=2)
    plt.plot(out_snpe[:,0].mean() + out_snpe[:,0].std()*np.array([-1,1]), [ hh_m/2, hh_m/2], 'g', linewidth=2)
    plt.ylabel('xi^2/eta^2 = ' + str(ksi2/eta2) )

    # Right column: fitted variances, same marker scheme.
    plt.subplot(len(ksi2s),2, 2*i+2)
    m_v, M_v = np.min((m_v, out_snpe[:,1].min())), np.max((M_v, out_snpe[:,1].max()))
    plt.hist(out_snpe[:,1], bins=np.linspace(m_v, M_v, n_bins), density=True)
    hh_v = np.max((hh_v, plt.axis()[3]))
    plt.plot([post_disp.std**2, post_disp.std**2], [0, hh_v], 'r', linewidth=2)
    plt.plot(out_snpe[:,1].mean() + out_snpe[:,1].std()*np.array([-1,-1]), [0, hh_v/2], 'g', linewidth=2)
    plt.plot(out_snpe[:,1].mean() + out_snpe[:,1].std()*np.array([0,0]), [0, hh_v/2], 'g', linewidth=2)
    plt.plot(out_snpe[:,1].mean() + out_snpe[:,1].std()*np.array([1,1]), [0, hh_v/2], 'g', linewidth=2)
    plt.plot(out_snpe[:,1].mean() + out_snpe[:,1].std()*np.array([-1,1]), [ hh_v/2, hh_v/2], 'g', linewidth=2)


plt.subplot(len(ksi2s),2,1)
plt.title('posterior mean')
plt.subplot(len(ksi2s),2,2)
plt.title('posterior variance')

# apply the shared axis limits to every row so the histograms are comparable
for i in range(len(ksi2s)):
    plt.subplot(len(ksi2s),2, 2*i+1)
    plt.axis([m_m, M_m, 0, hh_m])
    plt.subplot(len(ksi2s),2, 2*i+2)
    plt.axis([m_v, M_v, 0, hh_v])
plt.show()

In [None]:
# Mean and std of the fitted means mu_hat over the repeated fits
# (out_snpe holds the results for the last ksi2 of the loop above).
out_snpe[:,0].mean(), out_snpe[:,0].std()

In [None]:
# Mean and std of the fitted variances sig2_hat over the repeated fits
# (out_snpe holds the results for the last ksi2 of the loop above).
out_snpe[:,1].mean(), out_snpe[:,1].std()

In [None]:
# proposal-prior mean used in the last loop iteration
nu

In [None]:
# proposal-prior variances scanned in the experiment above
ksi2s

In [None]:
from delfi.utils.progress import no_tqdm, progressbar

n_fits = 200
N = 100

ksi2s = np.array([0.1, 0.5, 0.9, 0.999]) * eta2
nus = np.zeros(len(ksi2s)) * x0[0]

plt.figure(figsize=(4*len(ksi2s),8))
n_bins = 20

track_rp = True # track real posterior: if False, will compare with 'proposal-posterior'

# Running axis limits shared by all rows: min/max of the histogrammed means and
# variances, and the tallest histogram bar seen so far.
m_m, m_v, M_m, M_v, hh_m, hh_v = np.inf,np.inf,-np.inf,-np.inf,-np.inf,-np.inf
for i in range(len(ksi2s)):
    nu, ksi2 = nus[i], ksi2s[i]
    # proposal prior N(nu, ksi2) and the matching analytic proposal posterior
    ppr = dd.Gaussian(m=nu * np.ones(n_params), 
                    S=ksi2 * np.eye(n_params),
                    seed=seed)
    postpr = dd.Gaussian(m = np.ones(n_params) * (ksi2/(ksi2+sig2)*x0[0] + sig2/(ksi2+sig2)*nu), 
                         S=(ksi2 - ksi2**2 / (ksi2 + sig2)) * np.eye(n_params))
    # Joint Gaussian over (x, theta) proportional to prior/proposal ratio times
    # kernel: theta-variance (1/eta2 - 1/ksi2)^-1 (negative for ksi2 < eta2),
    # theta-mean eta2/(eta2 - ksi2) * nu.
    eta2p = 1/(1/eta2 - 1/ksi2)
    Sig_y = np.array([[eps2,0], [0,eta2p]])    
    mu_y = np.array([ [x0[0][0]], [eta2/(eta2-ksi2)*nu]])

    s = ds.Identity()
    g = dg.Default(model=m, prior=ppr, summary=s)
    out_snpe   = np.zeros((n_fits,2))  # columns: (mu_hat, sig2_hat) per fit
    pbar = progressbar(total=n_fits)
    desc = 'repeated fits'
    pbar.set_description(desc)
    with pbar:
        for idx_seed in range(n_fits):
            # local name, so the global `seed` of the setup cell is not clobbered
            fit_seed = 42 + idx_seed
            g.model.seed = fit_seed
            g.prior.seed = fit_seed

            data = g.gen(N, verbose=False)
            params, stats = data[0].reshape(-1), data[1].reshape(-1)

            normals = gauss_weights(data[0], data[1], mu_y, Sig_y) if track_rp else np.ones(N)/N
            ahat =       alpha(params, stats, normals)
            bhat =        beta(params, stats, normals, ahat)
            gamma2hat = gamma2(params, stats, normals, ahat, bhat)

            # Evaluate the fitted linear MDN at x0. Convert to plain floats:
            # x0 is (1, 1), and writing a (1, 1) array into a row slot relies on
            # deprecated implicit array->scalar conversion.
            mu_hat   = float(ahat + bhat * x0[0, 0])
            sig2_hat = float(gamma2hat)

            out_snpe[idx_seed,:] = (mu_hat, sig2_hat)
            pbar.update(1)

    post_disp = post if track_rp else postpr

    # Left column: fitted means. Red = analytic target; green = mean +/- 1 std
    # of the fits. `density` replaces `normed`, which Matplotlib removed in 3.1.
    plt.subplot(len(ksi2s), 2, 2*i+1)
    m_m, M_m = np.min((m_m, out_snpe[:,0].min())), np.max((M_m, out_snpe[:,0].max()))
    plt.hist(out_snpe[:,0], bins=np.linspace(m_m, M_m, n_bins), density=True)
    hh_m = np.max((hh_m, plt.axis()[3]))
    plt.plot([post_disp.mean, post_disp.mean], [0, hh_m], 'r', linewidth=2)
    plt.plot(out_snpe[:,0].mean() + out_snpe[:,0].std()*np.array([-1,-1]), [0, hh_m/2], 'g', linewidth=2)
    plt.plot(out_snpe[:,0].mean() + out_snpe[:,0].std()*np.array([0,0]), [0, hh_m/2], 'g', linewidth=2)
    plt.plot(out_snpe[:,0].mean() + out_snpe[:,0].std()*np.array([1,1]), [0, hh_m/2], 'g', linewidth=2)
    plt.plot(out_snpe[:,0].mean() + out_snpe[:,0].std()*np.array([-1,1]), [ hh_m/2, hh_m/2], 'g', linewidth=2)
    plt.ylabel('xi^2/eta^2 = ' + str(ksi2/eta2) )

    # Right column: fitted variances, same marker scheme.
    plt.subplot(len(ksi2s),2, 2*i+2)
    m_v, M_v = np.min((m_v, out_snpe[:,1].min())), np.max((M_v, out_snpe[:,1].max()))
    plt.hist(out_snpe[:,1], bins=np.linspace(m_v, M_v, n_bins), density=True)
    hh_v = np.max((hh_v, plt.axis()[3]))
    plt.plot([post_disp.std**2, post_disp.std**2], [0, hh_v], 'r', linewidth=2)
    plt.plot(out_snpe[:,1].mean() + out_snpe[:,1].std()*np.array([-1,-1]), [0, hh_v/2], 'g', linewidth=2)
    plt.plot(out_snpe[:,1].mean() + out_snpe[:,1].std()*np.array([0,0]), [0, hh_v/2], 'g', linewidth=2)
    plt.plot(out_snpe[:,1].mean() + out_snpe[:,1].std()*np.array([1,1]), [0, hh_v/2], 'g', linewidth=2)
    plt.plot(out_snpe[:,1].mean() + out_snpe[:,1].std()*np.array([-1,1]), [ hh_v/2, hh_v/2], 'g', linewidth=2)


plt.subplot(len(ksi2s),2,1)
plt.title('posterior mean')
plt.subplot(len(ksi2s),2,2)
plt.title('posterior variance')

# apply the shared axis limits to every row so the histograms are comparable
for i in range(len(ksi2s)):
    plt.subplot(len(ksi2s),2, 2*i+1)
    plt.axis([m_m, M_m, 0, hh_m])
    plt.subplot(len(ksi2s),2, 2*i+2)
    plt.axis([m_v, M_v, 0, hh_v])
plt.show()

# numerically check $\frac{\partial}{\partial{}\alpha}$

In [None]:
# Scan alpha with beta and gamma^2 held fixed: the zero of the analytic
# gradient dL/dalpha must coincide with the closed-form optimum alpha_hat.
# alpha() returns the *joint* (alpha, beta) optimum, which is a root of
# dL/dalpha only when beta is fixed at the matching joint optimum bhat, so
# hold beta_ at bhat rather than at an unrelated value.
alpha_hat = np.array(alpha(params, stats, normals))
beta_ = beta(params, stats, normals, ahat=alpha_hat)
gamma2_ = post.std**2

# narrow grid centred on the expected root
alphas = alpha_hat + np.linspace(-0.005, 0.005, 100000)

out = -2*(normals.reshape(-1,1) * (params.reshape(-1,1) - beta_ * stats.reshape(-1,1) - alphas.reshape(1,-1))/gamma2_).sum(axis=0)
plt.plot(alphas, out)
plt.show()

# gradient at the closed-form optimum; should be ~0
out_hat = -2*(normals.reshape(-1,1) * (params.reshape(-1,1) - beta_ * stats.reshape(-1,1) - alpha_hat)/gamma2_).sum(axis=0)


alphas[np.argmin(np.abs(out))], alpha_hat, out_hat

# numerically check $\frac{\partial}{\partial{}\beta}$

In [None]:
# Scan beta with alpha and gamma^2 held fixed; the zero of the analytic
# gradient should match the closed-form conditional optimum beta(ahat=alpha_).
alpha_ = 0.
gamma2_ = post.std**2
betas = np.linspace(0., 1., 1000)

out = -2 * (
    normals.reshape(-1, 1)
    * (params.reshape(-1, 1) - betas.reshape(1, -1) * stats.reshape(-1, 1) - alpha_)
    / gamma2_
    * stats.reshape(-1, 1)
).sum(axis=0)
plt.plot(betas, out)
plt.show()

beta_hat = beta(params, stats, normals, ahat=alpha_)
out_hat = -2 * (
    normals.reshape(-1, 1)
    * (params.reshape(-1, 1) - beta_hat * stats.reshape(-1, 1) - alpha_)
    / gamma2_
    * stats.reshape(-1, 1)
).sum(axis=0)

# (grid point with smallest |gradient|, closed-form optimum, gradient there ~ 0)
betas[np.argmin(np.abs(out))], beta_hat, out_hat

# numerically check $\frac{\partial}{\partial{}\gamma^2}$

In [None]:
# Scan gamma^2 with alpha and beta held fixed; the zero of dL/dgamma^2 should
# coincide with the closed-form conditional optimum gamma2(ahat=..., bhat=...).
alpha_ = 0.
beta_ = eta2/(eta2+sig2)
gamma2s = np.linspace(0.2, 2, 50)

tmp = (
    (params.reshape(-1, 1) - beta_ * stats.reshape(-1, 1) - alpha_)**2
    / gamma2s.reshape(1, -1)
)
out = 1 / gamma2s.reshape(-1,) * (normals.reshape(-1, 1) * (1 - tmp)).sum(axis=0)
plt.plot(gamma2s, out)
plt.show()

gamma2_hat = gamma2(params, stats, normals, ahat=alpha_, bhat=beta_)
tmp_ = (
    (params.reshape(-1, 1) - beta_ * stats.reshape(-1, 1) - alpha_)**2
    / gamma2_hat
)
out_hat = 1 / gamma2_hat * (normals.reshape(-1, 1) * (1 - tmp_)).sum(axis=0)

# (grid point with smallest |gradient|, closed-form optimum, gradient there ~ 0)
gamma2s[np.argmin(np.abs(out))], gamma2_hat, out_hat

In [None]:
# NOTE(review): this overwrites gamma2hat with the conditional optimum at a
# hand-picked (alpha, beta) = (0, 0.4), while ahat and bhat still hold the
# values from the last repeated fit above -- the three displayed numbers do
# not refer to the same (alpha, beta). Confirm this is intended.
gamma2hat = gamma2(params, stats, normals, 0, 0.4)
ahat, bhat, gamma2hat

In [None]:

plt.hist(
    out_snpe[:, 1],
    bins=np.linspace(out_snpe[:, 1].min(), out_snpe[:, 1].max(), 10),
)
plt.show()

# spread of the fitted variances sig2_hat across the repeated fits
out_snpe[:, 1].mean(), out_snpe[:, 1].std()