Based on the Bambi example notebook "Categorical Regression": https://bambinos.github.io/bambi/notebooks/categorical_regression.html

In [None]:
import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import warnings

from matplotlib.lines import Line2D

# Hide FutureWarning messages to keep the cell output uncluttered.
warnings.filterwarnings("ignore", category=FutureWarning)

In [None]:
# Apply arviz's dark-grid plotting theme to the figures that follow.
az.style.use("arviz-darkgrid")
# Seed for reproducible random draws.
SEED = 1234

In [None]:
# Simulated data: 50 points per class (A, B, C), with x drawn from normals
# centred at -2.5, 0 and 2.5 (standard deviations 1.2, 0.5, 1.2).
rng = np.random.default_rng(SEED)
x = np.hstack(
    [rng.normal(m, s, size=50) for m, s in zip([-2.5, 0, 2.5], [1.2, 0.5, 1.2])]
)
y = np.array(["A"] * 50 + ["B"] * 50 + ["C"] * 50)

# The vertical coordinate is pure jitter so overlapping points stay visible.
# Draw it from the seeded generator (not the global np.random state) so the
# figure is reproducible. Class membership is encoded by color.
colors = ["C0"] * 50 + ["C1"] * 50 + ["C2"] * 50
jitter = rng.uniform(size=x.size)
plt.scatter(x, jitter, color=colors)

# One proxy marker per class so the legend explains the color coding.
class_handles = [
    Line2D([], [], color=f"C{j}", marker="o", linestyle="", label=g)
    for j, g in enumerate("ABC")
]
plt.legend(handles=class_handles, title="y")
plt.xlabel("x")
plt.ylabel("jitter");

In [None]:
# Fit a categorical-family regression of the class label y on x.
data = pd.DataFrame({"y": y, "x": x})
model = bmb.Model("y ~ x", data, family="categorical")
# Seed the sampler with the notebook-wide SEED so the draws are reproducible.
idata = model.fit(random_seed=SEED)

In [None]:
# Posterior class probabilities over a grid of new x values.
x_new = np.linspace(-5, 5, num=200)
# predict() is called for its effect on idata (its return value is unused);
# the next line reads idata.posterior["p"] back, now evaluated on x_new.
model.predict(idata, data=pd.DataFrame({"x": x_new}))
# Keep every 10th draw to limit the number of curves drawn.
p = idata.posterior["p"].sel(draw=slice(0, None, 10))

for j, g in enumerate("ABC"):
    # One semi-transparent curve per retained (chain, draw) sample.
    plt.plot(
        x_new,
        p.sel({"y_dim": g}).stack(samples=("chain", "draw")),
        color=f"C{j}",
        alpha=0.2,
    )

# Each class contributes many curves, so use one proxy legend entry per class
# instead of labelling every line.
class_handles = [
    Line2D([], [], color=f"C{j}", label=g) for j, g in enumerate("ABC")
]
plt.legend(handles=class_handles, title="y")
plt.xlabel("x")
plt.ylabel("Probability");

In [None]:
# seaborn's "iris" sample dataset; preview its first three rows.
iris = sns.load_dataset("iris")
iris.iloc[:3]

In [None]:
# Pairwise scatter plots of the measurements, colored by species.
sns.pairplot(data=iris, hue="species");

In [None]:
# Categorical-family model: species as the response, all four measurements
# as predictors.
model = bmb.Model(
    "species ~ sepal_length + sepal_width + petal_length + petal_width",
    iris,
    family="categorical",
)
# Seed the sampler with the notebook-wide SEED so the summary is reproducible.
idata = model.fit(random_seed=SEED)
az.summary(idata)

In [None]:
# Trace plots for the posterior variables (no var_names filter is applied).
az.plot_trace(data=idata);

In [None]:
length = [
    1.3,
    1.32,
    1.32,
    1.4,
    1.42,
    1.42,
    1.47,
    1.47,
    1.5,
    1.52,
    1.63,
    1.65,
    1.65,
    1.65,
    1.65,
    1.68,
    1.7,
    1.73,
    1.78,
    1.78,
    1.8,
    1.85,
    1.93,
    1.93,
    1.98,
    2.03,
    2.03,
    2.31,
    2.36,
    2.46,
    3.25,
    3.28,
    3.33,
    3.56,
    3.58,
    3.66,
    3.68,
    3.71,
    3.89,
    1.24,
    1.3,
    1.45,
    1.45,
    1.55,
    1.6,
    1.6,
    1.65,
    1.78,
    1.78,
    1.8,
    1.88,
    2.16,
    2.26,
    2.31,
    2.36,
    2.39,
    2.41,
    2.44,
    2.56,
    2.67,
    2.72,
    2.79,
    2.84,
]
choice = [
    "I",
    "F",
    "F",
    "F",
    "I",
    "F",
    "I",
    "F",
    "I",
    "I",
    "I",
    "O",
    "O",
    "I",
    "F",
    "F",
    "I",
    "O",
    "F",
    "O",
    "F",
    "F",
    "I",
    "F",
    "I",
    "F",
    "F",
    "F",
    "F",
    "F",
    "O",
    "O",
    "F",
    "F",
    "F",
    "F",
    "O",
    "F",
    "F",
    "I",
    "I",
    "I",
    "O",
    "I",
    "I",
    "I",
    "F",
    "I",
    "O",
    "I",
    "I",
    "F",
    "F",
    "F",
    "F",
    "F",
    "F",
    "F",
    "O",
    "F",
    "I",
    "F",
    "F",
]

sex = ["Male"] * 32 + ["Female"] * 31
data = pd.DataFrame({"choice": choice, "length": length, "sex": sex})
data["choice"] = pd.Categorical(
    data["choice"].map({"I": "Invertebrates", "F": "Fish", "O": "Other"}),
    ["Other", "Invertebrates", "Fish"],
    ordered=True,
)
data.head(3)

In [None]:
# Food choice as a categorical response on length and sex.
model = bmb.Model("choice ~ length + sex", data, family="categorical")
# Seed the sampler with the notebook-wide SEED so the draws are reproducible.
idata = model.fit(random_seed=SEED)

In [None]:
# Predicted response as a function of length: one panel per sex, lines grouped
# by "estimate_dim" (presumably the response-level dimension; confirm in the
# bambi.interpret docs).
subplot_layout = {"main": "length", "group": "estimate_dim", "panel": "sex"}
figure_options = {"figsize": (12, 4)}
bmb.interpret.plot_predictions(
    model,
    idata,
    ["length", "sex"],
    subplot_kwargs=subplot_layout,
    fig_kwargs=figure_options,
    legend=True,
);

In [None]:
# Posterior predictive check: add response samples to idata via predict(),
# then plot them against the observed food choices.
model.predict(idata, kind="pps")

response_levels = model.response_component.term.levels
ax = az.plot_ppc(idata)
# One tick per response level. The 0.5 offsets presumably centre the labels on
# az.plot_ppc's integer-edged bins; TODO confirm.
ax.set_xticks([0.5, 1.5, 2.5])
ax.set_xticklabels(response_levels)
ax.set(xlabel="Choice", ylabel="Probability");