# Session 4: The Bayesian Workflow

In this session we apply the complete Bayesian workflow to COVID-19 case data. We'll iterate through model building, prior predictive checks, fitting, diagnostics, posterior predictive checks, model comparison, and forecasting.

Learning objectives:
- Apply prior predictive checks to validate model/priors
- Fit models and assess convergence (R-hat, ESS, divergences/energy)
- Evaluate fit with posterior predictive checks and residuals
- Compare models (LOO/WAIC) and refine iteratively
- Use pm.Data for forecasting and scenario analysis


## MCMC Output Processing and Model Checking with ArviZ

ArviZ is a library for exploratory analysis of Bayesian models. In the workflow, we use it to:

- Inspect sampler behavior (e.g., step size, tree depth, divergences)
- Assess convergence (R-hat close to 1, effective sample sizes sufficiently large)
- Visualize posterior distributions and dependencies
- Evaluate fit using posterior predictive checks (PPCs)
- Compare models using LOO/WAIC

Diagnostics must come before interpretation. A model that has not converged, or that explores the posterior inefficiently, cannot be trusted, however well it appears to fit.
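To make "R-hat close to 1" concrete, here is a minimal NumPy sketch of the classic split R-hat. It compares between-chain and within-chain variance after splitting each chain in half. It is a simplified version for illustration: ArviZ's default R-hat also rank-normalizes and folds the draws (Vehtari et al., 2019). The function name `split_rhat` and the simulated chains are hypothetical.

```python
import numpy as np


def split_rhat(draws: np.ndarray) -> float:
    """Classic split R-hat for draws of shape (n_chains, n_draws).

    Simplified sketch: ArviZ's default additionally rank-normalizes and folds.
    """
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Split every chain in half so that trends within a chain also inflate R-hat
    split = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    n = split.shape[1]
    within = split.var(axis=1, ddof=1).mean()          # W: mean within-chain variance
    between = n * split.mean(axis=1).var(ddof=1)       # B: variance of chain means, scaled
    var_hat = (n - 1) / n * within + between / n       # pooled posterior variance estimate
    return float(np.sqrt(var_hat / within))


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))   # all chains sample the same distribution
stuck = mixed.copy()
stuck[3] += 3.0                      # one chain stuck in a different region

print(round(split_rhat(mixed), 3))
print(round(split_rhat(stuck), 3))
```

The well-mixed chains give a value very close to 1, while the stuck chain pushes R-hat well above common thresholds such as 1.01.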


In [None]:
import load_covid_data
import pymc as pm
import arviz as az
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import warnings
warnings.simplefilter("ignore")

sns.set_context('talk')

sampler_kwargs = {"chains": 4, "cores": 4, "tune": 2000}
RANDOM_SEED = 20090425


## Load data

We use COVID-19 case counts aligned to the day each country first reached 100 confirmed cases. This reduces the influence of noisy early reporting and puts all countries on a common time axis, so their trajectories can be compared directly.
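The `load_covid_data` helper already provides a `days_since_100` column, and its internals may differ. The sketch below shows one way such an alignment can be computed with pandas. It uses a hypothetical toy table with `country`, `date`, and `confirmed` columns.

```python
import pandas as pd

# Toy data standing in for the real case table (hypothetical values)
toy = pd.DataFrame({
    "country": ["A"] * 5 + ["B"] * 5,
    "date": pd.to_datetime(
        ["2020-03-01", "2020-03-02", "2020-03-03", "2020-03-04", "2020-03-05"] * 2
    ),
    "confirmed": [40, 90, 120, 200, 350, 10, 20, 60, 100, 180],
})

# First date on which each country reached at least 100 confirmed cases
first_100 = (
    toy[toy.confirmed >= 100]
    .groupby("country")["date"]
    .min()
    .rename("date_100")
)
toy = toy.join(first_100, on="country")
toy["days_since_100"] = (toy["date"] - toy["date_100"]).dt.days

# Keep only the aligned part of each series
aligned = toy[toy.days_since_100 >= 0]
print(aligned[["country", "date", "confirmed", "days_since_100"]])
```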


In [None]:
df = load_covid_data.load_data(drop_states=True, filter_n_days_100=2)
countries = df.country.unique()
n_countries = len(countries)
# Align to day since crossing 100 confirmed cases
df = df.loc[lambda x: (x.days_since_100 >= 0)]
df.head()


## Bayesian workflow steps

1. Plot the data
2. Build model
3. Run prior predictive check
4. Fit model
5. Assess convergence
6. Run posterior predictive check
7. Improve model

### 1) Plot the data (Germany, first 30 days)


In [None]:
country = 'Germany'
date = '2020-07-31'
df_country = df.query(f'country=="{country}"').loc[:date].iloc[:30]

fig, ax = plt.subplots(figsize=(10, 8))
df_country.confirmed.plot(ax=ax)
ax.set(ylabel='Confirmed cases', title=country)
sns.despine()


### 2) Build an initial model (exponential with Normal likelihood)

We start with a simple exponential growth model with a Normal likelihood. It illustrates prior predictive checks and shows why a Normal likelihood is inadequate for count data.


In [None]:
# Time and observations
t = df_country.days_since_100.values
confirmed = df_country.confirmed.values

with pm.Model() as model_exp1:
    a = pm.Normal('a', mu=0, sigma=100)
    b = pm.Normal('b', mu=0.3, sigma=0.3)
    growth = a * (1 + b) ** t
    eps = pm.HalfNormal('eps', 100)
    pm.Normal('obs', mu=growth, sigma=eps, observed=confirmed)


### 3) Prior predictive check

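Simulate data from the prior predictive distribution and check whether the implied case counts are plausible. With `a ~ Normal(0, 100)` and a Normal likelihood, many simulated trajectories are negative, which is impossible for case counts.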


In [None]:
with model_exp1:
    prior_pred = pm.sample_prior_predictive(random_seed=RANDOM_SEED)

fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(prior_pred.prior_predictive['obs'].values.squeeze().T, color="0.5", alpha=.1)
ax.set(ylim=(-1000, 1000), xlim=(0, 10), title="Prior predictive", xlabel="Days since 100 cases", ylabel="Positive cases");
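The same prior predictive check can be reproduced in plain NumPy. This makes explicit what `pm.sample_prior_predictive` does: draw parameters from the priors, push them through the growth curve, and then through the likelihood. The sketch assumes a 30-day window to match the Germany series. The HalfNormal prior is emulated as the absolute value of a Normal draw.

```python
import numpy as np

rng = np.random.default_rng(20090425)
n_draws = 4000
t = np.arange(30)  # assumption: 30 days, matching the Germany window

# Draw parameters from the same priors as model_exp1
a = rng.normal(0, 100, size=n_draws)
b = rng.normal(0.3, 0.3, size=n_draws)
eps = np.abs(rng.normal(0, 100, size=n_draws))  # HalfNormal(100)

growth = a[:, None] * (1 + b[:, None]) ** t     # shape (n_draws, 30)
obs = rng.normal(growth, eps[:, None])          # Normal likelihood

frac_negative_day0 = (obs[:, 0] < 0).mean()
print(f"Share of negative simulated counts on day 0: {frac_negative_day0:.2f}")
```

Because the prior on `a` is centered at zero, roughly half of the simulated day-0 counts are negative, which is the problem the next model addresses.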


### 2b) Improve likelihood: Negative Binomial for counts

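The Normal likelihood allows negative counts and assumes a constant variance, whereas the variability of counts typically grows with their mean. We switch to a Negative Binomial likelihood with mean `mu` and overdispersion parameter `alpha`. Its variance is `mu + mu**2 / alpha`, so smaller `alpha` means more extra variance relative to a Poisson. We also tighten the priors: the intercept `a` is centered at 100 because every series starts at 100 cases.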
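A quick NumPy check of the Negative Binomial mean/variance relationship follows. NumPy parameterizes the distribution by `(n, p)`, so the sketch converts from the (mean, overdispersion) form used by PyMC with `n = alpha` and `p = alpha / (alpha + mu)`. The helper name `sample_nb` is hypothetical.

```python
import numpy as np


def sample_nb(rng, mu, alpha, size):
    """Draw NegativeBinomial(mu, alpha) counts with Var = mu + mu**2 / alpha.

    NumPy uses (n, p); convert with n = alpha, p = alpha / (alpha + mu).
    """
    return rng.negative_binomial(alpha, alpha / (alpha + mu), size=size)


rng = np.random.default_rng(1)
mu, alpha = 1000.0, 6.0
draws = sample_nb(rng, mu, alpha, size=200_000)

print(draws.mean(), mu)
print(draws.var(), mu + mu**2 / alpha)  # far larger than the Poisson variance (= mu)
```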


In [None]:
t = df_country.days_since_100.values
confirmed = df_country.confirmed.values

with pm.Model() as model_exp2:
    a = pm.Normal('a', mu=100, sigma=25)
    b = pm.Normal('b', mu=0.3, sigma=0.1)
    growth = a * (1 + b) ** t
    alpha = pm.Gamma("alpha", mu=6, sigma=1)
    pm.NegativeBinomial('obs', growth, alpha=alpha, observed=confirmed)

with model_exp2:
    prior_pred2 = pm.sample_prior_predictive(random_seed=RANDOM_SEED)

fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(prior_pred2.prior_predictive['obs'].values.squeeze().T, color="0.5", alpha=.1)
ax.set(ylim=(-100, 1000), xlim=(0, 10), title="Prior predictive (NB)", xlabel="Days since 100 cases", ylabel="Positive cases");


### 4) Fit model and 5) Assess convergence


In [None]:
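with model_exp2:
    trace_exp2 = pm.sample(**sampler_kwargs, random_seed=RANDOM_SEED)

# Trace plots: chains should overlap and look like stationary noise
az.plot_trace(trace_exp2, var_names=['a', 'b', 'alpha']);
plt.tight_layout();

# Energy plot: marginal energy and energy transitions should overlap
az.plot_energy(trace_exp2);

# Summary with R-hat and ESS (last expression, so the notebook displays it)
az.summary(trace_exp2, var_names=['a', 'b', 'alpha'])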


### 6) Posterior predictive check and residuals


In [None]:
with model_exp2:
    post_pred = pm.sample_posterior_predictive(trace_exp2.posterior)

fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(post_pred.posterior_predictive['obs'].sel(chain=0).values.squeeze().T, color='0.5', alpha=.05)
ax.plot(confirmed, color='r', label='data')
ax.set(xlabel="Days since 100 cases", ylabel="Confirmed cases (log scale)", title=country, yscale="log")
ax.legend();

# Residuals (predicted minus observed) for each posterior predictive draw
fig, ax = plt.subplots(figsize=(10, 8))
resid = post_pred.posterior_predictive["obs"].sel(chain=0).values - confirmed
ax.plot(resid.T, color="0.5", alpha=.01)
ax.set(ylim=(-50_000, 200_000), ylabel="Residual (predicted - observed)", xlabel="Days since 100 cases");


### 7) Improve model with constrained priors (exp3) and compare to exp2

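Encode prior knowledge: the intercept must be at least 100 (every series starts at 100 cases), and growth must be positive. We use a HalfNormal offset shifted by 100 for the intercept and a HalfNormal prior for the growth rate `b`, and keep the same Gamma prior on `alpha`.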
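A short NumPy sketch shows what these constrained priors imply before any data are seen. HalfNormal draws are emulated as absolute values of Normal draws. The day-10 check is an illustrative choice, not part of the original analysis.

```python
import numpy as np

rng = np.random.default_rng(2)
n_draws = 4000

# Same priors as model_exp3: shifted HalfNormal intercept, HalfNormal growth rate
a = 100 + np.abs(rng.normal(0, 25, size=n_draws))  # a = a0 + 100, a0 ~ HalfNormal(25)
b = np.abs(rng.normal(0, 0.2, size=n_draws))       # b ~ HalfNormal(0.2)

print(a.min() >= 100, b.min() >= 0)

# Implied expected counts on day 10: how fast does the prior let cases grow?
day10 = a * (1 + b) ** 10
print(np.percentile(day10, [5, 50, 95]).round())
```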


In [None]:
t = df_country.days_since_100.values
confirmed = df_country.confirmed.values

with pm.Model() as model_exp3:
    a0 = pm.HalfNormal('a0', sigma=25)
    a = pm.Deterministic('a', a0 + 100)
    b = pm.HalfNormal('b', sigma=0.2)
    growth = a * (1 + b) ** t
    alpha = pm.Gamma("alpha", mu=6, sigma=1)
    pm.NegativeBinomial('obs', growth, alpha=alpha, observed=confirmed)

with model_exp3:
    prior_pred3 = pm.sample_prior_predictive(random_seed=RANDOM_SEED)
    trace_exp3 = pm.sample(**sampler_kwargs, random_seed=RANDOM_SEED)

az.plot_trace(trace_exp3, var_names=['a', 'b', 'alpha']);
plt.tight_layout();

# Compute log-likelihoods for comparison
with model_exp2:
    pm.compute_log_likelihood(trace_exp2)
with model_exp3:
    pm.compute_log_likelihood(trace_exp3)

cmp = az.compare({"exp2": trace_exp2, "exp3": trace_exp3})
az.plot_compare(cmp);


### Forecasting with pm.Data

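Wrapping the time index and the observed counts in `pm.Data` containers lets us replace them after fitting. For forecasting, we extend the time index with `pm.set_data` and draw from the posterior predictive distribution at the new time points. The zeros supplied as future observations are placeholders that keep the data shapes consistent; posterior predictive sampling does not use their values.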
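Conceptually, the forecast takes each posterior draw, evaluates the growth curve on the extended time grid, and samples one Negative Binomial count per time point. The NumPy sketch below spells this out. The posterior draws are hypothetical stand-ins, not results from `trace_exp4`.

```python
import numpy as np

rng = np.random.default_rng(3)

# Hypothetical posterior draws standing in for trace_exp4.posterior (not real results)
n_draws = 1000
a_post = 100 + np.abs(rng.normal(0, 5, size=n_draws))
b_post = rng.normal(0.2, 0.01, size=n_draws)
alpha_post = rng.gamma(36, 1 / 6, size=n_draws)  # Gamma with mean 6, sd 1

t_hist = np.arange(30)
t_future = np.arange(30, 40)
t_all = np.concatenate([t_hist, t_future])

# Expected counts for every draw on the extended time grid: shape (n_draws, 40)
mu = a_post[:, None] * (1 + b_post[:, None]) ** t_all
# One Negative Binomial draw per draw and time point (mean mu, overdispersion alpha)
p = alpha_post[:, None] / (alpha_post[:, None] + mu)
y_pred = rng.negative_binomial(alpha_post[:, None], p)

# 95% predictive interval for the future days only
lower, upper = np.percentile(y_pred[:, len(t_hist):], [2.5, 97.5], axis=0)
```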


In [None]:
with pm.Model() as model_exp4:
    t_data = pm.Data('t', df_country.days_since_100.values)
    confirmed_data = pm.Data('confirmed', df_country.confirmed.values)

    a0 = pm.HalfNormal('a0', sigma=25)
    a = pm.Deterministic('a', a0 + 100)
    b = pm.HalfNormal('b', sigma=0.2)
    growth = a * (1 + b) ** t_data

    pm.NegativeBinomial('obs', growth, alpha=pm.Gamma("alpha", mu=6, sigma=1), observed=confirmed_data)

    trace_exp4 = pm.sample(**sampler_kwargs, random_seed=RANDOM_SEED)

# Forecast next 60 days
forecast_days = 60
future_days = np.arange(len(df_country.days_since_100.values), len(df_country.days_since_100.values) + forecast_days)

with model_exp4:
    pm.set_data({'t': np.concatenate([df_country.days_since_100.values, future_days]),
                 'confirmed': np.concatenate([df_country.confirmed.values, np.zeros(forecast_days, dtype='int')])})
    forecast = pm.sample_posterior_predictive(trace_exp4.posterior)


In [None]:
historical_days = len(df_country.days_since_100.values)
all_days = np.arange(historical_days + forecast_days)

forecast_samples = forecast.posterior_predictive['obs'].values
forecast_mean = forecast_samples.mean(axis=(0, 1))
forecast_lower = np.percentile(forecast_samples, 2.5, axis=(0, 1))
forecast_upper = np.percentile(forecast_samples, 97.5, axis=(0, 1))

fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(np.arange(historical_days), df_country.confirmed.values, label='Observed', color='black')
ax.plot(all_days, forecast_mean, label='Mean forecast', color='blue')
ax.fill_between(all_days, forecast_lower, forecast_upper, color='blue', alpha=0.2, label='95% predictive interval')
ax.axvline(historical_days-1, color='gray', linestyle='--')
ax.set(title=f'COVID-19 Forecast for {country}', xlabel='Days since 100 cases', ylabel='Confirmed cases')
ax.legend();


### Logistic growth model and comparison

The exponential model cannot capture plateauing behavior. We fit a logistic growth model with a carrying capacity and compare.
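The logistic curve used below can be checked numerically. With `a = K / intercept - 1`, the curve passes through `intercept` at `t = 0` and approaches the carrying capacity `K` for large `t`. Note that `b` is a continuous rate in `exp(-b * t)` here, unlike the `(1 + b) ** t` form of the exponential model. The function names and parameter values are illustrative.

```python
import numpy as np


def logistic_growth(t, intercept, b, carrying_capacity):
    """Expected cumulative cases: K / (1 + a * exp(-b * t)) with a = K / intercept - 1."""
    a = carrying_capacity / intercept - 1
    return carrying_capacity / (1 + a * np.exp(-b * t))


def exponential_growth(t, intercept, b):
    """Expected cumulative cases under the exponential model: intercept * (1 + b) ** t."""
    return intercept * (1 + b) ** t


t = np.arange(0, 121)
logistic = logistic_growth(t, intercept=100, b=0.2, carrying_capacity=200_000)
exponential = exponential_growth(t, intercept=100, b=0.2)

print(logistic[0], exponential[0])    # both start at the intercept
print(logistic[-1], exponential[-1])  # logistic saturates, exponential does not
```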


In [None]:
df_country_full = df.query(f'country=="{country}"').loc[:date]

with pm.Model() as logistic_model:
    t_data = pm.Data('t', df_country_full.days_since_100.values)
    confirmed_data = pm.Data('confirmed', df_country_full.confirmed.values)

    a0 = pm.HalfNormal('a0', sigma=25)
    intercept = pm.Deterministic('intercept', a0 + 100)
    b = pm.HalfNormal('b', sigma=0.2)
    carrying_capacity = pm.Uniform('carrying_capacity', lower=1_000, upper=80_000_000)
    a = carrying_capacity / intercept - 1
    growth = carrying_capacity / (1 + a * pm.math.exp(-b * t_data))

    pm.NegativeBinomial('obs', growth, alpha=pm.Gamma("alpha", mu=6, sigma=1), observed=confirmed_data)

    trace_logistic = pm.sample(**sampler_kwargs, target_accept=0.9, random_seed=RANDOM_SEED)
    pm.sample_posterior_predictive(trace_logistic, extend_inferencedata=True)

az.plot_trace(trace_logistic);
plt.tight_layout();

with model_exp4:
    pm.set_data({"t": df_country_full.days_since_100.values, "confirmed": df_country_full.confirmed.values})
    trace_exp4_full = pm.sample(**sampler_kwargs, random_seed=RANDOM_SEED)

with model_exp4:
    pm.compute_log_likelihood(trace_exp4_full)
with logistic_model:
    pm.compute_log_likelihood(trace_logistic)

az.plot_compare(az.compare({"exp4": trace_exp4_full, "logistic": trace_logistic}))


### Validation on another country

Fit the logistic model to a different country (here the US) to probe its assumptions and how well it generalizes.


In [None]:
country2 = 'US'
df_country2 = df.query(f'country=="{country2}"').loc[:date]

with pm.Model() as logistic_model2:
    t_data = pm.Data('t', df_country2.days_since_100.values)
    confirmed_data = pm.Data('confirmed', df_country2.confirmed.values)

    a0 = pm.HalfNormal('a0', sigma=25)
    intercept = pm.Deterministic('intercept', a0 + 100)
    b = pm.HalfNormal('b', sigma=0.2)
    carrying_capacity = pm.Uniform('carrying_capacity', lower=1_000, upper=100_000_000)
    a = carrying_capacity / intercept - 1
    growth = carrying_capacity / (1 + a * pm.math.exp(-b * t_data))

    pm.NegativeBinomial('obs', growth, alpha=pm.Gamma("alpha", mu=6, sigma=1), observed=confirmed_data)

    trace_logistic_us = pm.sample(**sampler_kwargs, random_seed=RANDOM_SEED)
    pm.sample_posterior_predictive(trace_logistic_us, extend_inferencedata=True)

az.plot_trace(trace_logistic_us); plt.tight_layout();

fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(trace_logistic_us.posterior_predictive['obs'].sel(chain=0).squeeze().values.T, color='0.5', alpha=.05)
ax.plot(df_country2.confirmed.values, color='r')
ax.set(xlabel='Days since 100 cases', ylabel='Confirmed cases', title=country2);


### Optional: Calibration and sensitivity

- LOO-PIT (leave-one-out probability integral transform) for calibration
- Sensitivity to prior choices for `alpha` and slope `b`
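The idea behind (LOO-)PIT fits in a few lines. For each observation, it computes the share of predictive draws at or below the observed value. For a calibrated model these values are roughly uniform. The sketch accepts optional per-draw weights. LOO-PIT would plug in PSIS-LOO importance weights, which this sketch does not compute. It also assumes continuous outcomes; for discrete counts the PIT is not exactly uniform even under a correct model. The names `pit_values`, `calibrated`, and `too_narrow` are hypothetical.

```python
import numpy as np


def pit_values(y_obs, y_rep, weights=None):
    """Share of (weighted) predictive draws <= each observation.

    y_obs: shape (n_obs,); y_rep: shape (n_draws, n_obs).
    weights: same shape as y_rep, each column summing to 1 (uniform if None).
    """
    n_draws = y_rep.shape[0]
    if weights is None:
        weights = np.full(y_rep.shape, 1.0 / n_draws)
    return np.sum(weights * (y_rep <= y_obs[None, :]), axis=0)


rng = np.random.default_rng(4)
y_obs = rng.normal(size=2000)
calibrated = rng.normal(size=(1000, 2000))             # predictive matches the data
too_narrow = rng.normal(scale=0.5, size=(1000, 2000))  # overconfident predictive

pit_good = pit_values(y_obs, calibrated)
pit_bad = pit_values(y_obs, too_narrow)
print(np.histogram(pit_good, bins=5, range=(0, 1))[0])
print(np.histogram(pit_bad, bins=5, range=(0, 1))[0])  # piles up near 0 and 1
```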


In [None]:
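loo_exp4 = az.loo(trace_exp4_full)
loo_logistic = az.loo(trace_logistic)
# Print elpd estimates and Pareto k diagnostics for both models
print(loo_exp4)
print(loo_logistic)
# LOO-PIT values should look roughly uniform for a well-calibrated model
az.plot_loo_pit(idata=trace_logistic, y='obs');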


## References

- Gabry, J., Simpson, D., Vehtari, A., Betancourt, M., & Gelman, A. (2019). Visualization in Bayesian workflow. JRSS-A, 182(2), 389–402.
- Gelman, A., Hwang, J., & Vehtari, A. (2014). Understanding predictive information criteria for Bayesian models. Statistics and Computing, 24(6), 997–1016.
- Vehtari, A., Gelman, A., Simpson, D., Carpenter, B., & Bürkner, P.-C. (2019). Rank-normalization, folding, and localization: An improved R-hat. arXiv:1903.08008.


### Sampler statistics and tuning behavior

Sampler statistics expose what HMC/NUTS is doing under the hood. We examine tree depth, acceptance rates, and energy behavior to detect pathologies. Consistently hitting the maximum tree depth, low acceptance rates, or poor energy overlap suggests re-tuning the sampler or reparameterizing the model.


In [None]:
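# Inspect sampler stats from the Negative Binomial fit (exp2)
trace = trace_exp2

# Count divergences and report the deepest tree (PyMC's default max_treedepth is 10)
n_divergent = int(trace.sample_stats["diverging"].sum())
deepest_tree = int(trace.sample_stats["tree_depth"].max())
print(f"Divergences: {n_divergent}, deepest tree: {deepest_tree}")

# Tree depth by chain
trace.sample_stats["tree_depth"].plot(col="chain", ls="none", marker=".", alpha=0.3);

# Acceptance rate histogram, drawn on a fresh figure
fig, ax = plt.subplots(figsize=(8, 5))
trace.sample_stats["acceptance_rate"].plot.hist(ax=ax, bins=20, density=True);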


## Posterior analysis and interpretation

After confirming sampler health, we interpret posteriors: marginal distributions, correlation structure, credible intervals, and effect sizes. Summaries should be coupled with plots to guard against overreliance on point estimates.
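Credible intervals and correlations are computed directly from the draws. The sketch below implements a highest density interval (HDI) as the shortest interval containing a given share of sorted draws, which assumes a unimodal posterior. It uses 0.94 to mirror ArviZ's default `hdi_prob`, and compares the HDI with an equal-tailed interval and a posterior correlation. The correlated draws are hypothetical stand-ins for posterior samples of two parameters. For a standard normal, the 94% HDI is close to ±1.88.

```python
import numpy as np


def hdi(samples, prob=0.94):
    """Shortest interval containing `prob` of the draws (unimodal posterior assumed)."""
    sorted_draws = np.sort(np.asarray(samples))
    n = len(sorted_draws)
    n_in = int(np.floor(prob * n))
    # Width of every candidate interval that spans n_in sorted draws
    widths = sorted_draws[n_in:] - sorted_draws[: n - n_in]
    start = int(np.argmin(widths))
    return sorted_draws[start], sorted_draws[start + n_in]


rng = np.random.default_rng(5)
# Hypothetical correlated draws standing in for two posterior parameters
cov = [[1.0, -0.8], [-0.8, 1.0]]
draws = rng.multivariate_normal([0.0, 0.0], cov, size=100_000)

print(hdi(draws[:, 0]))
print(np.percentile(draws[:, 0], [3, 97]))             # equal-tailed 94% interval
print(np.corrcoef(draws[:, 0], draws[:, 1])[0, 1])     # posterior correlation
```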


In [None]:
az.plot_posterior(trace_exp2, var_names=['a', 'b', 'alpha'], kind='hist'); plt.tight_layout();
az.plot_pair(trace_exp2, var_names=['a', 'b'], kind='kde'); plt.tight_layout();
# Numerical summary last, so the notebook displays it
az.summary(trace_exp2, var_names=['a', 'b', 'alpha'], round_to=2)


## Enhanced posterior predictive checks

Posterior predictive checks evaluate whether data replicated from the fitted model resemble the observed data. We use two complementary views: histogram/density overlays of replicated versus observed counts, and cumulative (ECDF) comparisons, which are more sensitive to discrepancies in the tails. Visual agreement is consistent with a good fit, while systematic deviations point to model misspecification.


In [None]:
with model_exp2:
    post_pred2 = pm.sample_posterior_predictive(trace_exp2.posterior)

# Histogram and KDE overlays
az.plot_ppc(post_pred2, data_pairs={"obs": "obs"});
# Cumulative PPC to assess tail behavior
az.plot_ppc(post_pred2, kind='cumulative', mean=False);


## Diagnostics: Divergences, BFMI, and Energy

Divergences flag integration failures in HMC that often arise from problematic posterior geometry. The Bayesian fraction of missing information (BFMI) quantifies how well momentum resampling lets the sampler move across the marginal energy distribution. Poor overlap between the marginal energy distribution and the distribution of energy transitions indicates inefficient exploration.

We will:
- Check for divergences and their locations
- Compute BFMI and visualize energy overlap
- Discuss remediation: increase `target_accept`, reparameterize, rescale predictors/responses

These diagnostics and their interpretations are adapted from the existing model checking notebook.


As we have seen, Hamiltonian Monte Carlo (and NUTS) performs numerical integration to explore the posterior distribution. When the integration goes wrong, it can go dramatically wrong. Divergences signal numerical integration failures in regions of difficult geometry and must be investigated.

![diverging HMC](images/diverging_hmc.png)

Two practical remedies:
1. Increase the target acceptance rate (e.g., `target_accept=0.9`), which typically reduces the step size and improves integration accuracy.
2. Reparameterize the model to improve geometry (e.g., non-centered parameterizations, appropriate scaling of predictors/responses).
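A non-centered parameterization does not change the model, only what the sampler sees. Drawing `theta ~ Normal(mu, tau)` directly (centered) is equivalent to drawing `z ~ Normal(0, 1)` and setting `theta = mu + tau * z` (non-centered). The hierarchical model at the end of this notebook uses this form via its `a_raw` and `b_raw` offsets. In the non-centered form, the sampled variable `z` has a fixed unit scale whatever `tau` is, which avoids the funnel-shaped geometry that causes divergences. The NumPy sketch below only checks that the two forms define the same distribution. The values of `mu` and `tau` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(6)
n = 200_000
mu, tau = 1.5, 0.3  # hypothetical group mean and group scale

# Centered: draw theta directly with scale tau
theta_centered = rng.normal(mu, tau, size=n)
# Non-centered: draw a standard-normal offset z, then shift and scale it
z = rng.normal(0.0, 1.0, size=n)
theta_noncentered = mu + tau * z

print(theta_centered.mean(), theta_noncentered.mean())
print(theta_centered.std(), theta_noncentered.std())
```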


In [None]:
# BFMI (one value per chain) and energy plot for the exp2 fit
print(az.bfmi(trace_exp2))
az.plot_energy(trace_exp2);
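The sketch below implements the standard E-BFMI estimator in NumPy: the mean squared change in energy between successive draws divided by the variance of the energy, computed per chain. ArviZ's `az.bfmi` reports a per-chain value of this kind. A commonly used rule of thumb, for example in Stan, flags values below about 0.3. The simulated energy traces are hypothetical.

```python
import numpy as np


def bfmi(energy):
    """E-BFMI per chain for an energy array of shape (n_chains, n_draws)."""
    energy = np.atleast_2d(energy)
    # Mean squared energy change between successive draws / marginal energy variance
    return np.square(np.diff(energy, axis=1)).mean(axis=1) / np.var(energy, axis=1)


rng = np.random.default_rng(7)
well_mixed = rng.normal(size=(4, 2000))                 # energies that decorrelate quickly
sticky = np.cumsum(rng.normal(size=(4, 2000)), axis=1)  # energies that drift like a random walk

print(bfmi(well_mixed))
print(bfmi(sticky))
```

When successive energies are nearly independent, the ratio is close to 2. When momentum resampling barely moves the energy, the ratio collapses toward 0.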


## Model evaluation and comparison (LOO/WAIC)

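Information criteria estimate out-of-sample predictive performance as the expected log pointwise predictive density (elpd). WAIC subtracts an effective-number-of-parameters penalty, and PSIS-LOO approximates leave-one-out cross-validation. `az.compare` uses LOO by default (`ic="waic"` switches to WAIC). It reports each model's elpd, the difference from the best model (`elpd_diff`), the standard error of that difference (`dse`), and model weights. Interpret differences together with their uncertainty: a difference that is small relative to `dse` warrants caution. Treat weights as soft evidence, not as a decisive selection rule.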
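The sketch below computes pointwise WAIC from a log-likelihood matrix (draws × observations). It then computes the total elpd difference between two models and the usual standard error, `sqrt(n * var(pointwise differences))`. It illustrates the quantities behind `elpd_diff` and `dse`; it does not reproduce ArviZ's exact implementation. The log-likelihood matrices in the usage example are hypothetical.

```python
import numpy as np


def pointwise_waic(log_lik):
    """Pointwise elpd_WAIC from a log-likelihood matrix of shape (n_draws, n_obs)."""
    n_draws = log_lik.shape[0]
    # lppd: log of the average likelihood over draws, via a stable log-sum-exp
    max_ll = log_lik.max(axis=0)
    lppd = max_ll + np.log(np.exp(log_lik - max_ll).sum(axis=0)) - np.log(n_draws)
    # p_waic: posterior variance of the log-likelihood for each observation
    p_waic = log_lik.var(axis=0, ddof=1)
    return lppd - p_waic


def elpd_difference(elpd_a, elpd_b):
    """Total elpd difference (a - b) and its standard error from pointwise values."""
    diff = np.asarray(elpd_a) - np.asarray(elpd_b)
    return diff.sum(), np.sqrt(len(diff) * np.var(diff))


rng = np.random.default_rng(8)
# Hypothetical log-likelihood matrices (draws x observations) for two models
ll_model_a = rng.normal(-2.0, 0.1, size=(4000, 50))
ll_model_b = rng.normal(-2.2, 0.1, size=(4000, 50))
diff, dse = elpd_difference(pointwise_waic(ll_model_a), pointwise_waic(ll_model_b))
print(f"elpd difference: {diff:.1f} (SE {dse:.1f})")
```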


In [None]:
cmp2 = az.compare({"exp4": trace_exp4_full, "logistic": trace_logistic})
az.plot_compare(cmp2);
# Comparison table last, so the notebook displays it
cmp2


## Best practices for model improvement

- Prior predictive checks: validate that prior + likelihood can generate plausible data
- Center/scale predictors; consider link functions matching the outcome's domain
- Use non-centered parameterizations in hierarchical models, especially when group-level scales are only weakly informed by the data
- Diagnose and remediate: increase `target_accept`, revisit priors, reparameterize
- Compare candidate models with LOO; prefer simpler models unless predictive gains are clear
- Validate on held-out data or out-of-sample regions (e.g., different country/time)
- Communicate uncertainty with intervals, posterior predictive envelopes, and scenario analyses


## Optional: Hierarchical COVID extension

Counts from multiple countries can be modeled with partial pooling to share information while allowing country-specific variation. Start with varying intercepts and consider varying slopes, using a log link to ensure positivity and a Negative Binomial likelihood. Diagnostics and PPCs proceed as before, with special attention to non-centered parameterizations.


In [None]:
subset_countries = ['Germany', 'US', 'Italy', 'Spain']
# Exponential growth is only plausible early on, so keep the first 30 days
# of each series (mirroring the 30-day window used for Germany above)
df_sub = df[df.country.isin(subset_countries) & (df.days_since_100 < 30)].copy()
country_idx, countries_unique = pd.factorize(df_sub.country)

t = df_sub.days_since_100.values
y = df_sub.confirmed.values

with pm.Model(coords={"country": countries_unique}) as hier_model:
    # Population-level parameters on the log scale
    log_a_group = pm.Normal('log_a_group', mu=np.log(100), sigma=0.5)
    log_b_group = pm.Normal('log_b_group', mu=np.log(0.2), sigma=0.5)
    sigma_a = pm.HalfNormal('sigma_a', 0.5)
    sigma_b = pm.HalfNormal('sigma_b', 0.5)

    # Non-centered country-level offsets
    a_raw = pm.Normal('a_raw', 0.0, 1.0, dims="country")
    b_raw = pm.Normal('b_raw', 0.0, 1.0, dims="country")

    log_a = log_a_group + sigma_a * a_raw
    a = pm.Deterministic('a', pm.math.exp(log_a), dims="country")
    b = pm.Deterministic('b', pm.math.exp(log_b_group + sigma_b * b_raw), dims="country")

    # Log link: log(mu) = log(a) + t * log(1 + b), so mu is always positive
    mu = pm.math.exp(log_a[country_idx] + pm.math.log(1 + b[country_idx]) * t)

    alpha = pm.Gamma('alpha', mu=6, sigma=1)
    pm.NegativeBinomial('obs', mu=mu, alpha=alpha, observed=y)

    hier_trace = pm.sample(tune=1500, chains=4, cores=4, random_seed=RANDOM_SEED, target_accept=0.9)

group_vars = ['log_a_group', 'log_b_group', 'sigma_a', 'sigma_b']
az.plot_trace(hier_trace, var_names=group_vars); plt.tight_layout();
az.plot_energy(hier_trace);
az.summary(hier_trace, var_names=group_vars)


In [None]:
%load_ext watermark
%watermark -n -u -v -iv -w
