# Bayesian Inference for a Probability: Beta Posterior Visualization

This notebook demonstrates Bayesian inference for a probability parameter using the Beta distribution as a conjugate prior. We visualize the posterior and its Laplace (Gaussian) approximation using Plotly, inspired by the style of the "Change of Measure" notebook.

## 1. Import Required Libraries

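Import `jax`, `jax.numpy`, the `beta` and `norm` distributions from `jax.scipy.stats`, `numpy`, and Plotly (`plotly.graph_objects`, `plotly.io`) for computation and visualization. The default Plotly template is set to `plotly_white`.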

In [1]:
import jax
import jax.numpy as jnp
from jax.scipy.stats import beta, norm
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio

pio.templates.default = "plotly_white"

The **Beta distribution** is a continuous probability distribution defined on the interval $[0, 1]$. It is parameterized by two positive shape parameters, $a$ and $b$, and is commonly used to model random variables that represent probabilities or proportions.

### Probability Density Function

The probability density function (PDF) of the Beta distribution is:
$$
\mathrm{Beta}(x; a, b) = \frac{\Gamma(a + b)}{\Gamma(a)\Gamma(b)} x^{a-1} (1-x)^{b-1}
$$
where $x \in [0, 1]$ and $\Gamma(\cdot)$ is the gamma function.
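To make the formula concrete, here is a minimal sketch that evaluates the density directly from its definition. The helper name `beta_pdf` is hypothetical (it is not part of the notebook), and the sketch uses only the standard library and assumes $0 < x < 1$ so that the logarithms are finite.

```python
import math

def beta_pdf(x: float, a: float, b: float) -> float:
    """Beta(x; a, b) for 0 < x < 1, computed in log space for stability."""
    log_norm = math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
    return math.exp(log_norm + (a - 1) * math.log(x) + (b - 1) * math.log1p(-x))

# Beta(1, 1) is the uniform density on [0, 1].
uniform_value = beta_pdf(0.3, 1.0, 1.0)

# Beta(4, 3): the normalizer is Gamma(7) / (Gamma(4) * Gamma(3)) = 720 / 12 = 60,
# so the density at x = 0.6 is 60 * 0.6**3 * 0.4**2.
value_at_06 = beta_pdf(0.6, 4.0, 3.0)

# Midpoint-rule check that the density integrates to (approximately) one.
n = 10_000
integral = sum(beta_pdf((i + 0.5) / n, 4.0, 3.0) for i in range(n)) / n

print(uniform_value, value_at_06, integral)
```

Working with `math.lgamma` rather than `math.gamma` keeps the normalizer finite for large $a$ and $b$, where the gamma function itself would overflow.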

### What is it good for?

- **Modeling probabilities:** The Beta distribution is ideal for modeling the probability of success in a Bernoulli or binomial process (e.g., the probability of heads in a coin toss).
- **Bayesian statistics:** It is the conjugate prior for the binomial and Bernoulli likelihoods, making Bayesian updating analytically tractable.
- **Flexibility:** By varying $a$ and $b$, the Beta distribution can take many shapes (uniform, U-shaped, bell-shaped, etc.).
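
The shapes mentioned in the last bullet can be seen by evaluating the density at a few points. This is a small sketch with illustrative parameter choices; `beta_pdf` is a hypothetical helper, not a function from the notebook.

```python
import math

def beta_pdf(x, a, b):
    """Beta(x; a, b) for 0 < x < 1, computed in log space."""
    log_norm = math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
    return math.exp(log_norm + (a - 1) * math.log(x) + (b - 1) * math.log1p(-x))

points = (0.05, 0.5, 0.95)  # near the left edge, the middle, near the right edge
shapes = {
    "uniform": (1.0, 1.0),
    "U-shaped": (0.5, 0.5),
    "bell-shaped": (5.0, 5.0),
    "skewed": (2.0, 5.0),
}

# Density values at the three probe points for each (a, b) pair.
densities = {
    name: [beta_pdf(x, a, b) for x in points] for name, (a, b) in shapes.items()
}

for name, values in densities.items():
    print(f"{name:12s}", [round(v, 4) for v in values])
```

A U-shaped density is larger at the edges than in the middle, a bell-shaped one is the reverse, and a skewed one puts much more density on one side than the other.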

### Bayesian Updating: From Prior to Posterior

Suppose we want to estimate the probability $\pi$ of success in a Bernoulli process. We start with a **prior** belief:
$$
\pi \sim \mathrm{Beta}(a_0, b_0)
$$

After observing data, say $n$ trials with $k$ successes and $n-k$ failures, the **likelihood** is:
$$
p(\text{data} \mid \pi) \propto \pi^k (1-\pi)^{n-k}
$$

By Bayes' theorem, the **posterior** is proportional to the product of the prior and the likelihood:
$$
\begin{aligned}
p(\pi \mid \text{data}) &\propto \pi^{a_0-1} (1-\pi)^{b_0-1} \cdot \pi^k (1-\pi)^{n-k} \\
&= \pi^{a_0 + k - 1} (1-\pi)^{b_0 + n - k - 1}
\end{aligned}
$$

This expression is the unnormalized kernel of a Beta density, so the posterior is also a Beta distribution:
$$
\pi \mid \text{data} \sim \mathrm{Beta}(a_0 + k,\, b_0 + n - k)
$$

This **conjugacy** makes the Beta distribution especially useful for Bayesian inference about probabilities. The parameters are simply updated by adding the observed counts of successes and failures to the prior parameters.
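
The update rule can be checked numerically. The sketch below multiplies prior and likelihood on a grid, normalizes the product, and compares it with the closed-form Beta posterior. The data values and the helper `beta_pdf` are illustrative assumptions, and the grid integration is only a sanity check, not how the posterior would normally be computed.

```python
import math

def beta_pdf(x, a, b):
    """Beta(x; a, b) for 0 < x < 1, computed in log space."""
    log_norm = math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
    return math.exp(log_norm + (a - 1) * math.log(x) + (b - 1) * math.log1p(-x))

a0, b0 = 1.0, 1.0  # uniform prior (illustrative choice)
k, n = 3, 5        # illustrative data: 3 successes in 5 trials

# Conjugate update: add successes to a0 and failures to b0.
a_post, b_post = a0 + k, b0 + (n - k)

# Brute-force posterior: prior times likelihood on a midpoint grid, then normalize.
m = 2000
grid = [(i + 0.5) / m for i in range(m)]
unnormalized = [beta_pdf(t, a0, b0) * t**k * (1 - t) ** (n - k) for t in grid]
evidence = sum(unnormalized) / m  # approximates the normalizing integral
numeric_posterior = [u / evidence for u in unnormalized]

# Largest pointwise gap between the grid posterior and Beta(a_post, b_post).
max_abs_diff = max(
    abs(q - beta_pdf(t, a_post, b_post)) for q, t in zip(numeric_posterior, grid)
)
print(a_post, b_post, max_abs_diff)
```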

## 2. Set Prior and Data Parameters

Define prior parameters (`a0`, `b0`) and observed data (number of successes and failures).

In [2]:
# Prior parameters (can be changed)
a0 = 1.0  # prior "successes"
b0 = 1.0  # prior "failures"

# Observed data (can be changed)
num_successes = 3
num_failures = 2

## 3. Compute Beta Posterior Parameters

Apply the conjugate update: add the observed successes to `a0` and the observed failures to `b0`.

In [3]:
# Posterior parameters
a = a0 + num_successes
b = b0 + num_failures

## 4. Compute and Plot Beta Posterior Density with Plotly

Compute the Beta posterior density over a grid and plot it with Plotly together with the prior density, marking the mean and the one-standard-deviation interval of each.

In [8]:
# Grid for probability parameter
x = jnp.linspace(0, 1, 300)

# Beta posterior density
p = beta.pdf(x, a, b)

# Posterior mean and standard deviation
mean_p = a / (a + b)
var_p = a * b / ((a + b) ** 2 * (a + b + 1))
std_p = jnp.sqrt(var_p)

fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=x,
        y=p,
        mode="lines",
        line=dict(color="crimson", width=2),
        name="Beta posterior",
        fill="tozeroy",
        fillcolor="rgba(220,20,60,0.15)",
    )
)

# Mean marker
fig.add_trace(
    go.Scatter(
        x=[mean_p],
        y=[beta.pdf(mean_p, a, b)],
        mode="markers",
        marker=dict(color="green", size=10, symbol="x"),
        name="Posterior mean",
    )
)

# Std deviation interval
fig.add_trace(
    go.Scatter(
        x=[mean_p - std_p, mean_p + std_p],
        y=[beta.pdf(mean_p - std_p, a, b), beta.pdf(mean_p + std_p, a, b)],
        mode="markers+lines",
        marker=dict(color="green", size=7, symbol="line-ns"),
        line=dict(color="green", width=1, dash="dot"),
        name="Std deviation",
        showlegend=True,
    )
)

# Beta prior density
prior_p = beta.pdf(x, a0, b0)
prior_mean = a0 / (a0 + b0)
prior_var = a0 * b0 / ((a0 + b0) ** 2 * (a0 + b0 + 1))
prior_std = jnp.sqrt(prior_var)

# Add prior density
fig.add_trace(
    go.Scatter(
        x=x,
        y=prior_p,
        mode="lines",
        line=dict(color="royalblue", width=2, dash="dot"),
        name="Beta prior",
        fill=None,
    )
)

# Prior mean marker
fig.add_trace(
    go.Scatter(
        x=[prior_mean],
        y=[beta.pdf(prior_mean, a0, b0)],
        mode="markers",
        marker=dict(color="blue", size=10, symbol="x"),
        name="Prior mean",
    )
)

# Prior std deviation interval
fig.add_trace(
    go.Scatter(
        x=[prior_mean - prior_std, prior_mean + prior_std],
        y=[
            beta.pdf(prior_mean - prior_std, a0, b0),
            beta.pdf(prior_mean + prior_std, a0, b0),
        ],
        mode="markers+lines",
        marker=dict(color="blue", size=7, symbol="line-ns"),
        line=dict(color="blue", width=1, dash="dot"),
        name="Prior std deviation",
        showlegend=True,
    )
)

fig.update_layout(
    title="Beta Prior and Posterior Densities",
    xaxis_title="$\\pi$",
    yaxis_title="$p(\\pi|a,b)$",
    width=900,
    height=400,
    legend=dict(itemsizing="constant"),
)
fig.show()

## 5. Compute and Plot Log Posterior Density

Compute the log of the Beta posterior density and plot it using Plotly.

In [5]:
logp = beta.logpdf(x, a, b)

fig_log = go.Figure()
fig_log.add_trace(
    go.Scatter(
        x=x,
        y=logp,
        mode="lines",
        line=dict(color="crimson", width=2),
        name="log Beta posterior",
    )
)

fig_log.update_layout(
    title="Log Beta Posterior Density $\\log p(\\pi|a,b)$",
    xaxis_title="$\\pi$",
    yaxis_title="$\\log p(\\pi|a,b)$",
    width=900,
    height=400,
    legend=dict(itemsizing="constant"),
)
fig_log.show()

## 6. Compute Posterior Mean, Variance, and Mode

Compute the mode of the Beta posterior and print it alongside the mean and standard deviation already computed in Section 4.

In [6]:
# A unique interior mode exists only if a > 1 and b > 1
if (a > 1) and (b > 1):
    mode_p = (a - 1) / (a + b - 2)
else:
    mode_p = None

print(f"Posterior mean: {mean_p:.3f}")
print(f"Posterior std deviation: {std_p:.3f}")
if mode_p is not None:
    print(f"Posterior mode: {mode_p:.3f}")
else:
    print("Posterior mode is undefined for a <= 1 or b <= 1.")

Posterior mean: 0.571
Posterior std deviation: 0.175
Posterior mode: 0.600
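
These values can be checked by hand. With the default settings (`a0 = b0 = 1`, 3 successes, 2 failures) the posterior is $\mathrm{Beta}(4, 3)$, and the formulas used in the code give:
$$
\begin{aligned}
\text{mean} &= \frac{a}{a+b} = \frac{4}{7} \approx 0.571, \\
\text{variance} &= \frac{ab}{(a+b)^2 (a+b+1)} = \frac{12}{49 \cdot 8} \approx 0.0306, \qquad \text{std} \approx 0.175, \\
\text{mode} &= \frac{a-1}{a+b-2} = \frac{3}{5} = 0.600.
\end{aligned}
$$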


## 7. Laplace Approximation (Optional) and Plot Comparison

If the mode is defined, compute the Laplace (Gaussian) approximation at the mode and overlay it on the density and log-density plots for comparison.

The **Laplace approximation** is a method for approximating a probability distribution with a Gaussian (normal) distribution, centered at the mode of the target distribution. It is especially useful in Bayesian inference when the posterior distribution is not analytically tractable or is difficult to work with directly.

### How does Laplace approximation work?

Suppose we have a posterior density $p(\theta \mid \text{data})$. The Laplace approximation proceeds as follows:

1. **Find the mode:**  
    Compute the value $\theta^*$ that maximizes the log-posterior:
    $$
    \theta^* = \arg\max_\theta \log p(\theta \mid \text{data})
    $$

2. **Compute the curvature at the mode:**  
    Calculate the second derivative (the Hessian) of the log-posterior at $\theta^*$:
    $$
    H = \left. \frac{d^2}{d\theta^2} \log p(\theta \mid \text{data}) \right|_{\theta = \theta^*}
    $$

3. **Approximate with a Gaussian:**  
    The Laplace approximation is a normal distribution:
    $$
    q(\theta) = \mathcal{N}(\theta^*,\, -1/H)
    $$
    where the mean is the mode $\theta^*$, and the variance is the negative inverse of the Hessian. Because $H < 0$ at a maximum, the variance $-1/H$ is positive.
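
The three steps above can be sketched for a Beta density using hand-written derivatives and Newton's method. This is a minimal standard-library illustration under assumed parameters (`a = 4`, `b = 3`); the function names `dlogp` and `d2logp` are hypothetical, and in practice the derivatives could instead come from automatic differentiation.

```python
import math

a, b = 4.0, 3.0  # assumed Beta parameters; a > 1 and b > 1 gives an interior mode

def dlogp(t):
    """First derivative of log Beta(t; a, b); the normalizer does not depend on t."""
    return (a - 1) / t - (b - 1) / (1 - t)

def d2logp(t):
    """Second derivative of log Beta(t; a, b)."""
    return -(a - 1) / t**2 - (b - 1) / (1 - t) ** 2

# Step 1: find the mode by Newton's method on the first derivative.
theta = 0.5  # starting guess inside (0, 1)
for _ in range(50):
    step = dlogp(theta) / d2logp(theta)
    theta -= step
    if abs(step) < 1e-12:
        break

# Step 2: curvature of the log-density at the mode (negative at a maximum).
H = d2logp(theta)

# Step 3: Gaussian approximation N(theta, -1/H).
laplace_var = -1.0 / H
laplace_std = math.sqrt(laplace_var)

print(f"mode={theta:.4f}  H={H:.4f}  laplace_std={laplace_std:.4f}")
```

For the Beta density the numerically found mode agrees with the closed form $(a-1)/(a+b-2)$, which is a useful check on the optimizer.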

### Why use Laplace approximation here?

- In this notebook, the posterior is a Beta distribution, which is tractable and easy to plot. However, in many real-world Bayesian models, the posterior is not a standard distribution and cannot be computed or sampled from directly.
- The Laplace approximation provides a **simple, analytic Gaussian approximation** to the posterior, which is easy to work with for further calculations (e.g., credible intervals, predictions).
- Comparing the Beta posterior and its Laplace approximation helps us **visualize how well a Gaussian can approximate the true posterior**. This is especially useful for intuition: the Beta is skewed for small sample sizes or extreme priors, and the Laplace (Gaussian) may not always be a good fit.
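
The effect of sample size on the quality of the fit can be quantified without plotting. The sketch below compares the exact posterior mean and standard deviation with the Laplace mode and standard deviation for the same 60/40 split of successes and failures at three sample sizes. The counts are illustrative assumptions, the prior is taken to be uniform, and the Hessian is the closed-form second derivative of the Beta log-density.

```python
import math

def exact_mean_std(a, b):
    """Exact mean and standard deviation of Beta(a, b)."""
    mean = a / (a + b)
    var = a * b / ((a + b) ** 2 * (a + b + 1))
    return mean, math.sqrt(var)

def laplace_mode_std(a, b):
    """Laplace approximation of Beta(a, b) as (mode, std); needs a > 1 and b > 1."""
    mode = (a - 1) / (a + b - 2)
    # Second derivative of log Beta(x; a, b) with respect to x, at the mode.
    hess = -(a - 1) / mode**2 - (b - 1) / (1 - mode) ** 2
    return mode, math.sqrt(-1.0 / hess)

rows = []
for successes, failures in ((3, 2), (30, 20), (300, 200)):
    a, b = 1.0 + successes, 1.0 + failures  # uniform Beta(1, 1) prior
    mean, std = exact_mean_std(a, b)
    mode, lap_std = laplace_mode_std(a, b)
    rows.append((successes + failures, abs(mode - mean), lap_std / std))

for n, gap, ratio in rows:
    print(f"n={n:4d}  |mode - mean|={gap:.5f}  laplace_std / exact_std={ratio:.4f}")
```

As the counts grow, the gap between mode and mean shrinks and the ratio of standard deviations approaches one, which matches the intuition that the Gaussian fit improves as the posterior becomes less skewed.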

### In summary

- The Laplace approximation is a general-purpose tool for approximating complicated posteriors with a Gaussian.
- Here, it serves as a demonstration: we compare the true Beta posterior to its Laplace (Gaussian) approximation to see how well the latter matches the former.
- This is important in Bayesian inference, especially for more complex models where the posterior is not available in closed form.

The Laplace approximation is centered at the **mode** (the value where the posterior is maximized), not the mean, because it is based on a second-order Taylor expansion of the **log-posterior** around its maximum.

Here's why:

- The Laplace approximation fits a Gaussian to the posterior by matching the curvature (second derivative) of the log-posterior at its peak.
- The peak of the posterior is the **mode**, not necessarily the mean (especially for skewed distributions).
- For symmetric unimodal distributions (like the normal) the mean and mode coincide, but for skewed distributions (like the Beta with small counts) they generally differ.
- At the mode the first derivative of the log-posterior is zero, so the second-order Taylor expansion has no linear term; exponentiating the remaining constant and quadratic terms gives an unnormalized Gaussian centered at the mode.

**Summary:**  
Laplace approximation uses the mode because the gradient of the log-posterior vanishes there, so the local quadratic expansion of the log-posterior corresponds exactly to a Gaussian centered at the point of maximum posterior density.
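
The local nature of this expansion can be illustrated numerically. The sketch below, for an assumed $\mathrm{Beta}(4, 3)$ posterior, confirms that the gradient of the log-density is zero at the mode and compares the log-density with its second-order Taylor expansion near the mode and in the tails. The names `log_kernel` and `quadratic` are hypothetical helpers, and the normalizing constant is dropped because it cancels in the comparison.

```python
import math

a, b = 4.0, 3.0                 # assumed posterior parameters
mode = (a - 1) / (a + b - 2)    # closed-form mode of Beta(a, b)

def log_kernel(t):
    """Log of the Beta density up to an additive constant."""
    return (a - 1) * math.log(t) + (b - 1) * math.log1p(-t)

# First and second derivatives of the log-density at the mode.
grad_at_mode = (a - 1) / mode - (b - 1) / (1 - mode)
hess_at_mode = -(a - 1) / mode**2 - (b - 1) / (1 - mode) ** 2

def quadratic(t):
    """Second-order Taylor expansion around the mode (the linear term is zero)."""
    return log_kernel(mode) + 0.5 * hess_at_mode * (t - mode) ** 2

# Expansion error close to the mode and far from it.
near_error = max(abs(log_kernel(t) - quadratic(t)) for t in (0.55, 0.65))
tail_error = abs(log_kernel(0.95) - quadratic(0.95))

print(f"gradient at mode: {grad_at_mode:.2e}")
print(f"max error near mode: {near_error:.2e}, error at 0.95: {tail_error:.2f}")
```

The expansion is accurate within a few hundredths of the mode but degrades in the tails, which is where the Gaussian and the Beta log-densities visibly separate.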

In [7]:
if mode_p is not None:
    # Second derivative of the log-posterior at the mode (negative at a maximum).
    # The lambda argument is named t to avoid shadowing the plotting grid x.
    hess_logpmode = jax.hessian(lambda t: beta.logpdf(t, a, b))(mode_p)
    laplace_std = jnp.sqrt(-1.0 / hess_logpmode)

    # Laplace (Gaussian) approximation
    laplace_pdf = norm.pdf(x, loc=mode_p, scale=laplace_std)

    # Overlay on density plot
    fig_laplace = go.Figure()
    fig_laplace.add_trace(
        go.Scatter(
            x=x,
            y=p,
            mode="lines",
            line=dict(color="crimson", width=2),
            name="Beta posterior",
            fill="tozeroy",
            fillcolor="rgba(220,20,60,0.15)",
        )
    )
    fig_laplace.add_trace(
        go.Scatter(
            x=x,
            y=laplace_pdf,
            mode="lines",
            line=dict(color="royalblue", width=2, dash="dash"),
            name="Laplace approx.",
        )
    )
    fig_laplace.add_trace(
        go.Scatter(
            x=[mode_p],
            y=[beta.pdf(mode_p, a, b)],
            mode="markers",
            marker=dict(color="red", size=10, symbol="star"),
            name="Posterior mode",
        )
    )
    fig_laplace.update_layout(
        title="Beta Posterior and Laplace Approximation",
        xaxis_title="$\\pi$",
        yaxis_title="Density",
        width=900,
        height=400,
        legend=dict(itemsizing="constant"),
    )
    fig_laplace.show()

    # Overlay on log-density plot
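    # Evaluate the Gaussian log-density directly; this avoids log(0) if the pdf underflows
    laplace_logpdf = norm.logpdf(x, loc=mode_p, scale=laplace_std)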
    fig_log_laplace = go.Figure()
    fig_log_laplace.add_trace(
        go.Scatter(
            x=x,
            y=logp,
            mode="lines",
            line=dict(color="crimson", width=2),
            name="log Beta posterior",
        )
    )
    fig_log_laplace.add_trace(
        go.Scatter(
            x=x,
            y=laplace_logpdf,
            mode="lines",
            line=dict(color="royalblue", width=2, dash="dash"),
            name="log Laplace approx.",
        )
    )
    fig_log_laplace.add_trace(
        go.Scatter(
            x=[mode_p],
            y=[beta.logpdf(mode_p, a, b)],
            mode="markers",
            marker=dict(color="red", size=10, symbol="star"),
            name="Posterior mode",
        )
    )
    fig_log_laplace.update_layout(
        title="Log Beta Posterior and Laplace Approximation",
        xaxis_title="$\\pi$",
        yaxis_title="Log Density",
        width=900,
        height=400,
        legend=dict(itemsizing="constant"),
    )
    fig_log_laplace.show()
else:
    print("Laplace approximation not shown: mode is undefined for a <= 1 or b <= 1.")