# Test ADVI

It works!

https://github.com/martiningram/jax_advi

We replace the likelihood with a pseudo-likelihood that includes the log det Jacobian of the transformation applied to the $N(0,1)$ prior.

For better posterior covariances (like a MAP approximation) we can use linear response variational Bayes (LRVB), which operates on the result of a vanilla ADVI run.

https://martiningram.github.io/vi-with-good-covariances/

## Elliptical slice sampling

Another technique which could make use of our $N(0,1)$ priors.
Elliptical slice sampling is an MCMC method for problems with Gaussian priors. [Murray2010]
For VI we use $N(0,1)$ priors which are then transformed and for nested sampling we use $U(0,1)$ priors which are then transformed.

In [1]:
%run init.ipy
from dgf import core
from dgf import isokernels
from lib import constants
from dgf import bijectors
from dgf.prior import lf
from dgf.prior import source

import tensorflow_probability.substrates.jax.distributions as tfd
import tensorflow_probability.substrates.jax.bijectors as tfb

import jax_advi

2022-12-18 11:47:33.116003: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory


In [2]:
# Sample from the source and noise priors and fit a prior in the `z` domain
# BOUNDS is a copy of the source bounds extended with a `noise_power` entry,
# so the shared `constants.SOURCE_BOUNDS` dict is left untouched.
BOUNDS = constants.SOURCE_BOUNDS.copy()
BOUNDS['noise_power'] = [constants.NOISE_FLOOR_POWER, 1.]
# Column order used for every packed/unpacked parameter array in this notebook
PARAMS = ['var', 'r', 'T', 'Oq', 'noise_power']
NUMPARAMS = len(PARAMS) # == 5

# NOTE(review): in the recorded run this call raised
# `AttributeError: module 'dgf.prior.source' has no attribute '_get_source_params_ppf'`
# (see the cell output). The private helper was presumably renamed or removed;
# confirm against the current `dgf.prior.source` API before re-running.
source_params_ppf = source._get_source_params_ppf(
    constants.SOURCE_BOUNDS, constants.SOURCE_MEDIAN, source.RHO
)

# No correlations at all b/c synthetic prior
# 1e4 uniform draws per parameter are pushed through the respective ppf;
# the first NUMPARAMS-1 columns are the source parameters, the last column
# is the noise power.
samples = np.hstack([
    source_params_ppf(rand(int(1e4), NUMPARAMS-1)),
    source.noise_power_ppf(rand(int(1e4), 1), constants.NOISE_FLOOR_DB)
])

AttributeError: module 'dgf.prior.source' has no attribute '_get_source_params_ppf'

In [None]:
import corner

fig = corner.corner(
    samples,
    labels=PARAMS,
    show_titles=True,
    smooth=.2
)

In [None]:
# Work in the log domain: z = log(theta)
z = np.log(samples)
# Empirical mean (per column) and covariance of the log-samples
z_mean = np.mean(z, axis=0)
z_cov = np.cov(z.T)
# Split the covariance into marginal standard deviations and a correlation
# matrix, so that the standard deviations can be rescaled independently later
z_sigma = np.sqrt(np.diag(z_cov))
z_corr = np.diag(1/z_sigma) @ z_cov @ np.diag(1/z_sigma)
# Lower Cholesky factor of the correlation matrix
L_z_corr = np.linalg.cholesky(z_corr)

# Log of the [lower, upper] bounds, one row per parameter in PARAMS order
z_bounds = np.log(np.array([
    BOUNDS[k] for k in PARAMS
]))

In [None]:
import dynesty
import scipy.stats

# Maps the log-domain variable `z` to the bounded parameters `theta`.
# `tfb.Chain` applies its list from last to first, so `z` is first soft-clipped
# to the log-bounds and then exponentiated.
# The third positional argument of `SoftClip` is `z_sigma`
# (the hinge softness in TFP's signature - see the TFP SoftClip docs).
static_bijector = tfb.Chain([
    tfb.Exp(), tfb.SoftClip(
        z_bounds[:,0], z_bounds[:,1], z_sigma
    )
])

def getprior(rescale):
    """Build the bounded prior for a given stretch of the marginal scales.

    The base distribution is a multivariate normal located at `z_mean` with
    Cholesky factor `diag(rescale * z_sigma) @ L_z_corr`; it is pushed
    through `static_bijector`.

    Args:
        rescale: Multiplier(s) applied to `z_sigma` before the diagonal
            scale matrix is formed.

    Returns:
        A `tfd.TransformedDistribution`.
    """
    stretched_sigma = rescale * z_sigma
    scale_tril = np.diag(stretched_sigma) @ L_z_corr
    base = tfd.MultivariateNormalTriL(loc=z_mean, scale_tril=scale_tril)
    return tfd.TransformedDistribution(
        distribution=base,
        bijector=static_bijector
    )

def loglike(rescale, data=samples):
    """Total log-density of `data` under the prior built from `rescale`.

    Args:
        rescale: Stretch factors forwarded to `getprior`.
        data: Points at which the prior is evaluated (defaults to `samples`).

    Returns:
        The summed log-probability as a Python float, or `-np.inf` when the
        sum is NaN.
    """
    total = np.sum(getprior(rescale).log_prob(data))
    if np.isnan(total):
        return -np.inf
    return float(total)

def ptform(
    u,
    rescale_prior=scipy.stats.expon(scale=1.)
):
    """Prior transform: map unit-cube coordinates to rescale factors.

    Args:
        u: Value(s) in [0, 1] to be transformed.
        rescale_prior: Frozen distribution whose `ppf` (inverse CDF) is
            applied to `u`; defaults to a unit-scale exponential.

    Returns:
        `rescale_prior.ppf(u)`.
    """
    inverse_cdf = rescale_prior.ppf
    return inverse_cdf(u)

# One rescale factor per parameter
ndim = NUMPARAMS
# Nested sampling over the rescale factors, with 5 live points per dimension
sampler = dynesty.NestedSampler(loglike, ptform, ndim, nlive=ndim*5)
sampler.run_nested()
# Keep the results; later cells read `results.samples[-1,:]`
results = sampler.results

In [None]:
# Rescale (stretch) factors
results.samples[-1,:]

In [None]:
# Very good model of the data without problems at the edges
priorml = getprior(results.samples[-1,:])

fig = corner.corner(
    np.array(priorml.sample(100000,seed=jax.random.PRNGKey(1387))),
    labels=PARAMS,
    show_titles=True,
    smooth=.2
)

In [None]:
def prior_bijector(mean, cov, **kwargs):
    """Chain the coloring bijector for (`mean`, `cov`) with `static_bijector`.

    Args:
        mean: Passed to `bijectors.color_bijector`.
        cov: Passed to `bijectors.color_bijector`.
        **kwargs: Accepted but not used.

    Returns:
        A `tfb.Chain` listing `static_bijector` first and the coloring
        bijector second.
    """
    coloring = bijectors.color_bijector(mean, cov)
    stages = [static_bijector, coloring]
    return tfb.Chain(stages)

# Cholesky factor built from the last nested-sampling sample, using the same
# construction as in `getprior`
L = np.diag(results.samples[-1,:]*z_sigma) @ L_z_corr
# Corresponding covariance matrix in the log (`z`) domain
z_cov_ml = L @ L.T

# Full bijector built from `z_mean` and the rescaled covariance
bijector = prior_bijector(z_mean, z_cov_ml)

In [None]:
prior = tfd.TransformedDistribution(
    distribution=tfd.MultivariateNormalDiag(scale_diag=jnp.ones(NUMPARAMS)),
    bijector=bijector,
    name='Prior'
)

In [None]:
# Check the convention: we need to **SUBTRACT** the forward log det jacobian

# Using TransformedDistribution
theta, theta_lp = prior.experimental_sample_and_log_prob(seed=jax.random.PRNGKey(10))
display(theta, theta_lp)

# Using bijector explicitly 
z = bijector.inverse(theta)
z_lp = tfd.MultivariateNormalDiag(scale_diag=jnp.ones(NUMPARAMS)).log_prob(z)
z_lp - bijector.forward_log_det_jacobian(z).squeeze(), theta_lp # Equal?

In [None]:
def unpack(theta):
    """Stack the named parameters of `theta` into a 2-D array.

    Args:
        theta: Mapping with the keys 'var', 'r', 'T', 'Oq' and 'noise_power'.

    Returns:
        The transpose of the vertical stack of the five entries, i.e. one
        column per parameter in the order listed above.
    """
    rows = [theta[name] for name in ('var', 'r', 'T', 'Oq', 'noise_power')]
    return jnp.transpose(jnp.vstack(rows))

def pack(a):
    """Split an array into a dict of named parameters.

    Args:
        a: Array whose transpose unpacks into exactly five entries, ordered
            as var, r, T, Oq, noise_power.

    Returns:
        A dict with the keys 'var', 'r', 'T', 'Oq' and 'noise_power'.
    """
    columns = a.T
    var, r, T, Oq, noise_power = columns
    return {
        'var': var,
        'r': r,
        'T': T,
        'Oq': Oq,
        'noise_power': noise_power,
    }

In [None]:
test_samples = prior.sample(100000, seed=jax.random.PRNGKey(54544))

import corner
corner.corner(np.array(test_samples), labels=['var', 'r', 'T', 'Oq', 'noise_power'])

pack(test_samples)

In [None]:
# We have annoying behavior of the density functions
# near the boundaries (and nans if you get too close)
# There *is* mass at the boundaries, as can be seen from
# samples, but not overwhelmingly so
MEDIANS = constants.SOURCE_MEDIAN.copy()
MEDIANS['noise_power'] = constants.db_to_power(-30.)

def probe_param_bounds(param, n=1000):
    """Evaluate the log prior along one parameter, holding the others at `MEDIANS`.

    Args:
        param: Key of the parameter to sweep; must be present in `BOUNDS`.
        n: Number of equally spaced grid points between the two bounds.

    Returns:
        A tuple `(values, log_prob)` with the grid and `prior.log_prob`
        evaluated on it.
    """
    lo, hi = BOUNDS[param]
    grid = jnp.linspace(lo, hi, n)

    # Tile the median parameter vector n times ...
    medians = unpack(MEDIANS)
    tiled = np.repeat(medians[None, :], n, axis=0)

    # ... then overwrite the swept parameter with the grid values
    theta_grid = pack(tiled)
    theta_grid[param] = grid
    stacked = unpack(theta_grid)

    return grid, prior.log_prob(stacked)

def test_param_bounds(param, n=1000):
    """Plot the log prior along `param` as computed by `probe_param_bounds`.

    Args:
        param: Key of the parameter to sweep.
        n: Number of grid points, forwarded to `probe_param_bounds`.
    """
    grid, log_prior = probe_param_bounds(param, n=n)
    plot(grid, log_prior)
    title(param)
    xlabel(param)
    ylabel(f'log prior({param}|median values of other params)')
    show()

test_param_bounds('var')
test_param_bounds('r')
test_param_bounds('T')
test_param_bounds('Oq')
test_param_bounds('noise_power')

In [None]:
from jax.scipy.stats import norm
from jax.experimental.host_callback import call

def minus_inf_if_nan(x):
    """Replace NaN entries of `x` with negative infinity.

    Implemented as an elementwise select rather than `jax.lax.cond`.
    `lax.cond` needs a scalar predicate and branch outputs of matching type,
    so it only handles scalar `x`. The select gives the same result for
    scalars and also accepts arrays of any shape.

    Args:
        x: Scalar or array of floating-point values.

    Returns:
        `x` with every NaN replaced by `-inf`; all other entries unchanged.
    """
    return jnp.where(jnp.isnan(x), -jnp.inf, x)

def calculate_prior(packed_z):
    """Log-density of `packed_z` under independent standard normals.

    Args:
        packed_z: Dict of parameters in the format accepted by `unpack`.

    Returns:
        The sum of `norm.logpdf` over all unpacked entries.
    """
    log_densities = norm.logpdf(unpack(packed_z))
    return log_densities.sum()

def calculate_likelihood(theta, sample, config):
    """Log-likelihood of one sample given the model parameters `theta`.

    Args:
        theta: Dict providing 'var', 'r', 'T', 'Oq' and 'noise_power'.
        sample: Dict providing 't' and 'u'.
        config: Dict providing 'kernel', 'kernel_M', 'c' and
            'impose_null_integral'.

    Returns:
        The value of `core.loglikelihood_hilbert` for the kernel matrix
        root computed by `core.kernelmatrix_root_gfd_oq`.
    """
    kernel_root = core.kernelmatrix_root_gfd_oq(
        config['kernel'], theta['var'], theta['r'],
        sample['t'], config['kernel_M'],
        theta['T'], theta['Oq'],
        config['c'], config['impose_null_integral'],
    )
    return core.loglikelihood_hilbert(
        kernel_root, sample['u'], theta['noise_power']
    )

def calculate_pseudo_likelihood(packed_z, sample, config):
    """Log-likelihood plus the log volume correction of the prior transform.

    This is a deliberate hack: the log det Jacobian **of the prior**
    transformation `z -> theta` is folded into the likelihood term. Since
    log prior and log likelihood are simply added to form the log posterior,
    the sum is the same. Strictly speaking the volume correction belongs to
    the prior only, because the likelihood is a density with respect to the
    data and not with respect to the parameters.

    NOTE(review): the target is optimised in `z` space with a standard normal
    prior on `z`; whether an extra Jacobian term is needed in that setting
    should be double-checked against the ADVI derivation - TODO confirm.

    Args:
        packed_z: Dict of `z`-domain parameters in the format of `unpack`.
        sample: Data sample passed on to `calculate_likelihood`.
        config: Model configuration passed on to `calculate_likelihood`.

    Returns:
        `log L(theta) - forward_log_det_jacobian(z)`, or `-inf` if that
        value is NaN.
    """
    z = unpack(packed_z)

    # Actual likelihood L(theta) = p(sample|theta), with theta = bijector(z).
    # (Host-callback printing of `theta` for debugging does not work with
    # gradients.)
    theta = pack(bijector.forward(z))
    log_like = calculate_likelihood(theta, sample, config).squeeze()

    # Log volume factor of the transform `z -> theta`.
    # It is SUBTRACTED - note the sign, this is the correct convention.
    forward_ldj = bijector.forward_log_det_jacobian(z).squeeze()

    return minus_inf_if_nan(log_like - forward_ldj)

In [None]:
lf_samples = source.get_lf_samples()

z = randn(5)
packed_z = pack(z)
theta = pack(bijector.forward(z))

sample = lf_samples[1]

config = dict(
    kernel_name = 'Matern32Kernel',
    kernel_M = 128,
    use_oq = True,
    impose_null_integral = True
)

assert config['use_oq'] == True
config['kernel'] = isokernels.resolve(config['kernel_name'])
config['c'] = constants.BOUNDARY_FACTOR

calculate_prior(packed_z), calculate_pseudo_likelihood(packed_z, sample, config)

In [None]:
from functools import partial

log_prior_func = jax.jit(calculate_prior)
log_like_func = jax.jit(partial(calculate_pseudo_likelihood, sample=sample, config=config))

packed_z = pack(randn(5))

display(jax.value_and_grad(log_prior_func)(packed_z))
display(jax.value_and_grad(log_like_func)(packed_z))

In [None]:
# We get infs in the objective function and nans in the gradients
# when one of the `M` samples happens to hit a bound of the
# `bijector`. If that happens, all evaluations return infs and nans.
# So the problems are due to the problematic behavior of the log det
# jac of the prior transformation from N(0,I) to the actual model
# parameters at the bounds.
# But our log likelihood is well behaved!!
#
# How to fix??
# Either set `M` very low (e.g. `M = 3`) or (and we did it here)
# rescale the covariance `cov_z` to be much smaller such that
# the bounds are never reached.
import jax_advi.advi

# All five parameters are scalars (empty shape tuples)
theta_shapes = {
    'var': (),
    'r': (),
    'T': (),
    'Oq': (),
    'noise_power': ()
}

# Mean-field ADVI using the jitted log prior and pseudo-likelihood from above
result = jax_advi.advi.optimize_advi_mean_field(
    theta_shapes,
    log_prior_func,
    log_like_func,
    verbose=True,
    M=50,
    #var_param_inits={'mean': (0.,1.), 'log_sd': (0.,1.)},
    opt_method="L-BFGS-B" # This is faster and seems to leap successfully over early local minima
)

In [None]:
pack(bijector.forward(unpack(result['free_means'])))

In [None]:
posterior = pack(bijector.forward(unpack(result['draws'])))

hist(np.array(posterior['T']))

corner.corner(np.array(unpack(posterior)), labels=PARAMS, 
              #range=[BOUNDS[k] for k in PARAMS]
             );

The open quotient `Oq` of the LF model correlates only moderately with the OQ as we see it. Namely we see the OQ as a "hard" close where the DGF waveform is zero. The LF model has an exponential return phase and a "soft" close such that the `Oq` is quite fuzzily defined. So we cannot expect our inferred `Oq` to correspond with the `Oq` of LF, because of this soft return phase. Our implementation of `Oq` is just dividing the pitch period into a hard zero (closed) phase and a nonzero (open) phase.

In [None]:
sample

In [None]:
## Try better errorbars from <https://github.com/martiningram/jax_advi/blob/main/examples/Tennis%20example.ipynb>

from jax_advi.lrvb import compute_lrvb_covariance, get_posterior_draws_lrvb

lrvb_free_sds, lrvb_cov_mat = compute_lrvb_covariance(
    result['final_var_params_flat'], result['objective_fun'], result['shape_summary'], batch_size=8)

In [None]:
matshow(lrvb_cov_mat); colorbar();

In [None]:
lrvb_free_sds # GOOD: These are several times larger than the vanilla free_sds

In [None]:
result['free_sds']