In [2]:
import numpy as np
import jax
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.initialization import init_to_feasible, init_to_median, init_to_sample, init_to_uniform, init_to_value
from numpyro.optim import Adam

This notebook looks at how each of the initialization options in NumPyro works.

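We begin by defining a simple model with a single Beta-distributed latent variable, so that the support of the latent is $(0, 1)$ rather than $\mathbb{R}$. This lets us check whether the guide's parameters are expressed in the constrained or the unconstrained space.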

In [10]:
def model():
    """Toy model with one latent site, ``theta``, drawn from a Beta(2, 5) prior.

    The model has no observed sites; ``theta`` is its only sample statement.
    A Beta-distributed variable lives on the unit interval, which is what makes
    the constrained/unconstrained distinction visible in this notebook.
    """
    theta_prior = dist.Beta(2.0, 5.0)
    numpyro.sample("theta", theta_prior)

# One AutoNormal guide per initialization strategy, keyed by the strategy name.
# "default" passes no initialization arguments at all. Every other entry passes
# both an init_loc_fn and an explicit init_scale (0.1, 0.5, 1.0, 1.5, 2.5) so that
# the initial standard deviation can be inspected next to the initial location.
# NOTE(review): the 0.1 used for init_to_uniform equals the initial scale that the
# saved output reports for the default guide, so that entry cannot be told apart
# from "default" -- consider a different value if the two should be distinguishable.
initializer_guides = {
    "default": AutoNormal(model),
    **{
        strategy_name: AutoNormal(model, init_loc_fn=loc_fn, init_scale=scale)
        for strategy_name, loc_fn, scale in (
            ("init_to_uniform", init_to_uniform, 0.1),
            ("init_to_feasible", init_to_feasible, 0.5),
            ("init_to_median", init_to_median, 1.0),
            ("init_to_sample", init_to_sample, 1.5),
            ("init_to_value", init_to_value(values={"theta": 0.5}), 2.5),
        )
    },
}

# SVI cannot be constructed without an optimizer and a loss. Nothing is optimized
# in this notebook, so these are stand-ins whose exact settings do not matter.
optimizer = Adam(step_size=1e-3)
elbo = Trace_ELBO()

Now let's look at each of the initialization function options: `init_to_feasible`, `init_to_median`, `init_to_sample`, `init_to_uniform`, and `init_to_value`.

In each case, we show that the standard deviation may also be controlled with init_scale.

We run the default and then each of the initialization options.

In [None]:
# Number of Monte Carlo draws behind each empirical mean/std printed below.
NUM_CHECK_SAMPLES = 100_000

# Every `svi.init` call deliberately receives this same key, so all strategies are
# initialized under identical randomness and differ only through the strategy itself.
rng_key = jax.random.PRNGKey(0)
# The two sampling checks get their own keys, derived from a different seed, so that
# neither shares randomness with the initialization or with the other check. Their
# printed statistics therefore agree up to Monte Carlo error rather than by construction.
posterior_key, manual_key = jax.random.split(jax.random.PRNGKey(1))

# Baseline: a guide built with no initialization arguments at all.
default_guide = AutoNormal(model)
svi_default = SVI(model, default_guide, optimizer, elbo)
default_state = svi_default.init(rng_key)
default_params = svi_default.get_params(default_state)
print(f"Default params: mean is {default_params['theta_auto_loc']:.3f} and std is {default_params['theta_auto_scale']:.3f}")


for key, guide in initializer_guides.items():
    svi = SVI(model, guide, optimizer, elbo)
    # read the parameters from the freshly initialized state, since we can't call
    # run with fewer than 1 step
    state = svi.init(rng_key)
    params = svi.get_params(state)
    print(f"-----\n {key} initialization results:\n-----")
    print(f"'Params' values after initializing: mean is {params['theta_auto_loc']:.3f} and std is {params['theta_auto_scale']:.3f}")

    # Question: are these parameters expressed in the constrained or the unconstrained space?
    # Reference answer: let the guide itself draw samples of theta and summarize them.
    samples = guide.sample_posterior(posterior_key, params, sample_shape=(NUM_CHECK_SAMPLES,))
    empirical_mean = jnp.mean(samples['theta'])
    empirical_std = jnp.std(samples['theta'])
    print(f"Empirical mean and std of samples from posterior:\n mean is {empirical_mean:.3f} and std is {empirical_std:.3f}")

    # Hypothesis check: treat the parameters as the loc/scale of a Normal in the
    # unconstrained space, draw from it by hand, and squash the draws into (0, 1) with
    # a sigmoid. If the statistics match the reference above, the hypothesis holds.
    unconstrained_samples = jax.random.normal(manual_key, (NUM_CHECK_SAMPLES,)) * params['theta_auto_scale'] + params['theta_auto_loc']
    constrained_samples = jax.nn.sigmoid(unconstrained_samples)
    constrained_empirical_mean = jnp.mean(constrained_samples)
    constrained_empirical_std = jnp.std(constrained_samples)
    print(f"Empirical mean and std of transformed samples drawn from the 'params' values above: \n mean is {constrained_empirical_mean:.3f} and std is {constrained_empirical_std:.3f}")


Default params: mean is 0.178 and std is 0.100
-----
 default initialization results:
-----
'Params' values after initializing: mean is 0.178 and std is 0.100
Empirical mean and std of samples from posterior:
 mean is 0.544 and std is 0.025
Empirical mean and std of transformed samples drawn from the 'params' values above: 
 mean is 0.544 and std is 0.025
-----
 init_to_uniform initialization results:
-----
'Params' values after initializing: mean is 0.178 and std is 0.100
Empirical mean and std of samples from posterior:
 mean is 0.544 and std is 0.025
Empirical mean and std of transformed samples drawn from the 'params' values above: 
 mean is 0.544 and std is 0.025
-----
 init_to_feasible initialization results:
-----
'Params' values after initializing: mean is 0.000 and std is 0.500
Empirical mean and std of samples from posterior:
 mean is 0.500 and std is 0.118
Empirical mean and std of transformed samples drawn from the 'params' values above: 
 mean is 0.500 and std is 0.118
---

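According to the exploration above, we can conclude that the values stored in `params` live in the unconstrained space: `theta_auto_loc` (initialized through `init_loc_fn`) is the mean, and `theta_auto_scale` (initialized through `init_scale`) is the standard deviation, of the guide's Normal distribution *before* it is transformed. The evidence is that drawing from that Normal by hand and pushing the draws through a sigmoid reproduces the empirical mean and std of `guide.sample_posterior`, whose samples are in the constrained space $(0, 1)$.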