In [1]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [5]:
from exponential_families import ExponentialFamily
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp  # Useful for log-likelihood calculations
import jax.scipy.stats as jss  # For standard distributions
import optax
from numpy.typing import ArrayLike
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import plotly.figure_factory as ff

In [6]:
# Make "plotly_white" the default template for Plotly figures created below.
pio.templates.default = "plotly_white"

# Lecture 05: Exponential Families

Based on the lecture slides by Philipp Hennig (SS 2023).

This notebook explores the concept of Exponential Families, their properties related to sufficient statistics, conjugate priors, and maximum likelihood estimation, with coding examples primarily using JAX.

## The Skeleton of ML and Conjugate Priors

The slides introduce probabilistic inference as a key component of ML.
$$p(w|x) = \frac{p(x|w) p(w)}{\int p(x|w) p(w) dw}$$
For i.i.d. data $x = \{x_1, \dots, x_n\}$, this becomes:
$$p(w|x) = \frac{\prod_{i=1}^n p(x_i|w) p(w)}{\int \prod_{i=1}^n p(x_i|w) p(w) dw}$$

This general form can be complex. The concept of **Conjugate Priors** simplifies Bayesian inference by ensuring the posterior has the same functional form as the prior, with parameters updated based on **sufficient statistics** $\phi(x)$ of the data.

$$p(w|x) \propto l(x; w) g(w; \theta) = g(w; \theta + \phi(x))$$

The power of conjugate priors is that the complex data likelihood $\prod_{i=1}^n p(x_i|w)$ combines with the prior $p(w)$ in a way that the data's influence is entirely captured by the sufficient statistics $\phi(x)$ and the number of data points $n$.

## Exponential Families: Definition

Exponential Families are a class of probability distributions for which conjugate priors naturally exist.

A probability distribution for a random variable $X$ is in the exponential family if its probability density/mass function can be written in the form:
$$p_w(x) = h(x) \exp[\phi(x)^T w - \log Z(w)]$$
or equivalently
$$p_w(x) = \frac{h(x)}{Z(w)} \exp[\phi(x)^T w]$$

Where:
- $h(x)$: the **base measure**. A non-negative function depending only on $x$.
- $\phi(x)$: the **sufficient statistics**. A vector function of the data $x$. It summarizes all the information from the data relevant to the natural parameters $w$.
- $w$: the **natural parameters**. A vector of parameters for the distribution.
- $Z(w)$: the **partition function**. A normalization constant ensuring the distribution integrates/sums to 1. It depends on $w$. $\log Z(w)$ is the log-partition function.

The slides also mention canonical parameters $\theta$, which are mapped to the natural parameters via $w = \eta(\theta)$.

### Example: The Poisson Distribution as an Exponential Family

Let's express the probability mass function (PMF) of a Poisson distribution $\mathrm{Poisson}(x; \lambda)$ in the exponential family form.

The standard PMF is:
$$
p(x|\lambda) = \frac{\lambda^x e^{-\lambda}}{x!}, \quad x \in \{0, 1, 2, \dots\}
$$

Here, $x$ represents the observed count (number of events), and $\lambda$ is the rate parameter (expected number of events per interval) of the Poisson distribution.

---

**Example:**  
Suppose you are monitoring the number of emails received by a helpdesk per hour.  
- $x$: The actual number of emails received in a particular hour (e.g., $x = 7$ emails in one hour).
- $\lambda$: The average rate of emails received per hour, estimated from historical data (e.g., $\lambda = 5$ emails/hour).

In this scenario, the Poisson distribution models the probability of observing $x$ emails in an hour, given the average rate $\lambda$.

---

We want to write the Poisson distribution in the exponential family form:
$$
p_w(x) = h(x) \exp[\phi(x)^T w - \log Z(w)]
$$

Let's rearrange the PMF:
$$
p(x|\lambda) = \frac{1}{x!} \exp\left(x \log \lambda - \lambda\right)
$$

Now, match terms to the exponential family form:

- **Sufficient Statistic:** The part depending only on $x$ and multiplied by the parameter in the exponent.
    $$
    \phi(x) = x
    $$

- **Natural Parameter:** The parameter that multiplies the sufficient statistic.
    $$
    w = \log \lambda
    $$

- **Base Measure:** The part depending only on $x$ outside the exponential.
    $$
    h(x) = \frac{1}{x!}
    $$

- **Log-Partition Function:** The part that ensures normalization, depending only on the parameter.
    $$
    \log Z(w) = e^{w}
    $$
    since $e^{w} = \lambda$.

So, the Poisson PMF in exponential family form is:
$$
p(x|w) = \frac{1}{x!} \exp\left(x w - e^{w}\right)
$$

**Summary Table:**

| Component            | Expression            |
|----------------------|----------------------|
| Sufficient Statistic | $\phi(x) = x$        |
| Natural Parameter    | $w = \log \lambda$   |
| Base Measure         | $h(x) = 1/x!$        |
| Log-Partition        | $\log Z(w) = e^{w}$  |

In [None]:
class Poisson(ExponentialFamily):
    """The Poisson distribution written as an exponential family.

    For a rate ``lambda`` the PMF is
    ``p(x | lambda) = (1 / x!) * exp(x * log(lambda) - lambda)``, hence

    * sufficient statistics ``phi(x) = x``
    * natural parameter ``w = log(lambda)``
    * base measure ``h(x) = 1 / x!``
    * log partition function ``log Z = lambda`` (``= exp(w)``)

    ``log_base_measure`` and ``log_partition`` read index 0 of the last axis of
    their argument, so counts and rates are expected to carry a trailing axis
    (e.g. ``ks[..., None]`` or ``[lam]``).
    """

    def __init__(self) -> None:
        """The Poisson has no fixed parameters."""
        super().__init__()

    def sufficient_statistics(self, x: ArrayLike | jnp.ndarray) -> jnp.ndarray:
        """Return ``phi(x) = x``: the input converted to a JAX array, unchanged."""
        return jnp.asarray(x)

    def log_base_measure(self, x: ArrayLike | jnp.ndarray) -> jnp.ndarray:
        """Return ``log h(x) = -log(x!)`` for the counts stored in ``x[..., 0]``.

        ``log(x!)`` is evaluated as ``gammaln(x + 1)``, using
        ``Gamma(n + 1) = n!`` for integer ``n``. Working in log space avoids
        overflowing the factorial and also accepts non-integer inputs.
        """
        x = jnp.asarray(x)
        return -jax.scipy.special.gammaln(x[..., 0] + 1)

    def log_partition(self, lambdas: ArrayLike | jnp.ndarray) -> jnp.ndarray:
        """Return ``log Z = lambda`` for the rates stored in ``lambdas[..., 0]``.

        Note that the argument is the rate ``lambda`` itself, not the natural
        parameter ``log(lambda)``.
        """
        lambdas = jnp.asarray(lambdas)
        return lambdas[..., 0]

    def parameters_to_natural_parameters(
        self, lambdas: ArrayLike | jnp.ndarray, /
    ) -> jnp.ndarray:
        """Map rates to natural parameters: ``w = log(lambda)`` (elementwise)."""
        lambdas = jnp.asarray(lambdas)
        return jnp.log(lambdas)

    def conjugate_log_partition(
        self, alpha: ArrayLike | jnp.ndarray, nu: ArrayLike | jnp.ndarray, /
    ) -> jnp.ndarray:
        """Return the log normalization constant of the conjugate (Gamma) prior.

        F(alpha, nu) = Gamma(alpha+1) / nu^(alpha+1)
        log F(alpha, nu) = gammaln(alpha+1) - (alpha+1) * log(nu)

        The result is computed elementwise, broadcasting ``alpha`` against ``nu``.
        """
        # Coerce like the sibling methods do, so that plain Python lists
        # (valid ``ArrayLike`` inputs) do not fail on ``alpha + 1``.
        alpha = jnp.asarray(alpha)
        nu = jnp.asarray(nu)
        return jax.scipy.special.gammaln(alpha + 1) - (alpha + 1) * jnp.log(nu)

In [33]:
likelihood = Poisson()

ks = jnp.arange(0, 20)  # counts k = 0, ..., 19

# Plot probability of k for different lambdas
fig1 = go.Figure()
lambdas_plot = jnp.arange(1, 10)  # integer rates 1, ..., 9, one curve each
for lam in lambdas_plot:
    fig1.add_trace(
        go.Scatter(
            x=ks,
            # ks[..., None] has shape (20, 1); [lam] is the single rate of this curve.
            y=jnp.exp(likelihood.logpdf(ks[..., None], [lam])),
            mode="lines+markers",
            name=f"$\\lambda={lam}$",
            line=dict(color="red"),
            marker=dict(size=6),
            opacity=0.7,
        )
    )
fig1.update_layout(
    xaxis_title="$k$",
    yaxis_title="Probability",
    # Plotly typesets only the first $...$ span of a string and drops the
    # surrounding plain text, so the whole title is written in math mode.
    title="$\\text{Probability of } k \\text{ for different } \\lambda$",
    yaxis=dict(range=[0, 0.4]),
    template="plotly_white",
)
fig1.update_xaxes(dtick=5, minor=dict(dtick=1))

fig1.show()

# Plot likelihood of lambda for all k
fig2 = go.Figure()
lambdas = jnp.linspace(1, 10, 100)  # dense grid of rates for the x-axis
for k in ks:
    fig2.add_trace(
        go.Scatter(
            x=lambdas,
            # lambdas[..., None] has shape (100, 1); [k] is the single observed count.
            y=jnp.exp(likelihood.logpdf([k], lambdas[..., None])),
            mode="lines",
            name=f"$k={k}$",
            line=dict(width=1, color="blue"),
            opacity=0.5,
            showlegend=bool(k <= 5),  # Only show legend for first few k
        )
    )
fig2.update_layout(
    xaxis_title="$\\lambda$",
    yaxis_title="Likelihood",
    # Whole title in math mode for the same reason as in fig1.
    title="$\\text{Likelihood of } \\lambda \\text{ for different } k$",
    template="plotly_white",
)
fig2.show()

Consider the exponential family 
$$ 
p_w(x \mid w) = h(x) \exp \left[ \phi(x)^\top w - \log Z(w) \right] 
$$

Its conjugate prior is the exponential family 
$$
p_\alpha(w \mid \alpha, \nu) = \exp \left[ \left( \begin{pmatrix} w \\ -\log Z(w) \end{pmatrix}^\top \begin{pmatrix} \alpha \\ \nu \end{pmatrix} \right) - \log F(\alpha, \nu) \right]
$$

with partition function / normalization constant (defined as below so that the prior normalizes to 1)
$$
F(\alpha, \nu) := \int \exp(\alpha^\top w - \nu \log Z(w)) \, dw 
$$

- $\alpha$ and $\nu$ are the natural parameters of the conjugate prior for the exponential family.
    - $\alpha$ is a vector (or scalar) that plays the same role for the prior as the sufficient statistics $\phi(x)$ do for the likelihood.
    - $\nu$ is a scalar that acts like a "prior sample size" or strength, controlling how much weight the prior has relative to the data.

The posterior is
$$
p_\alpha(w \mid \alpha, \nu) \prod_{i=1}^n p_w(x_i \mid w) \propto p_\alpha \left( w \mid \alpha + \sum \phi(x_i), \nu + n \right) 
$$

And the predictive posterior is

$$
p(x) = \int p_w(x \mid w) p_\alpha(w \mid \alpha, \nu) \, dw = h(x) \int e^{(\phi(x) + \alpha)^\top w - (\nu + 1) \log Z(w) - \log F(\alpha, \nu)} \, dw
$$

$$
= h(x) \frac{F(\phi(x) + \alpha, \nu + 1)}{F(\alpha, \nu)}
$$

The conjugate prior for

$$
p(x; \lambda) = \frac{1}{x!} \exp(x \log \lambda - \lambda)
$$

has natural parameter $w = \log \lambda$ and log partition function $\log Z = \lambda$. Its sufficient statistics are therefore $(w, -\log Z(w)) = (\log \lambda, -\lambda)$, and thus it is of the form

$$
p(\lambda \mid \alpha, \nu) = \exp((\log \lambda)\alpha - \lambda\nu - \log F(\alpha, \nu)) = \frac{\lambda^\alpha e^{-\lambda\nu}}{F(\alpha, \nu)}
$$
This is a Gamma distribution because its density has the same functional form:
$$
p(\lambda \mid \alpha, \nu) \propto \lambda^\alpha e^{-\lambda \nu}
$$
which matches the standard Gamma distribution (with shape $\alpha+1$ and rate $\nu$):
$$
\mathrm{Gamma}(\lambda; \alpha+1, \nu) = \frac{\nu^{\alpha+1}}{\Gamma(\alpha+1)} \lambda^{\alpha} e^{-\lambda \nu}
$$
So, the conjugate prior for the Poisson rate $\lambda$ is a Gamma distribution, with the normalization constant

$$
F(\alpha, \nu) = \frac{\Gamma(\alpha + 1)}{\nu^{\alpha+1}}
$$



Bayesian inference on Poisson distributed data thus "simply" becomes:


In [72]:
# Conjugate Bayesian inference for a Poisson rate, then a plot of the prior,
# likelihood and posterior over a grid of rates.
likelihood = Poisson()  # the observation model
prior = likelihood.conjugate_prior()  # the conjugate family induced by that model
prior_natural_parameters = [1, 1]  # one particular (arbitrary) prior
data = [5]  # number of events observed in a certain time interval
posterior = prior  # conjugacy: the posterior is in the same family as the prior
# Only the natural parameters change; the likelihood provides the update rule.
posterior_natural_parameters = likelihood.posterior_parameters(
    prior_natural_parameters, data
)

lambdas = jnp.linspace(0.1, 10, 100)

# One entry per curve: (legend name, log-values on the grid, line style).
curves = (
    (
        "prior",
        prior.logpdf(lambdas[..., None], prior_natural_parameters),
        {"dash": "dash", "color": "gray"},
    ),
    (
        "likelihood",
        likelihood.logpdf(data, lambdas[..., None]),
        {"color": "blue"},
    ),
    (
        "posterior",
        posterior.logpdf(lambdas[..., None], posterior_natural_parameters),
        {"color": "red"},
    ),
)

fig = go.Figure()
for label, log_values, line_style in curves:
    trace = go.Scatter(
        x=lambdas,
        y=jnp.exp(log_values),
        mode="lines",
        line=line_style,
        name=label,
        opacity=0.9,
    )
    fig.add_trace(trace)

layout_options = {
    "xaxis_title": "$\\lambda$",
    "yaxis_title": "Density",
    "template": "plotly_white",
    "legend": {"title": ""},
}
fig.update_layout(**layout_options)
fig.show()