# Gaussian Mean Bayesian Inference: Interactive Visualization

This notebook demonstrates Bayesian inference for the mean of a Gaussian distribution with known variance (set to 1), using a Gaussian conjugate prior. We visualize the prior, likelihood, posterior, and posterior predictive distributions with interactive controls for prior and data parameters.

## 1. Import Required Libraries

Import `numpy`, `jax`, `jax.numpy`, `scipy.stats`, `plotly.graph_objects`, `plotly.subplots`, `plotly.io`, `ipywidgets`, and `IPython.display` for computation, simulation, visualization, and interactivity.

In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from scipy.stats import norm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, Markdown, HTML

pio.templates.default = "plotly_white"
# For LaTeX rendering in Jupyter
display(
    HTML(
        '<script type="text/javascript" async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-MML-AM_SVG"></script>'
    )
)

## 2. Mathematical Background: Gaussian Mean Inference with Conjugate Prior

We consider Bayesian inference for the mean $\mu$ of a Gaussian distribution with known variance ($\sigma^2 = 1$).

### Likelihood

Given $N$ i.i.d. samples $x_1, \ldots, x_N \sim \mathcal{N}(\mu, 1)$, the likelihood for $\mu$ is:
$$
p(\mathbf{x} \mid \mu) = \prod_{i=1}^N \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{(x_i - \mu)^2}{2}\right)
$$

### Prior: Gaussian on Mean

We place a Gaussian prior on $\mu$:
$$
p(\mu) = \mathcal{N}(\mu \mid m, v)
$$

### Posterior

By conjugacy, the posterior is also Gaussian:
$$
p(\mu \mid \mathbf{x}) = \mathcal{N}(\mu \mid m', v')
$$
where
$$
v' = \left(\frac{N}{1} + \frac{1}{v}\right)^{-1}, \qquad m' = v' \left(\frac{N \bar{x}}{1} + \frac{m}{v}\right)
$$
with $\bar{x} = \frac{1}{N} \sum_{i=1}^N x_i$.
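
As a quick sanity check of the update equations, the sketch below applies them to a small made-up dataset. The function name `posterior_update` and the sample values are illustrative assumptions, not part of the notebook's own code.

```python
import numpy as np


def posterior_update(x, m, v):
    """Return (m', v') for a N(m, v) prior and unit-variance Gaussian data."""
    x = np.asarray(x, dtype=float)
    n = x.size
    v_post = 1.0 / (n + 1.0 / v)          # v' = (N + 1/v)^{-1}
    m_post = v_post * (x.sum() + m / v)   # m' = v' (N * xbar + m/v)
    return m_post, v_post


# Hypothetical data: four observations with sample mean 1.0
x = [0.5, 1.5, 1.0, 1.0]
m_post, v_post = posterior_update(x, m=0.0, v=1.0)
print(m_post, v_post)
```

With $N = 4$, $\bar{x} = 1$, $m = 0$, and $v = 1$, this gives $v' = 0.2$ and $m' = 0.8$: the posterior mean sits between the prior mean and the sample mean, and moves toward $\bar{x}$ as $N$ grows.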

### Posterior Predictive

The predictive density for a new $x^*$ is:
$$
p(x^* \mid \mathbf{x}) = \int p(x^* \mid \mu) p(\mu \mid \mathbf{x}) d\mu = \mathcal{N}(x^* \mid m', v' + 1)
$$
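
The predictive variance $v' + 1$ can be checked by simulation: draw $\mu$ from the posterior, then draw $x^*$ from $\mathcal{N}(\mu, 1)$. The sketch below assumes hypothetical posterior parameters $m' = 0.8$ and $v' = 0.2$ and uses NumPy's random generator rather than JAX for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)

m_post, v_post = 0.8, 0.2   # hypothetical posterior parameters
n_draws = 200_000

# Ancestral sampling: mu ~ N(m', v'), then x* ~ N(mu, 1)
mu_draws = rng.normal(m_post, np.sqrt(v_post), size=n_draws)
x_star = rng.normal(mu_draws, 1.0)

# Sample mean and variance should be close to m' and v' + 1
print(x_star.mean(), x_star.var())
```

The sample variance should land near $1.2$, not $0.2$ or $1$, because the uncertainty in $\mu$ and the observation noise both contribute.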

## 3. Set Parameters (Interactive Controls)

Use the sliders below to set the number of data points $N$, the true mean $\mu$, and the Gaussian prior parameters $m$ and $v$. The checkbox toggles the conjugate prior on or off; when it is off, the prior sliders have no effect on the plots.

In [13]:
# Data sliders
N_slider = widgets.IntSlider(value=30, min=1, max=100, step=1, description="N (data):")
mu_slider = widgets.FloatSlider(
    value=0.0, min=-3.0, max=3.0, step=0.1, description="μ (true):"
)

# Conjugate prior toggle
cp_checkbox = widgets.Checkbox(value=False, description="Conjugate Prior?")

# Prior parameter sliders
m_slider = widgets.FloatSlider(
    value=0.0,
    min=-3.0,
    max=3.0,
    step=0.1,
    description="m (prior mean):",
    style={"description_width": "initial"},
)
v_slider = widgets.FloatSlider(
    value=1.0,
    min=0.01,
    max=10.0,
    step=0.1,
    description="v (prior var):",
    style={"description_width": "initial"},
)


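def prior_box(cp):
    # interactive_output only captures displayed output, so the sliders
    # must be displayed explicitly; a returned widget would be discarded.
    if cp:
        display(widgets.HBox([m_slider, v_slider]))


ui = widgets.VBox(
    [
        # Prior sliders appear below the checkbox only when it is ticked
        widgets.HBox([N_slider, mu_slider]),
        cp_checkbox,
        widgets.interactive_output(prior_box, {"cp": cp_checkbox}),
    ]
)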
display(ui)


## 4. Simulate Gaussian Data

Simulate $N$ samples from $\mathcal{N}(\mu, 1)$ using JAX. The fixed default seed keeps the draw reproducible across widget updates.

In [3]:
def simulate_gaussian_data(N, mu, rng_seed=0):
    key = jax.random.PRNGKey(rng_seed)
    X = jax.random.normal(key, shape=(N,)) + mu
    return np.array(X)

## 5. Compute Likelihood, Prior, and Posterior

Compute the log-likelihood of the data as a function of $\mu$, and update the Gaussian prior to obtain the posterior parameters.

**Posterior update equations:**
$$
v' = \left(\frac{N}{1} + \frac{1}{v}\right)^{-1}, \qquad m' = v' \left(\frac{N \bar{x}}{1} + \frac{m}{v}\right)
$$

### Bayesian Inference for Gaussian Mean (Known Variance)

Suppose we observe data $x_1, \ldots, x_N$ drawn i.i.d. from a Gaussian distribution with unknown mean $\mu$ and known variance $\sigma^2 = 1$.

#### **Prior**

We place a Gaussian prior on $\mu$:
$$
p(\mu) = \mathcal{N}(\mu \mid m, v) = \frac{1}{\sqrt{2\pi v}} \exp\left(-\frac{1}{2v}(\mu - m)^2\right)
$$

#### **Likelihood**

The likelihood of the data given $\mu$ is:
$$
p(\mathbf{x} \mid \mu) = \prod_{i=1}^N \mathcal{N}(x_i \mid \mu, 1) = (2\pi)^{-N/2} \exp\left(-\frac{1}{2} \sum_{i=1}^N (x_i - \mu)^2\right)
$$

#### **Posterior Derivation**

By Bayes' rule:
$$
p(\mu \mid \mathbf{x}) \propto p(\mathbf{x} \mid \mu) \, p(\mu)
$$

Plug in the likelihood and prior:
\begin{align*}
\log p(\mu \mid \mathbf{x}) &= \log p(\mathbf{x} \mid \mu) + \log p(\mu) + \text{const} \\
&= -\frac{1}{2} \sum_{i=1}^N (x_i - \mu)^2 - \frac{1}{2v} (\mu - m)^2 + \text{const}
\end{align*}

Expand and collect terms in $\mu$:
\begin{align*}
-\frac{1}{2} \sum_{i=1}^N (x_i - \mu)^2 &= -\frac{1}{2} \left( \sum_{i=1}^N x_i^2 - 2\mu \sum_{i=1}^N x_i + N\mu^2 \right) \\
-\frac{1}{2v} (\mu - m)^2 &= -\frac{1}{2v} (\mu^2 - 2m\mu + m^2)
\end{align*}

Combine quadratic and linear terms in $\mu$:
\begin{align*}
\log p(\mu \mid \mathbf{x}) &= -\frac{N}{2} \mu^2 + \mu \sum_{i=1}^N x_i -\frac{1}{2v} \mu^2 + \frac{m}{v} \mu + \text{const} \\
&= -\frac{1}{2} \left( N + \frac{1}{v} \right) \mu^2 + \left( \sum_{i=1}^N x_i + \frac{m}{v} \right) \mu + \text{const}
\end{align*}

This is the log of a Gaussian in $\mu$, so the posterior is Gaussian:
$$
p(\mu \mid \mathbf{x}) = \mathcal{N}(\mu \mid m', v')
$$
where
$$
v' = \left( N + \frac{1}{v} \right)^{-1}
$$
$$
m' = v' \left( \sum_{i=1}^N x_i + \frac{m}{v} \right ) = v' \left( N \bar{x} + \frac{m}{v} \right )
$$
with $\bar{x} = \frac{1}{N} \sum_{i=1}^N x_i$.
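
The step from the quadratic form above to $m'$ and $v'$ is completing the square. Writing $A = N + \frac{1}{v}$ and $B = \sum_{i=1}^N x_i + \frac{m}{v}$ (shorthand introduced here only for this step):
$$
-\frac{1}{2} A \mu^2 + B \mu = -\frac{A}{2} \left( \mu - \frac{B}{A} \right)^2 + \frac{B^2}{2A}
$$
The last term does not depend on $\mu$ and is absorbed into the constant. Matching against $-\frac{1}{2v'}(\mu - m')^2$ gives $v' = 1/A$ and $m' = B/A = v' B$.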

**Why use log-likelihood instead of likelihood?**  
- The likelihood for the Gaussian mean is a product of $N$ density values, which can quickly underflow to zero as $N$ grows.
- Taking the logarithm turns products into sums, making computations numerically stable and easier to handle.
- Log-likelihood is also more convenient for optimization and plotting, as it avoids extremely small numbers.

**How to go from log-likelihood to likelihood:**  
- If $\log p(\mathbf{x} \mid \mu)$ is the log-likelihood, then the likelihood is $p(\mathbf{x} \mid \mu) = \exp(\log p(\mathbf{x} \mid \mu))$.
- In practice, for plotting or normalization, we often subtract the maximum log-likelihood before exponentiating:  
    $$
    \text{likelihood}(\mu) \propto \exp\left(\log p(\mathbf{x} \mid \mu) - \max_{\mu} \log p(\mathbf{x} \mid \mu)\right)
    $$
    This keeps the values in a numerically safe range.
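
The underflow problem and the max-subtraction fix can be seen in a few lines. This is a minimal sketch with a hypothetical sample of 2000 points and an arbitrary grid; it uses plain NumPy and is independent of the notebook's helper functions.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=2000)      # hypothetical sample
mu_grid = np.linspace(-1.0, 1.0, 201)

# Total log-likelihood of the sample at each candidate mu
log_lik = (-0.5 * (x[:, None] - mu_grid[None, :]) ** 2
           - 0.5 * np.log(2 * np.pi)).sum(axis=0)

naive = np.exp(log_lik)                    # underflows to 0.0 on the whole grid
stable = np.exp(log_lik - log_lik.max())   # peak rescaled to exactly 1.0

print(naive.max(), stable.max())
```

The naive version carries no usable shape information, while the shifted version keeps the full curve and still peaks at the grid point nearest $\bar{x}$.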


#### **Summary:**

- **Prior:** $p(\mu) = \mathcal{N}(\mu \mid m, v)$  
- **Likelihood:** $p(\mathbf{x} \mid \mu) = \prod_{i=1}^N \mathcal{N}(x_i \mid \mu, 1)$  
- **Posterior:** $p(\mu \mid \mathbf{x}) = \mathcal{N}(\mu \mid m', v')$  
    with  
    $$
    v' = \left( N + \frac{1}{v} \right)^{-1}, \qquad m' = v' \left( N \bar{x} + \frac{m}{v} \right )
    $$
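
One way to gain confidence in the closed-form update is to compare it with a brute-force evaluation of prior times likelihood on a grid. The sketch below assumes hypothetical data and prior values, and the helper `normal_pdf` is defined here only for illustration.

```python
import numpy as np


def normal_pdf(z, mean, var):
    """Density of N(mean, var) evaluated at z."""
    return np.exp(-0.5 * (z - mean) ** 2 / var) / np.sqrt(2 * np.pi * var)


rng = np.random.default_rng(1)
x = rng.normal(0.5, 1.0, size=20)   # hypothetical data with unit variance
m, v = 0.0, 2.0                     # hypothetical prior parameters

# Closed-form conjugate update
v_post = 1.0 / (x.size + 1.0 / v)
m_post = v_post * (x.sum() + m / v)

# Brute-force Bayes on a grid: log prior + log likelihood, then normalize
mu_grid = np.linspace(-4.0, 4.0, 4001)
dmu = mu_grid[1] - mu_grid[0]
log_post = (-0.5 * ((x[:, None] - mu_grid[None, :]) ** 2).sum(axis=0)
            + np.log(normal_pdf(mu_grid, m, v)))
grid_post = np.exp(log_post - log_post.max())
grid_post /= grid_post.sum() * dmu   # Riemann-sum normalization

max_abs_diff = np.max(np.abs(grid_post - normal_pdf(mu_grid, m_post, v_post)))
print(max_abs_diff)
```

The printed difference should be negligible compared with the peak density, which indicates that the grid posterior and $\mathcal{N}(\mu \mid m', v')$ coincide up to discretization error.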

In [4]:
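def log_likelihood_mu(X, mu_grid):
    # Total log-likelihood of the sample for each mu in the grid.
    # It is not averaged over N: exponentiating an average would give the
    # N-th root of the likelihood, which is far wider than the true curve.
    if len(X) == 0:
        return np.zeros_like(mu_grid)
    ll = -0.5 * ((X[:, None] - mu_grid[None, :]) ** 2) - 0.5 * np.log(2 * np.pi)
    return ll.sum(axis=0)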


def gaussian_posterior_params(m, v, X):
    N = len(X)
    if N == 0:
        return m, v
    xbar = np.mean(X)
    v_post = 1.0 / (N / 1.0 + 1.0 / v)
    m_post = v_post * (N * xbar / 1.0 + m / v)
    return m_post, v_post

## 6. Plot Observed Data and True Distribution

Use Plotly to plot a histogram of the observed data and overlay the true Gaussian density.

In [5]:
def plot_observed_data(X, mu):
    x_grid = np.linspace(-3 + mu, 3 + mu, 250)
    fig = go.Figure()
    if len(X) > 0:
        fig.add_trace(
            go.Histogram(
                x=X,
                nbinsx=30,
                histnorm="probability density",
                name="Observed Data",
                opacity=0.5,
                marker_color="#636EFA",
            )
        )
    fig.add_trace(
        go.Scatter(
            x=x_grid,
            y=norm.pdf(x_grid, mu, 1.0),
            name="True Gaussian",
            line=dict(color="#EF553B", width=3),
        )
    )
    fig.update_layout(
        title="Observed Data and True Gaussian Density",
        xaxis_title="x",
        yaxis_title="Density",
        width=1000,
        height=400,
        barmode="overlay",
    )
    fig.update_traces(opacity=0.7)
    fig.show()

In [6]:
def plot_observed_data_interactive(N, mu, cp, m, v):
    X = simulate_gaussian_data(N, mu)
    plot_observed_data(X, mu)


out = widgets.interactive_output(
    plot_observed_data_interactive,
    {
        "N": N_slider,
        "mu": mu_slider,
        "cp": cp_checkbox,
        "m": m_slider,
        "v": v_slider,
    },
)
display(ui)
display(out)


## 7. Plot Prior, Likelihood, and Posterior for Mean

Plot the prior, likelihood, and posterior distributions for $\mu$ using Plotly. Show how the posterior updates as data and prior parameters change.

In [7]:
def plot_prior_likelihood_posterior(X, mu, cp, m, v):
    mu_grid = jnp.linspace(-3, 3, 300)
    # Likelihood as a function of mu, rescaled to integrate to 1 over the grid
    ll = log_likelihood_mu(X, mu_grid)
    likelihood = jnp.exp(ll - jnp.max(ll))
    likelihood /= jnp.trapezoid(likelihood, mu_grid)
    # Prior and posterior
    if cp:
        prior = norm.pdf(mu_grid, m, jnp.sqrt(v))
        m_post, v_post = gaussian_posterior_params(m, v, X)
        posterior = norm.pdf(mu_grid, m_post, jnp.sqrt(v_post))
    else:
        prior = jnp.ones_like(mu_grid)
        posterior = likelihood
    # Plot
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=mu_grid,
            y=prior,
            name="Prior p(μ)",
            line=dict(color="#00CC96", dash="dash"),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=mu_grid,
            y=likelihood,
            name="Likelihood p̂(x|μ)",
            line=dict(color="#636EFA"),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=mu_grid,
            y=posterior,
            name="Posterior p(μ|x)",
            line=dict(color="#EF553B", width=3),
        )
    )
    fig.add_vline(
        x=mu,
        line_color="#222",
        line_dash="dot",
        annotation_text="True μ",
        annotation_position="top",
    )
    if cp:
        lower = m_post - 1.96 * jnp.sqrt(v_post)
        upper = m_post + 1.96 * jnp.sqrt(v_post)
        fig.add_vrect(
            x0=lower,
            x1=upper,
            fillcolor="rgba(239,85,59,0.15)",
            layer="below",
            line_width=0,
            annotation_text="95% CI",
            annotation_position="top right",
        )
    fig.update_layout(
        title="Prior, Likelihood, and Posterior for μ",
        xaxis_title="μ",
        yaxis_title="Density",
        width=1000,
        height=400,
        legend=dict(orientation="h", yanchor="bottom", y=-0.3, xanchor="center", x=0.5),
    )
    fig.show()

In [8]:
def plot_prior_likelihood_posterior_interactive(N, mu, cp, m, v):
    X = simulate_gaussian_data(N, mu)
    plot_prior_likelihood_posterior(X, mu, cp, m, v)


out2 = widgets.interactive_output(
    plot_prior_likelihood_posterior_interactive,
    {
        "N": N_slider,
        "mu": mu_slider,
        "cp": cp_checkbox,
        "m": m_slider,
        "v": v_slider,
    },
)
display(ui)
display(out2)


## 8. Plot Posterior Predictive Distribution

Plot the posterior predictive density for new data points, using the derived formula. Overlay this on the observed data histogram.

### Posterior Predictive Formula (with Detailed Math)

The **posterior predictive distribution** describes the probability of a new data point $x^*$ given the observed data $\mathbf{x}$, after integrating out the uncertainty in the parameter $\mu$:

$$
p(x^* \mid \mathbf{x}) = \int p(x^* \mid \mu) \, p(\mu \mid \mathbf{x}) \, d\mu
$$

For the Gaussian mean inference with known variance ($\sigma^2 = 1$), and a Gaussian prior $p(\mu) = \mathcal{N}(\mu \mid m, v)$, the posterior $p(\mu \mid \mathbf{x})$ is also Gaussian:

$$
p(\mu \mid \mathbf{x}) = \mathcal{N}(\mu \mid m', v')
$$

where

$$
v' = \left( \frac{N}{1} + \frac{1}{v} \right)^{-1}
$$

$$
m' = v' \left( \frac{N \bar{x}}{1} + \frac{m}{v} \right)
$$

The likelihood for a new data point is:

$$
p(x^* \mid \mu) = \mathcal{N}(x^* \mid \mu, 1)
$$

Plugging into the integral:

$$
p(x^* \mid \mathbf{x}) = \int \mathcal{N}(x^* \mid \mu, 1) \, \mathcal{N}(\mu \mid m', v') \, d\mu
$$

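This integral is the convolution of two Gaussian densities. Equivalently, $x^* = \mu + \varepsilon$ with $\varepsilon \sim \mathcal{N}(0, 1)$ independent of $\mu$, so the result is again Gaussian, with the mean unchanged and the variances added: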

$$
p(x^* \mid \mathbf{x}) = \mathcal{N}(x^* \mid m', v' + 1)
$$

**Summary:**  
- The posterior predictive for a new $x^*$ is a Gaussian with mean $m'$ (the posterior mean of $\mu$) and variance $v' + 1$ (posterior variance of $\mu$ plus the known data variance).
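
To illustrate the difference between uncertainty about $\mu$ and uncertainty about a new observation, the sketch below compares 95% intervals under the posterior and under the posterior predictive. It uses only the standard library's `statistics.NormalDist`; the posterior parameters are hypothetical values chosen for illustration.

```python
from statistics import NormalDist

m_post, v_post = 0.8, 0.2   # hypothetical posterior parameters

# NormalDist is parameterized by standard deviation, not variance
credible = NormalDist(mu=m_post, sigma=v_post ** 0.5)             # N(m', v')
predictive = NormalDist(mu=m_post, sigma=(v_post + 1.0) ** 0.5)   # N(m', v' + 1)

cred_lo, cred_hi = credible.inv_cdf(0.025), credible.inv_cdf(0.975)
pred_lo, pred_hi = predictive.inv_cdf(0.025), predictive.inv_cdf(0.975)

print(f"95% interval for mu: [{cred_lo:.2f}, {cred_hi:.2f}]")
print(f"95% interval for x*: [{pred_lo:.2f}, {pred_hi:.2f}]")
```

Both intervals are centered at $m'$, but the predictive interval is wider. As $N$ grows, $v' \to 0$ and the predictive approaches the data distribution $\mathcal{N}(m', 1)$ rather than collapsing to a point.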

In [9]:
def posterior_predictive_density(x_grid, m_post, v_post):
    # Predictive is N(m_post, v_post + 1)
    return norm.pdf(x_grid, m_post, np.sqrt(v_post + 1))


def plot_posterior_predictive(X, m_post, v_post, mu):
    x_grid = np.linspace(-3 + mu, 3 + mu, 250)
    fig = go.Figure()
    if len(X) > 0:
        fig.add_trace(
            go.Histogram(
                x=X,
                nbinsx=30,
                histnorm="probability density",
                name="Observed Data",
                opacity=0.5,
                marker_color="#636EFA",
            )
        )
    fig.add_trace(
        go.Scatter(
            x=x_grid,
            y=norm.pdf(x_grid, mu, 1.0),
            name="True Gaussian",
            line=dict(color="#EF553B", width=3),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=x_grid,
            y=posterior_predictive_density(x_grid, m_post, v_post),
            name="Posterior Predictive",
            line=dict(color="#00CC96", dash="dash"),
        )
    )
    fig.update_layout(
        title="Posterior Predictive Density vs Observed Data",
        xaxis_title="x",
        yaxis_title="Density",
        width=1000,
        height=400,
        barmode="overlay",
    )
    fig.show()

In [10]:
def plot_posterior_predictive_interactive(N, mu, cp, m, v):
    X = simulate_gaussian_data(N, mu)
    if cp:
        m_post, v_post = gaussian_posterior_params(m, v, X)
        plot_posterior_predictive(X, m_post, v_post, mu)
    else:
        plot_observed_data(X, mu)


out3 = widgets.interactive_output(
    plot_posterior_predictive_interactive,
    {
        "N": N_slider,
        "mu": mu_slider,
        "cp": cp_checkbox,
        "m": m_slider,
        "v": v_slider,
    },
)
display(ui)
display(out3)


## 9. Display Summary and Mathematical Explanation

Display a summary of the Bayesian updating process, including the formulas for prior, likelihood, posterior, and predictive distributions. Explain the interpretation of each step.

In [11]:
def display_summary(N, mu, cp, m, v, X):
    if cp:
        m_post, v_post = gaussian_posterior_params(m, v, X)
        summary = f"""
**Bayesian Updating for Gaussian Mean:**

- **Prior:** $\\mu \\sim \\mathcal{{N}}(m, v)$ with $m = {m:.2f}$, $v = {v:.2f}$
- **Data:** $N={N}$, true $\\mu={mu:.2f}$, observed $\\bar{{x}} = {np.mean(X):.2f}$
- **Posterior:** $\\mu \\mid \\mathbf{{x}} \\sim \\mathcal{{N}}(m', v')$ with
  $$
  v' = \\left(\\frac{{N}}{{1}} + \\frac{{1}}{{v}}\\right)^{{-1}} = {v_post:.3f}, \\qquad
  m' = v' \\left(\\frac{{N \\bar{{x}}}}{{1}} + \\frac{{m}}{{v}}\\right) = {m_post:.3f}
  $$
- **Posterior Predictive:** For new $x^*$,
  $$
  p(x^* \\mid \\mathbf{{x}}) = \\mathcal{{N}}(x^* \\mid m', v' + 1)
  $$
"""
    else:
        summary = f"""
**Likelihood-only Inference (No Prior):**

- **Data:** $N={N}$, true $\\mu={mu:.2f}$
- **Posterior:** Proportional to the likelihood.
- **Interpretation:** No prior information is used; inference is based only on observed data.
"""
    display(Markdown(summary))

## 10. Interactive Dashboard for Gaussian Mean Inference

Combine all widgets, plots, and explanations into an interactive dashboard that updates as parameters change.

In [12]:
def gauss_mean_inference_dashboard(N, mu, cp, m, v):
    X = simulate_gaussian_data(N, mu)
    plot_observed_data(X, mu)
    plot_prior_likelihood_posterior(X, mu, cp, m, v)
    if cp:
        m_post, v_post = gaussian_posterior_params(m, v, X)
        plot_posterior_predictive(X, m_post, v_post, mu)
    display_summary(N, mu, cp, m, v, X)


dashboard_out = widgets.interactive_output(
    gauss_mean_inference_dashboard,
    {
        "N": N_slider,
        "mu": mu_slider,
        "cp": cp_checkbox,
        "m": m_slider,
        "v": v_slider,
    },
)
display(ui)
display(dashboard_out)
