In [None]:
import matplotlib.pyplot as plt
import pandas as pd

# Enable 64-bit (double-precision) floats in JAX. This should run before any
# JAX arrays are created, which is why it lives in the first cell.
# (`from jax.config import config` is deprecated and no longer works in
# current JAX releases; `jax.config` is the supported entry point.)
import jax

jax.config.update("jax_enable_x64", True)

In [None]:
import pymc as pm
import pytensor

# The radon dataset ships with PyMC (fetched via pm.get_data), which makes this
# a convenient example model.
def fetch_radon_model():
    """Build a hierarchical varying-intercept, varying-slope model of log radon.

    Reads ``radon.csv`` through ``pm.get_data`` and models ``log_radon`` as
    ``a[county] + b[county] * floor`` with Normal noise of scale ``eps``. The
    per-county intercepts ``a`` and slopes ``b`` are partially pooled towards
    the group-level means ``mu_a`` / ``mu_b``.

    Returns:
        pm.Model: the hierarchical model (not yet fitted or sampled).
    """
    data = pd.read_csv(pm.get_data("radon.csv"))
    data["log_radon"] = data["log_radon"].astype(pytensor.config.floatX)
    # NOTE(review): assumes county_code runs 0..n_counties-1 so it can index
    # the length-n_counties vectors a and b below — confirm against the data.
    county_idx = data.county_code.values.astype("int32")
    n_counties = len(data.county.unique())

    with pm.Model() as hierarchical_model:
        # Hyperpriors for the group-level nodes
        mu_a = pm.Normal("mu_a", mu=0.0, sigma=100.0)
        sigma_a = pm.HalfNormal("sigma_a", 5.0)
        mu_b = pm.Normal("mu_b", mu=0.0, sigma=100.0)
        sigma_b = pm.HalfNormal("sigma_b", 5.0)

        # Rather than fixing mu and sigma for each county, a and b (vectors of
        # length n_counties) share a common group-level distribution.
        # Intercept for each county, distributed around group mean mu_a
        a = pm.Normal("a", mu=mu_a, sigma=sigma_a, shape=n_counties)
        # Slope on `floor` for each county, distributed around group mean mu_b
        b = pm.Normal("b", mu=mu_b, sigma=sigma_b, shape=n_counties)

        # Model error
        eps = pm.HalfCauchy("eps", 5.0)

        radon_est = a[county_idx] + b[county_idx] * data.floor.values

        # Data likelihood; the model context registers it, so no local is kept.
        pm.Normal("radon_like", mu=radon_est, sigma=eps, observed=data.log_radon)

    return hierarchical_model

In [None]:
# Define the PyMC model as usual. The radon model is only an example: any
# other PyMC model could be used here instead.
m = fetch_radon_model()

In [None]:
# fit_pymc_dadvi_with_jax is the most convenient entry point for fitting a
# PyMC model with DADVI.
from dadvi.pymc.jax_api import fit_pymc_dadvi_with_jax

# Number of fixed draws handed to DADVI via `num_fixed_draws`.
NUM_FIXED_DRAWS = 30
model_result = fit_pymc_dadvi_with_jax(m, num_fixed_draws=NUM_FIXED_DRAWS)

In [None]:
# Draws from the DADVI mean-field estimate, indexed by variable name
# (mean_field_draws['b'] is used in the comparison plot below).
mean_field_draws = model_result.get_posterior_draws_mean_field()

In [None]:
# Posterior means of the mean-field approximation, displayed as cell output.
posterior_means = model_result.get_posterior_means()
posterior_means

In [None]:
# Posterior standard deviations of the mean-field approximation.
posterior_sds_mean_field = model_result.get_posterior_standard_deviations_mean_field()
posterior_sds_mean_field

In [None]:
# Reference posterior from NUTS to compare the DADVI approximation against.
# Seeding the sampler makes the comparison reproducible on Restart & Run All.
NUTS_RANDOM_SEED = 42
with m as model:
    nuts_res = pm.sample(random_seed=NUTS_RANDOM_SEED)

In [None]:
posterior_b = nuts_res.posterior["b"]  # NUTS samples of the per-county slopes
example_variable = posterior_b.values

In [None]:
# Merge the leading (chain, draw) axes of the NUTS samples into one sample
# axis, giving one row per draw and one column per element of `b`.
# The row count is derived from the array rather than hard-coded as
# 4000 (= 4 chains x 1000 draws): pm.sample's default number of chains
# depends on the machine, and a fixed 4000 fails or mis-shapes otherwise.
reshaped = example_variable.reshape(-1, example_variable.shape[-1])

In [None]:
# Compare the NUTS marginal of b[0] against the DADVI mean-field draws.
fig, ax = plt.subplots()
ax.hist(reshaped[:, 0], density=True, alpha=0.5, label="NUTS")
ax.hist(mean_field_draws['b'][:, 0], density=True, alpha=0.5, label="DADVI mean field")
ax.set(xlabel="b[0]", ylabel="density", title="Posterior of b[0]: NUTS vs. DADVI mean field")
ax.legend();  # trailing ';' suppresses the Legend repr; the figure still renders

In [None]:
# LRVB correction for a scalar-valued function of the parameters.
# Here the function simply picks out b[0], the first county's slope; other
# scalar-valued functions work too (handled via the delta method).
def first_county_b(params):
    return params['b'][0]


lrvb_corrected = model_result.get_frequentist_sd_and_lrvb_correction_of_scalar_valued_function(
    first_county_b
)

In [None]:
# Display the result. Besides the LRVB-corrected posterior SD, it also
# includes the frequentist SD estimate.
lrvb_corrected

In [None]:
# lrvb_sd is the LRVB estimate of the posterior SD
# freq_sd is the frequentist estimate of the SD caused by the randomness in the fixed draws
# n_hvp_calls is the number of Hessian-vector product (HVP) calls required to compute these