# Gamma-Gaussian Bayesian Inference: Interactive Visualization

This notebook demonstrates Bayesian inference for the variance (or precision) of a Gaussian distribution using a Gamma prior. We visualize the prior, likelihood, posterior, and posterior predictive distributions with interactive controls for prior and data parameters.

## 1. Import Required Libraries

Import `numpy`, `jax`, `jax.numpy`, `jax.scipy.stats`, `jax.scipy.special`, `scipy.stats`, `tensorflow_probability` (JAX substrate), `plotly.graph_objects`, `plotly.io`, `ipywidgets`, and `IPython.display` for computation, simulation, interactivity, and visualization.

In [None]:
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.stats import gamma, norm
from jax.scipy.special import gamma as gamma_func
from scipy.stats import gamma as scipy_gamma
import tensorflow_probability.substrates.jax as tfp
import plotly.graph_objects as go
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, Markdown, HTML

pio.templates.default = "plotly_white"
# For LaTeX rendering in Jupyter
display(
    HTML(
        '<script type="text/javascript" async src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-MML-AM_SVG"></script>'
    )
)

## 2. Mathematical Background: Gamma-Gaussian Conjugacy

We consider Bayesian inference for the variance (or precision) of a Gaussian distribution with known mean (assume $\mu=0$ for simplicity).

### Likelihood

Given $N$ i.i.d. samples $x_1, \ldots, x_N \sim \mathcal{N}(0, \sigma^2)$, the likelihood for $\sigma^2$ (or precision $\tau = 1/\sigma^2$) is:
$$
p(\mathbf{x} \mid \sigma^2) = \prod_{i=1}^N \frac{1}{\sqrt{2\pi \sigma^2}} \exp\left(-\frac{x_i^2}{2\sigma^2}\right)
$$

### Prior: Gamma on Precision

We place a Gamma prior on the precision $\tau = 1/\sigma^2$:
$$
p(\tau) = \mathrm{Gamma}(\tau \mid \alpha, \beta) = \frac{\beta^\alpha}{\Gamma(\alpha)} \tau^{\alpha-1} e^{-\beta \tau}
$$

### Posterior

By conjugacy, the posterior is also Gamma:
$$
p(\tau \mid \mathbf{x}) = \mathrm{Gamma}(\tau \mid \alpha', \beta')
$$
where
$$
\alpha' = \alpha + \frac{N}{2}, \qquad \beta' = \beta + \frac{1}{2} \sum_{i=1}^N x_i^2
$$
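
As a small worked example of the update rule, the sketch below applies it to four made-up observations with $\alpha = \beta = 1$. The helper name `update_gamma` and the data values are illustrative assumptions, not part of the notebook's own code.

```python
def update_gamma(alpha, beta, xs):
    """Conjugate update for a Gamma(alpha, beta) prior on the precision of N(0, 1/tau)."""
    n = len(xs)
    sum_sq = sum(x * x for x in xs)
    return alpha + n / 2, beta + 0.5 * sum_sq


# Hypothetical data, chosen so the arithmetic is easy to follow.
xs = [1.0, -2.0, 0.5, 1.5]           # sum of squares = 7.5
alpha_post, beta_post = update_gamma(1.0, 1.0, xs)

print(alpha_post, beta_post)          # 3.0 4.75
print(alpha_post / beta_post)         # posterior mean of tau
```

Each observation adds $1/2$ to the shape and $x_i^2/2$ to the rate, so the posterior mean of $\tau$, $\alpha'/\beta'$, moves toward $N / \sum_i x_i^2$ as data accumulate.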

### Posterior Predictive

The predictive density for a new $x^*$ is:
$$
p(x^* \mid \mathbf{x}) = \int p(x^* \mid \tau) p(\tau \mid \mathbf{x}) d\tau
$$
which is a scaled Student-$t$ distribution with $2\alpha'$ degrees of freedom, location $0$, and scale $\sqrt{\beta'/\alpha'}$.

For $\mu=0$:
$$
p(x^* \mid \mathbf{x}) = \frac{\Gamma(\alpha'+0.5)}{\Gamma(\alpha')}\frac{1}{\sqrt{2\pi\beta'}} \left(1 + \frac{(x^*)^2}{2\beta'}\right)^{-(\alpha'+0.5)}
$$
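
The closed form can be checked numerically. The sketch below, using illustrative posterior parameters and a hand-written trapezoidal rule (both assumptions made for this example), compares the formula with a direct numerical integration over $\tau$ and confirms that the density integrates to one over $x^*$.

```python
import math
import numpy as np


def trapezoid(y, x):
    """Plain trapezoidal rule, written out to avoid depending on a NumPy version."""
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(x)))


def predictive_closed_form(x, a, b):
    """Closed-form posterior predictive for N(0, 1/tau) with tau ~ Gamma(a, b)."""
    log_coeff = math.lgamma(a + 0.5) - math.lgamma(a) - 0.5 * math.log(2 * math.pi * b)
    return math.exp(log_coeff) * (1 + x**2 / (2 * b)) ** (-(a + 0.5))


def predictive_numeric(x, a, b, tau):
    """Integrate N(x | 0, 1/tau) * Gamma(tau | a, b) over a tau grid."""
    log_gamma_pdf = a * math.log(b) - math.lgamma(a) + (a - 1) * np.log(tau) - b * tau
    normal_pdf = np.sqrt(tau / (2 * math.pi)) * np.exp(-0.5 * tau * x**2)
    return trapezoid(normal_pdf * np.exp(log_gamma_pdf), tau)


a_post, b_post = 3.0, 4.75             # illustrative posterior parameters
tau_grid = np.linspace(1e-8, 12.0, 200_001)

x_star = 1.3
closed = predictive_closed_form(x_star, a_post, b_post)
numeric = predictive_numeric(x_star, a_post, b_post, tau_grid)

# The closed form should also integrate to one over x.
x_grid = np.linspace(-200.0, 200.0, 400_001)
total_mass = trapezoid(predictive_closed_form(x_grid, a_post, b_post), x_grid)

print(closed, numeric, total_mass)
```

Working with `math.lgamma` rather than the Gamma function itself keeps the coefficient finite even for large $\alpha'$.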

## 3. Set Parameters (Interactive Controls)

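Use the sliders below to set the number of data points $N$, the true standard deviation $\sigma$, and the Gamma prior parameters $\alpha$ and $\beta$. Use the checkbox to toggle the conjugate prior on or off; when it is off, inference is based on the likelihood alone and $\alpha$, $\beta$ have no effect.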

In [None]:
# Data sliders
N_slider = widgets.IntSlider(value=30, min=1, max=100, step=1, description="N (data):")
sigma_slider = widgets.FloatSlider(
    value=1.0, min=0.01, max=5.0, step=0.05, description="σ (true):"
)

# Conjugate prior toggle
cp_checkbox = widgets.Checkbox(value=False, description="Conjugate Prior?")

# Prior parameter sliders
alpha_slider = widgets.FloatSlider(
    value=1.0, min=0.01, max=3.0, step=0.05, description="α (prior):"
)
beta_slider = widgets.FloatSlider(
    value=1.0, min=0.01, max=3.0, step=0.05, description="β (prior):"
)


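# Container for the prior sliders, shown only when the conjugate prior is enabled.
prior_sliders = widgets.HBox([alpha_slider, beta_slider])


def toggle_prior_sliders(change=None):
    # interactive_output discards return values, so toggle visibility via the layout.
    prior_sliders.layout.display = "flex" if cp_checkbox.value else "none"


cp_checkbox.observe(toggle_prior_sliders, names="value")
toggle_prior_sliders()

ui = widgets.VBox([widgets.HBox([N_slider, sigma_slider]), cp_checkbox, prior_sliders])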
display(ui)

## 4. Simulate Gaussian Data

Simulate $N$ samples from $\mathcal{N}(0, \sigma^2)$ using JAX. The PRNG seed is fixed, so a given pair $(N, \sigma)$ always produces the same data. The samples are plotted in Section 6.

In [None]:
def simulate_gaussian_data(N, sigma, rng_seed=0):
    key = jax.random.PRNGKey(rng_seed)
    X = jax.random.normal(key, shape=(N,)) * sigma
    return X

## 5. Compute Likelihood and Posterior Parameters

Compute the log-likelihood of the data as a function of $\sigma$, and update the Gamma prior to obtain the posterior parameters.

**Posterior update equations:**
$$
\alpha' = \alpha + \frac{N}{2}, \qquad \beta' = \beta + \frac{1}{2} \sum_{i=1}^N x_i^2
$$

### Likelihood and posterior

Given a Gamma **prior** on the precision $\tau = 1/\sigma^2$:
$$
p(\tau) = \mathrm{Gamma}(\tau \mid \alpha, \beta) = \frac{\beta^\alpha}{\Gamma(\alpha)} \tau^{\alpha-1} e^{-\beta \tau}
$$
- **$\alpha$ (shape parameter):** Controls the concentration of the prior. Larger $\alpha$ makes the prior more peaked, expressing stronger prior belief about the precision. Smaller $\alpha$ makes the prior more diffuse, representing weaker prior information.

- **$\beta$ (rate parameter):** Controls the scale of the prior. Larger $\beta$ shifts the prior toward lower precision (higher variance), while smaller $\beta$ shifts it toward higher precision (lower variance).

Together, $\alpha$ and $\beta$ encode your prior beliefs about the likely values of the precision $\tau$ before seeing any data.

**Typical values:**  
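- $\alpha = 1$, $\beta = 1$ yields an $\mathrm{Exponential}(1)$ prior on $\tau$ with mean 1; it is weakly informative rather than uninformative.
- $\alpha < 1$ gives a density that diverges at $\tau = 0$ and has a large relative spread (coefficient of variation $1/\sqrt{\alpha} > 1$); a small $\beta$ stretches the prior over larger values of $\tau$.
- $\alpha > 1$ gives an interior mode at $(\alpha - 1)/\beta$; increasing $\alpha$ and $\beta$ together at a fixed ratio keeps the mean $\alpha/\beta$ and shrinks the variance $\alpha/\beta^2$, encoding stronger prior beliefs.
- In practice, $\alpha$ and $\beta$ are often set based on prior knowledge or chosen to be weakly informative (e.g., $\alpha = 1$, $\beta = 0.1$).  
- The non-informative (Jeffreys) prior on the variance corresponds to the limit $\alpha \to 0$, $\beta \to 0$, that is $p(\tau) \propto 1/\tau$, which is improper.  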
- In this notebook, you can interactively adjust $\alpha$ and $\beta$ to see their effect.
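
To make the roles of $\alpha$ and $\beta$ concrete, the sketch below tabulates the mean $\alpha/\beta$, standard deviation $\sqrt{\alpha}/\beta$, and coefficient of variation $1/\sqrt{\alpha}$ for a few parameter pairs. The helper `gamma_summary` and the chosen pairs are illustrative assumptions.

```python
import math


def gamma_summary(alpha, beta):
    """Mean, standard deviation, and coefficient of variation of Gamma(alpha, rate=beta)."""
    mean = alpha / beta
    std = math.sqrt(alpha) / beta
    return mean, std, std / mean


for alpha, beta in [(1.0, 1.0), (1.0, 0.1), (0.5, 0.5), (3.0, 3.0)]:
    mean, std, cv = gamma_summary(alpha, beta)
    print(f"alpha={alpha:<4} beta={beta:<4} mean={mean:6.2f} std={std:6.2f} cv={cv:5.2f}")
```

The coefficient of variation depends only on $\alpha$, so $\beta$ rescales the prior without changing its relative spread.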

**Likelihood function:**  
Given observed data $\mathbf{x} = (x_1, \ldots, x_N)$ and known mean $\mu=0$, the likelihood for variance $\sigma^2$ (or precision $\tau$) is:
$$
p(\mathbf{x} \mid \sigma^2) = \prod_{i=1}^N \frac{1}{\sqrt{2\pi \sigma^2}} \exp\left(-\frac{x_i^2}{2\sigma^2}\right)
$$
or, in terms of precision $\tau = 1/\sigma^2$:
$$
p(\mathbf{x} \mid \tau) = (2\pi)^{-N/2} \tau^{N/2} \exp\left(-\frac{\tau}{2} \sum_{i=1}^N x_i^2\right)
$$

- The likelihood expresses how probable the observed data are for different values of $\sigma^2$ (or $\tau$).
- It is maximized at $\hat{\sigma}^2 = \frac{1}{N} \sum_{i=1}^N x_i^2$ (the mean of squares, since the mean is known to be $0$), and becomes more peaked as $N$ increases.  
- In Bayesian inference, the likelihood is combined with the prior to form the posterior.
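
The location of the maximum can be confirmed on a grid. The following sketch uses synthetic data with an illustrative true $\sigma$ and compares the grid maximizer of the log-likelihood with the closed-form value $\frac{1}{N}\sum_i x_i^2$; all names here are assumptions made for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
true_sigma = 1.5                        # illustrative value
x = rng.normal(0.0, true_sigma, size=200)


def log_likelihood(x, sigma2):
    """Total log-likelihood of N(0, sigma2) for each value in the sigma2 grid."""
    n = x.size
    return -0.5 * n * np.log(2 * np.pi * sigma2) - 0.5 * np.sum(x**2) / sigma2


sigma2_grid = np.linspace(0.5, 6.0, 5_501)     # grid step 0.001
sigma2_hat_grid = sigma2_grid[np.argmax(log_likelihood(x, sigma2_grid))]
sigma2_hat_closed = np.mean(x**2)               # closed-form maximizer

print(sigma2_hat_grid, sigma2_hat_closed)
```

The two estimates agree up to the grid resolution.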

**Why use log-likelihood instead of likelihood?**  
- The likelihood for the Gaussian variance is a product of $N$ density values, which can underflow to zero even for moderate $N$.
- Taking the logarithm turns products into sums, making computations numerically stable and easier to handle.
- Log-likelihood is also more convenient for optimization and plotting, as it avoids extremely small numbers.

**How to go from log-likelihood to likelihood:**  
- If $\log p(\mathbf{x} \mid \sigma^2)$ is the log-likelihood, then the likelihood is $p(\mathbf{x} \mid \sigma^2) = \exp(\log p(\mathbf{x} \mid \sigma^2))$.
- In practice, for plotting or normalization, we often subtract the maximum log-likelihood before exponentiating:  
    $$
    \text{likelihood}(\sigma^2) \propto \exp\left(\log p(\mathbf{x} \mid \sigma^2) - \max_{\sigma^2} \log p(\mathbf{x} \mid \sigma^2)\right)
    $$
    This keeps the values in a numerically safe range.
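
The underflow problem and the max-subtraction fix can be seen directly. This sketch, with synthetic data and an illustrative $\sigma$ grid, evaluates the likelihood both as a naive product and through the log-likelihood.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(0.0, 1.0, size=2_000)
sigma_grid = np.linspace(0.5, 2.0, 301)

# Naive likelihood: product of 2000 densities, each below 1, underflows to 0.0.
dens = np.exp(-0.5 * (x[:, None] / sigma_grid[None, :]) ** 2) / (
    np.sqrt(2 * np.pi) * sigma_grid[None, :]
)
naive = dens.prod(axis=0)

# Stable version: sum log-densities, subtract the maximum, then exponentiate.
log_lik = np.log(dens).sum(axis=0)
scaled = np.exp(log_lik - log_lik.max())

print(naive.max())    # 0.0
print(scaled.max())   # 1.0
```

The rescaled curve has the same shape as the likelihood, which is all that is needed for plotting or for normalizing on a grid.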

After observing data $\mathbf{x} = (x_1, \ldots, x_N)$, the **posterior** is:
$$
p(\tau \mid \mathbf{x}) = \mathrm{Gamma}(\tau \mid \alpha', \beta')
$$
where
$$
\alpha' = \alpha + \frac{N}{2}, \qquad \beta' = \beta + \frac{1}{2} \sum_{i=1}^N x_i^2
$$

### Derivation: From Likelihood × Prior to Posterior (Gamma-Gaussian Conjugacy)
#### Posterior (by Bayes' rule, up to normalization)
$$
p(\tau \mid \mathbf{x}) \propto p(\mathbf{x} \mid \tau) \, p(\tau)
$$
Plug in the likelihood and prior:
$$
p(\tau \mid \mathbf{x}) \propto \tau^{N/2} \exp\left(-\frac{\tau}{2} S\right) \cdot \tau^{\alpha-1} e^{-\beta \tau}
$$
where $S = \sum_{i=1}^N x_i^2$.

Combine exponents:
$$
p(\tau \mid \mathbf{x}) \propto \tau^{\alpha-1 + N/2} \exp\left(-\left(\beta + \frac{S}{2}\right)\tau\right)
$$

#### Recognize as a Gamma distribution
This is the unnormalized form of a Gamma:
$$
\mathrm{Gamma}(\tau \mid \alpha', \beta') \propto \tau^{\alpha'-1} e^{-\beta' \tau}
$$
where:
$$
\alpha' = \alpha + \frac{N}{2}, \qquad \beta' = \beta + \frac{1}{2} S
$$

#### Conclusion
$$
\boxed{
p(\tau \mid \mathbf{x}) = \mathrm{Gamma}\left(\tau \mid \alpha + \frac{N}{2}, \; \beta + \frac{1}{2} \sum_{i=1}^N x_i^2 \right)
}
$$

**Summary:**  
Multiplying the likelihood and prior gives a new Gamma distribution for the posterior over precision $\tau$, with updated parameters reflecting both prior information and observed data.
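
The derivation can be sanity-checked on a grid: multiply the likelihood and prior kernels, normalize numerically, and compare with the closed-form $\mathrm{Gamma}(\alpha', \beta')$ density. The data and prior values below are illustrative assumptions.

```python
import math
import numpy as np

alpha, beta = 1.0, 1.0                         # illustrative prior
x = np.array([1.0, -2.0, 0.5, 1.5])            # made-up data
n, s = x.size, float(np.sum(x**2))

tau = np.linspace(1e-6, 8.0, 80_001)
step = tau[1] - tau[0]

# Unnormalized posterior: likelihood kernel times prior kernel.
unnorm = tau ** (n / 2) * np.exp(-0.5 * s * tau) * tau ** (alpha - 1) * np.exp(-beta * tau)
grid_post = unnorm / (np.sum(0.5 * (unnorm[1:] + unnorm[:-1])) * step)

# Closed-form Gamma(alpha', beta') density.
a_post, b_post = alpha + n / 2, beta + 0.5 * s
gamma_post = np.exp(
    a_post * math.log(b_post) - math.lgamma(a_post) + (a_post - 1) * np.log(tau) - b_post * tau
)

max_abs_diff = float(np.max(np.abs(grid_post - gamma_post)))
print(max_abs_diff)
```

The difference is at the level of the numerical integration error, which confirms that the normalized product is the stated Gamma density.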

In [None]:
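def log_likelihood_sigma(X, sigma_grid):
    # Total log-likelihood of N(0, sigma^2) for each sigma in the grid,
    # up to the additive constant -N/2 * log(2*pi), which cancels on normalization.
    if len(X) == 0:
        return jnp.zeros_like(sigma_grid)
    ll = -0.5 * ((X[:, None] ** 2) / (sigma_grid[None, :] ** 2)) - jnp.log(
        sigma_grid[None, :]
    )
    # Sum (not average) over data points so the curve sharpens as N grows.
    return ll.sum(axis=0)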


def gamma_posterior_params(alpha, beta, X):
    N = len(X)
    sumsq = jnp.sum(X**2)
    alpha_post = alpha + N / 2
    beta_post = beta + 0.5 * sumsq
    return alpha_post, beta_post

## 6. Plot Observed Data and True Distribution

Use Plotly to plot a histogram of the observed data and overlay the true Gaussian density.

In [None]:
def plot_observed_data(X, sigma):
    x_grid = jnp.linspace(-3 * sigma, 3 * sigma, 250)
    fig = go.Figure()
    if len(X) > 0:
        fig.add_trace(
            go.Histogram(
                x=X,
                nbinsx=30,
                histnorm="probability density",
                name="Observed Data",
                opacity=0.5,
                marker_color="#636EFA",
            )
        )
    fig.add_trace(
        go.Scatter(
            x=x_grid,
            y=norm.pdf(x_grid, 0, sigma),
            name="True Gaussian",
            line=dict(color="#EF553B", width=3),
        )
    )
    fig.update_layout(
        title="Observed Data and True Gaussian Density",
        xaxis_title="x",
        yaxis_title="Density",
        width=1200,
        height=500,
        barmode="overlay",
    )
    fig.update_traces(opacity=0.7)
    fig.show()

In [None]:
def plot_observed_data_interactive(N, sigma, cp, alpha, beta):
    X = simulate_gaussian_data(N, sigma)
    plot_observed_data(X, sigma)


out = widgets.interactive_output(
    plot_observed_data_interactive,
    {
        "N": N_slider,
        "sigma": sigma_slider,
        "cp": cp_checkbox,
        "alpha": alpha_slider,
        "beta": beta_slider,
    },
)

display(ui)
display(out)

## 7. Plot Prior, Likelihood, and Posterior for Precision/Variance

Plot the prior, likelihood, and posterior as functions of $\sigma$ using Plotly. The Gamma densities on the precision $\tau = 1/\sigma^2$ are mapped to densities on $\sigma$ by a change of variables. The plot shows how the posterior updates as the data and prior parameters change.
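
Since $\tau = \sigma^{-2}$, a density on $\tau$ becomes a density on $\sigma$ through $p(\sigma) = p(\tau)\,\lvert d\tau/d\sigma \rvert = p(\tau)\, 2/\sigma^3$. The sketch below checks, for illustrative parameters and with helper names that are assumptions of this example, that the transformed density integrates to one.

```python
import math
import numpy as np


def gamma_pdf(tau, a, b):
    """Gamma(tau | shape a, rate b) density."""
    return np.exp(a * math.log(b) - math.lgamma(a) + (a - 1) * np.log(tau) - b * tau)


def sigma_density(sigma, a, b):
    """Density of sigma = tau**-0.5 when tau ~ Gamma(a, b): p(tau) * |d tau / d sigma|."""
    return gamma_pdf(sigma ** -2.0, a, b) * 2.0 / sigma**3


a, b = 3.0, 2.0                                 # illustrative parameters
sigma = np.linspace(0.05, 30.0, 300_001)
dens = sigma_density(sigma, a, b)
total = float(np.sum(0.5 * (dens[1:] + dens[:-1]) * np.diff(sigma)))
print(total)
```

Omitting or inverting the Jacobian factor leaves the shape in $\tau$ intact but yields a curve that is not a properly normalized density in $\sigma$.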

In [None]:
def plot_prior_likelihood_posterior(X, sigma, cp, alpha, beta):
    sigma_grid = jnp.linspace(0.01, 4.0, 300)
    tau_grid = 1.0 / (sigma_grid**2)
    # Likelihood as a function of sigma, normalized to integrate to 1 over the grid
    ll = log_likelihood_sigma(X, sigma_grid)
    likelihood = jnp.exp(ll - jnp.max(ll))
    likelihood /= jnp.trapezoid(likelihood, sigma_grid)
    # Prior and posterior over sigma via the change of variables tau = sigma**-2:
    # p(sigma) = p(tau) * |d tau / d sigma| = p(tau) * 2 / sigma**3
    if cp:
        jacobian = 2.0 / sigma_grid**3
        prior = gamma.pdf(tau_grid, a=alpha, scale=1 / beta) * jacobian
        alpha_post, beta_post = gamma_posterior_params(alpha, beta, X)
        posterior = gamma.pdf(tau_grid, a=alpha_post, scale=1 / beta_post) * jacobian
    else:
        # Flat (improper) prior on sigma: the posterior equals the normalized likelihood
        prior = jnp.ones_like(sigma_grid)
        posterior = likelihood
    # Plot
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=sigma_grid,
            y=prior,
            name="Prior p(σ)",
            line=dict(color="#00CC96", dash="dash"),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=sigma_grid,
            y=likelihood,
            name="Likelihood p_hat(x|σ)",
            line=dict(color="#636EFA"),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=sigma_grid,
            y=posterior,
            name="Posterior p(σ|x)",
            line=dict(color="#EF553B", width=3),
        )
    )
    if cp:
        # 95% equal-tailed credible interval for the posterior over sigma.
        # The posterior over tau is Gamma(alpha_post, beta_post); because
        # sigma = 1 / sqrt(tau) is decreasing in tau, the tau quantiles swap roles.
        gamma_dist = tfp.distributions.Gamma(concentration=alpha_post, rate=beta_post)
        tau_low = gamma_dist.quantile(0.025)
        tau_high = gamma_dist.quantile(0.975)
        # SciPy alternative:
        # tau_low = scipy_gamma.ppf(0.025, alpha_post, scale=1 / beta_post)
        # tau_high = scipy_gamma.ppf(0.975, alpha_post, scale=1 / beta_post)
        sigma_low = float(1 / jnp.sqrt(tau_high))
        sigma_high = float(1 / jnp.sqrt(tau_low))
        fig.add_vrect(
            x0=sigma_low,
            x1=sigma_high,
            fillcolor="rgba(0,204,150,0.15)",
            line_width=0,
            annotation_text="95% CI",
            annotation_position="top right",
            annotation=dict(font=dict(size=12, color="#00CC96")),
            layer="below",
        )
    fig.add_vline(
        x=sigma,
        line_color="#222",
        line_dash="dot",
        annotation_text="True σ",
        annotation_position="top",
    )
    fig.update_layout(
        title="Prior, Likelihood, and Posterior for σ",
        xaxis_title="σ",
        yaxis_title="Density (unnormalized)",
        width=1200,
        height=500,
        legend=dict(orientation="h", yanchor="bottom", y=-0.3, xanchor="center", x=0.5),
    )
    fig.show()

In [None]:
def plot_prior_likelihood_posterior_interactive(N, sigma, cp, alpha, beta):
    X = simulate_gaussian_data(N, sigma)
    plot_prior_likelihood_posterior(X, sigma, cp, alpha, beta)


out2 = widgets.interactive_output(
    plot_prior_likelihood_posterior_interactive,
    {
        "N": N_slider,
        "sigma": sigma_slider,
        "cp": cp_checkbox,
        "alpha": alpha_slider,
        "beta": beta_slider,
    },
)
display(ui)
display(out2)

## 8. Plot Posterior Predictive Distribution

Plot the posterior predictive density for new data points, using the derived formula. Overlay this on the observed data histogram.

### Posterior Predictive Distribution Formulae
The posterior predictive for a new $x^*$ is:
$$
p(x^* \mid \mathbf{x}) = \int p(x^* \mid \tau) p(\tau \mid \mathbf{x}) d\tau
$$
which evaluates to:
$$
p(x^* \mid \mathbf{x}) = \frac{\Gamma(\alpha'+0.5)}{\Gamma(\alpha')} \frac{1}{\sqrt{2\pi\beta'}} \left(1 + \frac{(x^*)^2}{2\beta'}\right)^{-(\alpha'+0.5)}
$$

The **posterior predictive** is the probability distribution of a new data point $x^*$ given the observed data $\mathbf{x}$, after the beliefs about the model parameter (here, the precision) have been updated using Bayes' rule.

- It is **not** a point prediction, but a full probability distribution over possible new observations.
- It incorporates both the uncertainty in the model parameter (from the posterior) and the randomness of the data-generating process.
- In the integral above, $p(\tau \mid \mathbf{x})$ is the posterior over precision and $p(x^* \mid \tau)$ is the Gaussian density of a new $x^*$ given $\tau$.

**Interpretation:**  
The posterior predictive gives the density of new values $x^*$, taking into account both the observed data and the prior. Because it averages over the remaining uncertainty in $\tau$, it has heavier tails than a Gaussian with a plug-in variance, and it approaches that Gaussian as $N$ grows.
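
The integral also describes a two-step sampling procedure: draw $\tau$ from the posterior, then draw $x^*$ from $\mathcal{N}(0, 1/\tau)$. The sketch below uses NumPy and illustrative posterior parameters (assumptions made for this example) and compares the Monte Carlo variance with the Student-$t$ value $\beta'/(\alpha'-1)$, which holds for $\alpha' > 1$.

```python
import numpy as np

rng = np.random.default_rng(0)
a_post, b_post = 6.0, 10.0                      # illustrative posterior parameters
n_draws = 200_000

# Step 1: draw precisions from the posterior (NumPy uses shape and scale = 1 / rate).
tau = rng.gamma(shape=a_post, scale=1.0 / b_post, size=n_draws)
# Step 2: draw one new observation per sampled precision.
x_new = rng.normal(0.0, 1.0 / np.sqrt(tau))

mc_var = float(np.var(x_new))
exact_var = b_post / (a_post - 1.0)             # variance of the Student-t predictive
print(mc_var, exact_var)
```

A histogram of `x_new` would trace out the same curve as the closed-form predictive density.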

In [None]:
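from jax.scipy.special import gammaln


def posterior_predictive_density(x_grid, alpha_post, beta_post):
    # Student-t predictive (2 * alpha_post degrees of freedom) for N(0, sigma^2)
    # with a Gamma prior on the precision.
    # Use log-gamma: in JAX's default float32, gamma_func overflows once
    # alpha_post exceeds about 35, which turns the ratio into NaN.
    log_coeff = (
        gammaln(alpha_post + 0.5)
        - gammaln(alpha_post)
        - 0.5 * jnp.log(2 * jnp.pi * beta_post)
    )
    return jnp.exp(log_coeff) * (1 + x_grid**2 / (2 * beta_post)) ** (-alpha_post - 0.5)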


def plot_posterior_predictive(X, alpha_post, beta_post, sigma):
    x_grid = jnp.linspace(-3 * sigma, 3 * sigma, 250)
    fig = go.Figure()
    if len(X) > 0:
        fig.add_trace(
            go.Histogram(
                x=X,
                nbinsx=30,
                histnorm="probability density",
                name="Observed Data",
                opacity=0.5,
                marker_color="#636EFA",
            )
        )
    fig.add_trace(
        go.Scatter(
            x=x_grid,
            y=norm.pdf(x_grid, 0, sigma),
            name="True Gaussian",
            line=dict(color="#EF553B", width=3),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=x_grid,
            y=posterior_predictive_density(x_grid, alpha_post, beta_post),
            name="Posterior Predictive",
            line=dict(color="#00CC96", dash="dash"),
        )
    )

    fig.update_layout(
        title="Posterior Predictive Density vs Observed Data",
        xaxis_title="x",
        yaxis_title="Density",
        width=1200,
        height=500,
        barmode="overlay",
    )
    fig.show()

In [None]:
# Plot the interactive posterior predictive distribution using current widget values
def plot_posterior_predictive_interactive(N, sigma, cp, alpha, beta):
    X = simulate_gaussian_data(N, sigma)
    if cp:
        alpha_post, beta_post = gamma_posterior_params(alpha, beta, X)
        plot_posterior_predictive(X, alpha_post, beta_post, sigma)
    else:
        # If not using conjugate prior, just plot observed data and true Gaussian
        plot_observed_data(X, sigma)


out3 = widgets.interactive_output(
    plot_posterior_predictive_interactive,
    {
        "N": N_slider,
        "sigma": sigma_slider,
        "cp": cp_checkbox,
        "alpha": alpha_slider,
        "beta": beta_slider,
    },
)
display(ui)
display(out3)

## 9. Display Summary and Mathematical Explanation

Display a summary of the Bayesian updating process, including the formulas for prior, likelihood, posterior, and predictive distributions. Explain the interpretation of each step.

In [None]:
def display_summary(N, sigma, cp, alpha, beta, X):
    if cp:
        alpha_post, beta_post = gamma_posterior_params(alpha, beta, X)
        summary = f"""
**Bayesian Updating for Gaussian Variance (Precision):**

- **Prior:** $\\tau = 1/\\sigma^2 \\sim \\mathrm{{Gamma}}(\\alpha, \\beta)$ with $\\alpha = {alpha:.2f}$, $\\beta = {beta:.2f}$
- **Data:** $N={N}$, true $\\sigma={sigma:.2f}$, observed $\\sum x_i^2 = {np.sum(X**2):.2f}$
- **Posterior:** $\\tau \\mid \\mathbf{{x}} \\sim \\mathrm{{Gamma}}(\\alpha', \\beta')$ with
  $$
  \\alpha' = \\alpha + \\frac{{N}}{{2}} = {alpha_post:.2f}, \\qquad
  \\beta' = \\beta + \\frac{{1}}{{2}} \\sum x_i^2 = {beta_post:.2f}
  $$
- **Posterior Predictive:** For new $x^*$,
  $$
  p(x^* \\mid \\mathbf{{x}}) = \\frac{{\\Gamma(\\alpha'+0.5)}}{{\\Gamma(\\alpha')}} \\frac{{1}}{{\\sqrt{{2\\pi\\beta'}}}} \\left(1 + \\frac{{(x^*)^2}}{{2\\beta'}}\\right)^{{-(\\alpha'+0.5)}}
  $$
"""
    else:
        summary = f"""
**Likelihood-only Inference (No Prior):**

- **Data:** $N={N}$, true $\\sigma={sigma:.2f}$
- **Posterior:** Proportional to the likelihood.
- **Interpretation:** No prior information is used; inference is based only on observed data.
"""
    display(Markdown(summary))

## 10. Interactive Dashboard for Gamma Inference

Combine all widgets, plots, and explanations into an interactive dashboard that updates as parameters change.

In [None]:
def gamma_inference_dashboard(N, sigma, cp, alpha, beta):
    # Simulate data
    X = simulate_gaussian_data(N, sigma)
    # Plots
    plot_observed_data(X, sigma)
    plot_prior_likelihood_posterior(X, sigma, cp, alpha, beta)
    if cp:
        alpha_post, beta_post = gamma_posterior_params(alpha, beta, X)
        plot_posterior_predictive(X, alpha_post, beta_post, sigma)
    display_summary(N, sigma, cp, alpha, beta, X)


dashboard_out = widgets.interactive_output(
    gamma_inference_dashboard,
    {
        "N": N_slider,
        "sigma": sigma_slider,
        "cp": cp_checkbox,
        "alpha": alpha_slider,
        "beta": beta_slider,
    },
)
display(ui)
display(dashboard_out)