# No more rounds!


### SNPE, CDELFI and rounds

- SNPE and CDELFI are methods to fit a factorizing model $q_\phi(\theta|x)p(x)$ for the joint density $p(x,\theta)$ of parameters $\theta$ and summary statistics $x$ defined by a stochastic simulation model $p(x|\theta)$ and prior $p(\theta)$. 
- both methods support proposal distributions $\tilde{p}(\theta)$ that substitute the prior $p(\theta)$ during fitting, which allows the density learning to focus on regions of $p(x,\theta)$ where $x \approx x_0$ for actually observed data $x_0$. 
- coming up with such proposal distributions $\tilde{p}(\theta)$ is cumbersome. The standard approach of both methods is to run a sequence of 'rounds', each being a full model fit of its own that initializes $\tilde{p}(\theta) \leftarrow q_\phi(\theta \ | \ x_0)$ with the posterior estimate from the previous round.


- rounds suck.


- The simple assignment $\tilde{p}(\theta) \leftarrow q_\phi(\theta \ | \ x_0)$ introduces hard-to-control dynamics across rounds.
- the initial $\tilde{p}(\theta)$ might be poor and not focus well on $x_0$, but early rounds nonetheless need to be run with large enough simulated data-sets $\{(x_n, \theta_n)\}_{n=1}^N$ to generate a better proposal for the next round. In practice this means sampling enough data to fit a full mixture density network!

- a different approach is to parametrize the proposal and try to learn both proposal $p_\psi(\theta)$ and conditional density $q_\phi(\theta \ | \ x)$ jointly. 

- but what loss term should we use to choose a good proposal parameter $\psi$?

### parametrized proposal distributions

- another problem, however, already arises from the SNPE/CDELFI loss itself, 
$$\mathcal{L}(\phi,\psi) = - \ D_\mbox{KL}(p_\psi(\theta) \ p(x|\theta) \ || \ q_\phi(\theta|x) p(x)) = \int p_\psi(\theta) p(x|\theta) \log q_\phi(\theta | x) dx d\theta + \int p_\psi(\theta) p(x|\theta) \log \frac{p(x)}{p_\psi(\theta)p(x|\theta)} dx d\theta $$
When optimizing only over $\phi$ with $\psi$ held constant, this loss lets us avoid evaluating the unknown densities $p(x | \theta)$ and especially $p(x) =  \int p(x|\theta) p_\psi(\theta) d\theta = f(x, \psi)$ entirely, because the integrals can be approximated with Monte Carlo samples. This no longer holds for variable $\psi$: likelihood and marginal evaluations then need to be estimated, which requires costly simulations. This is particularly true for the marginal density $p(x)$, which in practice would require many likelihood estimates $p(x|\theta_i), \theta_i \sim p_\psi(\theta)$ to approximate well with Monte Carlo. 

- it is interesting to note that the rhs term that is constant in $\phi$ (but not in $\psi$), 
$$ - \int p_\psi(\theta) p(x|\theta) \log \frac{p(x)}{p_\psi(\theta)p(x|\theta)} dx d\theta := H[p(X,\theta),p(X)]$$
is the cross-entropy between the target joint distribution $p_\psi(\theta)p(x|\theta)$ and the problematic marginal $p(x)$, meaning that this term can provide information to fit a parametrized model $p_\omega(x)$ to approximate the marginal $p(x)$ induced by $p_\psi(\theta)$!
- $H[p_\psi(X,\theta),p_\omega(X)] = - \int p_\psi(\theta,x) \log \frac{p_\omega(x)}{p_\psi(\theta|x)p(x)} dx d\theta = D_{KL}(p(x) ||  p_\omega(x)) + \int p_\psi(\theta,x) \log p_\psi(\theta|x) dx d\theta$ is minimized w.r.t. $\omega$ if $p_\omega(x) = p(x)$ almost everywhere. 

- thanks to this link between proposal $p_\psi$ and marginal $p_\omega$, having a good marginal estimate can in turn help us choose a good proposal prior: the basis of most variational approximations is indeed to try to maximize the marginal data-likelihood $p_\omega(x_0)$.  

### adaptive proposal loss

$$\mathcal{L}(\phi,\psi,\omega) = - \ D_\mbox{KL}(p_\psi(\theta) \ p(x|\theta) \ || \ q_\phi(\theta|x) p_\omega(x)) + \log p_\omega(x_0) \\ 
= \int p_\psi(\theta) p(x|\theta) \log q_\phi(\theta | x) dx d\theta + \int p_\psi(\theta) p(x|\theta) \log \frac{p_\omega(x)}{p_\psi(\theta)p(x|\theta)} dx d\theta + \log p_\omega(x_0)$$

- only unparametrized part of the loss is the likelihood $p(x|\theta)$, the rest can be jointly optimized with (stochastic) gradient descent!
- gradient w.r.t. $\phi$ is virtually unchanged relative to SNPE / CDELFI
- gradient w.r.t. $\psi$ tries to balance maximizing conditional probabilities $q_\phi(\theta|x)$ while keeping the induced marginal $p(x)$ close to $p_\omega(x)$.  
- gradient w.r.t. $\omega$ tries to balance staying close to $p(x)$ and maximizing the observed-data likelihood $p_\omega(x_0)$.  
- remember $\frac{\partial}{\partial{}\psi} \int p_\psi(\theta) f(\theta) d\theta = \int p_\psi(\theta) f(\theta)\frac{\partial}{\partial{}\psi} \log p_\psi(\theta) d\theta$ under mild constraints, i.e. we can use Monte Carlo for gradients wrt. the proposal
- we are left with an evaluation of the log-likelihood in the new loss! In practice, we will need the conditional entropy $H_{x|\theta}[X|\theta_n]$ over summary statistics at sampled parameters $\theta_n$. Since $H_{x|\theta}$ maps only from parameters $\theta$ into the reals (whereas $p(x|\theta)$ is defined over the joint space of all $(x,\theta)$), this quantity may be easier to acquire than the likelihood, esp. when there are far fewer parameters than summary statistics. 
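
The score-function identity in the list above can be checked numerically. A minimal sketch, assuming a Gaussian proposal $p_\psi(\theta) = \mathcal{N}(\theta \ | \ \nu, \xi^2)$ and the test function $f(\theta) = \theta^2$ (both picked here for illustration, function names are hypothetical), for which $E[f] = \nu^2 + \xi^2$ and hence $\partial_\nu E[f] = 2\nu$ and $\partial_{\xi^2} E[f] = 1$:

```python
import numpy as np

def score_gaussian(theta, nu, xi2):
    """Score d/dpsi log N(theta | nu, xi2) for psi = (nu, xi2)."""
    d_nu = (theta - nu) / xi2
    d_xi2 = ((theta - nu) ** 2 - xi2) / (2.0 * xi2 ** 2)
    return d_nu, d_xi2

def score_function_gradient(f, nu, xi2, n_samples, rng):
    """Monte Carlo estimate of d/dpsi E_{theta ~ p_psi}[f(theta)]."""
    theta = nu + np.sqrt(xi2) * rng.standard_normal(n_samples)
    d_nu, d_xi2 = score_gaussian(theta, nu, xi2)
    f_vals = f(theta)
    return np.mean(f_vals * d_nu), np.mean(f_vals * d_xi2)

rng = np.random.default_rng(0)
nu, xi2 = 0.5, 0.25  # illustrative proposal parameters
g_nu, g_xi2 = score_function_gradient(lambda t: t ** 2, nu, xi2, 200_000, rng)
print(g_nu, g_xi2)  # should land near 2*nu = 1.0 and 1.0, up to Monte Carlo noise
```

The estimator only needs samples from the proposal and evaluations of $f$, which is why it fits the simulation-based setting; its variance can be large, so the sample size here is deliberately generous.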


### questions
- do the three major loss terms ($q_\phi$, $H[p(X,\theta),p(X)]$, $\log p_\omega(x_0)$) work in concert as desired, or will one dominate (e.g. will the proposal overadapt to the current MDN $\phi$ rather than try to follow $p_\omega(x_0)$ and give it 'wiggle-space')? Will we need to introduce weights, e.g. $\lambda \log p_\omega(x_0)$?  
- what are good model assumptions for $p_\omega$? How important is capturing the data marginal $p(x)$, how important is matching proposal*likelihood ?
- how to best estimate the conditional entropies $H[X|\theta]$ from simulations? Repeated sampling for each $\theta$? Fit another network?
- can this be extended with importance sampling? The above scheme requires analytical correction for the final proposal distribution (after convergence), but can we relax this when re-introducing importance weights? If yes, does the ability to choose another proposal distribution after each individual batch-gradient allow us to get more stable importance sampling schemes?
- can this be extended to SVI? Note that $p_\omega(x)$ essentially just completes the parametrization of the joint-data likelihood $q_\phi(\theta_n|x_n)p_\omega(x_n) = p(\theta_n,x_n | \phi, \omega)$, but that a constantly updated proposal $p_\psi(\theta)$ destroys the notion of having a single fixed data-set. How do we do Bayesian online-learning when the data-distribution is known to be a (stochastically) moving target? 
- no more rounds? 

No more rounds!

# toy problem setup
- how much can be done analytically?

prior: 

$p(\theta) = \mathcal{N}(\theta \ | \ 0, \eta^2)$

proposal prior: 

$p_\psi(\theta) = \mathcal{N}(\theta \ | \ \nu, \xi^2)$

simulator: 

$p(x \ | \ \theta) =  \mathcal{N}(x \ | \ \theta, \sigma^2)$

analytic posteriors: 

$p(\theta \ | \ x) = \mathcal{N}(\theta \ | \frac{\eta^2}{\eta^2 + \sigma^2} x, \eta^2 - \frac{\eta^4}{\eta^2 + \sigma^2})$ 

$\tilde{p}(\theta \ | \ x) = \mathcal{N}(\theta \ | \frac{\xi^2}{\xi^2 + \sigma^2} x + \frac{\sigma^2}{\xi^2 + \sigma^2} \nu, \xi^2 - \frac{\xi^4}{\xi^2 + \sigma^2})$
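
Both posteriors above are instances of the same conjugate update, and the proposal posterior can be mapped back to the true posterior by multiplying with the prior and dividing by the proposal in precision form. A minimal sketch of that check (function names are hypothetical; $\eta^2$, $\sigma^2$ and $x_0$ take the values used further down in this notebook, the proposal values are illustrative):

```python
import numpy as np

def gaussian_posterior(x, prior_mean, prior_var, noise_var):
    """Posterior N(theta | m, v) for x ~ N(theta, noise_var), theta ~ N(prior_mean, prior_var)."""
    w = prior_var / (prior_var + noise_var)
    m = w * x + (1.0 - w) * prior_mean
    v = prior_var - prior_var ** 2 / (prior_var + noise_var)
    return m, v

def correct_for_proposal(m, v, eta2, nu, xi2):
    """Multiply N(m, v) by the prior N(0, eta2) and divide by the proposal N(nu, xi2)."""
    P = 1.0 / v + 1.0 / eta2 - 1.0 / xi2    # precision
    Pm = m / v + 0.0 / eta2 - nu / xi2      # precision * mean
    return Pm / P, 1.0 / P

eta2, sig2, x0 = 1.0, 1.0 / 9.0, 0.8  # prior variance, likelihood variance, observation
nu, xi2 = 0.3, 0.5                    # illustrative proposal mean and variance

m_true, v_true = gaussian_posterior(x0, 0.0, eta2, sig2)
m_prop, v_prop = gaussian_posterior(x0, nu, xi2, sig2)
m_corr, v_corr = correct_for_proposal(m_prop, v_prop, eta2, nu, xi2)
print((m_true, v_true), (m_corr, v_corr))  # the two pairs should agree
```

The correction is only well defined while the resulting precision stays positive, which is one reason the choice of proposal variance matters.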

variational marginal:

$p_\omega(x) = \mathcal{N}(x \ | \ \tau, \rho^2)$

Data:

$(x_n, \theta_n) \sim p_\psi(\theta) p(x \ | \ \theta) = \mathcal{N}( (x_n, \theta_n) \ | \ (\nu, \nu), 
\begin{pmatrix}
\xi^{2} + \sigma^{2} &  \xi^{2}  \\
\xi^{2} & \xi^{2}  \\
\end{pmatrix})$
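
The joint Gaussian above can be sanity-checked by simulation. A minimal sketch, assuming the pairs are drawn from the proposal $p_\psi(\theta)$ (which is what the stated mean $(\nu, \nu)$ and covariance imply) and using illustrative values for $\nu$ and $\xi^2$:

```python
import numpy as np

rng = np.random.default_rng(1)
nu, xi2, sig2, N = 0.3, 0.5, 1.0 / 9.0, 200_000  # illustrative values

# theta ~ N(nu, xi2), then x | theta ~ N(theta, sig2)
theta = nu + np.sqrt(xi2) * rng.standard_normal(N)
x = theta + np.sqrt(sig2) * rng.standard_normal(N)

emp_mean = np.array([x.mean(), theta.mean()])
emp_cov = np.cov(np.vstack((x, theta)))  # rows are variables: (x, theta)
ana_cov = np.array([[xi2 + sig2, xi2],
                    [xi2, xi2]])

print(emp_mean)           # should be close to (nu, nu)
print(emp_cov - ana_cov)  # should be close to zero
```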

Model: 

$ q_\phi(\theta_n | x_n) = \mathcal{N}(\theta_n \ | \ \mu_\phi(x_n), \sigma^2_\phi(x_n))$

$ (\mu_\phi(x), \sigma^2_\phi(x)) = MDN_\phi(x) = \begin{pmatrix} \beta \\ 0 \end{pmatrix} x + \begin{pmatrix} \alpha \\ \gamma^2 \end{pmatrix}$
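
Since the proposal posterior $\tilde{p}(\theta \ | \ x)$ stated above is Gaussian with a mean that is linear in $x$ and a variance that does not depend on $x$, this linear-Gaussian model can represent it exactly. Matching terms gives the parameter values a perfect fit under the proposal would reach (a worked identity from the formulas above, not an additional result):

$$\beta^* = \frac{\xi^2}{\xi^2 + \sigma^2}, \qquad \alpha^* = \frac{\sigma^2}{\xi^2 + \sigma^2} \nu, \qquad (\gamma^2)^* = \xi^2 - \frac{\xi^4}{\xi^2 + \sigma^2} = \frac{\xi^2 \sigma^2}{\xi^2 + \sigma^2}$$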

Loss: 

$\mathcal{L}(\phi, \psi, \omega) = - \lambda \ D_\mbox{KL}(p_\psi(\theta) \ p(x|\theta) \ || \ q_\phi(\theta|x) p_\omega(x)) + \log p_\omega(x_0)$

Gradients: 

$\frac{\partial\mathcal{L}}{\partial\phi} = \lambda \ \frac{\partial}{\partial\phi} \int p(x|\theta) p_\psi(\theta) \ \log q_\phi(\theta\ | \ x) \ dx d\theta \approx \frac{\lambda}{N} \sum_n \frac{\partial}{\partial\phi} \log q_\phi(\theta_n | x_n)$
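
For the linear-Gaussian $q_\phi$ of the toy problem, the sample average of $\frac{\partial}{\partial\phi} \log q_\phi(\theta_n | x_n)$ has a simple closed form, and it vanishes at the least-squares fit of $\theta$ on $x$. A minimal sketch of that check (function name hypothetical, proposal values illustrative, unweighted samples assumed):

```python
import numpy as np

def grad_log_q(theta, x, alpha, beta, gamma2):
    """Sample mean of d/dphi log N(theta | alpha + beta*x, gamma2), phi = (alpha, beta, gamma2)."""
    r = theta - alpha - beta * x  # residuals
    return np.array([
        np.mean(r) / gamma2,
        np.mean(r * x) / gamma2,
        np.mean(r ** 2 - gamma2) / (2.0 * gamma2 ** 2),
    ])

rng = np.random.default_rng(2)
nu, xi2, sig2, N = 0.3, 0.5, 1.0 / 9.0, 50_000  # illustrative values
theta = nu + np.sqrt(xi2) * rng.standard_normal(N)
x = theta + np.sqrt(sig2) * rng.standard_normal(N)

# The maximum-likelihood fit is ordinary least squares of theta on x.
beta_hat, alpha_hat = np.polyfit(x, theta, 1)
gamma2_hat = np.mean((theta - alpha_hat - beta_hat * x) ** 2)

grad_at_fit = grad_log_q(theta, x, alpha_hat, beta_hat, gamma2_hat)
print(grad_at_fit)                    # should be numerically zero
print(beta_hat, xi2 / (xi2 + sig2))   # fitted vs. analytic slope
```

The same pattern (closed-form solution, then gradient evaluated at that solution) is what the numerical checks at the end of this notebook do for the weighted case.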

$\frac{\partial\mathcal{L}}{\partial\omega} = \frac{\partial}{\partial\omega} \log p_\omega(x_0) + \lambda \ \frac{\partial}{\partial\omega}  \int p(x|\theta) p_\psi(\theta) \ \log p_\omega(x) \ dx d\theta \approx  \frac{\partial}{\partial\omega} \log p_\omega(x_0) + \frac{\lambda}{N} \sum_n \frac{\partial}{\partial\omega} \log p_\omega(x_n)$
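
With the Gaussian $p_\omega(x) = \mathcal{N}(x \ | \ \tau, \rho^2)$ of the toy problem, setting the Monte Carlo version of this gradient to zero can be solved by hand: $\tau$ becomes a weighted average of $x_0$ and the sample mean, and $\rho^2$ a weighted average of the squared deviations. A minimal sketch that verifies this stationary point with finite differences (function names hypothetical, sample and $\lambda$ values illustrative):

```python
import numpy as np

def log_normal_pdf(x, mean, var):
    return -0.5 * np.log(2.0 * np.pi * var) - 0.5 * (x - mean) ** 2 / var

def omega_objective(tau, rho2, x0, x, lam):
    """omega-dependent part of the loss: log p_omega(x0) + lam * mean_n log p_omega(x_n)."""
    return log_normal_pdf(x0, tau, rho2) + lam * np.mean(log_normal_pdf(x, tau, rho2))

def omega_stationary_point(x0, x, lam):
    tau = (x0 + lam * np.mean(x)) / (1.0 + lam)
    rho2 = ((x0 - tau) ** 2 + lam * np.mean((x - tau) ** 2)) / (1.0 + lam)
    return tau, rho2

rng = np.random.default_rng(3)
x = 0.3 + np.sqrt(0.5 + 1.0 / 9.0) * rng.standard_normal(1000)  # simulated statistics
x0, lam = 0.8, 2.0
tau_hat, rho2_hat = omega_stationary_point(x0, x, lam)

# Central finite differences of the objective at the closed-form solution.
h = 1e-5
d_tau = (omega_objective(tau_hat + h, rho2_hat, x0, x, lam)
         - omega_objective(tau_hat - h, rho2_hat, x0, x, lam)) / (2 * h)
d_rho2 = (omega_objective(tau_hat, rho2_hat + h, x0, x, lam)
          - omega_objective(tau_hat, rho2_hat - h, x0, x, lam)) / (2 * h)
print(tau_hat, rho2_hat, d_tau, d_rho2)  # both derivatives should be close to zero
```

The closed form makes the role of $\lambda$ explicit: large $\lambda$ pulls $p_\omega$ towards the simulated marginal, small $\lambda$ towards a density concentrated on $x_0$.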

$\frac{\partial\mathcal{L}}{\partial\psi} 
%= \frac{\partial}{\partial\psi} \int p(x|\theta) p_\psi(\theta) \left( \log q_\phi(\theta|x) + \log p_\omega(x) - \log p(x|\theta) p_\psi(\theta) \right) dx d\theta 
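= \lambda \left( \int p(x|\theta) p_\psi(\theta) \left[ \frac{\partial}{\partial\psi} \log p_\psi(\theta)\right] \left( \log q_\phi(\theta|x) + \log p_\omega(x) \right) dx d\theta + \int p_\psi(\theta) \left[ \frac{\partial}{\partial\psi} \log p_\psi(\theta)\right] H[X|\theta] d\theta + \frac{\partial}{\partial\psi} H_{\psi}[\theta] \right) \\ 
\approx \frac{\lambda}{N}\sum_n \left( \log q_\phi(\theta_n|x_n) + \log p_\omega(x_n) + H[X|\theta_n] \right) \frac{\partial}{\partial\psi} \log p_\psi(\theta_n) + \lambda \ \frac{\partial}{\partial\psi} H_{\psi}[\theta]$

Here $H[X|\theta] = - \int p(x|\theta) \log p(x|\theta) dx$ and $H_\psi[\theta] = - \int p_\psi(\theta) \log p_\psi(\theta) d\theta$ are the usual (differential) entropies, so both enter with a positive sign after expanding $- \log p(x|\theta) p_\psi(\theta)$.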
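
The two entropy terms in this gradient are available in closed form for the toy problem: $H[X|\theta] = \frac{1}{2} \log(2 \pi e \sigma^2)$ for every $\theta$, and $H_\psi[\theta] = \frac{1}{2} \log(2 \pi e \xi^2)$, so $\frac{\partial}{\partial\xi^2} H_\psi[\theta] = \frac{1}{2\xi^2}$ and $\frac{\partial}{\partial\nu} H_\psi[\theta] = 0$. A minimal sketch comparing the closed form against an estimate from repeated simulations at a fixed $\theta$ (function names are hypothetical, and the Gaussian plug-in estimator is an assumption that only suits simulators with roughly Gaussian noise):

```python
import numpy as np

def gaussian_entropy(var):
    """Differential entropy of a 1D Gaussian with variance var."""
    return 0.5 * np.log(2.0 * np.pi * np.e * var)

def plugin_conditional_entropy(theta, sig2, n_rep, rng):
    """Estimate H[X|theta] by fitting a Gaussian to n_rep repeated simulations at theta."""
    x = theta + np.sqrt(sig2) * rng.standard_normal(n_rep)  # stand-in for simulator calls
    return gaussian_entropy(np.var(x, ddof=1))

rng = np.random.default_rng(4)
sig2, xi2 = 1.0 / 9.0, 0.5  # likelihood variance, illustrative proposal variance

H_exact = gaussian_entropy(sig2)
H_est = plugin_conditional_entropy(0.3, sig2, 20_000, rng)

# Finite-difference check of dH_psi/dxi2 = 1 / (2 * xi2).
h = 1e-6
dH_fd = (gaussian_entropy(xi2 + h) - gaussian_entropy(xi2 - h)) / (2 * h)
dH_exact = 1.0 / (2.0 * xi2)
print(H_exact, H_est, dH_fd, dH_exact)
```

Because $H[X|\theta]$ does not depend on $\theta$ in this toy problem and the score $\frac{\partial}{\partial\psi} \log p_\psi(\theta)$ has zero mean under $p_\psi$, the conditional-entropy term contributes nothing in expectation here; it only becomes informative for simulators whose noise level varies with $\theta$.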



## stuff below is old and mainly in for recycling gradients w.r.t. $\phi$

In [15]:
%%capture
import delfi.distribution as dd
import delfi.generator as dg
import delfi.inference as infer
import delfi.summarystats as ds
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

from scipy.stats import multivariate_normal as mvn
from delfi.utils.progress import no_tqdm, progressbar

from delfi.simulator.Gauss import Gauss
    
def gauss_weights(params, stats, mu_y, Sig_y):
    
    y = np.hstack((stats, params))
    return mvn.pdf(x=y, mean=mu_y.reshape(-1), cov=Sig_y, allow_singular=True)   


def gauss_weights_eps0(params, stats, mu_y, Sig_y):
    """ stable version in case eps^2 is giant - stats.mvn return nonsense here """
    # Note: making use of the fact that covariances are zero for normal weights in SNPE/CDELFI MLE solutions
    
    x = -0.5 * (params-mu_y[1])**2 / Sig_y[1,1] # would like to use mvn.pdf, but that one freaks  
    return np.exp( x.reshape(-1) )              # out for 1D problems with negative (co-)variance


def sel_gauss_implementation(eps2, thresh=1000): 
    
    return gauss_weights_eps0 if eps2 > thresh else gauss_weights


#def studentT_weights(params, stats, mu_y, Sig_y):
#    
#    raise NotImplementedError
#    
#    
#def studentT_weights_eps0(params, stats, mu_y, Sig_y, df=3):
#    """ stable version in case eps^2 is giant - stats.mvn return nonsense here """
#    # Note: making use of the fact that covariances are zero for normal weights in SNPE/CDELFI MLE solutions
#    
#    exponent = -(df+1)/2 
#    return (1 + (params-mu_y[1])**2/(df*Sig_y[1,1]))**exponent
#
#
#def sel_studentT_implementation(eps2, thresh=1000): 
#    
#    return studentT_weights_eps0 if eps2 > thresh else studentT_weights
#

def get_weights_fun(eps2, thresh=1000, proposal_form='normal'):
    
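    if proposal_form=='normal':
        selector = sel_gauss_implementation 
    elif proposal_form=='studentT':
        # sel_studentT_implementation is commented out above; get_weights handles 'studentT' itself
        raise NotImplementedError("no weight-function selector for proposal_form='studentT'")
    else:
        raise ValueError('unknown proposal_form: ' + str(proposal_form))
        
    return selector(eps2, thresh)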

def get_weights(proposal_form, eta2, ksi2, eps2, x0, nu, stats, params, df=3): 
    
    assert proposal_form in ('normal', 'studentT')
    
    if proposal_form == 'normal':
        
        eta2p = 1/(1/eta2 - 1/ksi2)
        Sig_y = np.array([[eps2,0], [0,eta2p]])    
        mu_y = np.array([ [x0[0]], [eta2/(eta2-ksi2)*nu]])

        comp_weights =get_weights_fun(eps2, thresh=1000, proposal_form=proposal_form)
        
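        # use the function arguments (as column vectors) rather than the global `data`
        normals = comp_weights(params.reshape(-1,1), stats.reshape(-1,1), mu_y, Sig_y)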
        
    if proposal_form == 'studentT':

        exponent = -(df+1)/2 
        proposal_pdf = (1 + (params-nu)**2/(df*ksi2))**exponent
        prior_pdf    = mvn.pdf(x=params, mean=0., cov=eta2)
        normals = prior_pdf / proposal_pdf
        if eps2 < 1000:
            calibration_kernel_pdf = mvn.pdf(x=stats, mean=x0, cov=eps2)
            normals *= calibration_kernel_pdf
            
    return normals



def dL_dalpha( params, stats, normals, beta, gamma2, alphas):
    # derivative w.r.t. the offset alpha, evaluated on a grid of alphas
    return -2*(normals.reshape(-1,1) * (params.reshape(-1,1) - beta * stats.reshape(-1,1) - alphas.reshape(1,-1))/gamma2).sum(axis=0)


def dL_dbeta( params, stats, normals, alpha, gamma2, betas):
    # derivative w.r.t. the slope beta, evaluated on a grid of betas (inner derivative adds the factor stats)
    return -2*(normals.reshape(-1,1) * (params.reshape(-1,1) - betas.reshape(1,-1) * stats.reshape(-1,1) - alpha)/gamma2 * stats.reshape(-1,1)).sum(axis=0)

def dL_dgamma2( params, stats, normals, alpha, beta, gamma2s):
    # derivative w.r.t. the variance gamma2, evaluated on a grid of gamma2s
    tmp = (params.reshape(-1,1) - beta*stats.reshape(-1,1) - alpha)**2 / gamma2s.reshape(1,-1)
    return 1/gamma2s.reshape(1,-1) * (normals.reshape(-1,1) * (1 - tmp)).sum(axis=0)
    
def alpha(params, stats, normals):
    
    N = normals.size    

    Eo  = (normals * params).sum()
    Eox = (normals * stats * params).sum()
    Ex2 = (normals * stats**2).sum()
    Ex  = (normals * stats).sum()
    E1  = normals.sum()
    
    #ahat = (normals * (Ex2 * params - Eox * stats)).sum()
    #ahat /= (E1 * Ex2 - Ex**2)
    
    ahat = (Eo - Eox/Ex2 * Ex) / (E1 - Ex**2/Ex2)
    
    return ahat

def beta(params, stats, normals, ahat=None):

    ahat = alpha(params, stats, normals) if ahat is None else ahat

    Eox = (normals * stats * params).sum()
    Ex2 = (normals * stats**2).sum()
    Ex  = (normals * stats).sum()
    
    bhat = (Eox - ahat * Ex) / Ex2
    
    return bhat
    
def gamma2(params, stats, normals, ahat=None, bhat=None):

    ahat = alpha(params, stats, normals) if ahat is None else ahat
    bhat = beta(params, stats, normals, ahat) if bhat is None else bhat

    gamma2hat = (normals*(params - ahat - bhat * stats )**2).sum() / normals.sum()
    
    return gamma2hat


def analytic_div(out, eta2, nus, ksi2s):
    """ analytic correction of onedimensional Gaussians for proposal priors"""
    # assumes true prior to have zero mean!
    # INPUTS:
    # - out: 3D-tensor: 
    #        1st axis gives ksi2s (proposal variances), 
    #        2nd axis gives number of experiments/runs/fits
    #        3rd axis is size 2: out[i,j,0] Gaussian mean, out[i,j,1] Gaussian variance
    # - eta2:  prior variance (scalar)
    # - nus:   vector of proposal prior means
    # - ksi2s: vector of proposal prior variances
    
    # OUTPUTS
    # - out_: 3D tensor of proposal-corrected posterior means and variances
    
    out_ = np.empty_like(out)
    for i in range(out_.shape[0]):
        
        # precision and precision*mean
        P = 1/out[i,:,1]
        Pm = P * out[i,:,0]

        # multiply with prior
        P = P + 1/eta2
        Pm = Pm + 0/eta2

        # divide by proposal
        P = P - 1/ksi2s[i]
        Pm = Pm - nus[i]/ksi2s[i]  # use the i-th proposal mean, not the global `nu`

        out_[i,:,:] = np.vstack((Pm/P, 1/P)).T

    return out_

n_bins = 50 # number of bins for plotting


# define problem setup


In [12]:

## problem setup ##

n_params = 1

assert n_params == 1 # cannot be overstressed: everything in this notebook goes downhill sharply otherwise

sig2 = 1.0/9. # likelihood variance
eta2 = 1.0     # prior variance
eps2 = 1e20    # calibration kernel width (everything above a certain threshold will be treated as 'uniform')

# pick observed summary statistics
x0 = 0.8 * np.ones(1) #_,obs = g.gen(1) 

# prior and analytic (!) likelihood & posterior
m = Gauss(dim=n_params, noise_cov=sig2)
p = dd.Gaussian(m=0. * np.ones(n_params), 
                S=eta2 * np.eye(n_params))
post   = dd.Gaussian(m = np.ones(n_params) * eta2/(eta2+sig2)*x0[0], 
                     S=(eta2 - eta2**2 / (eta2 + sig2)) * np.eye(n_params))    

## simulation setup ##

n_fits = 500  # number of MLE fits (i.e. dataset draws), each single-round fits with pre-specified proposal!
N      = 500  # number of simulations per dataset

# set proposal priors (one per experiment)
ksi2s = np.array([0.01, 0.1, 0.5, 0.999]) * eta2  # proposal variance
nus = eta2/(eta2+sig2)*x0[0]* np.ones(len(ksi2s))              # proposal mean


res = {'normal' : np.zeros((len(ksi2s), n_fits,2)),
       't_df10' : np.zeros((len(ksi2s), n_fits,2)),
       't_df3'  : np.zeros((len(ksi2s), n_fits,2)),
       'cdelfi' : np.zeros((len(ksi2s), n_fits,2)),
       'sig2' : sig2,
       'eta2' : eta2,
       'eps2' : eps2,
       'ksi2s' : ksi2s,
       'nus' : nus,
       'x0' : x0,
      }

# numerical checks for gradients
- tbd. lots.


In [None]:
N = 3
track_rp = True
proposal_form = 'normal'
df = None

nu = 0.
ksi2 = 0.5 * eta2

ppr = dd.Gaussian(m=nu * np.ones(n_params), 
                S=ksi2 * np.eye(n_params))
s = ds.Identity()
g = dg.Default(model=m, prior=ppr, summary=s)

seed = 42
g.model.seed = seed
g.prior.seed = seed
g.seed = seed

data = g.gen(N, verbose=False)
params, stats = data[0].reshape(-1), data[1].reshape(-1)

normals = get_weights(proposal_form, eta2, ksi2, eps2, x0, nu, stats, params, df=df) if track_rp else np.ones(N)/N


# numerically check $\frac{\partial}{\partial{}\alpha}$

- $\frac{\partial}{\partial\alpha}$ is being difficult here. The analytic solution $\hat{\alpha}$ still fails to numerically set the stated partial derivative $\frac{\partial\mathcal{L}}{\partial{}\alpha}(\hat{\alpha})$ to zero ...
- The obtained $\hat{\alpha}$ are nonetheless sensible (correct 'ballpark')

In [None]:
alpha_hat = np.array(alpha(params, stats, normals))

gamma2_ = post.std**2
alphas = np.linspace(-0.09, -0.02, 100000)

beta_ = beta(params, stats, normals, alpha_hat)
out_hat = -2*(normals.reshape(-1,1) * (params.reshape(-1,1) - beta_ * stats.reshape(-1,1) - alpha_hat)/gamma2_).sum(axis=0)
out = -2*(normals.reshape(-1,1) * (params.reshape(-1,1) - beta_ * stats.reshape(-1,1) - alphas.reshape(1,-1))/gamma2_).sum(axis=0)
plt.plot(alphas, out)
plt.show()


# (numerical solution, analytical solution, derivative evaluated at analytical solution (should be zero-ish) ) = 
alphas[np.argmin(np.abs(out))], alpha_hat, out_hat

# numerically check $\frac{\partial}{\partial{}\beta}$

In [None]:
alpha_ = alpha_hat
gamma2_ = post.std**2
betas = np.linspace(0., 1., 1000)
out = -2*(normals.reshape(-1,1) * (params.reshape(-1,1) - betas.reshape(1,-1) * stats.reshape(-1,1) - alpha_)/gamma2_ * stats.reshape(-1,1)).sum(axis=0)
plt.plot(betas, out)
plt.show()

beta_hat = beta(params, stats, normals, ahat=alpha_)
out_hat = -2*(normals.reshape(-1,1) * (params.reshape(-1,1) - beta_hat * stats.reshape(-1,1) - alpha_)/gamma2_ * stats.reshape(-1,1)).sum(axis=0)

# (numerical solution, analytical solution, derivative evaluated at analytical solution (should be zero-ish) ) = 
betas[np.argmin(np.abs(out))], beta_hat, out_hat

# numerically check $\frac{\partial}{\partial{}\gamma^2}$

In [None]:
alpha_ = alpha_hat
beta_ = beta_hat
gamma2s = np.linspace(0.0009, 0.001, 1000)

# something is off with the (hard-coded...) gradients below. Commenting out the numerical solution for now!

tmp = (params.reshape(-1,1) - beta_*stats.reshape(-1,1) - alpha_)**2 / gamma2s.reshape(1,-1)
out = 1/gamma2s.reshape(-1,) * (normals.reshape(-1,1) * (1 - tmp)).sum(axis=0)

plt.plot(gamma2s, out)
plt.show()

gamma2_hat = gamma2(params, stats, normals, ahat=alpha_, bhat=beta_)
tmp_ = (params.reshape(-1,1) - beta_*stats.reshape(-1,1) - alpha_)**2 / gamma2_hat
out_hat = 1/gamma2_hat * (normals.reshape(-1,1) * (1 - tmp_)).sum(axis=0)

#gamma2s *= np.nan 
#out = np.zeros_like(gamma2s)

# (numerical solution, analytical solution, derivative evaluated at analytical solution (should be zero-ish) ) = 
gamma2s[np.argmin(np.abs(out))], gamma2_hat, out_hat