Based on https://bambinos.github.io/bambi/notebooks/hierarchical_binomial_bambi.html

In [None]:
import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np

from matplotlib.lines import Line2D
from matplotlib.patches import Patch

In [None]:
# One fixed seed, passed to every sampler call below, so the runs are reproducible.
random_seed = 1_234

In [None]:
df = bmb.load_data("batting")

# Clean the data. A player with zero at-bats has an undefined batting average,
# so mark AB == 0 as missing. Then drop rows that are missing a value needed by
# the model (H or AB) -- a bare dropna() would also discard rows that merely
# have NaN in some unrelated column.
df["AB"] = df["AB"].replace(0, np.nan)
df = df.dropna(subset=["H", "AB"])
df["batting_avg"] = df["H"] / df["AB"]

# Keep recent seasons only, and a small sample of 15 rows so the plots stay readable.
df = df[df["yearID"] >= 2016]
df = df.iloc[0:15]
df.head(5)

In [None]:
# Colours shared by the plots in this notebook.
BLUE, RED = "#2a5674", "#b13f64"

In [None]:
# One horizontal row per player; derived from the data rather than hard-coding 15.
n_players = len(df)
y_positions = np.arange(n_players)

_, ax = plt.subplots(figsize=(10, 6))

# Customize x limits.
# The negative side (down to -120) holds a small gauge and a label with each
# player's batting average. The right limit is 320 as before, but it grows when
# a player has more at-bats, so that no dot is drawn outside the axes.
x_max = max(320, 1.05 * df["AB"].max())
ax.set_xlim(-120, x_max)

# Add dots for the number of hits and the times at bat
ax.scatter(df["H"], y_positions, s=140, color=RED, zorder=10)
ax.scatter(df["AB"], y_positions, s=140, color=BLUE, zorder=10)

# Also a line connecting them
ax.hlines(y_positions, df["H"], df["AB"], color="#b3b3b3", lw=4)

# Dashed line at x = 0 separating the counts from the batting-average gauge.
ax.axvline(ls="--", lw=1.4, color="#a3a3a3")
# Gauge from x = -110 (0%) to x = -50 (100%), with a dot at the batting average.
ax.hlines(y_positions, -110, -50, lw=6, color="#b3b3b3", capstyle="round")
ax.scatter(60 * df["batting_avg"] - 110, y_positions, s=28, color=RED, zorder=10)

# Add the percentage of hits
for j, avg in enumerate(df["batting_avg"]):
    text = f"{round(avg * 100)}%"
    ax.text(-12, j, text, ha="right", va="center", fontsize=14, color="#333")

# Customize tick positions and labels (x ticks every 100, only on the count side)
ax.yaxis.set_ticks(y_positions)
ax.yaxis.set_ticklabels(df["playerID"])
ax.xaxis.set_ticks(np.arange(0, x_max, 100))

# Create handles for the legend (just dots and labels)
handles = [
    Line2D(
        [0],
        [0],
        label="Hits",
        marker="o",
        color="None",
        markeredgewidth=0,
        markerfacecolor=RED,
        markersize=13,
    ),
    Line2D(
        [0],
        [0],
        label="At Bat",
        marker="o",
        color="None",
        markeredgewidth=0,
        markerfacecolor=BLUE,
        markersize=12,
    ),
]

# Add legend on top-right corner
ax.legend(handles=handles, loc="upper right", fontsize=14, handletextpad=0.4, frameon=True)

# Finally add labels and a title
ax.set_xlabel("Count", fontsize=14)
ax.set_ylabel("Player", fontsize=14)
ax.set_title("How often do batters hit the ball?", fontsize=20);

In [None]:
# Unpooled model: no common intercept ("0 +"), one coefficient per player,
# binomial likelihood for H successes out of AB trials.
model_non_hierarchical = bmb.Model(
    formula="p(H, AB) ~ 0 + playerID",
    data=df,
    family="binomial",
)
model_non_hierarchical

In [None]:
# Sample the posterior with the default sampler settings, seeded for reproducibility.
idata_non_hierarchical = model_non_hierarchical.fit(
    random_seed=random_seed,
)

In [None]:
# Trace plot with compact=False, so each player's coefficient gets its own row.
az.plot_trace(
    idata_non_hierarchical,
    compact=False,
    backend_kwargs=dict(layout="constrained"),
);

In [None]:
# Hierarchical model: a common intercept plus a per-player group-specific
# intercept (1|playerID), using Bambi's default priors.
model_hierarchical = bmb.Model(
    formula="p(H, AB) ~ 1 + (1|playerID)",
    data=df,
    family="binomial",
)
model_hierarchical

In [None]:
# Fit the default-prior hierarchical model (default sampler settings, fixed seed).
idata_hierarchical = model_hierarchical.fit(
    random_seed=random_seed,
)

In [None]:
# Draw from the prior predictive distribution of the default-prior model.
idata_prior = model_hierarchical.prior_predictive()
# az.extract_dataset is deprecated. az.extract(keep_dataset=True) is its replacement:
# it returns a Dataset with chain/draw stacked into one "sample" dimension.
# keep_dataset=True is needed because this group holds a single variable, and
# without it az.extract would return a bare DataArray that cannot be indexed by name.
prior = az.extract(idata_prior, group="prior_predictive", keep_dataset=True)["p(H, AB)"]

In [None]:
# We define this function because this plot is going to be repeated below.
def plot_prior_predictive(df, prior):
    """Plot each observation's prior predictive hit proportion against the data.

    For every row of ``df``, a small histogram shows the prior predictive
    draws for that observation divided by its at-bats (``AB``). A vertical
    line marks the observed proportion ``H / AB``. Panels are laid out on a
    grid with three columns (5 x 3 for the 15 rows used in this notebook).
    Grid cells left over after the last observation are hidden.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``"H"`` and ``"AB"`` columns. Rows are assumed to be in
        the same order as the ``"__obs__"`` coordinate of ``prior``.
    prior : xarray.DataArray
        Prior predictive draws with an ``"__obs__"`` dimension, e.g. the
        ``"p(H, AB)"`` variable extracted from the ``prior_predictive`` group.

    Returns
    -------
    None
        The figure is drawn on pyplot's current state and is not returned.
    """
    AB = df["AB"].values
    H = df["H"].values
    n_obs = len(df)

    # Three panels per row; ceiling division gives enough rows (at least one).
    ncols = 3
    nrows = max(1, -(-n_obs // ncols))
    # squeeze=False keeps `axes` 2-D even for a single row, so ravel() always works.
    fig, axes = plt.subplots(nrows, ncols, figsize=(10, 6), sharex="col", squeeze=False)
    flat_axes = axes.ravel()

    for idx, (ax, ab, h) in enumerate(zip(flat_axes, AB, H)):
        # Prior predictive draws for this observation, turned into proportions.
        pps = prior.sel({"__obs__": idx})
        ax.hist(pps / ab, bins=25, color="#a3a3a3")
        ax.axvline(h / ab, color=RED, lw=2)
        ax.set_yticks([])
        ax.tick_params(labelsize=12)

    # Hide grid cells beyond the number of observations.
    for ax in flat_axes[n_obs:]:
        ax.set_visible(False)

    fig.subplots_adjust(left=0.025, right=0.975, hspace=0.05, wspace=0.05, bottom=0.125)
    fig.legend(
        handles=[Line2D([0], [0], label="Observed proportion", color=RED, linewidth=2)],
        handlelength=1.5,
        handletextpad=0.8,
        borderaxespad=0,
        frameon=True,
        fontsize=11,
        bbox_to_anchor=(0.975, 0.92),
        loc="right",
    )
    fig.text(
        0.5,
        0.05,
        "Prior probability of hitting",
        fontsize=15,
        ha="center",
        va="baseline",
    )

In [None]:
# Compare the default-prior predictive proportions with the observed ones.
plot_prior_predictive(df=df, prior=prior)

In [None]:
# Explicit priors: Normal(0, 1) on the common intercept, and Normal(0, sigma)
# on the per-player intercepts, with a HalfNormal(1) hyperprior on sigma.
priors = {
    "Intercept": bmb.Prior("Normal", mu=0, sigma=1),
    "1|playerID": bmb.Prior(
        "Normal",
        mu=0,
        sigma=bmb.Prior("HalfNormal", sigma=1),
    ),
}
model_hierarchical = bmb.Model(
    "p(H, AB) ~ 1 + (1|playerID)",
    df,
    family="binomial",
    priors=priors,
)
model_hierarchical

In [None]:
# Build the model with the explicit priors and inspect its prior predictive.
model_hierarchical.build()
idata_prior = model_hierarchical.prior_predictive()
# az.extract_dataset is deprecated. az.extract(keep_dataset=True) returns the same
# Dataset (chain/draw stacked into "sample"), so indexing by name keeps working.
prior = az.extract(idata_prior, group="prior_predictive", keep_dataset=True)["p(H, AB)"]
plot_prior_predictive(df, prior)

In [None]:
# Fit with the default sampler settings. The next cell refits the same model with
# longer tuning and target_accept=0.95 -- presumably because this run reports
# sampling problems (e.g. divergences); check the sampler warnings here.
idata_hierarchical = model_hierarchical.fit(
    random_seed=random_seed
)

In [None]:
# Refit with more conservative settings: longer tuning, more draws, and a higher
# target acceptance rate.
idata_hierarchical = model_hierarchical.fit(
    draws=2000,
    tune=2000,
    target_accept=0.95,
    random_seed=random_seed,
)

In [None]:
# Trace plots for the common intercept, the per-player intercepts and their scale.
var_names = [
    "Intercept",
    "1|playerID",
    "1|playerID_sigma",
]
az.plot_trace(
    idata_hierarchical,
    var_names=var_names,
    compact=False,
    backend_kwargs=dict(layout="constrained"),
);

In [None]:
# Add the posterior of the mean response "p" to each InferenceData (used by the
# forest plot below, which reads var_names="p").
model_non_hierarchical.predict(idata=idata_non_hierarchical)
model_hierarchical.predict(idata=idata_hierarchical)

In [None]:
_, ax = plt.subplots(figsize=(8, 8))

# Dashed reference line at the average of the observed per-player proportions
# H / AB (an unweighted mean across players).
ax.axvline(x=(df["H"] / df["AB"]).mean(), ls="--", color="black", alpha=0.5)

# Forest plot of the mean response "p" from both models (chains combined):
# non-hierarchical in grey, hierarchical in red.
az.plot_forest(
    [idata_non_hierarchical, idata_hierarchical],
    var_names="p",
    combined=True,
    colors=["#666666", RED],
    linewidth=2.6,
    markersize=8,
    ax=ax,
)

# Replace the default y tick labels with each player's hits and at-bats.
# Labels run from the last data row to the first, since tick labels are assigned
# bottom-up (presumably plot_forest draws the first observation at the top).
ylabels = [
    f"H: {round(h)}, AB: {round(ab)}"
    for h, ab in zip(df["H"].values[::-1], df["AB"].values[::-1])
]
ax.set_yticklabels(ylabels, ha="right")

# Legend: one colour patch per model plus the dashed reference line.
handles = [
    Patch(label="Non-hierarchical", facecolor="#666666"),
    Patch(label="Hierarchical", facecolor=RED),
    Line2D([0], [0], label="Mean probability", ls="--", color="black", alpha=0.5),
]

legend = ax.legend(
    handles=handles,
    loc="lower right",
    fontsize=14,
    frameon=True,
    framealpha=0.8,
);