In [None]:
import jax
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.infer.initialization import init_to_feasible, init_to_median, init_to_sample, init_to_uniform, init_to_value
from numpyro.optim import Adam
from numpyro.distributions.transforms import biject_to

This notebook looks at how each of NumPyro's initialization strategies (the `init_loc_fn` options in `numpyro.infer.initialization`) works.

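We begin by defining a simple Beta–Bernoulli model. The Beta prior keeps $\theta$ in $(0, 1)$, so the posterior's support is not $\mathbb{R}$, and the guide has to work through a constrained-to-unconstrained transform.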

In [None]:
def model(y=None):
    """Beta-Bernoulli model: theta ~ Beta(0.5, 0.5), obs ~ Bernoulli(theta).

    Args:
        y: optional observed 0/1 outcome(s). When None, "obs" is not
           conditioned on and is sampled from the model instead.
    """
    # Beta(0.5, 0.5) has support (0, 1), not R, so an AutoNormal guide must
    # go through a constrained <-> unconstrained transform for "theta".
    theta = numpyro.sample("theta", dist.Beta(0.5, 0.5))
    # Pass theta by keyword. The old positional call `dist.Bernoulli(0, theta)`
    # bound probs=0 and logits=theta; the probs branch wins, so theta never
    # reached the likelihood.
    numpyro.sample("obs", dist.Bernoulli(probs=theta), obs=y)
# SVI components for `model`. numpyro.optim.Adam has no default learning
# rate: `Adam()` raises TypeError, so step_size must be given explicitly.
optimizer = Adam(step_size=0.01)
elbo = Trace_ELBO()
# AutoNormal's default init_loc_fn is init_to_uniform.
guide = AutoNormal(model)



In [None]:
def demonstrate_init_loc_fn():
    """Show how ``init_loc_fn`` sets AutoNormal's initial variational parameters.

    For each initialization strategy, a fresh ``AutoNormal`` guide over
    ``model`` is initialized with ``svi.init``. The initial parameters of the
    ``theta`` site are then printed.

    ``init_loc_fn`` picks a value for ``theta`` inside its CONSTRAINED support
    (0, 1). AutoNormal stores the location of its Normal, ``theta_auto_loc``,
    in UNCONSTRAINED space (R). Pushing it through ``biject_to(support)``
    recovers the constrained value.

    Prints to stdout; returns None.
    """
    rng_key = random.PRNGKey(42)
    # Small illustrative dataset. With y=None, "obs" would become a discrete
    # latent site, which AutoNormal's Normal guide does not support.
    y_obs = jnp.array([1.0, 0.0, 1.0, 1.0, 0.0, 1.0])

    print("=" * 70)
    print("Demonstrating init_loc_fn for Beta distribution")
    print("=" * 70)
    print()

    # Bijection R -> (0, 1) for the support of the model's Beta(0.5, 0.5)
    # prior (the same biject_to(support) AutoNormal applies to "theta").
    transform = biject_to(dist.Beta(0.5, 0.5).support)

    # init_to_value matches sites by name, so it must use "theta" (the site
    # name in `model`). Sites not listed fall back to init_to_uniform.
    strategies = [
        ("init_to_uniform (default)", init_to_uniform),
        ("init_to_feasible", init_to_feasible),
        ("init_to_median", init_to_median),
        ("init_to_sample", init_to_sample),
        ("init_to_value(0.1)", init_to_value(values={"theta": 0.1})),
        ("init_to_value(0.5)", init_to_value(values={"theta": 0.5})),
        ("init_to_value(0.9)", init_to_value(values={"theta": 0.9})),
    ]

    for strategy_name, init_strategy in strategies:
        print(f"\n{strategy_name}:")
        print("-" * 70)

        guide = AutoNormal(model, init_loc_fn=init_strategy)
        svi = SVI(model, guide, Adam(step_size=0.01), loss=Trace_ELBO())
        # The same key is reused for every strategy so the random ones
        # (uniform / median / sample) start from the same PRNG state.
        svi_state = svi.init(rng_key, y=y_obs)

        # get_params maps each param through its own constraint. The loc is
        # an unconstrained (real) param, so it comes back as-is, i.e. in R.
        # The scale comes back positive.
        params = svi.get_params(svi_state)
        loc_unconstrained = params["theta_auto_loc"]
        scale = params["theta_auto_scale"]
        loc_constrained = transform(loc_unconstrained)

        print(f"  Initial variational mean (UNCONSTRAINED space, in R):   {float(loc_unconstrained):.6f}")
        print(f"  Initial variational mean (CONSTRAINED space, in (0,1)): {float(loc_constrained):.6f}")
        print("  Transform check: unconstrained -> constrained -> unconstrained")
        print(
            f"    {float(loc_unconstrained):.6f} -> {float(loc_constrained):.6f}"
            f" -> {float(transform.inv(loc_constrained)):.6f}"
        )
        print(f"  Initial variational scale (unconstrained space, > 0):   {float(scale):.6f}")

    print("\n" + "=" * 70)
    print("Key Observation:")
    print("=" * 70)
    print("init_loc_fn picks an initial value in CONSTRAINED space (0,1) for Beta.")
    print("AutoNormal maps it to UNCONSTRAINED space (R) for optimization.")
    print("The 'theta_auto_loc' parameter in init_params is in UNCONSTRAINED space;")
    print("apply biject_to(support) to read it back in (0,1).")
    print("=" * 70)

# Inside a Jupyter/IPython kernel __name__ is "__main__", so executing this
# cell also runs the demonstration. Importing the code as a module would not.
if __name__ == "__main__":
    demonstrate_init_loc_fn()
