Based on https://bambinos.github.io/bambi/notebooks/multi-level_regression.html

This notebook uses [formulae](https://github.com/bambinos/formulae), which implements [Wilkinson's formulas for mixed-effects models](https://de.mathworks.com/help/stats/wilkinson-notation.html).

In [None]:
import arviz as az
import bambi as bmb
import numpy as np
import matplotlib.pyplot as plt
import statsmodels.api as sm
import xarray as xr

In [None]:
# Seed constant for random number generation.
SEED = 7355608
# Apply the "arviz-darkgrid" plotting style.
az.style.use("arviz-darkgrid")

In [None]:
# Fetch the "dietox" dataset of the R package "geepack" through statsmodels
# and keep its `.data` attribute as the working table.
dietox = sm.datasets.get_rdataset("dietox", "geepack")
data = dietox.data
# Last expression of the cell: descriptive statistics of the table.
data.describe()

In [None]:
# Preview the first five rows (five is also the default of `head`).
data.head(n=5)

In [None]:
# Multi-level model in Wilkinson notation: a common intercept and `Time`
# slope, plus pig-specific intercepts and `Time` slopes via `(Time|Pig)`.
model = bmb.Model("Weight ~ Time + (Time|Pig)", data)
# Seed the sampler with the notebook-wide SEED so reruns are reproducible.
results = model.fit(random_seed=SEED)

In [None]:
# Bare expression as the last line of the cell: the notebook renders the
# model object's representation.
model

In [None]:
# Plot the model's priors; the trailing semicolon keeps the notebook from
# echoing the call's return value.
model.plot_priors();

In [None]:
# Trace plots for the common effects, the pig-specific effects and sigma.
trace_vars = ["Intercept", "Time", "1|Pig", "Time|Pig", "sigma"]
az.plot_trace(results, var_names=trace_vars, compact=True);

In [None]:
# Summary table restricted to the common effects, the `*_sigma` terms of the
# pig-specific effects, and sigma.
summary_vars = ["Intercept", "Time", "1|Pig_sigma", "Time|Pig_sigma", "sigma"]
az.summary(results, var_names=summary_vars)

In [None]:
# Posterior regression lines for a single pig.
# The ID of the first pig is 4601: an integer in `data["Pig"]`, but a string
# label on the `Pig__factor_dim` coordinate of the posterior.
pig_id = 4601
pig_label = str(pig_id)
data_0 = data.loc[data["Pig"] == pig_id, ["Time", "Weight"]]
# End points of the time axis; two points are enough to draw a straight line.
time = np.array([1, 12])

# `az.extract` supersedes the deprecated `az.extract_dataset`; by default it
# returns the posterior group with chains and draws stacked into "sample".
posterior = az.extract(results)
intercept_common = posterior["Intercept"]
slope_common = posterior["Time"]

intercept_specific_0 = posterior["1|Pig"].sel(Pig__factor_dim=pig_label)
slope_specific_0 = posterior["Time|Pig"].sel(Pig__factor_dim=pig_label)

# Per-draw intercept and slope of this pig: common part + pig-specific part.
a = intercept_common + intercept_specific_0
b = slope_common + slope_specific_0

# make time a DataArray so we can get automatic broadcasting
time_xi = xr.DataArray(time)
# One thin line per posterior draw ...
plt.plot(time_xi, (a + b * time_xi).T, color="C1", lw=0.1)
# ... with the line built from the posterior means drawn on top.
plt.plot(time_xi, a.mean() + b.mean() * time_xi, color="black")
# Observed weights of this pig.
plt.scatter(data_0["Time"], data_0["Weight"], zorder=2)
plt.ylabel("Weight (kg)")
plt.xlabel("Time (weeks)");

In [None]:
# Posterior-mean regression line of every pig.
intercept_group_specific = posterior["1|Pig"]
slope_group_specific = posterior["Time|Pig"]

# Average over the "sample" dimension, leaving one offset per pig.
mean_intercept_offsets = intercept_group_specific.mean("sample")
mean_slope_offsets = slope_group_specific.mean("sample")

# Common posterior mean plus each pig's mean offset.
a = intercept_common.mean() + mean_intercept_offsets
b = slope_common.mean() + mean_slope_offsets

# DataArray version of the time grid so the arithmetic broadcasts.
time_xi = xr.DataArray(time)
pig_lines = a + b * time_xi
plt.plot(time_xi, pig_lines.T, color="C1", alpha=0.7, lw=0.8)
plt.xlabel("Time (weeks)")
plt.ylabel("Weight (kg)");

In [None]:
# Forest plot restricted to the two common effects, on a wide, short figure.
common_effects = ["Intercept", "Time"]
az.plot_forest(results, var_names=common_effects, figsize=(8, 2));

In [None]:
# Posterior plots of the common effects, with a reference value of 0 and a
# ROPE of [-1, 1] passed to ArviZ.
az.plot_posterior(
    results,
    var_names=["Intercept", "Time"],
    ref_val=0,
    rope=[-1, 1],
);