Based on https://bambinos.github.io/bambi/notebooks/model_comparison.html

In [None]:
import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import numpy as np
import pandas as pd
import seaborn as sns
import warnings

from scipy.special import expit as invlogit

In [None]:
# Apply the "arviz-darkgrid" style to the figures drawn below
az.style.use("arviz-darkgrid")
# Disable a FutureWarning in ArviZ at the moment of running the notebook.
# NOTE: the filter is not scoped to ArviZ; it silences every FutureWarning.
warnings.simplefilter(action="ignore", category=FutureWarning)

In [None]:
# Load the "adults" dataset through Bambi's data loader
data = bmb.load_data("adults")

In [None]:
# Print column names, dtypes and non-null counts
data.info()
# Last expression of the cell, so the first rows are rendered as the cell output
data.head()

In [None]:
# Convert every object-typed column to pandas' categorical dtype in one pass.
is_object_column = data.dtypes == object
categorical_cols = list(data.columns[is_object_column])
data = data.astype({name: "category" for name in categorical_cols})
# Print the frame summary again to confirm the new dtypes.
data.info()

In [None]:
# Utility function to truncate tick labels and avoid overlapping in plots
def truncate_labels(ticklabels, width=8):
    """Return the text of each tick label, shortened to at most ``width`` characters.

    Labels that already fit within ``width`` are returned unchanged. Longer
    labels are cut and suffixed with "..." so that the result is exactly
    ``width`` characters long; a truncated label is therefore never longer
    than the label it replaces.

    Parameters
    ----------
    ticklabels : iterable
        Objects exposing a ``get_text()`` method, such as the items returned
        by ``Axes.get_xticklabels()``.
    width : int, default 8
        Maximum number of characters in each returned label.

    Returns
    -------
    list of str
        One (possibly shortened) label per input item, in the same order.
    """
    ellipsis = "..."

    def truncate(label, width):
        if len(label) <= width:
            return label
        if width <= len(ellipsis):
            # No room for any text plus the ellipsis: hard-cut instead.
            return label[: max(width, 0)]
        return label[: width - len(ellipsis)] + ellipsis

    return [truncate(x.get_text(), width) for x in ticklabels]

In [None]:
# Overview of the raw variables: count plots for the categorical columns and
# histograms for the numeric ones, laid out on a 3x2 grid.
fig, axes = plt.subplots(3, 2, figsize=(12, 15))

count_panels = [("income", axes[0, 0]), ("sex", axes[0, 1]), ("race", axes[1, 0])]
for column, panel in count_panels:
    sns.countplot(x=column, color="C0", data=data, ax=panel, saturation=1)

# Shorten the race tick labels so they do not overlap.
race_ax = axes[1, 0]
race_ax.set_xticklabels(truncate_labels(race_ax.get_xticklabels()))

hist_panels = [
    ("age", "Age", axes[1, 1]),
    ("hs_week", "Hours of work / week", axes[2, 0]),
]
for column, xlabel, panel in hist_panels:
    panel.hist(data[column], bins=20)
    panel.set_xlabel(xlabel)
    panel.set_ylabel("Count")

# The sixth panel has nothing to show.
axes[2, 1].axis("off");

In [None]:
# Keep only the rows whose race is "Black" or "White". The explicit copy makes
# `data` an independent frame, so the column assignments below do not write
# into a slice of the unfiltered frame (avoids pandas' SettingWithCopyWarning).
data = data[data["race"].isin(["Black", "White"])].copy()
# Drop the race categories that no longer have any rows.
data["race"] = data["race"].cat.remove_unused_categories()

# Bin edges for age (years) and weekly working hours; the binned columns are
# only used for the exploratory count plots.
age_bins = [17, 25, 35, 45, 65, 90]
data["age_binned"] = pd.cut(data["age"], age_bins)
hours_bins = [0, 20, 40, 60, 100]
data["hs_week_binned"] = pd.cut(data["hs_week"], hours_bins)

In [None]:
# Income counts overall, then broken down by each candidate predictor.
fig, axes = plt.subplots(3, 2, figsize=(12, 15))
panels = axes.ravel()  # row-major order: [0,0], [0,1], [1,0], [1,1], [2,0], [2,1]

sns.countplot(x="income", color="C0", data=data, ax=panels[0])

predictors = ["sex", "race", "age_binned", "hs_week_binned"]
for predictor, panel in zip(predictors, panels[1:5]):
    sns.countplot(x=predictor, hue="income", data=data, ax=panel)

# The last panel is left empty.
panels[5].axis("off");

In [None]:
# Standardization constants. They stay in the namespace because the prediction
# cell at the end of the notebook uses them to convert between raw and
# standardized units.
age_mean, age_std = np.mean(data["age"]), np.std(data["age"])
hs_mean, hs_std = np.mean(data["hs_week"]), np.std(data["hs_week"])

# Standardize each numeric predictor, then add its square and cube as
# "<name>2" and "<name>3" columns.
scaling = (("age", age_mean, age_std), ("hs_week", hs_mean, hs_std))
for column, center, scale in scaling:
    data[column] = (data[column] - center) / scale
    for power in (2, 3):
        data[f"{column}{power}"] = data[column] ** power

# The binned helper columns were only needed for the exploratory plots.
data = data.drop(columns=["age_binned", "hs_week_binned"])

In [None]:
# Model 1: linear terms only, fitted on a 200-row subsample of `data`.
# `random_state` pins the subsample so that re-running the cell, or any other
# cell that samples `data` with the same seed, draws exactly the same rows.
# Models compared with `az.compare` must be fitted on the same observations.
model1 = bmb.Model(
    "income['>50K'] ~ sex + race + age + hs_week",
    data.sample(n=200, replace=False, random_state=1234),
    family="bernoulli",
)
# The fit is deliberately not wrapped in try/except: swallowing a failure would
# leave `fitted1` undefined (or stale from an earlier run) for the cells below.
# `random_seed` matches the seed already used for model 3.
fitted1 = model1.fit(
    draws=1000,
    random_seed=1234,
    # Keep the pointwise log-likelihood for the later model comparison.
    idata_kwargs={"log_likelihood": True},
)

Model 1 raises `EOFError` when it is fitted on the full dataset, so it is fitted on a random subsample of 200 rows instead.

In [None]:
# Trace plot for the model 1 fit
az.plot_trace(fitted1)
# Last expression of the cell, so the summary table is rendered as the cell output
az.summary(fitted1)

In [None]:
# Model 2: adds quadratic terms for age and weekly hours, fitted on a 200-row
# subsample of `data`.
# `random_state` pins the subsample so that re-running the cell, or any other
# cell that samples `data` with the same seed, draws exactly the same rows.
# Models compared with `az.compare` must be fitted on the same observations.
model2 = bmb.Model(
    "income['>50K'] ~ sex + race + age + age2 + hs_week + hs_week2",
    data.sample(n=200, replace=False, random_state=1234),
    family="bernoulli",
)
# The fit is deliberately not wrapped in try/except: swallowing a failure would
# leave `fitted2` undefined (or stale from an earlier run) for the cells below.
# `random_seed` matches the seed already used for model 3.
fitted2 = model2.fit(
    random_seed=1234,
    # Keep the pointwise log-likelihood for the later model comparison.
    idata_kwargs={"log_likelihood": True},
)

Model 2 raises `EOFError` when it is fitted on the full dataset, so it is fitted on a random subsample of 200 rows instead.

In [None]:
# Trace plot for the model 2 fit
az.plot_trace(fitted2)
# Last expression of the cell, so the summary table is rendered as the cell output
az.summary(fitted2)

In [None]:
# Model 3: adds cubic terms for age and weekly hours, fitted on a 200-row
# subsample of `data`.
# `random_state` pins the subsample so that re-running the cell, or any other
# cell that samples `data` with the same seed, draws exactly the same rows.
# Models compared with `az.compare` must be fitted on the same observations.
model3 = bmb.Model(
    "income['>50K'] ~ age + age2 + age3 + hs_week + hs_week2 + hs_week3 + sex + race",
    data.sample(n=200, replace=False, random_state=1234),
    family="bernoulli",
)
# The fit is deliberately not wrapped in try/except: swallowing a failure would
# leave `fitted3` undefined (or stale from an earlier run) for the cells below.
fitted3 = model3.fit(
    draws=1000,
    random_seed=1234,
    target_accept=0.9,
    # Keep the pointwise log-likelihood for the later model comparison.
    idata_kwargs={"log_likelihood": True},
)

Model 3 raises `EOFError` when it is fitted on the full dataset, so it is fitted on a random subsample of 200 rows instead.

In [None]:
# Trace plot for the model 3 fit
az.plot_trace(fitted3)
# Last expression of the cell, so the summary table is rendered as the cell output
az.summary(fitted3)

In [None]:
# Map a display name to each fitted result; the keys label the rows of the comparison
models_dict = {"model1": fitted1, "model2": fitted2, "model3": fitted3}
# Compare the three fits with ArviZ's default settings
df_compare = az.compare(models_dict)
# Last expression of the cell, so the comparison table is rendered as the cell output
df_compare

In [None]:
# Plot the comparison table; the trailing semicolon suppresses the returned object's repr.
# NOTE(review): insample_dev=False presumably hides the in-sample deviance markers;
# confirm against the installed ArviZ version.
az.plot_compare(df_compare, insample_dev=False);

Interestingly, model2 comes out ahead here, whereas model3 is preferred in the Bambi docs. The difference is most likely due to the models being fitted on a reduced dataset.

In [None]:
# Posterior predictions from model 2 over an age grid, one band per race/sex
# combination, with weekly hours held fixed at 40.
HS_WEEK = (40 - hs_mean) / hs_std  # 40 hours/week on the standardized scale
AGE = (np.linspace(18, 75) - age_mean) / age_std  # ages 18-75, standardized
age_years = AGE * age_std + age_mean  # x axis back in raw years
n_points = len(AGE)

fig, ax = plt.subplots()
handles = []

groups = [(race, sex) for race in ["Black", "White"] for sex in ["Female", "Male"]]
for i, (race, sex) in enumerate(groups):
    color = f"C{i}"
    label = f"{race} - {sex}"
    # Proxy line used only to build the legend entry for this group.
    handles.append(mlines.Line2D([], [], color=color, label=label, lw=3))

    new_data = pd.DataFrame(
        {
            "sex": [sex] * n_points,
            "race": [race] * n_points,
            "age": AGE,
            "age2": AGE**2,
            "hs_week": np.full(n_points, HS_WEEK),
            "hs_week2": np.full(n_points, HS_WEEK**2),
        }
    )
    new_idata = model2.predict(fitted2, data=new_data, inplace=False)
    mean = new_idata.posterior["p"].values

    # Two bands per group: one with the default HDI probability, one at 50%.
    for hdi_kwargs in ({}, {"hdi_prob": 0.5}):
        az.plot_hdi(age_years, mean, ax=ax, color=color, **hdi_kwargs)

ax.set_xlabel("Age")
ax.set_ylabel("P(Income > $50K)")
ax.legend(handles=handles, loc="upper left");