In [16]:
%load_ext autoreload
%autoreload 2


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


Univariate Gaussian (MeanField and FullRank are equivalent)

    Model:

        Prior: μ ~ N(μ₀, σ₀²)

        Likelihood: y | μ ~ N(μ, σ²) (with σ known)

    True Posterior:

        μ | y ~ N(μₙ, σₙ²), where:

            μₙ = (μ₀/σ₀² + n ȳ/σ²) / (1/σ₀² + n/σ²)

            σₙ² = 1 / (1/σ₀² + n/σ²)

    ADVI Result:

        MeanField and FullRank should both recover μₙ and σₙ².

        ELBO should equal the true log marginal likelihood.


Example 1: 1D Gaussian (Analytical Results)

Let μ₀ = 0, σ₀² = 1, σ² = 1, and data y = [1.0] (n=1):

    True Posterior:

        μₙ = 0.5, σₙ² = 0.5

    Log Marginal Likelihood:
    python

log_p_y = -0.5 * (np.log(4 * np.pi) + 0.5) ≈ -1.2655 - 0.25 = -1.5155

ADVI Check:

    Variational mean ≈ 0.5, variance ≈ 0.5.

    ELBO ≈ -1.5155.

In [54]:
from jax.scipy.stats import norm,beta,expon
from jax_advi.advi import optimize_advi
from jax_advi.constraints import constrain_range_stable,constrain_sigmoid_stable,constrain_exp
from jax_advi.guide import LowRankGuide , MeanFieldGuide,FullRankGuide, VariationalGuide,MAPGuide,LaplaceApproxGuide
import jax
import jax.numpy as jnp
from functools import partial
jax.config.update('jax_enable_x64', True)


def generate_shapes_and_constraints():
    theta_shapes = {'mu': (1,)}
    theta_constraints = {}
    return theta_shapes,theta_constraints
def log_prior_fun(params,mup=0,sigmap=1):
    return norm.logpdf(params["mu"],mup,sigmap)

def log_lik_fun(params,data,sigma=1): #Assume sicma knowm
    return jnp.sum(norm.logpdf(data, params["mu"], sigma))

data= jnp.array([1.])
M = 100_000
opt_method = "L-BFGS-B"
verbose = True

for guide in [LowRankGuide(1),FullRankGuide(),MeanFieldGuide(),LaplaceApproxGuide(),MAPGuide()][:]:
#guide = FullRankGuide()  # OptiGuide , FullRankGui

        # Initialize optimizer

    curried_lik = jax.jit(partial(log_lik_fun, data=data))

    curried_prior  = log_prior_fun

    theta_shapes,theta_constraints = generate_shapes_and_constraints()

    result = optimize_advi(
        theta_shapes,
        log_prior_fun=curried_prior,
        log_lik_fun=curried_lik,
        constrain_fun_dict=theta_constraints,
        verbose=False,
        guide=guide, 
        M=M ,  # Number of MC samples
        opt_method=opt_method,
        n_draws=1_000_000
        #M=10
    )
    posterior_mu_sample = jnp.mean(result["draws"]["mu"])
    posterior_var_sample = jnp.std(result["draws"]["mu"])**2
    elbo = result["elbo"]
    #print(result["guide"])
    print(f"{guide.name}:")
    print(f"mu {posterior_mu_sample:.2e} , var {posterior_var_sample:.2e} , elbo {elbo:.2e}")


LowRankGuide:
mu 4.97e-01 , var 5.04e-01 , elbo -1.51e+00
FullRankGuide:
mu 5.00e-01 , var 4.99e-01 , elbo -1.52e+00
MeanFieldGuide:
mu 5.00e-01 , var 4.99e-01 , elbo -1.52e+00
MAPGuide:
mu 5.00e-01 , var 0.00e+00 , elbo -2.09e+00
LaplaceApproxGuide:
mu 5.00e-01 , var 5.00e-01 , elbo -1.34e+00
