Based on https://bambinos.github.io/bambi/notebooks/logistic_regression.html

In [None]:
import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
# Seed passed to the sampler further down so the fitted posterior is reproducible.
SEED = 7355608
# Use ArviZ's dark-grid theme for every figure in the notebook.
az.style.use("arviz-darkgrid")

In [None]:
# Load the ANES example dataset shipped with Bambi and preview the first five rows.
data = bmb.load_data("ANES")
data.head(5)

In [None]:
# How many respondents gave each answer in the `vote` column.
data.vote.value_counts()

In [None]:
# How many respondents fall into each `party_id` level.
data.party_id.value_counts()

In [None]:
# Age distribution of respondents in each party, with that party's mean age
# marked by a vertical line. One panel per `party_id` level (in groupby's
# sorted order), so the figure no longer assumes exactly three levels.
party_groups = data.groupby("party_id")
fig, ax = plt.subplots(
    1, party_groups.ngroups, figsize=(10, 4), sharey=True, constrained_layout=True
)
ax = np.atleast_1d(ax)  # a single level would otherwise come back as a bare Axes
for panel, (label, df) in zip(ax, party_groups):
    panel.hist(df["age"])
    panel.set_xlim([18, 90])
    panel.set_xlabel("Age")
    panel.set_title(label)
    panel.axvline(df["age"].mean(), color="C1")
# The y-axis is shared across panels, so label it once on the leftmost one.
ax[0].set_ylabel("Frequency")

In [None]:
# Contingency table: vote choice (rows) by party identification (columns).
pd.crosstab(index=data["vote"], columns=data["party_id"])

In [None]:
# Restrict to respondents whose vote was either "clinton" or "trump".
clinton_data = data[data["vote"].isin(["clinton", "trump"])]
clinton_data.head()

In [None]:
# Bernoulli (logistic) model for voting "clinton": a party intercept plus a
# separate age slope per party. The log-likelihood is stored because the LOO
# cross-validation below (az.loo) needs it.
clinton_model = bmb.Model(
    formula="vote['clinton'] ~ party_id + party_id:age",
    data=clinton_data,
    family="bernoulli",
)
clinton_fitted = clinton_model.fit(
    random_seed=SEED,
    draws=2000,
    target_accept=0.85,
    idata_kwargs={"log_likelihood": True},
)

In [None]:
# Display the model object (its text summary) in the notebook output.
clinton_model

In [None]:
# Plot the prior distributions assigned to the model's parameters. Seeding the
# prior draws with the notebook-wide SEED makes this figure reproducible, like
# the posterior fit above.
clinton_model.plot_priors(random_seed=SEED);

In [None]:
# Trace plots of the posterior; compact=False draws each parameter coordinate
# (e.g. each party level) in its own panel instead of overlaying them.
az.plot_trace(clinton_fitted, compact=False);

In [None]:
# Tabular posterior summary (point estimates, intervals and convergence
# diagnostics) for every parameter.
az.summary(clinton_fitted)

In [None]:
# Generate posterior predictive draws of the response (kind="response").
# NOTE(review): presumably stored in-place on clinton_fitted as the
# posterior_predictive group that az.plot_separation(y="vote") reads below.
clinton_model.predict(clinton_fitted, kind="response")

In [None]:
# Separation plot: observed "vote" outcomes against the model's posterior
# predictive "vote" values, drawn as a thin horizontal strip.
ax = az.plot_separation(clinton_fitted, y="vote", figsize=(9, 0.5));

In [None]:
# Approximate leave-one-out cross-validation (PSIS-LOO). pointwise=True keeps the
# per-observation results, including the Pareto k-hat diagnostics used below.
loo = az.loo(clinton_fitted, pointwise=True)

In [None]:
# Plot the Pareto k-hat diagnostic for each observation (large values mark
# observations the PSIS-LOO approximation handles poorly).
az.plot_khat(loo.pareto_k);

In [None]:
# Per-observation Pareto k-hat values as a flat array (position == row of clinton_data).
khat = loo.pareto_k.values.ravel()
ax = az.plot_khat(khat)
sorted_kappas = np.sort(khat)

# Flag the observation(s) sharing the single largest k-hat value. Take a scalar
# (not the 1-element slice `sorted_kappas[-1:]`) so `axhline` gets a plain number.
threshold = sorted_kappas[-1]
ax.axhline(threshold, ls="--", color="orange")
influential_observations = np.flatnonzero(khat >= threshold)

# Label each flagged point with its observation position, just above the marker.
for x in influential_observations:
    ax.text(x, khat[x] + 0.01, str(x), ha="center", va="baseline")

In [None]:
# Rows of clinton_data flagged above (original index kept as the "index" column).
clinton_data.reset_index().loc[loo.pareto_k.values >= threshold]

In [None]:
# Per-observation Pareto k-hat values as a flat array (position == row of clinton_data).
khat = loo.pareto_k.values.ravel()
ax = az.plot_khat(khat)

# Flag the observations with the `n_top` largest k-hat values. Every point tied
# with the cut-off value is included, so more than `n_top` may be flagged.
n_top = 6
threshold = np.sort(khat)[-n_top:].min()
ax.axhline(threshold, ls="--", color="orange")
influential_observations = np.flatnonzero(khat >= threshold)

# Label each flagged point with its observation position, just above the marker.
for x in influential_observations:
    ax.text(x, khat[x] + 0.01, str(x), ha="center", va="baseline")

In [None]:
# Rows of clinton_data flagged by the k-hat cut-off above (original index as a column).
clinton_data.reset_index().loc[loo.pareto_k.values >= threshold]

In [None]:
# Respondents older than 80.
clinton_data.loc[clinton_data["age"].gt(80)]

In [None]:
# Respondents identifying as Republican who nevertheless voted for Clinton.
clinton_data.loc[
    clinton_data["vote"].eq("clinton") & clinton_data["party_id"].eq("republican")
]

In [None]:
import matplotlib.patheffects as pe

ax = az.plot_separation(clinton_fitted, y="vote", figsize=(9, 0.5))

# The separation plot does not draw observations in dataset order: it sorts them
# by their mean posterior-predictive value and spreads the sorted bars over
# [0, 1]. Rebuild that ordering so each marker sits on the bar of the
# observation it labels (the old `x / len(clinton_data)` used dataset order).
# NOTE(review): mirrors ArviZ's plot_separation layout (np.argsort of the mean,
# bars at np.linspace(0, 1, n)) -- re-check if ArviZ changes that function.
y_hat = clinton_fitted.posterior_predictive["vote"].mean(("chain", "draw")).values
n_obs = len(y_hat)
rank = np.empty(n_obs, dtype=int)
rank[np.argsort(y_hat)] = np.arange(n_obs)
bar_positions = np.linspace(0, 1, n_obs)

# Seeded vertical jitter so nearby labels do not overlap, reproducibly.
rng = np.random.default_rng(SEED)
jitter = rng.uniform(0.1, 0.5, size=len(influential_observations))

for obs, y in zip(influential_observations, jitter):
    x = bar_positions[rank[obs]]
    ax.scatter(x, y, marker="+", s=50, color="red", zorder=3)
    ax.text(
        x,
        y + 0.1,
        str(obs),
        color="white",
        ha="center",
        va="bottom",
        path_effects=[pe.withStroke(linewidth=2, foreground="black")],
    )

In [None]:
# Split the posterior of the party-specific age slope ("party_id:age") into one
# DataArray per party level, in the order democrat, independent, republican.
parties = ["democrat", "independent", "republican"]
age_slopes = clinton_fitted.posterior["party_id:age"]
dem, ind, rep = (age_slopes.sel({"party_id:age_dim": level}) for level in parties)

In [None]:
# Overlay the posterior densities of the three age slopes, coloured C0, C1, C2
# (democrat, independent, republican) and labelled with their party level.
_, ax = plt.subplots()
for slope, colour in zip([dem, ind, rep], ["C0", "C1", "C2"]):
    party_label = slope["party_id:age_dim"].item()
    az.plot_dist(slope, label=party_label, plot_kwargs={"color": colour}, ax=ax)
ax.legend(loc="upper left");

In [None]:
# Prediction grid: every integer age from 18 to 90, repeated for each party.
# Rows are grouped by party in the order democrat, republican, independent, and
# within each group run through the ages in ascending order.
age = np.arange(18, 91)
new_data = pd.DataFrame(
    {
        "age": np.concatenate([age, age, age]),
        "party_id": np.repeat(["democrat", "republican", "independent"], age.size),
    }
)
new_data

In [None]:
# Predict the mean response for the rows of new_data (default `kind`).
# NOTE(review): presumably done in-place, replacing the posterior "p" values in
# clinton_fitted with one entry per new_data row -- the next cells rely on that.
clinton_model.predict(clinton_fitted, data=new_data)

In [None]:
# Randomly subsample 2000 posterior draws (chains and draws combined into one
# "sample" dimension) of the mean probability of voting for Clinton, "p".
# `az.extract` supersedes the deprecated `az.extract_dataset`; seeding it with
# SEED makes the subsample, and therefore the plot below, reproducible.
vote_posterior = az.extract(clinton_fitted, num_samples=2000, rng=SEED)["p"]

In [None]:
# Posterior draws of P(vote = clinton) as a function of age, one colour per party.
# Colours follow the same party order as the age-slope density plot
# (C0 democrat, C1 independent, C2 republican) so the two figures agree.
_, ax = plt.subplots(figsize=(7, 5))

for i, party in enumerate(["democrat", "independent", "republican"]):
    # Positions of this party's rows in new_data, i.e. along vote_posterior's first axis.
    idx = new_data.index[new_data["party_id"] == party].tolist()
    draws = vote_posterior[idx]  # (ages, samples), one faint curve per sample
    ax.plot(age, draws, alpha=0.04, color=f"C{i}")
    # Posterior-mean curve at full opacity; it also supplies the legend entry.
    ax.plot(age, np.asarray(draws).mean(axis=1), color=f"C{i}", lw=2, label=party)

ax.set_ylabel("P(vote='clinton' | age)", fontsize=15)
ax.set_xlabel("Age", fontsize=15)
ax.set_ylim(0, 1)
ax.set_xlim(18, 90)
ax.legend();