# Dirichlet-Multinomial Bayesian Inference: Interactive Visualization

This notebook demonstrates Bayesian inference for categorical probabilities using the Dirichlet prior and the multinomial likelihood. We visualize the prior, posterior, and posterior predictive distributions with interactive controls for prior and data parameters.

## 1. Import Required Libraries

Import `numpy`, `scipy.stats` (`multinomial`, `betabinom`), `plotly.graph_objects`, `plotly.io`, `ipywidgets`, and `IPython.display` for simulation, computation, interactivity, and visualization.

In [1]:
import numpy as np
from scipy.stats import multinomial, betabinom
import plotly.graph_objects as go
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, Markdown, HTML

pio.templates.default = "plotly_white"
# Load MathJax so LaTeX in Markdown and Plotly output renders in the notebook
display(
    HTML(
        '<script type="text/javascript" async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-MML-AM_SVG"></script>'
    )
)

## 2. Dirichlet-Multinomial Model: Mathematical Background

The **Dirichlet-Multinomial model** is a Bayesian model for categorical/multinomial data.

### Prior: Dirichlet Distribution

The Dirichlet prior for a probability vector $\mathbf{f} = [f_1, \dots, f_K]$ (here $K = 3$) is:
$$
p(\mathbf{f} \mid \boldsymbol{\alpha}) = \frac{1}{B(\boldsymbol{\alpha})} \prod_{k=1}^K f_k^{\alpha_k-1}
$$
where $f_k \geq 0$, $\sum_{k=1}^K f_k = 1$, $\alpha_k > 0$, and $B(\boldsymbol{\alpha}) = \prod_{k=1}^K \Gamma(\alpha_k) \big/ \Gamma\left(\sum_{k=1}^K \alpha_k\right)$ is the multivariate Beta function.
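
As a quick numerical check of the density above, the sketch below evaluates $\log B(\boldsymbol{\alpha})$ with `scipy.special.gammaln` and compares a hand-written density against `scipy.stats.dirichlet.pdf`. The helper names `log_multivariate_beta` and `dirichlet_pdf_manual` and the example values of $\boldsymbol{\alpha}$ and $\mathbf{f}$ are illustrative assumptions, not part of the notebook.

```python
import numpy as np
from scipy.special import gammaln
from scipy.stats import dirichlet


def log_multivariate_beta(alpha):
    """Hypothetical helper: log B(alpha) = sum_k log Gamma(alpha_k) - log Gamma(sum_k alpha_k)."""
    alpha = np.asarray(alpha, dtype=float)
    return gammaln(alpha).sum() - gammaln(alpha.sum())


def dirichlet_pdf_manual(f, alpha):
    """Hypothetical helper: Dirichlet density at one interior point f of the simplex."""
    f = np.asarray(f, dtype=float)
    alpha = np.asarray(alpha, dtype=float)
    return np.exp(np.sum((alpha - 1.0) * np.log(f)) - log_multivariate_beta(alpha))


alpha = np.array([2.0, 3.0, 4.0])   # example prior parameters
f = np.array([0.2, 0.3, 0.5])       # example point on the simplex
print(dirichlet_pdf_manual(f, alpha))
print(dirichlet.pdf(f, alpha))      # scipy reference value
```

For $\boldsymbol{\alpha} = [1, 1, 1]$, $B(\boldsymbol{\alpha}) = \Gamma(1)^3 / \Gamma(3) = 1/2$, so that prior is uniform on the simplex with constant density 2.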

### Likelihood: Multinomial

Given $N$ independent categorical trials with counts $\mathbf{n} = [n_1, n_2, n_3]$, where $N = n_1 + n_2 + n_3$, the likelihood is:
$$
p(\mathbf{n} \mid \mathbf{f}) = \frac{N!}{n_1! n_2! n_3!} \prod_{k=1}^3 f_k^{n_k}
$$
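
A minimal sketch of the multinomial likelihood, computed in log space so that $N!$ cannot overflow, and checked against `scipy.stats.multinomial.pmf`. `multinomial_pmf_manual` is a hypothetical helper and the counts and probabilities are example values; `xlogy` treats $0 \cdot \log 0$ as 0, so a category with $f_k = 0$ and $n_k = 0$ is handled.

```python
import numpy as np
from scipy.special import gammaln, xlogy
from scipy.stats import multinomial


def multinomial_pmf_manual(n, f):
    """Hypothetical helper: N!/(n_1! ... n_K!) * prod_k f_k^{n_k}, evaluated in log space."""
    n = np.asarray(n)
    f = np.asarray(f, dtype=float)
    N = n.sum()
    log_coef = gammaln(N + 1) - gammaln(n + 1).sum()  # log multinomial coefficient
    return np.exp(log_coef + xlogy(n, f).sum())


n = np.array([3, 5, 7])          # example counts, N = 15
f = np.array([0.2, 0.3, 0.5])    # example category probabilities
print(multinomial_pmf_manual(n, f))
print(multinomial.pmf(n, n=n.sum(), p=f))  # scipy reference value
```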

### Posterior: Dirichlet

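By conjugacy, the posterior is also Dirichlet. Multiplying the prior by the likelihood gives $p(\mathbf{f} \mid \mathbf{n}, \boldsymbol{\alpha}) \propto \prod_{k=1}^K f_k^{\alpha_k + n_k - 1}$, which is the kernel of a Dirichlet with parameters $\boldsymbol{\alpha} + \mathbf{n}$: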
$$
p(\mathbf{f} \mid \mathbf{n}, \boldsymbol{\alpha}) = \mathrm{Dirichlet}(\boldsymbol{\alpha} + \mathbf{n})
$$
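
One way to check the conjugate update numerically, assuming example values for $\boldsymbol{\alpha}$ and $\mathbf{n}$: at any interior point $\mathbf{f}$, the quantity $\log p(\mathbf{f} \mid \boldsymbol{\alpha}) + \log p(\mathbf{n} \mid \mathbf{f}) - \log \mathrm{Dirichlet}(\mathbf{f}; \boldsymbol{\alpha} + \mathbf{n})$ should be the same constant, namely the log marginal likelihood $\log p(\mathbf{n} \mid \boldsymbol{\alpha})$.

```python
import numpy as np
from scipy.stats import dirichlet, multinomial

alpha = np.array([2.0, 1.5, 3.0])   # example prior parameters
counts = np.array([4, 6, 5])        # example observed counts, N = 15
alpha_post = alpha + counts         # conjugate update

# A few random interior points of the simplex to evaluate at
rng = np.random.default_rng(0)
points = rng.dirichlet([1.0, 1.0, 1.0], size=5)

log_prior_times_lik = np.array([
    dirichlet.logpdf(f, alpha) + multinomial.logpmf(counts, n=counts.sum(), p=f)
    for f in points
])
log_posterior = np.array([dirichlet.logpdf(f, alpha_post) for f in points])

# Should be the same number at every point: log p(n | alpha)
diff = log_prior_times_lik - log_posterior
print(diff)
```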

### Posterior Predictive: Beta-Binomial Marginals

The predictive probability of observing $y$ counts in category $k$ among $N$ future trials is:
$$
p(y \mid N, \boldsymbol{\alpha}') = \mathrm{BetaBinomial}(y; N, \alpha_k', \sum_{j \neq k} \alpha_j')
$$
where $\boldsymbol{\alpha}' = \boldsymbol{\alpha} + \mathbf{n}$.
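
The Beta-Binomial form can be sanity-checked by simulation: draw $\mathbf{f}$ from the posterior Dirichlet, draw $y \sim \mathrm{Binomial}(N, f_k)$, and compare the empirical frequencies with `scipy.stats.betabinom.pmf`. The posterior parameters, $N$, and Monte Carlo sample size below are arbitrary choices for the sketch.

```python
import numpy as np
from scipy.stats import betabinom

alpha_post = np.array([5.0, 7.0, 6.0])  # example posterior parameters alpha'
N_future = 10                           # example number of future trials
k = 0                                   # category of interest (Cat 1)

rng = np.random.default_rng(1)
f = rng.dirichlet(alpha_post, size=200_000)      # posterior draws of f
y = rng.binomial(N_future, f[:, k])              # future counts in category k
mc_pmf = np.bincount(y, minlength=N_future + 1) / y.size

a, b = alpha_post[k], alpha_post.sum() - alpha_post[k]
exact_pmf = betabinom.pmf(np.arange(N_future + 1), N_future, a, b)
print(np.abs(mc_pmf - exact_pmf).max())  # Monte Carlo error, should be small
```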


**Why plot Beta-Binomial marginals instead of the Dirichlet-Multinomial?**

- The **Dirichlet-Multinomial** distribution describes the joint distribution of the counts $\mathbf{n}$ across *all* categories when the category probabilities $\mathbf{f}$ are integrated out under a Dirichlet distribution.
- The **Beta-Binomial** is its two-category special case. By the aggregation property of the Dirichlet, the posterior marginal of a single probability is $f_k \sim \mathrm{Beta}(\alpha_k', \sum_{j \neq k} \alpha_j')$, so the number of "successes" in category $k$ (with all other categories lumped together as "failures") is Beta-Binomial once $f_k$ is integrated out.

**Posterior Predictive:**
- After observing data, the posterior for $\mathbf{f}$ is Dirichlet.
- For *future* data, the predictive distribution for the full vector of counts is Dirichlet-Multinomial.
- But if we focus on the *marginal* predictive for the count in a single category (e.g., $y$ out of $N$ future trials in category $k$), this marginal is Beta-Binomial:
    $$
    p(y \mid N, \alpha_k', \alpha_{-k}') = \mathrm{BetaBinomial}(y; N, \alpha_k', \sum_{j \neq k} \alpha_j')
    $$
    where $\alpha_k'$ is the posterior Dirichlet parameter for category $k$.

**Summary:**
- **Dirichlet-Multinomial:** Joint predictive for all categories.
- **Beta-Binomial:** Marginal predictive for one category (summing/integrating out the others).

This is why, when plotting the marginal predictive for each category, we use the Beta-Binomial distribution.
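
To see the joint/marginal relationship concretely, the sketch below writes the Dirichlet-Multinomial pmf by hand (hypothetical helper `dirichlet_multinomial_logpmf`), sums it over every count vector that shares the same $n_k$, and compares the result with the Beta-Binomial marginal. The parameter values are illustrative.

```python
import numpy as np
from scipy.special import gammaln
from scipy.stats import betabinom


def dirichlet_multinomial_logpmf(n, alpha):
    """Hypothetical helper: log of N!/prod(n_k!) * G(a0)/G(N+a0) * prod G(n_k+a_k)/G(a_k)."""
    n = np.asarray(n, dtype=float)
    alpha = np.asarray(alpha, dtype=float)
    N, a0 = n.sum(), alpha.sum()
    return (gammaln(N + 1) - gammaln(n + 1).sum()
            + gammaln(a0) - gammaln(N + a0)
            + (gammaln(n + alpha) - gammaln(alpha)).sum())


alpha_post = np.array([5.0, 7.0, 6.0])  # example posterior parameters
N = 8                                   # example number of future trials
k = 0                                   # category whose marginal we want

# Sum the joint pmf over all (n1, n2, n3) with n1 + n2 + n3 = N, grouped by n_k
marginal = np.zeros(N + 1)
for n1 in range(N + 1):
    for n2 in range(N + 1 - n1):
        n = np.array([n1, n2, N - n1 - n2])
        marginal[n[k]] += np.exp(dirichlet_multinomial_logpmf(n, alpha_post))

bb = betabinom.pmf(np.arange(N + 1), N, alpha_post[k], alpha_post.sum() - alpha_post[k])
print(np.abs(marginal - bb).max())
```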


## 3. Set Prior and Data Parameters (Interactive)

Use the sliders below to set the Dirichlet prior parameters ($\alpha_1, \alpha_2, \alpha_3$), the number of trials $N$, and the true probabilities used to simulate data. Only $p_1$ and $p_2$ have sliders: $p_3 = 1 - p_1 - p_2$ is derived, and the $p_2$ slider is capped at $1 - p_1$, so that $p_1 + p_2 + p_3 = 1$.
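
The constraint can also be written as a small pure function that is usable outside the widgets. This is a hypothetical helper, `complete_probabilities`, not something the notebook defines: it derives $p_3$, clips floating-point rounding at zero, and renormalizes so the result is always a valid probability vector for a multinomial sampler.

```python
import numpy as np


def complete_probabilities(p1, p2):
    """Hypothetical helper: return a valid [p1, p2, p3] with p3 = 1 - p1 - p2."""
    p = np.array([p1, p2, 1.0 - p1 - p2], dtype=float)
    p = np.clip(p, 0.0, None)  # rounding (or p1 + p2 > 1) can make p3 negative
    return p / p.sum()         # renormalise so the entries sum to 1


print(complete_probabilities(0.3, 0.3))
print(complete_probabilities(0.6, 0.5))  # p1 + p2 > 1, so p3 is clipped to 0
```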

In [2]:
# Prior parameter sliders
alpha1_slider = widgets.FloatSlider(
    value=1.0, min=0.1, max=5.0, step=0.1, description="α₁ (prior):"
)
alpha2_slider = widgets.FloatSlider(
    value=1.0, min=0.1, max=5.0, step=0.1, description="α₂ (prior):"
)
alpha3_slider = widgets.FloatSlider(
    value=1.0, min=0.1, max=5.0, step=0.1, description="α₃ (prior):"
)

# Number of trials
N_slider = widgets.IntSlider(
    value=15, min=1, max=100, step=1, description="N (trials):"
)

# True probability sliders: only p1 and p2 are free, p3 = 1 - p1 - p2
p1_slider = widgets.FloatSlider(
    value=0.3, min=0.0, max=1.0, step=0.01, description="p₁ (true):"
)
p2_slider = widgets.FloatSlider(
    value=0.3, min=0.0, max=0.7, step=0.01, description="p₂ (true):"
)
p3_label = widgets.HTML()


def update_p2_range(*args):
    # Cap p2 at 1 - p1 so that p1 + p2 <= 1
    p2_slider.max = 1.0 - p1_slider.value
    if p2_slider.value > p2_slider.max:
        p2_slider.value = p2_slider.max


def update_p3_display(*args):
    # p3 is implied by the sum-to-one constraint (clipped at 0 for rounding)
    p3 = max(0.0, 1.0 - p1_slider.value - p2_slider.value)
    p3_label.value = f"<b>p₃ (true):</b> {p3:.2f}"


# Register callbacks only after every widget they reference exists
p1_slider.observe(update_p2_range, names="value")
p1_slider.observe(update_p3_display, names="value")
p2_slider.observe(update_p3_display, names="value")
update_p3_display()  # initialise the label from the default slider values

ui = widgets.VBox(
    [
        widgets.HBox([alpha1_slider, alpha2_slider, alpha3_slider]),
        N_slider,
        widgets.HBox([p1_slider, p2_slider, p3_label]),
    ]
)
display(ui)

## 4. Simulate Multinomial Data

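Simulate observed counts $\mathbf{n} = [n_1, n_2, n_3]$ from a multinomial distribution with parameters $N$ and $[p_1, p_2, p_3]$. The random seed is fixed, so the same $N$ and $\mathbf{p}$ always reproduce the same counts; moving only the prior sliders therefore does not resample the data.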

In [3]:
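def simulate_multinomial_data(N, p_vec, rng_seed=42):
    # Clip rounding noise from p3 = 1 - p1 - p2 and renormalise so sum(p) == 1
    p = np.clip(np.asarray(p_vec, dtype=float), 0.0, None)
    p = p / p.sum()
    rng = np.random.default_rng(rng_seed)  # fixed seed: same (N, p) -> same counts
    return multinomial.rvs(n=N, p=p, random_state=rng)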
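
Because `simulate_multinomial_data` creates a fresh generator from `rng_seed` on every call, repeated calls with the same arguments return identical counts. The sketch below demonstrates this with a local copy of the sampling logic (named `simulate` here, a hypothetical name chosen so it does not shadow the notebook's function); any seed other than 42 gives a different but equally valid draw.

```python
import numpy as np
from scipy.stats import multinomial


def simulate(N, p_vec, rng_seed=42):
    # Same idea as the notebook helper: a new seeded generator per call
    rng = np.random.default_rng(rng_seed)
    return multinomial.rvs(n=N, p=p_vec, random_state=rng)


p_vec = np.array([0.3, 0.3, 0.4])       # example true probabilities
a = simulate(15, p_vec)
b = simulate(15, p_vec)                 # identical to a
c = simulate(15, p_vec, rng_seed=7)     # a different draw
print(a, b, c)
```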

## 5. Compute Posterior Parameters

Update the Dirichlet parameters: $\alpha'_k = \alpha_k + n_k$ for each category $k$.

In [4]:
def compute_posterior_params(alpha_vec, counts):
    # Conjugate update: alpha'_k = alpha_k + n_k
    return np.asarray(alpha_vec, dtype=float) + np.asarray(counts)
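
A worked example of the update, with illustrative numbers: a uniform prior $\boldsymbol{\alpha} = [1, 1, 1]$ and counts $\mathbf{n} = [4, 5, 6]$ give $\boldsymbol{\alpha}' = [5, 6, 7]$. The posterior mean $\alpha_k' / \sum_j \alpha_j'$ is a weighted average of the prior mean $\alpha_k / \alpha_0$ and the observed frequency $n_k / N$, with weight $N / (\alpha_0 + N)$ on the data, where $\alpha_0 = \sum_k \alpha_k$.

```python
import numpy as np

alpha = np.array([1.0, 1.0, 1.0])   # example uniform prior
counts = np.array([4, 5, 6])        # example observed counts, N = 15
alpha_post = alpha + counts         # alpha'_k = alpha_k + n_k

N = counts.sum()
alpha0 = alpha.sum()
post_mean = alpha_post / alpha_post.sum()

# The same quantity written as shrinkage between prior mean and observed frequencies
w = N / (alpha0 + N)                                   # weight on the data
shrunk = (1 - w) * (alpha / alpha0) + w * (counts / N)
print(alpha_post, post_mean, shrunk)
```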

## 6. Plot Observed Data

Use Plotly to create a grouped bar chart of the observed counts for each category, shown next to the expected counts $N p_k$ under the true probabilities.

In [19]:
def plot_observed_counts(counts, N, p_vec):
    categories = ["Cat 1", "Cat 2", "Cat 3"]
    expected = np.array(p_vec) * N
    fig = go.Figure()
    fig.add_trace(
        go.Bar(x=categories, y=counts, name="Observed", marker_color="#636EFA")
    )
    fig.add_trace(
        go.Bar(
            x=categories,
            y=expected,
            name="Expected (true p)",
            marker_color="#00CC96",
            opacity=0.5,
        )
    )
    fig.update_layout(
        title="Observed vs Expected Counts",
        yaxis_title="Count",
        barmode="group",
        width=1200,
        height=350,
    )
    display(fig)

In [20]:
def plot_observed_counts_interactive(N, p1, p2):
    p_vec = np.array([p1, p2, max(0.0, 1.0 - p1 - p2)])
    counts = simulate_multinomial_data(N, p_vec)
    plot_observed_counts(counts, N, p_vec)


out = widgets.interactive_output(
    plot_observed_counts_interactive,
    {
        "N": N_slider,
        "p1": p1_slider,
        "p2": p2_slider,
    },
)
display(ui)
display(out)

## 7. Plot Prior and Posterior Dirichlet Density (Ternary Plot)

Visualize the prior and posterior Dirichlet densities on a ternary plot using Plotly, and mark the true probability vector, the prior mean, and the posterior mean.
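
The ternary plot should show the posterior tightening around the true $\mathbf{p}$ as $N$ grows. A minimal sketch of that behavior, assuming a uniform prior and an example true vector, tracks the posterior mean and the marginal posterior standard deviations $\sqrt{\alpha_k'(\alpha_0' - \alpha_k') / (\alpha_0'^2(\alpha_0' + 1))}$ for increasing $N$.

```python
import numpy as np

rng = np.random.default_rng(0)
p_true = np.array([0.3, 0.3, 0.4])   # example "true" probabilities
alpha = np.array([1.0, 1.0, 1.0])    # example uniform prior

max_error, post_sd = {}, {}
for N in [10, 100, 10_000]:
    counts = rng.multinomial(N, p_true)
    alpha_post = alpha + counts
    a0 = alpha_post.sum()
    mean = alpha_post / a0
    # Marginal sd of each Dirichlet component
    sd = np.sqrt(alpha_post * (a0 - alpha_post) / (a0**2 * (a0 + 1)))
    max_error[N] = np.abs(mean - p_true).max()
    post_sd[N] = sd
    print(f"N={N:>6}: posterior mean={np.round(mean, 3)}, sd={np.round(sd, 3)}")
```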

In [83]:
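def dirichlet_pdf_grid(alpha, grid_size=80, eps=1e-3):
    from scipy.stats import dirichlet

    # Grid over the *interior* of the simplex: scipy's dirichlet.pdf raises an
    # error for points with f_k = 0 when alpha_k < 1 (the density diverges
    # there), and the prior sliders allow alpha_k as small as 0.1.
    f = np.linspace(eps, 1.0 - eps, grid_size)
    F1, F2 = np.meshgrid(f, f)
    F3 = 1.0 - F1 - F2
    mask = F3 >= eps  # keep only points strictly inside the simplex
    f1v, f2v, f3v = F1[mask], F2[mask], F3[mask]

    # scipy expects the components along the first axis: shape (3, n_points)
    pdf = dirichlet.pdf(np.stack([f1v, f2v, f3v], axis=0), alpha)
    return f1v, f2v, f3v, pdf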


def plot_dirichlet_ternary(alpha, alpha_post, p_vec, title="Dirichlet Densities"):
    # Prior
    f1v, f2v, f3v, prior_pdf = dirichlet_pdf_grid(alpha)
    # Posterior
    f1v_post, f2v_post, f3v_post, post_pdf = dirichlet_pdf_grid(alpha_post)
    # Posterior mean of Dirichlet(alpha') is alpha' / sum(alpha')
    post_mean = alpha_post / np.sum(alpha_post)
    fig = go.Figure()

    # Prior density
    fig.add_trace(
        go.Scatterternary(
            a=f1v,
            b=f2v,
            c=f3v,
            mode="markers",
            marker=dict(
                color=prior_pdf,
                colorscale="Blues",
                size=4,
                opacity=0.9,
                colorbar=dict(
                    title="Prior PDF",
                    x=0.0,  # Move colorbar to the left
                    y=0.5,  # Center vertically
                ),
            ),
            name="Prior",
            showlegend=False,
        )
    )
    # Posterior density
    fig.add_trace(
        go.Scatterternary(
            a=f1v_post,
            b=f2v_post,
            c=f3v_post,
            mode="markers",
            marker=dict(
                color=post_pdf,
                colorscale="Reds",
                size=4,
                opacity=0.4,
                colorbar=dict(title="Posterior PDF"),
            ),
            name="Posterior",
            showlegend=False,
        )
    )

    # True p
    fig.add_trace(
        go.Scatterternary(
            a=[p_vec[0]],
            b=[p_vec[1]],
            c=[p_vec[2]],
            mode="markers",
            marker=dict(color="green", size=12, symbol="x"),
            name="True p",
        )
    )

    # Prior mean
    prior_mean = alpha / np.sum(alpha)
    fig.add_trace(
        go.Scatterternary(
            a=[prior_mean[0]],
            b=[prior_mean[1]],
            c=[prior_mean[2]],
            mode="markers",
            marker=dict(color="blue", size=12, symbol="star"),
            name="Prior mean",
        )
    )

    # Posterior mean
    fig.add_trace(
        go.Scatterternary(
            a=[post_mean[0]],
            b=[post_mean[1]],
            c=[post_mean[2]],
            mode="markers",
            marker=dict(color="red", size=12, symbol="star"),
            name="Posterior mean",
        )
    )
    fig.update_layout(
        title=title,
        ternary=dict(
            sum=1,
            aaxis=dict(title="f₁"),
            baxis=dict(title="f₂"),
            caxis=dict(title="f₃"),
        ),
        width=1200,
        height=500,
        showlegend=True,
    )
    fig.update_layout(
        legend=dict(orientation="h", yanchor="bottom", y=-0.3, xanchor="center", x=0.5),
    )
    display(fig)

In [84]:
def plot_dirichlet_ternary_interactive(alpha1, alpha2, alpha3, N, p1, p2):
    p_vec = np.array([p1, p2, max(0.0, 1.0 - p1 - p2)])
    counts = simulate_multinomial_data(N, p_vec)
    alpha = np.array([alpha1, alpha2, alpha3])
    alpha_post = compute_posterior_params(alpha, counts)
    plot_dirichlet_ternary(alpha, alpha_post, p_vec)


out = widgets.interactive_output(
    plot_dirichlet_ternary_interactive,
    {
        "alpha1": alpha1_slider,
        "alpha2": alpha2_slider,
        "alpha3": alpha3_slider,
        "N": N_slider,
        "p1": p1_slider,
        "p2": p2_slider,
    },
)
display(ui)
display(out)

## 8. Plot Marginal Posterior Predictive Distributions

For each category, plot the marginal posterior predictive distribution of the number of successes in $N$ future trials (the same $N$ as the observed data), using the Beta-Binomial distribution.
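
The Beta-Binomial also has closed-form moments. For category $k$ with $a = \alpha_k'$ and $b = \sum_{j \neq k} \alpha_j'$, the mean is $N a / (a + b)$ and the variance is $N a b (a + b + N) / ((a + b)^2 (a + b + 1))$. That variance exceeds the plug-in binomial variance by the factor $(\alpha_0' + N)/(\alpha_0' + 1)$ because the predictive also carries the uncertainty in $f_k$. The sketch below, using example posterior parameters, checks these formulas against `scipy.stats.betabinom.mean` and `.var`.

```python
import numpy as np
from scipy.stats import betabinom

alpha_post = np.array([5.0, 7.0, 6.0])  # example posterior parameters alpha'
N = 15                                  # example number of future trials
alpha0 = alpha_post.sum()

a = alpha_post
b = alpha0 - alpha_post
mean = N * a / alpha0
var = N * a * b * (alpha0 + N) / (alpha0**2 * (alpha0 + 1))
binom_var = N * (a / alpha0) * (1 - a / alpha0)   # plug-in binomial variance
print(mean, var, var / binom_var)                 # ratio is (alpha0 + N) / (alpha0 + 1)
```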

In [87]:
def plot_marginal_predictives(N, alpha_post):
    fig = go.Figure()
    y = np.arange(N + 1)
    colors = ["#636EFA", "#EF553B", "#00CC96"]
    for k in range(3):
        pred = betabinom.pmf(y, N, alpha_post[k], np.sum(alpha_post) - alpha_post[k])
        fig.add_trace(
            go.Bar(
                x=y, y=pred, name=f"Cat {k + 1}", marker_color=colors[k], opacity=0.7
            )
        )
    fig.update_layout(
        title="Marginal Posterior Predictive (Beta-Binomial) for Each Category",
        xaxis_title="Number of Successes in N Trials",
        yaxis_title="Probability",
        barmode="group",
        width=1200,
        height=500,
    )
    # For Jupyter, use display(fig) to show the figure with the specified size.
    # If the width/height still doesn't take effect, try using fig.show() instead of display(fig)
    # or set pio.renderers.default = "notebook" at the top of your notebook.
    display(fig)

In [88]:
def plot_marginal_predictives_interactive(alpha1, alpha2, alpha3, N, p1, p2):
    # Compute p3 and posterior parameters
    p3 = max(0.0, 1.0 - p1 - p2)
    p_vec = np.array([p1, p2, p3])
    alpha = np.array([alpha1, alpha2, alpha3])
    counts = simulate_multinomial_data(N, p_vec)
    alpha_post = compute_posterior_params(alpha, counts)
    plot_marginal_predictives(N, alpha_post)


out_predictives = widgets.interactive_output(
    plot_marginal_predictives_interactive,
    {
        "alpha1": alpha1_slider,
        "alpha2": alpha2_slider,
        "alpha3": alpha3_slider,
        "N": N_slider,
        "p1": p1_slider,
        "p2": p2_slider,
    },
)
display(ui)
display(out_predictives)

## 9. Display Summary and Mathematical Explanation

Show a summary of the Bayesian updating process, the formulas for the prior, likelihood, posterior, and predictive distributions, and the interpretation of the results.
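
The summary reports point estimates only. One possible extension, sketched here with example posterior parameters rather than anything the dashboard computes, is an equal-tailed 95% credible interval for each $f_k$ from its Beta marginal $\mathrm{Beta}(\alpha_k', \alpha_0' - \alpha_k')$ via `scipy.stats.beta.ppf`.

```python
import numpy as np
from scipy.stats import beta

alpha_post = np.array([5.0, 7.0, 6.0])   # example posterior parameters alpha'
alpha0 = alpha_post.sum()
post_mean = alpha_post / alpha0

# Marginal of f_k under Dirichlet(alpha') is Beta(alpha'_k, alpha'_0 - alpha'_k)
lower = beta.ppf(0.025, alpha_post, alpha0 - alpha_post)
upper = beta.ppf(0.975, alpha_post, alpha0 - alpha_post)
for k in range(3):
    print(f"f_{k + 1}: mean = {post_mean[k]:.3f}, 95% interval = [{lower[k]:.3f}, {upper[k]:.3f}]")
```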

In [78]:
def display_summary(alpha, counts, alpha_post, N, p_vec):
    summary = f"""
**Bayesian Updating for Dirichlet-Multinomial:**

- Prior: $\\mathbf{{f}} \\sim \\mathrm{{Dirichlet}}(\\boldsymbol{{\\alpha}})$ with $\\boldsymbol{{\\alpha}} = [{alpha[0]:.2f}, {alpha[1]:.2f}, {alpha[2]:.2f}]$
- Data: $N={N}$, observed counts $\\mathbf{{n}} = [{counts[0]}, {counts[1]}, {counts[2]}]$
- Posterior: $\\mathbf{{f}} \\mid \\mathbf{{n}} \\sim \\mathrm{{Dirichlet}}(\\boldsymbol{{\\alpha}} + \\mathbf{{n}}) = \\mathrm{{Dirichlet}}([{alpha_post[0]:.2f}, {alpha_post[1]:.2f}, {alpha_post[2]:.2f}])$

**Posterior Mean:** $\\mathbb{{E}}[f_k \\mid \\mathbf{{n}}] = \\frac{{\\alpha_k + n_k}}{{\\sum_j (\\alpha_j + n_j)}}$

**Posterior Predictive (marginal for category $k$):**
$$
p(y \\mid N, \\alpha_k', \\alpha_{{-k}}') = \\mathrm{{BetaBinomial}}(y; N, \\alpha_k', \\sum_{{j \\neq k}} \\alpha_j')
$$

- The Dirichlet prior expresses beliefs about the category probabilities before seeing data.
- The multinomial likelihood models the observed counts.
- The posterior Dirichlet combines prior and data.
- The Beta-Binomial gives the predictive distribution for future counts in each category.
"""
    display(Markdown(summary))

## 10. Interactive Controls for Dirichlet Inference

Combine all widgets and plots into an interactive dashboard using ipywidgets, updating all visualizations and summaries as parameters change.

In [89]:
def dirichlet_multinomial_dashboard(alpha1, alpha2, alpha3, N, p1, p2):
    # Compute p3
    p3 = max(0.0, 1.0 - p1 - p2)
    p_vec = np.array([p1, p2, p3])
    alpha = np.array([alpha1, alpha2, alpha3])
    # Simulate data
    counts = simulate_multinomial_data(N, p_vec)
    # Posterior
    alpha_post = compute_posterior_params(alpha, counts)
    # Plots
    plot_observed_counts(counts, N, p_vec)
    plot_dirichlet_ternary(alpha, alpha_post, p_vec)
    plot_marginal_predictives(N, alpha_post)
    display_summary(alpha, counts, alpha_post, N, p_vec)


out = widgets.interactive_output(
    dirichlet_multinomial_dashboard,
    {
        "alpha1": alpha1_slider,
        "alpha2": alpha2_slider,
        "alpha3": alpha3_slider,
        "N": N_slider,
        "p1": p1_slider,
        "p2": p2_slider,
    },
)

display(ui)
display(out)
