In [23]:
import numpy as np
import jax
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.initialization import init_to_feasible, init_to_median, init_to_sample, init_to_uniform, init_to_value
from numpyro.optim import Adam

This notebook looks at how each of the initialization options in NumPyro works.

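We begin by defining a simple model with a Beta prior, so that the support of `theta` is $(0, 1)$ rather than $\mathbb{R}$. This lets us tell values in the constrained space apart from values in the unconstrained space.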

In [25]:
def model():
    """Minimal model with a single latent site ``theta`` under a Beta(2, 5) prior.

    There are no observations; the model exists only so that the guides
    built from it have one latent variable to initialize.
    """
    theta_prior = dist.Beta(2.0, 5.0)
    numpyro.sample("theta", theta_prior)

# One AutoNormal guide per location-initialization strategy.
# To make the effect of init_scale visible as well, the i-th guide gets
# init_scale = 0.5 * (i + 1), i.e. 0.5, 1.0, 1.5, 2.0, 2.5 in listing order.
initializer_guides = {
    name: AutoNormal(model, init_loc_fn=loc_fn, init_scale=0.5 * (i + 1))
    for i, (name, loc_fn) in enumerate(
        [
            ("init_to_feasible", init_to_feasible),
            ("init_to_median", init_to_median),
            ("init_to_sample", init_to_sample),
            ("init_to_uniform", init_to_uniform),
            ("init_to_value", init_to_value(values={"theta": 0.5})),
        ]
    )
}

# Shared by every SVI object below. Only the state returned by `init` is
# inspected in this notebook, so no optimization step is taken with this optimizer.
elbo = Trace_ELBO()
optimizer = Adam(step_size=1e-16)

Now let's look at each of the initialization function options: init_to_feasible, init_to_median, init_to_sample, init_to_uniform, init_to_value

In each case, we show that the standard deviation may also be controlled with init_scale.

We run the default and then each of the initialization options.

In [42]:
# Baseline: an AutoNormal guide built with no explicit init_loc_fn / init_scale.
default_guide = AutoNormal(model)
svi_default = SVI(model, default_guide, optimizer, elbo)
rng_key = jax.random.PRNGKey(0)
default_state = svi_default.init(rng_key)
default_params = svi_default.get_params(default_state)
print(f"Default params: mean is {default_params['theta_auto_loc']} and std is {default_params['theta_auto_scale']}")

# Initial guide parameters for every strategy, keyed by strategy name.
# This dict has to be created before the loop: writing to it without defining
# it first raises NameError when the notebook is run on a fresh kernel.
param_dict = {}

# Number of draws used for each empirical mean / std estimate below.
num_samples = 100000

for key, guide in initializer_guides.items():
    svi = SVI(model, guide, optimizer, elbo)
    # NOTE(review): the same key is reused for init and for both sampling calls
    # below; this is kept unchanged so the printed numbers stay reproducible.
    rng_key = jax.random.PRNGKey(0)
    # get initial values of svi since we can't call run with fewer than 1 step
    state = svi.init(rng_key)
    params = svi.get_params(state)
    param_dict[key] = params
    print(f"-----\n {key} initialization results:\n-----")
    print(f"'Params' values after initializing: mean is {params['theta_auto_loc']} and std is {params['theta_auto_scale']}")

    # we would now like to understand whether these are in the constrained or unconstrained space.
    # First reference point: draws of theta produced by the guide itself.
    samples = guide.sample_posterior(rng_key, params, sample_shape=(num_samples,))
    # print empirical mean and std of samples
    empirical_mean = jnp.mean(samples['theta'])
    empirical_std = jnp.std(samples['theta'])
    print(f"Empirical mean and std of samples from posterior:\n mean is {empirical_mean} and std is {empirical_std}")

    # Second reference point: draw from Normal(loc, scale) using the raw params
    # and push the draws through a sigmoid, which maps the reals onto (0, 1).
    # If the params live in the unconstrained space, these statistics should
    # closely match the ones computed from the guide's own samples above.
    unconstrained_samples = jax.random.normal(rng_key, (num_samples,)) * params['theta_auto_scale'] + params['theta_auto_loc']
    constrained_samples = jax.nn.sigmoid(unconstrained_samples)
    constrained_empirical_mean = jnp.mean(constrained_samples)
    constrained_empirical_std = jnp.std(constrained_samples)
    print(f"Empirical mean and std of transformed samples drawn from the 'params' values above: \n mean is {constrained_empirical_mean} and std is {constrained_empirical_std}")


Default params: mean is 0.2227025032043457 and std is 0.09999997913837433
-----
 init_to_feasible initialization results:
-----
'Params' values after initializing: mean is 0.0 and std is 0.4999999701976776
Empirical mean and std of samples from posterior:
 mean is 0.500145673751831 and std is 0.11838821321725845
Empirical mean and std of transformed samples drawn from the 'params' values above: 
 mean is 0.5006129741668701 and std is 0.11804007738828659
-----
 init_to_median initialization results:
-----
'Params' values after initializing: mean is -1.205141544342041 and std is 1.0
Empirical mean and std of samples from posterior:
 mean is 0.2682666778564453 and std is 0.17252814769744873
Empirical mean and std of transformed samples drawn from the 'params' values above: 
 mean is 0.2687114477157593 and std is 0.17225044965744019
-----
 init_to_sample initialization results:
-----
'Params' values after initializing: mean is -0.36108484864234924 and std is 1.5
Empirical mean and std of s

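From the exploration above, we can conclude that both initialization arguments act in the unconstrained space: the value produced by `init_loc_fn` is stored as the guide's unconstrained mean (`theta_auto_loc`), and `init_scale` sets the guide's unconstrained standard deviation (`theta_auto_scale`). The evidence is that samples drawn from a normal distribution with these parameters and passed through a sigmoid have the same empirical mean and standard deviation as the samples returned by `sample_posterior`.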