# GLM: Mini-batch ADVI on hierarchical regression model

:::{post} Sept 23, 2021
:tags: generalized linear model, hierarchical model, variational inference
:category: intermediate
:::

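Unlike Gaussian mixture models, (hierarchical) regression models have independent variables (predictors). These variables enter the likelihood function but are not random variables themselves. When training on mini-batches, the predictors must therefore be sub-sampled together with the observations, so that every mini-batch row keeps its own predictor values.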

In [None]:
%env PYTENSOR_FLAGS=device=cpu, floatX=float32, warn_float64=ignore

import os

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pytensor
import pytensor.tensor as pt
import seaborn as sns

from scipy import stats

print(f"Running on PyMC v{pm.__version__}")

In [None]:
%config InlineBackend.figure_format = 'retina'
RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")

In [None]:
try:
    data = pd.read_csv(os.path.join("..", "data", "radon.csv"))
except FileNotFoundError:
    data = pd.read_csv(pm.get_data("radon.csv"))

data

In [None]:
county_idx = data["county_code"].values
floor_idx = data["floor"].values
log_radon_idx = data["log_radon"].values

coords = {"counties": data.county.unique()}

Here, `log_radon_idx_t` is the dependent variable, while `floor_idx_t` and `county_idx_t` are the independent variables: the floor of each measurement and the county it was taken in.
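
To see why the mini-batches of the response and the predictors have to select the same rows, here is a minimal NumPy sketch of the idea. It is not PyMC's implementation: the array sizes and values are made up for illustration, and names such as `batch_rows` and `other_rows` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Made-up stand-ins for the data columns: row i of every array
# describes the same household.
n_obs, n_counties, batch_size = 1000, 20, 100
county = rng.integers(0, n_counties, size=n_obs)
floor = rng.integers(0, 2, size=n_obs)
log_radon = 1.5 - 0.6 * floor + rng.normal(0.0, 0.7, size=n_obs)

# Aligned mini-batch: draw ONE set of row indices and apply it to every array.
batch_rows = rng.integers(0, n_obs, size=batch_size)
log_radon_batch = log_radon[batch_rows]
floor_batch = floor[batch_rows]
county_batch = county[batch_rows]

# Misaligned mini-batch: the predictor is sliced with its own random rows.
other_rows = rng.integers(0, n_obs, size=batch_size)
floor_misaligned = floor[other_rows]

# Share of mini-batch rows whose floor value belongs to the household
# that produced the response in that row.
aligned_match = np.mean(floor_batch == floor[batch_rows])
misaligned_match = np.mean(floor_misaligned == floor[batch_rows])
print(f"aligned: {aligned_match:.2f}, misaligned: {misaligned_match:.2f}")
```

With independent indices, a response is paired with the predictors of some other household, so the regression would be fit to scrambled data.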

In [None]:
# Pass all three arrays in one call so every mini-batch draws the same rows;
# separate calls would sample rows independently and misalign the data.
log_radon_idx_t, floor_idx_t, county_idx_t = pm.Minibatch(
    log_radon_idx, floor_idx, county_idx, batch_size=100
)

In [None]:
with pm.Model(coords=coords) as hierarchical_model:
    # Hyperpriors for group nodes
    mu_a = pm.Normal("mu_alpha", mu=0.0, sigma=100**2)
    sigma_a = pm.Uniform("sigma_alpha", lower=0, upper=100)
    mu_b = pm.Normal("mu_beta", mu=0.0, sigma=100**2)
    sigma_b = pm.Uniform("sigma_beta", lower=0, upper=100)

Next come the intercept and slope for each county, distributed around the group means `mu_a` and `mu_b`. Above we set `mu` and `sigma` to fixed values, while here we plug in a common group distribution for all `a` and `b` (vectors with one entry per unique county in our example).
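
As a rough illustration of what a common group distribution means, the sketch below draws one intercept and one slope per county from shared group-level normals. The group-level values and the number of counties are made up; in the model they are random variables with their own priors.

```python
import numpy as np

rng = np.random.default_rng(3)

n_counties = 5  # hypothetical; the model uses the number of unique counties

# Hypothetical group-level means and spreads (fixed here for illustration).
mu_a, sigma_a = 1.5, 0.3
mu_b, sigma_b = -0.6, 0.2

# One intercept and one slope per county, all drawn from the shared group distributions.
a = rng.normal(mu_a, sigma_a, size=n_counties)
b = rng.normal(mu_b, sigma_b, size=n_counties)
print(a.shape, b.shape)
```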

In [None]:
with hierarchical_model:
    # Intercept for each county, distributed around group mean mu_a
    a = pm.Normal("alpha", mu=mu_a, sigma=sigma_a, dims="counties")
    # Slope for each county, distributed around group mean mu_b
    b = pm.Normal("beta", mu=mu_b, sigma=sigma_b, dims="counties")

The model prediction of the radon level uses `a[county_idx]`, which translates to `a[0, 0, 0, 1, 1, ...]`. This links the multiple household measurements in a county to that county's coefficients.
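
The indexing step can be tried out with plain NumPy. The coefficient values and the tiny index arrays below are made up for illustration.

```python
import numpy as np

# Hypothetical per-county intercepts and slopes for three counties.
a = np.array([1.2, 0.8, 1.5])
b = np.array([-0.6, -0.4, -0.7])

# Six households: three in county 0, two in county 1, one in county 2.
county_idx = np.array([0, 0, 0, 1, 1, 2])
floor = np.array([0, 1, 0, 1, 0, 1])

# Indexing with county_idx repeats each county's coefficient once per household.
a_per_household = a[county_idx]

# Linear predictor: county intercept plus county slope times the floor indicator.
radon_est = a[county_idx] + b[county_idx] * floor
print(a_per_household)
print(radon_est)
```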

In [None]:
with hierarchical_model:
    radon_est = a[county_idx_t] + b[county_idx_t] * floor_idx_t

Finally, we specify the likelihood:

In [None]:
with hierarchical_model:
    # Model error
    eps = pm.Uniform("eps", lower=0, upper=100)

    # Data likelihood
    radon_like = pm.Normal(
        "radon_like", mu=radon_est, sigma=eps, observed=log_radon_idx_t, total_size=len(data)
    )

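The observed random variable `radon_like` is tied to the mini-batch `log_radon_idx_t` through `observed=`, which marks it as the data in the likelihood term. Because each optimization step sees only 100 rows, `total_size=len(data)` tells PyMC the size of the full data set so that the mini-batch log-likelihood can be rescaled accordingly.

The predictors `floor_idx_t` and `county_idx_t` enter through `radon_est`, so all three mini-batch variables defined above take part in the model.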
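
The effect of rescaling a mini-batch log-likelihood to the full data size can be illustrated with NumPy. This is a sketch under the assumption that the scaling factor is the ratio of the full data size to the batch size; the data, the parameter values, and the helper names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)


def normal_logpdf(y, mu, sigma):
    """Log density of a normal distribution, evaluated elementwise."""
    return -0.5 * np.log(2.0 * np.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2


# Made-up observations and fixed parameter values.
n_obs, batch_size = 1000, 100
y = rng.normal(1.0, 0.8, size=n_obs)
mu, sigma = 0.9, 1.0

# Log-likelihood of the full data set.
full_loglik = normal_logpdf(y, mu, sigma).sum()


def scaled_batch_loglik():
    """Mini-batch log-likelihood, scaled up by n_obs / batch_size."""
    rows = rng.integers(0, n_obs, size=batch_size)
    return (n_obs / batch_size) * normal_logpdf(y[rows], mu, sigma).sum()


# Each estimate is noisy, but their average approaches the full-data value.
estimates = np.array([scaled_batch_loglik() for _ in range(2000)])
print(f"full: {full_loglik:.1f}, mean of estimates: {estimates.mean():.1f}")
```

Without the scaling factor, the likelihood would count for only 100 observations while the priors keep their full weight, which would pull the approximation towards the prior.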

Then, run ADVI with mini-batch.

In [None]:
with hierarchical_model:
    approx = pm.fit(100_000, callbacks=[pm.callbacks.CheckParametersConvergence(tolerance=1e-4)])

idata_advi = approx.sample(500)

Check the trace of ELBO and compare the result with MCMC.
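
Because each step sees only a mini-batch, the loss history tends to be noisy, and a moving average can make the trend easier to read. The sketch below uses a synthetic curve as a stand-in for `approx.hist`; the curve shape, the noise level, and the window size are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(2)

# Synthetic loss history: a decaying curve plus mini-batch noise.
steps = np.arange(5000)
hist = 1000.0 + 500.0 * np.exp(-steps / 500.0) + rng.normal(0.0, 30.0, size=steps.size)

# Moving average over a fixed window; "valid" keeps only full windows.
window = 200
kernel = np.ones(window) / window
smoothed = np.convolve(hist, kernel, mode="valid")

print(hist.shape, smoothed.shape)
```

The same two lines of smoothing could be applied to the real history before plotting it.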

In [None]:
plt.plot(approx.hist);

We can extract the covariance matrix from the mean field approximation and use it as the scaling matrix for the NUTS algorithm.
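
A mean-field approximation treats the parameters as independent, so its covariance matrix is diagonal: it carries one variance per parameter and no correlations. A small NumPy sketch with made-up standard deviations shows the structure.

```python
import numpy as np

# Hypothetical standard deviations of three parameters under a mean-field fit.
std = np.array([0.5, 0.1, 2.0])

# Independence means all off-diagonal covariances are zero.
cov = np.diag(std**2)
print(cov)

# The per-parameter variances can be read back from the diagonal.
variances = np.diag(cov)
```

Used as a scaling matrix, such a covariance adjusts the step scale per parameter but cannot tell the sampler about posterior correlations.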

In [None]:
scaling = approx.cov.eval()

Also, we can generate samples (one for each chain) to use as the starting points for the sampler.

In [None]:
n_chains = 4
# Draw one point per chain from the approximation
sample = approx.sample(draws=n_chains, return_inferencedata=False)
start_dict = [sample[i] for i in range(n_chains)]

In [None]:
# Inference button (TM)!
with pm.Model(coords=coords):
    mu_a = pm.Normal("mu_alpha", mu=0.0, sigma=100**2)
    sigma_a = pm.Uniform("sigma_alpha", lower=0, upper=100)
    mu_b = pm.Normal("mu_beta", mu=0.0, sigma=100**2)
    sigma_b = pm.Uniform("sigma_beta", lower=0, upper=100)

    a = pm.Normal("alpha", mu=mu_a, sigma=sigma_a, dims="counties")
    b = pm.Normal("beta", mu=mu_b, sigma=sigma_b, dims="counties")

    # Model error
    eps = pm.Uniform("eps", lower=0, upper=100)

    radon_est = a[county_idx] + b[county_idx] * floor_idx

    radon_like = pm.Normal("radon_like", mu=radon_est, sigma=eps, observed=log_radon_idx)

    # essentially, this is what init='advi' does
    step = pm.NUTS(scaling=scaling, is_cov=True)
    hierarchical_trace = pm.sample(draws=2000, step=step, chains=n_chains, initvals=start_dict)

In [None]:
az.plot_density(
    data=[idata_advi, hierarchical_trace],
    var_names=["~alpha", "~beta"],
    data_labels=["ADVI", "NUTS"],
);

## Watermark

In [None]:
%load_ext watermark
%watermark -n -u -v -iv -w -p xarray

:::{include} ../page_footer.md
:::