
Adding some basic VI approximation and fitting routine #397

Open
junpenglao opened this issue Nov 12, 2022 · 9 comments · Fixed by #433
Labels
enhancement (New feature or request) · good first issue (Good for newcomers) · help wanted (Extra attention is needed) · important · vi (Variational Inference)

Comments

@junpenglao
Member

Copying over from #392 (comment)

After #392, we should add the two most basic VI algorithms: mean-field and full-rank ADVI [1]. Below is a working example of mean-field ADVI:

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from jax.scipy import stats

def gen_meanfield_logprob(params):
    # Returns log q(position) for a single (unbatched) position PyTree.
    mu_param, rho_param = params
    # rho is the log standard deviation, so sigma = exp(rho) is always positive
    sigma_param = jax.tree_util.tree_map(jnp.exp, rho_param)
    def meanfield_logprob(position):
        logq_pytree = jax.tree_util.tree_map(
            stats.norm.logpdf, position, mu_param, sigma_param
            )
        # Sum over every dimension of each leaf, then over the leaves
        logq = jax.tree_util.tree_map(jnp.sum, logq_pytree)
        return jax.tree_util.tree_reduce(jnp.add, logq)
    return meanfield_logprob

# gen_meanfield_logprob(init_params)(init_position)

def meanfield_sample(
    rng_key, meanfield_param, num_samples: int
    ):
    # Reparameterized draws mu + sigma * eps; pass num_samples=() for a single draw
    if not isinstance(num_samples, tuple):
        num_samples = (num_samples,)
    mu_param, rho_param = meanfield_param
    sigma_param = jax.tree_util.tree_map(jnp.exp, rho_param)
    mu_flatten, unravel_fn = ravel_pytree(mu_param)
    sigma_flatten, _ = ravel_pytree(sigma_param)
    flatten_sample = jax.random.normal(
        rng_key, num_samples + mu_flatten.shape
        ) * sigma_flatten + mu_flatten
    if len(num_samples) == 0:
        return unravel_fn(flatten_sample)
    # Every leaf of the returned PyTree gets a leading sample axis
    return jax.vmap(unravel_fn)(flatten_sample)

# meanfield_sample(rng, init_params, ())

def meanfield_approximate(rng, init_params, log_prob_fn, optimizer, sample_size=5, num_steps=200):
    def meanfield_approximate_step(
        state, rng_key_sample
        ):
        params, opt_state = state
        def kl_fn(params):
            sample = meanfield_sample(rng_key_sample, params, sample_size)
            # Evaluate log q and log p once per draw, then average over the draws
            logq = jax.vmap(gen_meanfield_logprob(params))(sample)
            logp = jax.vmap(log_prob_fn)(sample)
            return (logq - logp).mean()
        # Monte Carlo estimate of KL(q || p) up to a constant, i.e. the negative ELBO
        loss, grad = jax.value_and_grad(kl_fn)(params)
        updates, opt_state = optimizer.update(grad, opt_state, params)
        params = jax.tree_util.tree_map(
            lambda p, u: p + u, params, updates
            )
        return (params, opt_state), loss
    
    def run_optimization(init_params):
        opt_state = optimizer.init(init_params)
        state = (init_params, opt_state)
        rng_key = jax.random.split(rng, num_steps)
        # Returns ((final_params, final_opt_state), per-step losses)
        return jax.lax.scan(
            meanfield_approximate_step, state, rng_key
            )
    
    return run_optimization(init_params)
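Because `meanfield_approximate_step` draws `sample_size` samples at once, log q has to be evaluated once per draw and only then averaged; summing over the batch axis as well would mix the draws into a single number. A minimal sketch of per-draw evaluation with `jax.vmap`, assuming a flat-array position for brevity (`meanfield_logdensity` and `meanfield_draws` are illustrative names, not BlackJAX API):

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats


def meanfield_logdensity(params, position):
    """Log density of a diagonal Gaussian q at a single (unbatched) position."""
    mu, rho = params
    sigma = jnp.exp(rho)  # rho is the log standard deviation
    return jnp.sum(stats.norm.logpdf(position, mu, sigma))


def meanfield_draws(rng_key, params, num_samples):
    """Reparameterized draws z = mu + sigma * eps with eps ~ N(0, I)."""
    mu, rho = params
    eps = jax.random.normal(rng_key, (num_samples,) + mu.shape)
    return mu + jnp.exp(rho) * eps


params = (jnp.zeros(3), jnp.log(jnp.full(3, 2.0)))
draws = meanfield_draws(jax.random.PRNGKey(0), params, 5)
# vmap maps over the leading (sample) axis only, giving one log q value per draw.
logq = jax.vmap(meanfield_logdensity, in_axes=(None, 0))(params, draws)
```

The result has shape `(5,)`, so it can be subtracted elementwise from a per-draw `log p` before taking the mean.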

Fitting a model looks like:

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

import optax
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions

rng = jax.random.PRNGKey(0)

seed0, seed1, rng = jax.random.split(rng, 3)
X = jax.random.normal(seed0, (100, 98))
y = X @ np.arange(98) + jax.random.normal(seed1, (100,))

@tfd.JointDistributionCoroutineAutoBatched
def model():
    sigma = yield tfd.HalfNormal(5.0, name='sigma')
    mu = yield tfd.Normal(0.0, 1.0, name='mu')
    beta = yield tfd.Sample(tfd.Normal(mu, sigma), X.shape[-1], name='beta')
    yield tfd.Normal(X @ beta, 1.0, name="y")

# init_position = model.sample(seed=rng)
# Use separate keys for initialization and for the optimization loop
init_key, fit_key = jax.random.split(rng)
pinned = model.experimental_pin(y=y)
init_position = pinned.sample_unpinned(seed=init_key)

# Map unconstrained parameters onto the model's support (e.g. sigma > 0)
bijectors = pinned.experimental_default_event_space_bijector()
def log_prob_fn(unbound_param):
    param = bijectors.forward(unbound_param)
    log_det_jacobian = bijectors.forward_log_det_jacobian(unbound_param)
    return pinned.unnormalized_log_prob(param) + log_det_jacobian
# This is just one way to do it. We could also use a flattened array to represent mu and rho
mu_param = jax.tree_util.tree_map(jnp.zeros_like, init_position)
rho_param = jax.tree_util.tree_map(jnp.zeros_like, init_position)  # sigma = exp(0) = 1
init_params = (mu_param, rho_param)

optimizer = optax.chain(optax.clip(10.), optax.adam(1.))
(params, opt_state), losses = meanfield_approximate(fit_key, init_params, log_prob_fn, optimizer)
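`log_prob_fn` lets the optimizer work in unconstrained space: the bijector maps real-valued parameters onto the model's support, and the log-det-Jacobian term keeps the transformed density correctly normalized. A one-dimensional sketch for the `sigma ~ HalfNormal(5.0)` prior, assuming an exp transform because its Jacobian is simple (TFP's default event-space bijector for this distribution may be a different positive transform, such as softplus):

```python
import jax.numpy as jnp
from jax.scipy import stats


def halfnormal_logpdf(x, scale=5.0):
    # HalfNormal(scale) density on x > 0 is twice the Normal(0, scale) density.
    return jnp.log(2.0) + stats.norm.logpdf(x, 0.0, scale)


def unconstrained_logpdf(u):
    # Assumed transform: sigma = exp(u) maps the real line onto (0, inf),
    # and log|d sigma / d u| = u is the log-det-Jacobian term.
    sigma = jnp.exp(u)
    return halfnormal_logpdf(sigma) + u


# Riemann-sum check that the transformed density still integrates to one over u.
num_points = 20_001
u_grid = jnp.linspace(-30.0, 6.0, num_points)
du = (6.0 - (-30.0)) / (num_points - 1)
total = jnp.sum(jnp.exp(unconstrained_logpdf(u_grid))) * du
```

Dropping the `+ u` term would give a function that no longer integrates to one over `u`, so the variational fit would target the wrong distribution.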

[1] https://arxiv.org/abs/1603.00788

junpenglao added the enhancement, good first issue, and help wanted labels on Nov 12, 2022
rlouf added the vi and important labels on Nov 12, 2022
@rlouf
Member

rlouf commented Nov 12, 2022

In #392 we define a VIAlgorithm, and here we would need to define a new ParametrizedVIAlgorithm base type.

rlouf mentioned this issue on Dec 11, 2022
@xidulu
Contributor

xidulu commented Dec 14, 2022

@rlouf
I think I can take charge of this, but just to make sure:

We assume that log_prob_fn (i.e. log p(x, z)) in BlackJAX takes a real-valued flattened array (rather than a dict, or something in constrained space), right?

@rlouf
Member

rlouf commented Dec 14, 2022

Great! No, there is no such assumption in the library (or at least there shouldn't be); we try to support PyTree states as much as we can.
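Mean-field VI does not need a flat-array assumption either. A sketch of sampling with PyTree parameters: `jax.flatten_util.ravel_pytree` flattens a dict (mirroring the `sigma`/`mu`/`beta` structure of the regression model in the opening comment) for the Gaussian draw, and the unravel function restores the structure for each draw. `pytree_meanfield_sample` is a hypothetical helper for illustration, not a BlackJAX function.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree


def pytree_meanfield_sample(rng_key, mu_tree, rho_tree, num_samples):
    """Draw mean-field Gaussian samples whose structure matches mu_tree."""
    mu_flat, unravel_fn = ravel_pytree(mu_tree)
    rho_flat, _ = ravel_pytree(rho_tree)  # log standard deviations, same layout
    eps = jax.random.normal(rng_key, (num_samples,) + mu_flat.shape)
    flat_draws = mu_flat + jnp.exp(rho_flat) * eps
    # Unflatten each draw; vmap adds a leading sample axis to every leaf.
    return jax.vmap(unravel_fn)(flat_draws)


mu_tree = {"sigma": jnp.zeros(()), "mu": jnp.zeros(()), "beta": jnp.zeros(4)}
rho_tree = jax.tree_util.tree_map(jnp.zeros_like, mu_tree)
draws = pytree_meanfield_sample(jax.random.PRNGKey(1), mu_tree, rho_tree, 10)
```

Each leaf of `draws` keeps its original shape with a leading axis of 10, so a PyTree-aware `log_prob_fn` can be vmapped over it directly.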

@xidulu
Copy link
Contributor

xidulu commented Dec 26, 2022

@rlouf

To follow the design principles of BlackJAX, I believe VI should also have an API of the form below, right?

new_state, info = kernel(rng_key, state)

which would perform one optimization step for the ELBO.

@rlouf
Member

rlouf commented Dec 26, 2022

As you can see from the Pathfinder implementation, BlackJAX treats VI differently from MCMC algorithms.

The idea is that you first fit an approximation to the target density and then sample from this approximation, with something like (in pseudo-code):

approx, info = approximate(rng_key, position)
samples = sample(sample_key, approx, num_samples)

I think at the higher level the API will always be more or less like this. We can consider a kernel-like, lower-level interface for some algorithms if it makes sense. But again, I am no VI expert and I am open to suggestions.
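Read this way, a kernel-like `step(rng_key, state) -> (state, info)` (the form asked about above) and the two-stage `approximate`/`sample` interface can coexist: `approximate` simply scans the step. A minimal sketch for a mean-field Gaussian on a diagonal-Gaussian toy target; `MFVIState`, `make_step`, `approximate` and `sample` are hypothetical names, not BlackJAX's API, and plain SGD stands in for an optax optimizer.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy import stats


class MFVIState(NamedTuple):
    mu: jnp.ndarray
    rho: jnp.ndarray  # log standard deviation


def make_step(logdensity_fn, num_draws=64, learning_rate=0.05):
    """Hypothetical lower-level kernel: one SGD step on the negative ELBO."""

    def neg_elbo(state, rng_key):
        eps = jax.random.normal(rng_key, (num_draws,) + state.mu.shape)
        z = state.mu + jnp.exp(state.rho) * eps  # reparameterization trick
        logq = jnp.sum(stats.norm.logpdf(z, state.mu, jnp.exp(state.rho)), axis=-1)
        logp = jax.vmap(logdensity_fn)(z)
        return jnp.mean(logq - logp)

    def step(rng_key, state):
        loss, grads = jax.value_and_grad(neg_elbo)(state, rng_key)
        new_state = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, state, grads)
        return new_state, loss  # the loss plays the role of `info`

    return step


def approximate(rng_key, init_state, logdensity_fn, num_steps=2000):
    """Higher-level API: run the kernel num_steps times with lax.scan."""
    step = make_step(logdensity_fn)

    def body(state, key):
        return step(key, state)

    keys = jax.random.split(rng_key, num_steps)
    return jax.lax.scan(body, init_state, keys)  # (final_state, losses)


def sample(rng_key, state, num_samples):
    eps = jax.random.normal(rng_key, (num_samples,) + state.mu.shape)
    return state.mu + jnp.exp(state.rho) * eps


target_mean = jnp.array([1.0, -2.0])
target_std = jnp.array([0.5, 2.0])


def logdensity_fn(z):
    return jnp.sum(stats.norm.logpdf(z, target_mean, target_std))


approx_key, sample_key = jax.random.split(jax.random.PRNGKey(0))
init = MFVIState(jnp.zeros(2), jnp.zeros(2))
approx, losses = approximate(approx_key, init, logdensity_fn)
samples = sample(sample_key, approx, 1000)
```

`approximate` returns the fitted state plus the per-step loss trace, and `sample` needs only that state, which matches the two-stage pseudo-code; `step` is the optional lower-level interface.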

@LarsKarbach
Copy link

Can someone give a minimal working example of mean-field VI, then? I believe this would also help with refactoring the Pathfinder API in #465 and with implementing the full-rank approach.

@rlouf
Member

rlouf commented Jan 29, 2023

MFVI is implemented here, and full-rank VI is being implemented in #479. The refactoring of Pathfinder is a bit involved, but it is up for grabs :)

@LarsKarbach
Copy link

I understand. Although I would argue that such an example would help people get a foot in the door if they want to help develop VI further. It could be as simple as taking a multivariate normal and evaluating mean-field and full-rank VI on it. In the current implementation, I don't immediately see how the pseudo-code you provided maps onto the library.
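For the multivariate-normal comparison suggested here, the optimal approximations are known in closed form, which gives a reference point before running any optimizer. Under reverse KL, the best full-rank Gaussian is the target itself, while the best mean-field Gaussian keeps the target mean and sets each variance to 1 / Λ_ii, where Λ is the target's precision matrix (a standard result for factorized approximations of a Gaussian). A sketch, assuming a 2-D target with correlation 0.9; `gaussian_kl` is an illustrative helper:

```python
import jax.numpy as jnp


def gaussian_kl(mean_q, cov_q, mean_p, cov_p):
    """KL(q || p) between two multivariate normals, in closed form."""
    d = mean_q.shape[0]
    prec_p = jnp.linalg.inv(cov_p)
    diff = mean_p - mean_q
    _, logdet_p = jnp.linalg.slogdet(cov_p)
    _, logdet_q = jnp.linalg.slogdet(cov_q)
    return 0.5 * (jnp.trace(prec_p @ cov_q) + diff @ prec_p @ diff - d + logdet_p - logdet_q)


# Correlated 2-D Gaussian target.
mean_p = jnp.array([0.0, 0.0])
cov_p = jnp.array([[1.0, 0.9], [0.9, 1.0]])

# Best full-rank Gaussian under reverse KL: the target itself.
kl_fullrank = gaussian_kl(mean_p, cov_p, mean_p, cov_p)

# Best mean-field Gaussian under reverse KL: variances 1 / diag(precision).
meanfield_var = 1.0 / jnp.diag(jnp.linalg.inv(cov_p))
kl_meanfield = gaussian_kl(mean_p, jnp.diag(meanfield_var), mean_p, cov_p)
```

With correlation 0.9, the mean-field variances come out as 0.19 instead of 1, the usual variance underestimation of mean-field VI, and the KL is about 0.83 nats versus 0 for full rank. A fitted mean-field or full-rank approximation can be checked against these numbers.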

@xidulu
Copy link
Contributor

xidulu commented Jan 30, 2023

@LarsKarbach I understand your point, and I really wish there were a template for implementing VI variants (e.g. as simple as providing a log_q function and a sampling function), but the APIs are still at a very early stage. At the moment there is still a lot of boilerplate code in the implementation. Once the full-rank VI PR is merged, we could probably start working on simplifying the VI implementation process.
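The template described here, where a variant supplies only a `log_q` function and a sampling function, could look roughly like a generic negative-ELBO builder. This is an assumption about what such a template might be, not an existing BlackJAX API; `make_neg_elbo` and the argument signatures are hypothetical, and the diagonal Gaussian below is just one instance plugged into it.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats


def make_neg_elbo(logq_fn, sample_fn, logdensity_fn, num_draws=32):
    """Build a Monte Carlo negative-ELBO estimator from two user-supplied pieces.

    logq_fn(params, z) -> log q(z) for a single draw z (hypothetical signature).
    sample_fn(rng_key, params, n) -> n reparameterized draws stacked on axis 0.
    """

    def neg_elbo(params, rng_key):
        z = sample_fn(rng_key, params, num_draws)
        logq = jax.vmap(logq_fn, in_axes=(None, 0))(params, z)
        logp = jax.vmap(logdensity_fn)(z)
        return jnp.mean(logq - logp)

    return neg_elbo


# A diagonal (mean-field) Gaussian as one instance of the template.
def diag_gaussian_logq(params, z):
    mu, rho = params
    return jnp.sum(stats.norm.logpdf(z, mu, jnp.exp(rho)))


def diag_gaussian_sample(rng_key, params, n):
    mu, rho = params
    return mu + jnp.exp(rho) * jax.random.normal(rng_key, (n,) + mu.shape)


def standard_normal_logdensity(z):
    return jnp.sum(stats.norm.logpdf(z))


neg_elbo = make_neg_elbo(diag_gaussian_logq, diag_gaussian_sample, standard_normal_logdensity)
params = (jnp.zeros(3), jnp.zeros(3))
loss, grads = jax.value_and_grad(neg_elbo)(params, jax.random.PRNGKey(0))
```

Because the approximation here already equals the standard-normal target, every draw has log q = log p and the estimate is zero. Under this design, a full-rank variant would only swap in a Cholesky-based `logq_fn` and `sample_fn`.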

4 participants