# Lecture 05: Exponential Families

Based on the lecture slides by Philipp Hennig (SS 2023).

This notebook explores the concept of Exponential Families, their properties related to sufficient statistics, conjugate priors, and maximum likelihood estimation, with coding examples primarily using JAX.

## The Skeleton of ML and Conjugate Priors

The slides introduce probabilistic inference as a key component of ML.
$$p(w|x) = \frac{p(x|w) p(w)}{\int p(x|w) p(w) dw}$$
For i.i.d. data $x = \{x_1, \dots, x_n\}$, this becomes:
$$p(w|x) = \frac{\prod_{i=1}^n p(x_i|w) p(w)}{\int \prod_{i=1}^n p(x_i|w) p(w) dw}$$

This general form can be complex. The concept of **Conjugate Priors** simplifies Bayesian inference by ensuring the posterior has the same functional form as the prior, with parameters updated based on **sufficient statistics** $\phi(x)$ of the data.

$$p(w|x) \propto l(x; w)\, g(w; \theta) \propto g(w; \theta + \phi(x)), \quad \text{so} \quad p(w|x) = g(w; \theta + \phi(x))$$

The power of conjugate priors is that the complex data likelihood $\prod_{i=1}^n p(x_i|w)$ combines with the prior $p(w)$ in a way that the data's influence is entirely captured by the sufficient statistics $\phi(x)$ and the number of data points $n$.

## Exponential Families: Definition

Exponential Families are a class of probability distributions for which conjugate priors naturally exist.

A probability distribution for a random variable $X$ is in the exponential family if its probability density/mass function can be written in the form:
$$p_w(x) = h(x) \exp[\phi(x)^T w - \log Z(w)]$$
or equivalently
$$p_w(x) = \frac{h(x)}{Z(w)} \exp[\phi(x)^T w]$$

Where:
- $h(x)$: the **base measure**. A non-negative function depending only on $x$.
- $\phi(x)$: the **sufficient statistics**. A vector function of the data $x$. It summarizes all the information from the data relevant to the natural parameters $w$.
- $w$: the **natural parameters**. A vector of parameters for the distribution.
- $Z(w)$: the **partition function**. A normalization constant ensuring the distribution integrates/sums to 1. It depends on $w$. $\log Z(w)$ is the log-partition function.

The slides also mention canonical parameters $\theta$, related to the natural parameters by $w = \eta(\theta)$.


### Example: The Univariate Gaussian Distribution as an Exponential Family

Let's express the probability density function (PDF) of a univariate Gaussian $N(x; \mu, \sigma^2)$ in the exponential family form.

The standard PDF is:
$$p(x | \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left(-\frac{(x - \mu)^2}{2\sigma^2}\right)$$

We rearrange the exponent into terms of the form (a parameter-dependent coefficient) times (a function of $x$), plus a remainder that depends only on the parameters:
$$-\frac{(x - \mu)^2}{2\sigma^2} = -\frac{x^2 - 2\mu x + \mu^2}{2\sigma^2} = \left(\frac{\mu}{\sigma^2}\right) x + \left(-\frac{1}{2\sigma^2}\right) x^2 - \frac{\mu^2}{2\sigma^2}$$

Now, let's match this to the exponential family form $h(x) \exp[\phi(x)^T w - \log Z(w)]$:

- **Sufficient Statistics:** These are the functions of $x$ alone that multiply the parameters in the exponent. The rearranged exponent suggests
  $$\phi(x) = \begin{bmatrix} x \\ x^2 \end{bmatrix},$$
  which would pair with natural parameters $[\mu/\sigma^2, -1/(2\sigma^2)]^T$. Rescaling a component of $\phi$ simply rescales the matching component of $w$, so either choice is valid. The slides use $\phi(x) = [x, -x^2/2]^T$, and we follow them for consistency:
  $$\phi(x) = \begin{bmatrix} x \\ -x^2/2 \end{bmatrix}$$

- **Natural Parameters:** These are the parameters that multiply the sufficient statistics in the exponent. Matching the terms:
  $$w = \begin{bmatrix} \mu/\sigma^2 \\ 1/\sigma^2 \end{bmatrix}$$

- **Base Measure:** $h(x)$ collects the factors that depend only on $x$. The Gaussian PDF has no such factor: the prefactor $\frac{1}{\sqrt{2\pi\sigma^2}}$ depends only on the parameters. We therefore set $h(x) = 1$ and absorb the prefactor into the normalization constant $Z(w)$.

- **Log-Partition Function:** $-\log Z(w)$ must collect all terms of $\log p(x \mid \mu, \sigma^2)$ that depend only on the parameters $(\mu, \sigma^2)$, or equivalently on $w$. These are $-\frac{\mu^2}{2\sigma^2}$ from the exponent and $\log\frac{1}{\sqrt{2\pi\sigma^2}} = -\log\sqrt{2\pi\sigma^2}$ from the prefactor:
  $$-\log Z(w) = -\frac{\mu^2}{2\sigma^2} - \log\left(\sqrt{2\pi\sigma^2}\right)$$
  So, $\log Z(w) = \frac{\mu^2}{2\sigma^2} + \log(\sqrt{2\pi\sigma^2})$.
  Now, express this in terms of $w_1 = \mu/\sigma^2$ and $w_2 = 1/\sigma^2$. This means $\mu = w_1/w_2$ and $\sigma^2 = 1/w_2$.
  $$\log Z(w) = \frac{(w_1/w_2)^2}{2(1/w_2)} + \log\left(\sqrt{\frac{2\pi}{w_2}}\right) = \frac{w_1^2}{2w_2} + \frac{1}{2}\log(2\pi) - \frac{1}{2}\log(w_2)$$
  The constant $\frac{1}{2}\log(2\pi)$ is often dropped from $\log Z(w)$ (or moved into the base measure as $h(x) = 1/\sqrt{2\pi}$): it does not depend on $w$, so it has no effect on gradients or on the location of the MLE. It is, however, needed for the density to integrate to 1, so we keep it here to match the standard PDF.
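A quick numerical sanity check of this derivation: evaluate the defining integral $Z(w) = \int h(x) \exp[\phi(x)^T w]\, dx$ with $h(x) = 1$ as a Riemann sum on a dense grid and compare it with the closed form. This is a minimal sketch assuming only `jax`; the helper names `to_natural`, `from_natural`, `log_Z_closed_form` and `log_Z_numeric`, the grid limits and the test parameters are hypothetical choices for illustration.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def to_natural(mu, sigma2):
    """(mu, sigma^2) -> w = [mu / sigma^2, 1 / sigma^2]."""
    return jnp.array([mu / sigma2, 1.0 / sigma2])


def from_natural(w):
    """w -> (mu, sigma^2) via mu = w1 / w2, sigma^2 = 1 / w2 (requires w2 > 0)."""
    w1, w2 = w
    return w1 / w2, 1.0 / w2


def log_Z_closed_form(w):
    """log Z(w) = w1^2 / (2 w2) - log(w2) / 2 + log(2 pi) / 2, as derived above."""
    w1, w2 = w
    return 0.5 * w1**2 / w2 - 0.5 * jnp.log(w2) + 0.5 * jnp.log(2 * jnp.pi)


def log_Z_numeric(w, lo=-50.0, hi=50.0, num=100_001):
    """log of the integral of exp(phi(x)^T w) dx, approximated by a Riemann sum (h(x) = 1)."""
    x = jnp.linspace(lo, hi, num)
    exponent = w[0] * x + w[1] * (-0.5 * x**2)  # phi(x)^T w with phi(x) = [x, -x^2/2]
    dx = (hi - lo) / (num - 1)
    return logsumexp(exponent) + jnp.log(dx)  # log-sum-exp avoids overflow in exp


for mu, sigma2 in [(1.0, 0.5), (5.0, 4.0), (-3.0, 0.1)]:
    w = to_natural(mu, sigma2)
    print(mu, sigma2, w, log_Z_closed_form(w), log_Z_numeric(w))
```

The two columns should agree to several decimal places; the log-sum-exp form keeps the sum finite even when $\phi(x)^T w$ is large (e.g. $w = [-30, 10]$).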

Let's implement the sufficient statistics and log-partition function for the univariate Gaussian using JAX.

In [None]:
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp  # Useful for log-likelihood calculations
import jax.scipy.stats as jss  # For standard distributions
import optax


# Define the sufficient statistics function for a single data point x
def gaussian_sufficient_statistics(x):
    """
    Sufficient statistics for the univariate Gaussian, following the slide's definition.
    phi(x) = [x, -x^2/2]
    """
    return jnp.array([x, -0.5 * x**2])


# Define the log-partition function for the univariate Gaussian in terms of natural parameters
def gaussian_log_partition_function(w):
    """
    Log-partition function for the univariate Gaussian in terms of natural parameters.
    w = [w1, w2], where w1 = mu/sigma^2, w2 = 1/sigma^2
    log Z(w) = 0.5 * (w1^2 / w2) - 0.5 * log(w2) + 0.5 * log(2*pi)
    We include the constant here to match the standard log PDF.
    """
    w1, w2 = w
    # Ensure w2 is positive (corresponds to sigma^2 > 0)
    # Adding a small epsilon can improve numerical stability if w2 can be close to zero.
    # For this example, we assume w2 > 0.
    log_Z = 0.5 * (w1**2 / w2) - 0.5 * jnp.log(w2) + 0.5 * jnp.log(2 * jnp.pi)
    return log_Z


# Define the log PDF of the Gaussian in Exponential Family form
# log p_w(x) = log h(x) + phi(x)^T w - log Z(w)
# Assuming h(x) = 1, so log h(x) = 0
def gaussian_log_pdf_ef(x, w):
    """
    Log PDF of the univariate Gaussian in Exponential Family form.
    Assumes h(x) = 1.
    """
    phi_x = gaussian_sufficient_statistics(x)
    log_Z_w = gaussian_log_partition_function(w)
    return jnp.dot(phi_x, w) - log_Z_w


# Example usage:
# Let's pick some natural parameters corresponding to N(x; mu=1, sigma^2=0.5)
# mu = 1.0, sigma^2 = 0.5
# w1 = mu/sigma^2 = 1.0 / 0.5 = 2.0
# w2 = 1/sigma^2 = 1 / 0.5 = 2.0
w_example = jnp.array([2.0, 2.0])

# Calculate log Z for these parameters
log_Z_example = gaussian_log_partition_function(w_example)
print(f"Example Natural Parameters (mu=1, sigma^2=0.5): w = {w_example}")
print(f"Log Partition Function for w: {log_Z_example}")

# Evaluate the log PDF for a data point x=1.0 using the EF form
x_example = 1.0
log_pdf_example_ef = gaussian_log_pdf_ef(x_example, w_example)
print(f"Log PDF at x={x_example} with w={w_example} (EF form): {log_pdf_example_ef}")

# For comparison, the standard log PDF of N(x; mu=1, sigma^2=0.5) at x=1.0
log_pdf_standard = jss.norm.logpdf(x_example, loc=1.0, scale=jnp.sqrt(0.5))
print(f"Standard Log PDF at x={x_example} for N(1, 0.5): {log_pdf_standard}")

# The two log PDF values should match closely, confirming the EF representation.

Example Natural Parameters (mu=1, sigma^2=0.5): w = [2. 2.]
Log Partition Function for w: 1.5723649263381958
Log PDF at x=1.0 with w=[2. 2.] (EF form): -0.5723649263381958
Standard Log PDF at x=1.0 for N(1, 0.5): -0.5723649263381958


## Sufficient Statistics and Data Reduction
A key property of exponential families is that for i.i.d. data $x = \{x_1, \dots, x_n\}$, the joint likelihood is:

$$
p_w(x_1, \dots, x_n \mid w) = \prod_{i=1}^n p_w(x_i \mid w) = \prod_{i=1}^n \left[ h(x_i) \exp\left( \phi(x_i)^T w - \log Z(w) \right) \right] = \left( \prod_{i=1}^n h(x_i) \right) \exp\left( \left( \sum_{i=1}^n \phi(x_i) \right)^T w - n \log Z(w) \right)
$$

This shows that the joint distribution depends on the data only through the sum of the sufficient statistics $\sum_{i=1}^n \phi(x_i)$ and the number of data points $n$. This is the **data reduction property**: instead of needing the full dataset, you only need to compute and store the sum of sufficient statistics.

For the univariate Gaussian, the sum of sufficient statistics is

$$
\sum_{i=1}^n \begin{bmatrix} x_i \\ -\frac{1}{2} x_i^2 \end{bmatrix} = \begin{bmatrix} \sum x_i \\ -\frac{1}{2} \sum x_i^2 \end{bmatrix}
$$

So, whatever the size of the dataset, the count $n$ together with $\sum_i x_i$ and $\sum_i x_i^2$ is all we need to store.
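The data reduction property can be checked directly by comparing the log-likelihood summed over every data point with the one computed from $n$ and $\sum_i \phi(x_i)$ alone; with $h(x) = 1$ the two should agree up to floating-point error. A minimal sketch assuming only `jax`; the dataset, seed, parameter values and the helper `loglik_from_stats` are illustrative.

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jss


def phi(x):
    """phi(x) = [x, -x^2/2]; works elementwise, returning shape (..., 2)."""
    return jnp.stack([x, -0.5 * x**2], axis=-1)


def log_Z(w):
    w1, w2 = w
    return 0.5 * w1**2 / w2 - 0.5 * jnp.log(w2) + 0.5 * jnp.log(2 * jnp.pi)


def loglik_from_stats(w, n, sum_phi):
    """(sum_i phi(x_i))^T w - n log Z(w); the sum of log h(x_i) is zero because h = 1."""
    return jnp.dot(sum_phi, w) - n * log_Z(w)


x = jax.random.normal(jax.random.PRNGKey(1), (50,)) * 2.0 + 5.0  # illustrative data
mu, sigma2 = 4.0, 3.0  # any parameter value works, not only the true one
w = jnp.array([mu / sigma2, 1.0 / sigma2])

# Full data: sum of per-point log densities.
ll_full = jnp.sum(jss.norm.logpdf(x, loc=mu, scale=jnp.sqrt(sigma2)))

# Reduced data: only n and the summed sufficient statistics are kept.
n, sum_phi = x.shape[0], jnp.sum(phi(x), axis=0)
ll_reduced = loglik_from_stats(w, n, sum_phi)

print(ll_full, ll_reduced)
```

Any other dataset with the same $(n, \sum_i \phi(x_i))$, for example a reordering of `x`, receives exactly the same likelihood for every $w$.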

In [2]:
# Generate some synthetic univariate Gaussian data
key = jax.random.PRNGKey(0)
mu_true = 5.0
sigma_true = 2.0
n_samples = 100
synthetic_data = jax.random.normal(key, (n_samples,)) * sigma_true + mu_true

print(f"Generated {n_samples} data points from N(mu={mu_true}, sigma={sigma_true})")
print(f"First 5 data points: {synthetic_data[:5]}")

# Calculate the sufficient statistics for each data point
phi_data = jax.vmap(gaussian_sufficient_statistics)(synthetic_data)

# Calculate the sum of sufficient statistics for the dataset
sum_phi_x = jnp.sum(phi_data, axis=0)

print(f"\nSum of sufficient statistics for the dataset: {sum_phi_x}")

# For univariate Gaussian, sum_phi_x = [sum(x_i), sum(-0.5 * x_i^2)]
# Let's verify this against direct sums:
direct_sum_x = jnp.sum(synthetic_data)
direct_sum_neg_half_x_sq = jnp.sum(-0.5 * synthetic_data**2)

print(f"Direct sum of x_i: {direct_sum_x}")
print(f"Direct sum of -0.5 * x_i^2: {direct_sum_neg_half_x_sq}")
# The values in sum_phi_x should match the direct sums.
# This confirms that the sum of sufficient statistics captures the necessary information.

Generated 100 data points from N(mu=5.0, sigma=2.0)
First 5 data points: [8.245284  9.0505295 4.132811  4.8427653 5.352182 ]

Sum of sufficient statistics for the dataset: [  521.8606 -1538.9651]
Direct sum of x_i: 521.860595703125
Direct sum of -0.5 * x_i^2: -1538.965087890625


## Conjugate Prior Parameter Update
The slides state that if the likelihood $p_w(x \mid w)$ is in an exponential family, its conjugate prior $p_\alpha(w \mid \alpha, \nu)$ is also in a related exponential family form. The key result is the simple update rule for the prior's parameters $(\alpha, \nu)$ to obtain the posterior's parameters $(\alpha', \nu')$ after observing data $x = \{x_1, \ldots, x_n\}$:

$$
\alpha' = \alpha + \sum_{i=1}^n \phi(x_i)
$$

$$
\nu' = \nu + n
$$

Here, $\alpha$ is a vector parameter with the same dimension as $\phi(x)$, and $\nu$ is a scalar parameter representing the "prior number of observations".
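Because the update is a plain sum, it can also be applied one batch at a time: each posterior becomes the prior for the next batch, and the end result matches a single update on all the data. A minimal sketch, assuming the Gaussian statistics $\phi(x) = [x, -x^2/2]^T$ used in this notebook; the prior values and the helper `conjugate_update` are illustrative, not taken from the slides.

```python
import jax
import jax.numpy as jnp


def phi(x):
    return jnp.stack([x, -0.5 * x**2], axis=-1)


def conjugate_update(alpha, nu, x_batch):
    """Posterior parameters after x_batch: (alpha + sum_i phi(x_i), nu + n)."""
    return alpha + jnp.sum(phi(x_batch), axis=0), nu + x_batch.shape[0]


alpha0, nu0 = jnp.array([0.0, -1.0]), 1.0  # illustrative prior parameters
data = jax.random.normal(jax.random.PRNGKey(2), (90,)) * 2.0 + 5.0

# One update with all 90 points ...
alpha_batch, nu_batch = conjugate_update(alpha0, nu0, data)

# ... versus three updates with 30 points each, feeding each posterior in as the next prior.
alpha_seq, nu_seq = alpha0, nu0
for chunk in jnp.split(data, 3):
    alpha_seq, nu_seq = conjugate_update(alpha_seq, nu_seq, chunk)

print(alpha_batch, nu_batch)
print(alpha_seq, nu_seq)
```

This is what makes conjugate exponential-family models convenient for streaming data: only $(\alpha, \nu)$ has to be kept between batches.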

In [3]:
# Conceptual Conjugate Prior Update
# For the Gaussian likelihood, the conjugate prior over (mu, 1/sigma^2) is the Normal-Gamma distribution.
# Its parameters are related to the (alpha, nu) parameters of the EF conjugate prior form.
# Let's assume some initial prior parameters (alpha_0, nu_0) meant to represent a weak prior.
# Note: with phi(x) = [x, -x^2/2], this prior is normalizable only if
# alpha_0[1] + alpha_0[0]**2 / (2 * nu_0) < 0. The values below violate this, so the prior is
# improper; the posterior after observing the data below is nevertheless proper.
alpha_0 = jnp.array([0.1, 0.1])  # Prior "pseudo-sufficient statistics"
nu_0 = 1.0  # Prior "pseudo-observation count"

print(f"Initial prior parameters: alpha_0 = {alpha_0}, nu_0 = {nu_0}")

# We observed the 'n_samples' synthetic data points with sum of sufficient statistics 'sum_phi_x'

# The updated posterior parameters are:
alpha_posterior = alpha_0 + sum_phi_x
nu_posterior = nu_0 + n_samples

print(f"Observed data: n = {n_samples}, sum_phi_x = {sum_phi_x}")
print(
    f"Posterior parameters: alpha_posterior = {alpha_posterior}, nu_posterior = {nu_posterior}"
)

# These posterior parameters (alpha_posterior, nu_posterior) completely specify the posterior over the
# natural parameters w; after a change of variables it is a Normal-Gamma distribution over (mu, 1/sigma^2).
# The complexity of the data has been reduced to a simple addition to the prior parameters.

Initial prior parameters: alpha_0 = [0.1 0.1], nu_0 = 1.0
Observed data: n = 100, sum_phi_x = [  521.8606 -1538.9651]
Posterior parameters: alpha_posterior = [  521.9606 -1538.8651], nu_posterior = 101.0


## Maximum Likelihood Estimation (MLE) in Exponential Families
Exponential Families simplify Maximum Likelihood Estimation. The log-likelihood for i.i.d. data is:

$$
\log p_w(x_1, \ldots, x_n \mid w) = \left(\sum_{i=1}^n \phi(x_i)\right)^T w - n \log Z(w) + \sum_{i=1}^n \log h(x_i)
$$

To find the MLE $\hat{w}$, we take the gradient with respect to $w$ and set it to zero:

$$
\nabla_w \log p(x \mid w) = \sum_{i=1}^n \phi(x_i) - n \nabla_w \log Z(w) = 0
$$

This gives the crucial property:

$$
\nabla_w \log Z(w) = \frac{1}{n} \sum_{i=1}^n \phi(x_i)
$$

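At the MLE $\hat{w}$, the gradient of the log-partition function with respect to the natural parameters therefore equals the empirical average of the sufficient statistics. Differentiating $Z(w) = \int h(x) \exp[\phi(x)^T w]\, dx$ under the integral sign shows that $\nabla_w \log Z(w) = \mathbb{E}_{p_w}[\phi(x)]$, so this is a **moment-matching** condition: the MLE is the parameter value whose expected sufficient statistics equal their empirical average.

The MLE for $w$ is found by solving this equation.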
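Differentiating once more gives $\nabla_w \nabla_w^T \log Z(w) = \operatorname{Cov}_{p_w}[\phi(x)]$, a positive semi-definite matrix, so $\log Z$ is convex and the log-likelihood is concave in $w$; when this covariance is positive definite, the MLE equation has at most one solution. The sketch below checks this for the Gaussian by comparing `jax.hessian` of $\log Z$ with the covariance of $\phi(x) = [x, -x^2/2]^T$ computed from the Gaussian moments $\mathbb{E}[x^3] = \mu^3 + 3\mu\sigma^2$ and $\mathbb{E}[x^4] = \mu^4 + 6\mu^2\sigma^2 + 3\sigma^4$. It assumes only `jax`; the parameter values are illustrative.

```python
import jax
import jax.numpy as jnp


def log_Z(w):
    w1, w2 = w
    return 0.5 * w1**2 / w2 - 0.5 * jnp.log(w2) + 0.5 * jnp.log(2 * jnp.pi)


mu, sigma2 = 5.0, 4.0  # illustrative parameters
w = jnp.array([mu / sigma2, 1.0 / sigma2])

# Hessian of log Z by automatic differentiation.
H = jax.hessian(log_Z)(w)

# Covariance of phi(x) = [x, -x^2/2] under N(mu, sigma2), from the Gaussian moments:
#   Var[x] = sigma2
#   Cov[x, -x^2/2] = -(E[x^3] - E[x] E[x^2]) / 2 = -mu * sigma2
#   Var[-x^2/2] = (E[x^4] - E[x^2]^2) / 4 = mu^2 * sigma2 + sigma2^2 / 2
cov_phi = jnp.array(
    [
        [sigma2, -mu * sigma2],
        [-mu * sigma2, mu**2 * sigma2 + 0.5 * sigma2**2],
    ]
)

print(H)
print(cov_phi)
print("eigenvalues:", jnp.linalg.eigvalsh(H))  # all positive => log Z strictly convex at w
```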

**Example: MLE for the Gaussian**

For the univariate Gaussian, we found

$$
\log Z(w) = \frac{w_1^2}{2w_2} - \frac{1}{2} \log(w_2) + \text{const}.
$$

The gradient with respect to $w = [w_1, w_2]^T$ is:

$$
\nabla_w \log Z(w) = \left[
    \frac{\partial}{\partial w_1},\ \frac{\partial}{\partial w_2}
\right] \left( \frac{w_1^2}{2w_2} - \frac{1}{2} \log w_2 \right)
= \left[ \frac{w_1}{w_2},\ -\frac{w_1^2}{2w_2^2} - \frac{1}{2w_2} \right]
$$

(Ignoring the constant term as its gradient is zero.)

The empirical mean of the sufficient statistics is

$$
\frac{1}{n} \sum_{i=1}^n \phi(x_i) = \frac{1}{n} \sum_{i=1}^n \begin{bmatrix} x_i \\ -\frac{1}{2} x_i^2 \end{bmatrix}
= \begin{bmatrix} \bar{x} \\ -\frac{1}{2} \overline{x^2} \end{bmatrix}
$$

Setting $\nabla_w \log Z(\hat{w}) = \frac{1}{n} \sum \phi(x_i)$:

$$
\begin{aligned}
\frac{\hat{w}_1}{\hat{w}_2} &= \bar{x} \\
-\frac{\hat{w}_1^2}{2\hat{w}_2^2} - \frac{1}{2\hat{w}_2} &= -\frac{1}{2} \overline{x^2}
\end{aligned}
$$

Solving these equations for $\hat{w}_1$ and $\hat{w}_2$: the first gives $\hat{w}_1 = \bar{x}\,\hat{w}_2$, and substituting this into the second yields $-\frac{1}{2}\bar{x}^2 - \frac{1}{2\hat{w}_2} = -\frac{1}{2}\overline{x^2}$. As shown in the slides, this leads to

$$
\hat{w}_1 = \frac{\bar{x}}{\overline{x^2} - \bar{x}^2}, \qquad \hat{w}_2 = \frac{1}{\overline{x^2} - \bar{x}^2}
$$

which correspond to the standard MLEs $\hat{\mu} = \hat{w}_1/\hat{w}_2 = \bar{x}$ and $\hat{\sigma}^2 = 1/\hat{w}_2 = \overline{x^2} - \bar{x}^2$.

We can use JAX's automatic differentiation to compute the gradient of our `gaussian_log_partition_function` and verify this property.

### Moment Matching vs. Direct Maximization: Convergence Speed and Stability

**Moment Matching (Solving $\nabla_w \log Z(w) = \text{target}$):**
- **Pros:**  
    - If the moment map $w \mapsto \nabla_w \log Z(w)$ can be inverted in closed form (as for the Gaussian), no iterative optimization is needed at all.
    - The equation states the optimality condition directly in terms of the sufficient statistics, which makes the result easy to interpret.
- **Cons:**  
    - Without a closed-form inverse, one option (used later in this notebook) is to minimize the squared residual $\|\nabla_w \log Z(w) - \text{target}\|^2$. This objective is not guaranteed to be convex in $w$ (even for the Gaussian, $\nabla_w \log Z$ is nonlinear in $w$), can be badly conditioned, and may require careful initialization.
    - Differentiating the squared residual requires second derivatives of $\log Z(w)$, which may be expensive for some distributions.

**Direct Maximization (Maximizing Log Likelihood):**
- **Pros:**  
    - More general: works for any model where you can write the likelihood, not just exponential families.
    - For exponential families the negative log-likelihood is convex in $w$, because $\log Z(w)$ is convex (its Hessian is the covariance of $\phi(x)$); this makes optimization well-behaved.
- **Cons:**  
    - May converge slower if the objective is poorly scaled or if the optimizer is not well-tuned.
    - Still requires evaluating $\log Z(w)$, which can be costly for some models.

**In Practice:**
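- For exponential families, **both methods target the same solution**: setting the gradient of the log likelihood to zero gives exactly the moment-matching equation. Their numerical behavior can differ, however.
- **Moment matching** is fastest when the moment map can be inverted analytically; solved numerically through the squared residual, it can converge slowly.
- **Direct maximization** is more flexible and often preferred in practice, especially when extending to non-exponential-family models.

**Summary Table:**

| Method                             | Speed                | Stability                       | Generality                    |
|------------------------------------|----------------------|---------------------------------|-------------------------------|
| Moment matching (closed form)      | Immediate            | No iterations needed            | EF with invertible moment map |
| Moment matching (squared residual) | Can be slow          | Objective not guaranteed convex | EF only                       |
| Direct maximization                | Depends on optimizer | Convex in $w$ for EF            | General                       |

For the univariate Gaussian in this notebook, the moment map can be inverted analytically, so the numerical runs below serve as illustrations only. With Adam at a learning rate of $10^{-2}$, the logged runs show both objectives still short of the analytic solution after 1,200 iterations, with the squared-residual loss decreasing only slowly after the first 200. For more complex models, direct maximization is usually preferred for its flexibility.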

In [4]:
# Compute the gradient of the log-partition function using JAX's autograd
grad_log_Z = jax.grad(gaussian_log_partition_function)

# Let's evaluate the gradient at the natural parameters corresponding to the true mean and variance

# mu_true = 5.0, sigma_true = 2.0, sigma_true^2 = 4.0
# w1_true = mu_true / sigma_true^2 = 5.0 / 4.0 = 1.25
# w2_true = 1 / sigma_true^2 = 1 / 4.0 = 0.25
w_true = jnp.array([1.25, 0.25])

gradient_at_true_w = grad_log_Z(w_true)

print(f"True Natural Parameters: w = {w_true}")
print(f"Gradient of Log Z evaluated at true w: {gradient_at_true_w}")

# According to theory, this gradient should be equal to the expected value of the sufficient statistics
# under the distribution N(mu_true, sigma_true^2).
# E[phi(x)] = E[[x, -x^2/2]] = [E[x], -0.5 * E[x^2]]
# E[x] = mu_true = 5.0
# E[x^2] = Var(x) + (E[x])^2 = sigma_true^2 + mu_true^2 = 4.0 + 5.0**2 = 4.0 + 25.0 = 29.0
# E[-x^2/2] = -0.5 * E[x^2] = -0.5 * 29.0 = -14.5

expected_phi_true_w = jnp.array([mu_true, -0.5 * (sigma_true**2 + mu_true**2)])
print(f"Expected value of Sufficient Statistics at true w: {expected_phi_true_w}")

# The gradient evaluated at the true natural parameters should match the expected value of the sufficient statistics.

# Now, let's find the empirical mean of sufficient statistics from our synthetic data
empirical_mean_phi = sum_phi_x / n_samples
print(
    f"\nEmpirical Mean of Sufficient Statistics from synthetic data: {empirical_mean_phi}"
)

# The MLE property states that grad_log_Z(w_hat_mle) = empirical_mean_phi.
# We can find the w_hat_mle that satisfies this equation.
# For the Gaussian, we derived the analytical form of w_hat_mle from empirical stats:
mean_x = empirical_mean_phi[0]
# Recover mean(x^2) from mean(-0.5 * x^2)
mean_x_sq = -2 * empirical_mean_phi[1]

# Calculate MLE natural parameters from empirical moments
# w2_hat_mle = 1 / (mean(x^2) - mean(x)^2)
# w1_hat_mle = mean(x) * w2_hat_mle
w2_hat_mle = 1.0 / (mean_x_sq - mean_x**2)
w1_hat_mle = mean_x * w2_hat_mle
w_hat_mle = jnp.array([w1_hat_mle, w2_hat_mle])

print(f"MLE Natural Parameters estimated from empirical stats: w_hat_mle = {w_hat_mle}")

# Let's verify if the gradient of log Z at w_hat_mle is close to the empirical mean of sufficient statistics
gradient_at_mle_w = grad_log_Z(w_hat_mle)
print(f"Gradient of Log Z evaluated at w_hat_mle: {gradient_at_mle_w}")

# These two values (empirical_mean_phi and gradient_at_mle_w) should be very close,
# confirming the MLE property for Exponential Families computationally.

True Natural Parameters: w = [1.25 0.25]
Gradient of Log Z evaluated at true w: [  5.  -14.5]
Expected value of Sufficient Statistics at true w: [  5.  -14.5]

Empirical Mean of Sufficient Statistics from synthetic data: [  5.218606 -15.389651]
MLE Natural Parameters estimated from empirical stats: w_hat_mle = [1.4719148  0.28205132]
Gradient of Log Z evaluated at w_hat_mle: [  5.218606 -15.389652]


### Analytic MLE Solution for the Univariate Gaussian (Normal Form)

Given a set of independent and identically distributed (i.i.d.) samples $\{x_1, x_2, \dots, x_n\}$ drawn from a univariate Gaussian distribution:

$$
x_i \sim \mathcal{N}(\mu, \sigma^2)
$$

The probability density function (PDF) of the normal distribution is:

$$
p(x_i \mid \mu, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}} \exp\left( -\frac{(x_i - \mu)^2}{2\sigma^2} \right)
$$

#### Log-Likelihood Function

The likelihood of the data is the product of the individual probabilities:

$$
L(\mu, \sigma^2) = \prod_{i=1}^n p(x_i \mid \mu, \sigma^2)
$$

Taking the natural logarithm gives the log-likelihood:

$$
\log L(\mu, \sigma^2) = -\frac{n}{2} \log(2\pi\sigma^2) - \frac{1}{2\sigma^2} \sum_{i=1}^n (x_i - \mu)^2
$$

#### MLE for $\mu$ and $\sigma^2$

To find the MLEs, we take partial derivatives of the log-likelihood with respect to $\mu$ and $\sigma^2$, and set them to zero.

**Estimate for $\mu$:**

$$
\frac{\partial}{\partial \mu} \log L = \frac{1}{\sigma^2} \sum_{i=1}^n (x_i - \mu) = 0
$$

Solving for $\mu$:

$$
\hat{\mu}_{\text{MLE}} = \frac{1}{n} \sum_{i=1}^n x_i
$$

**Estimate for $\sigma^2$:**

$$
\frac{\partial}{\partial \sigma^2} \log L = -\frac{n}{2\sigma^2} + \frac{1}{2(\sigma^2)^2} \sum_{i=1}^n (x_i - \mu)^2 = 0
$$

Solving for $\sigma^2$:

$$
\hat{\sigma}^2_{\text{MLE}} = \frac{1}{n} \sum_{i=1}^n (x_i - \hat{\mu})^2
$$


In [19]:
# Compute MLE estimates for mu and sigma^2 from the synthetic data
mu_mle = jnp.mean(synthetic_data)
sigma2_mle = jnp.mean((synthetic_data - mu_mle) ** 2)
sigma_mle = jnp.sqrt(sigma2_mle)

print(f"MLE mu: {mu_mle}")
print(f"MLE sigma^2: {sigma2_mle}")
print(f"MLE sigma: {sigma_mle}")

# Compute natural parameters from these MLE estimates
w1_from_mle = mu_mle / sigma2_mle
w2_from_mle = 1.0 / sigma2_mle
w_from_mle = jnp.array([w1_from_mle, w2_from_mle])

print(f"Natural parameters from (mu, sigma^2): w = {w_from_mle}")

# Compare to w_hat_mle computed via exponential family moment matching
print(f"w_hat_mle from exponential family: {w_hat_mle}")

# Show the difference
diff = w_from_mle - w_hat_mle
print(f"Difference (w_from_mle - w_hat_mle): {diff}")

MLE mu: 5.218605995178223
MLE sigma^2: 3.5454607009887695
MLE sigma: 1.882939338684082
Natural parameters from (mu, sigma^2): w = [1.4719119 0.2820508]
w_hat_mle from exponential family: [1.4719148  0.28205132]
Difference (w_from_mle - w_hat_mle): [-2.861023e-06 -5.364418e-07]


To numerically infer the natural parameters (MLE) for an exponential family, you solve the equation:

$$
\nabla_w \log Z(w) = \frac{1}{n} \sum_{i=1}^n \phi(x_i)
$$

This is a root-finding or optimization problem. The typical approach is:

1. **Define a loss/objective function** that measures the squared difference between $\nabla_w \log Z(w)$ and the empirical mean of sufficient statistics.
2. **Use a numerical optimizer** (e.g., gradient descent, Adam) to minimize this loss with respect to $w$.
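One consequence of this formulation is worth seeing explicitly: the gradient of the squared residual passes through the Hessian of $\log Z$. By the chain rule, $\nabla_w \|\nabla_w \log Z(w) - t\|^2 = 2\, H(w)\, (\nabla_w \log Z(w) - t)$ with $H(w) = \nabla_w\nabla_w^T \log Z(w)$. Near the solution, where the residual is small, the Hessian of this loss is approximately $2H^2$, whose condition number is the square of that of $H$; this may contribute to the slow progress in the logged run below. The sketch checks the chain-rule formula and prints the condition number of $H$ at the parameters of $N(5, 4)$. It assumes only `jax`; the target `t` and the evaluation point are illustrative.

```python
import jax
import jax.numpy as jnp


def log_Z(w):
    w1, w2 = w
    return 0.5 * w1**2 / w2 - 0.5 * jnp.log(w2) + 0.5 * jnp.log(2 * jnp.pi)


grad_log_Z = jax.grad(log_Z)
hess_log_Z = jax.hessian(log_Z)
t = jnp.array([5.0, -14.5])  # illustrative target: E[phi(x)] under N(5, 4)


def moment_residual_loss(w):
    r = grad_log_Z(w) - t
    return jnp.sum(r**2)


w = jnp.array([0.5, 1.0])  # arbitrary evaluation point with w2 > 0
g_autodiff = jax.grad(moment_residual_loss)(w)
g_chain_rule = 2.0 * hess_log_Z(w) @ (grad_log_Z(w) - t)
print(g_autodiff, g_chain_rule)

# Conditioning at the solution w = [mu / sigma^2, 1 / sigma^2] = [1.25, 0.25].
eig = jnp.linalg.eigvalsh(hess_log_Z(jnp.array([1.25, 0.25])))
cond_H = eig.max() / eig.min()
print("cond(H):", cond_H, " cond(H^2):", cond_H**2)
```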

Example code:

In [47]:
from functools import partial

learning_rate = 1e-2
optimizer = optax.adam(learning_rate=learning_rate)


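# One optimization step. `objective_fn` is a static argument, so JAX compiles a separate
# version for each objective; `optimizer` is read from the global scope at trace time.
@partial(jax.jit, static_argnames=["objective_fn"])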
def step(w, opt_state, objective_fn):
    loss, grads = jax.value_and_grad(objective_fn)(w)
    updates, opt_state = optimizer.update(grads, opt_state, w)
    w = optax.apply_updates(w, updates)
    return w, opt_state, loss

In [63]:
from tqdm import tqdm


# empirical_mean_phi is already computed
def mle_objective(w):
    grad = grad_log_Z(w)
    return jnp.sum((grad - empirical_mean_phi) ** 2)


# Use an optimizer (e.g., Adam) to minimize mle_objective
w_guess = jnp.array([0.0, 1.0])
opt_state = optimizer.init(w_guess)


for i in tqdm(range(10000)):
    w_guess, opt_state, loss = step(w_guess, opt_state, mle_objective)
    if i % 100 == 0:
        print(f"Iteration {i}: w_guess = {w_guess}, loss = {loss}")

    if loss < 1e-6:
        print(f"Converged at iteration {i} with loss {loss}")
        break

print("\nNumerical MLE (w_guess after optimization):", w_guess)
print("Analytic MLE (w_hat_mle):", w_hat_mle)
print("Difference (numerical - analytic):", w_guess - w_hat_mle)


 96%|█████████▌| 9596/10000 [00:00<00:00, 47043.45it/s]

Iteration 0: w_guess = [0.00999993 0.99000007], loss = 248.93556213378906
Iteration 100: w_guess = [0.8779475 0.1792744], loss = 0.8200246691703796
Iteration 200: w_guess = [0.8870315  0.17669371], loss = 0.041085150092840195
Iteration 300: w_guess = [0.8895079  0.17716952], loss = 0.04048626869916916
Iteration 400: w_guess = [0.8925052  0.17770821], loss = 0.03981725499033928
Iteration 500: w_guess = [0.8959408 0.1783258], loss = 0.03906245157122612
Iteration 600: w_guess = [0.89977425 0.17901497], loss = 0.038234904408454895
Iteration 700: w_guess = [0.9039749  0.17977014], loss = 0.037345465272665024
Iteration 800: w_guess = [0.90851647 0.18058664], loss = 0.03640415146946907
Iteration 900: w_guess = [0.91337705 0.18146057], loss = 0.03541946038603783
Iteration 1000: w_guess = [0.9185372  0.18238838], loss = 0.03439943119883537
Iteration 1100: w_guess = [0.9239797 0.183367 ], loss = 0.03335116058588028
Iteration 1200: w_guess = [0.9296882  0.18439354], loss = 0.03228117898106575
Ite




### Direct Maximization of the Log Likelihood

Instead of solving for the MLE using the moment-matching property of exponential families, you can maximize the log likelihood directly with respect to the natural parameters $w$:

$$
\log p_w(x_1, \ldots, x_n) = \left(\sum_{i=1}^n \phi(x_i)\right)^T w - n \log Z(w) + \sum_{i=1}^n \log h(x_i)
$$

#### Numerical Optimization Approach

1. **Define the negative log likelihood (NLL) function** to minimize, dropping $\sum_i \log h(x_i)$ because it does not depend on $w$:
    $$
    \text{NLL}(w) = -\left(\sum_{i=1}^n \phi(x_i)\right)^T w + n \log Z(w)
    $$
2. **Use a numerical optimizer** (e.g., Adam, gradient descent) to minimize NLL with respect to $w$.
3. **JAX's autograd** can compute gradients automatically.

#### Optimization Loop

You can use the same optimizer setup as before, but now minimizing `neg_log_likelihood(w)`. The optimizer will iteratively update $w$ to maximize the log likelihood.

**Summary:**  
Direct maximization and moment matching are two routes to the same MLE: setting the gradient of the log likelihood to zero yields exactly the moment-matching equation. Direct maximization uses the likelihood itself as the objective, which is also what makes it more general: it works for any model whose likelihood you can write down.
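Because the NLL is convex in $w$ with Hessian $n\,\nabla_w\nabla_w^T \log Z(w)$, a second-order method is a natural alternative to Adam. The sketch below swaps in Newton's method with a backtracking (Armijo) line search, a standard technique that neither the slides nor this notebook's own code use: each step is halved until $w_2$ stays positive and the NLL decreases sufficiently. It assumes only `jax`; the data, starting point and helper names (`make_nll`, `damped_newton`) are illustrative.

```python
import jax
import jax.numpy as jnp


def log_Z(w):
    w1, w2 = w
    return 0.5 * w1**2 / w2 - 0.5 * jnp.log(w2) + 0.5 * jnp.log(2 * jnp.pi)


def make_nll(n, sum_phi):
    """NLL(w) = -(sum_i phi(x_i))^T w + n log Z(w), dropping the constant log h terms."""
    return lambda w: -jnp.dot(sum_phi, w) + n * log_Z(w)


def damped_newton(nll, w0, max_iter=100, tol=1e-8):
    """Newton's method with Armijo backtracking that keeps w2 > 0."""
    grad_fn, hess_fn = jax.grad(nll), jax.hessian(nll)
    w = w0
    for it in range(max_iter):
        g, H = grad_fn(w), hess_fn(w)
        step = -jnp.linalg.solve(H, g)
        decrement = -jnp.dot(g, step)  # squared Newton decrement, >= 0 since H is PD
        if decrement / 2 < tol:
            break
        f_w, t = nll(w), 1.0
        for _ in range(30):
            w_new = w + t * step
            if w_new[1] > 0 and nll(w_new) <= f_w - 0.25 * t * decrement:
                w = w_new
                break
            t *= 0.5
        else:
            break  # no acceptable step (e.g. float32 round-off very close to the optimum)
    return w, it


x = jax.random.normal(jax.random.PRNGKey(0), (100,)) * 2.0 + 5.0  # illustrative data
nll = make_nll(x.shape[0], jnp.array([jnp.sum(x), jnp.sum(-0.5 * x**2)]))
w_newton, n_iter = damped_newton(nll, jnp.array([0.0, 1.0]))

# Closed-form MLE for comparison: w2 = 1 / var, w1 = mean / var.
var = jnp.mean((x - jnp.mean(x)) ** 2)
w_closed = jnp.array([jnp.mean(x) / var, 1.0 / var])
print("Newton:", w_newton, "after", n_iter, "iterations; closed form:", w_closed)
```

The line search matters here: from $w = [0, 1]$ the full Newton step would make $w_2$ negative, where $\log Z$ is undefined.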

In [59]:
def neg_log_likelihood(w):
    return -jnp.dot(sum_phi_x, w) + n_samples * gaussian_log_partition_function(w)


# Use an optimizer (e.g., Adam) to minimize neg_log_likelihood
w_guess = jnp.array([0.0, 1.0])
opt_state = optimizer.init(w_guess)


for i in tqdm(range(2000)):
    w_guess, opt_state, loss = step(w_guess, opt_state, neg_log_likelihood)
    if i % 100 == 0:
        print(f"Iteration {i}: w_guess = {w_guess}, loss = {loss}")

print("\nNumerical MLE (w_guess after optimization):", w_guess)
print("Analytic MLE (w_hat_mle):", w_hat_mle)
print("Difference (numerical - analytic):", w_guess - w_hat_mle)


100%|██████████| 2000/2000 [00:00<00:00, 28902.12it/s]

Iteration 0: w_guess = [0.00999993 0.99000007], loss = 1630.85888671875
Iteration 100: w_guess = [0.8192815  0.14025278], loss = 217.92494201660156
Iteration 200: w_guess = [0.8784846  0.17459245], loss = 210.4298553466797
Iteration 300: w_guess = [0.9520907 0.188031 ], loss = 209.02294921875
Iteration 400: w_guess = [1.0243574  0.20108947], loss = 207.907470703125
Iteration 500: w_guess = [1.0917753  0.21327691], loss = 207.07298278808594
Iteration 600: w_guess = [1.1528786  0.22432601], loss = 206.46908569335938
Iteration 700: w_guess = [1.2071865  0.23414855], loss = 206.0419158935547
Iteration 800: w_guess = [1.2547268  0.24274847], loss = 205.74549865722656
Iteration 900: w_guess = [1.2957942  0.25017852], loss = 205.5435028076172
Iteration 1000: w_guess = [1.3308265  0.25651732], loss = 205.40838623046875
Iteration 1100: w_guess = [1.3603371  0.26185745], loss = 205.31979370117188
Iteration 1200: w_guess = [1.3848749  0.26629803], loss = 205.26300048828125
Iteration 1300: w_guess




### Comparing MLE and Bayesian Inference in the Exponential Family

**Maximum Likelihood Estimation (MLE):**
- MLE finds the parameter $w$ that maximizes the likelihood of the observed data.
- For exponential families, the MLE solution is where the gradient of the log-partition function matches the empirical mean of the sufficient statistics:
    $$
    \nabla_w \log Z(\hat{w}_{\text{MLE}}) = \frac{1}{n} \sum_{i=1}^n \phi(x_i)
    $$
- In this notebook, `w_hat_mle` is computed directly from the data, ignoring any prior information.

**Bayesian Inference:**
- Bayesian inference combines the likelihood with a prior to produce a posterior distribution over $w$.
- For exponential families with conjugate priors, the posterior is in the same family, with updated parameters:
    $$
    \alpha' = \alpha + \sum_{i=1}^n \phi(x_i), \quad \nu' = \nu + n
    $$
- The mode of the posterior (MAP estimate) is found by solving:
    $$
    \nabla_w \log Z(w^*) = \frac{\alpha + \sum_{i=1}^n \phi(x_i)}{\nu + n}
    $$
- In this notebook, `w_optimized` is the posterior mode, which incorporates both the prior (`alpha_0`, `nu_0`) and the data.

**Key Differences:**
- **MLE** uses only the data; it can overfit if $n$ is small.
- **Bayesian inference** incorporates prior beliefs, leading to more regularized estimates, especially with limited data.
- As $n$ increases, the influence of the prior diminishes, and the Bayesian posterior mode approaches the MLE.

**In summary:**  
- `w_hat_mle` is the MLE estimate (data only).
- `w_optimized` is the Bayesian posterior mode (prior + data).
- For large datasets, both estimates become similar; for small datasets, the Bayesian approach can be more robust, provided the prior is sensible.
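For the Gaussian, both estimates can be computed without any optimizer by inverting the moment map in closed form: $\nabla_w \log Z(w) = t$ gives $w_2 = 1/(-2t_2 - t_1^2)$ and $w_1 = t_1 w_2$. The sketch below uses the target $\frac{1}{n}\sum_i \phi(x_i)$ for the MLE and $\frac{\alpha + \sum_i \phi(x_i)}{\nu + n}$ for the posterior mode. Repeating the same five points 1000 times leaves the MLE unchanged but shrinks the gap between the two estimates, illustrating how the prior's influence fades as $n$ grows. It assumes only `jax`; the prior $(\alpha, \nu) = ([0, -1]^T, 1)$, the toy data and the helper names are hypothetical.

```python
import jax.numpy as jnp


def phi(x):
    return jnp.stack([x, -0.5 * x**2], axis=-1)


def invert_moment_map(t):
    """Solve grad log Z(w) = t for the Gaussian, with t = [E[x], -E[x^2]/2]."""
    sigma2 = -2.0 * t[1] - t[0] ** 2
    return jnp.array([t[0] / sigma2, 1.0 / sigma2])


def mle(x):
    return invert_moment_map(jnp.mean(phi(x), axis=0))


def posterior_mode(x, alpha, nu):
    return invert_moment_map((alpha + jnp.sum(phi(x), axis=0)) / (nu + x.shape[0]))


alpha0, nu0 = jnp.array([0.0, -1.0]), 1.0  # hypothetical proper prior
x_small = jnp.array([4.1, 5.3, 6.8, 3.9, 5.6])  # toy data
x_large = jnp.tile(x_small, 1000)  # same empirical moments, 1000 times as many points

gaps = {}
for name, x in [("n=5", x_small), ("n=5000", x_large)]:
    w_mle, w_map = mle(x), posterior_mode(x, alpha0, nu0)
    gaps[name] = float(jnp.linalg.norm(w_map - w_mle))
    print(f"{name}: MLE {w_mle}, posterior mode {w_map}, |difference| {gaps[name]:.4g}")
```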

Consider the exponential family $p_w(x \mid w) = h(x) \exp \left[ \phi(x)^\top w - \log Z(w) \right]$. Its conjugate prior is the exponential family

$$
p_\alpha(w \mid \alpha, \nu) = \exp \left[ \left( \begin{pmatrix} w \\ -\log Z(w) \end{pmatrix}^\top \begin{pmatrix} \alpha \\ \nu \end{pmatrix} \right) - \log F(\alpha, \nu) \right]
$$

with partition function
$$
F(\alpha, \nu) := \int \exp(\alpha^\top w - \nu \log Z(w)) \, dw 
$$


and the predictive distribution is

$$
p(x) = \int p_w(x \mid w) p_\alpha(w \mid \alpha, \nu) \, dw = h(x) \int e^{(\phi(x) + \alpha)^\top w - (\nu + 1) \log Z(w) - \log F(\alpha, \nu)} \, dw
$$

$$
= h(x) \frac{F(\phi(x) + \alpha, \nu + 1)}{F(\alpha, \nu)}
$$
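For the univariate Gaussian (with $h(x) = 1$ and the $\log Z$ used above, including $\frac{1}{2}\log 2\pi$), the integral defining $F$ can be done in closed form. Integrating out $w_1$ (a Gaussian integral) and then $w_2$ (a Gamma integral) gives, in our own derivation rather than one taken from the slides,
$$\log F(\alpha, \nu) = -\tfrac{\nu}{2}\log 2\pi + \tfrac{1}{2}\log\tfrac{2\pi}{\nu} + \log\Gamma\!\left(\tfrac{\nu+3}{2}\right) - \tfrac{\nu+3}{2}\log c, \qquad c = -\left(\alpha_2 + \tfrac{\alpha_1^2}{2\nu}\right),$$
which is finite only when $c > 0$. The sketch below evaluates the predictive $h(x)F(\phi(x)+\alpha, \nu+1)/F(\alpha,\nu)$ and checks numerically that it integrates to one. It assumes only `jax`; `log_F` and `predictive_pdf` are hypothetical helper names. For $\alpha = [0, -1]^T$, $\nu = 1$ the predictive reduces to $\frac{3}{8}(1 + x^2/4)^{-5/2}$, the density of a Student-t distribution with 4 degrees of freedom.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def log_F(a1, a2, nu):
    """Closed-form log F(alpha, nu) for the Gaussian family, alpha = [a1, a2].
    Only valid when c = -(a2 + a1^2 / (2 nu)) > 0 (otherwise the integral diverges)."""
    c = -(a2 + a1**2 / (2.0 * nu))
    return (
        -0.5 * nu * jnp.log(2 * jnp.pi)
        + 0.5 * jnp.log(2 * jnp.pi / nu)
        + gammaln((nu + 3.0) / 2.0)
        - 0.5 * (nu + 3.0) * jnp.log(c)
    )


def predictive_pdf(x, alpha, nu):
    """p(x) = h(x) F(alpha + phi(x), nu + 1) / F(alpha, nu), with h(x) = 1, phi(x) = [x, -x^2/2]."""
    a1, a2 = alpha
    return jnp.exp(log_F(a1 + x, a2 - 0.5 * x**2, nu + 1.0) - log_F(a1, a2, nu))


alpha, nu = jnp.array([0.0, -1.0]), 1.0  # illustrative prior parameters with c = 1 > 0
x = jnp.linspace(-60.0, 60.0, 24_001)
dx = 120.0 / 24_000  # grid spacing, computed exactly in Python
p = predictive_pdf(x, alpha, nu)
print("integral:", jnp.sum(p) * dx)  # close to 1; tails beyond |x| = 60 are negligible
print("p(0):", predictive_pdf(0.0, alpha, nu))
```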


## Laplace Approximation (when F is intractable)
The slides mention that computing the normalization constant $F(\alpha, \nu)$ for the conjugate prior can be difficult or impossible in closed form for some exponential families. When the full analytic posterior is intractable, we can use approximation methods.

**Laplace Approximation** is one such method. It approximates the posterior distribution $p(w|x)$ with a Gaussian distribution centered at the mode of the true posterior.

To perform Laplace approximation:

1. **Find the mode $\hat{w}$ of the posterior $p(w|x)$.**  
  This is done by finding the value of $w$ where the gradient of the log posterior is zero:  
  $$\nabla_w \log p(w|x) = 0.$$
  As derived in the slides and the previous section, setting the gradient of the log posterior to zero is equivalent to solving the root-finding problem:
  $$
  \nabla_w \log Z(\hat{w}) = \frac{\alpha + \sum_i \phi(x_i)}{\nu + n}
  $$
  where $\alpha$ and $\nu$ are the prior parameters, and $\sum \phi(x_i)$ and $n$ are from the data.

2. **Evaluate the Hessian (the matrix of second partial derivatives) of the negative log posterior at the mode $\hat{w}$.**  
  Let this Hessian be
  $$
  \Psi = -\nabla_w \nabla_w^T \log p(w|x) \Big|_{w = \hat{w}}
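  $$
  For the conjugate posterior above, this equals $\Psi = (\nu + n)\, \nabla_w \nabla_w^T \log Z(\hat{w})$, because the term linear in $w$ has zero Hessian.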

3. **Approximate the posterior as a Gaussian distribution**  
  $$
  N(w; \hat{w}, \Psi^{-1})
  $$
  The mode $\hat{w}$ becomes the mean, and $\Psi^{-1}$, the inverse of the negative log-posterior Hessian, becomes the covariance matrix of the approximating Gaussian.
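The three steps can be sketched for the Bernoulli family, where everything has a closed form: $\nabla_w \log Z(w) = \sigma(w)$, so step 1 gives $\hat{w} = \operatorname{logit}\!\left(\frac{\alpha + \sum_i x_i}{\nu + n}\right)$. The sketch uses `jax.grad` and `jax.hessian` to check two things: that the negative log posterior has zero gradient at $\hat{w}$, and that its Hessian equals $(\nu + n)\,\nabla_w \nabla_w^T \log Z(\hat{w})$. The prior values and observations are hypothetical, and the `lap_`-prefixed names are illustrative.

```python
import jax
import jax.numpy as jnp

def bernoulli_log_Z(w):
    """Bernoulli log-partition function log(1 + e^w), computed stably."""
    return jax.nn.softplus(w)

# Hypothetical prior parameters and observations (six ones out of eight)
lap_alpha, lap_nu = 2.0, 4.0
lap_x = jnp.array([1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0])
lap_n = lap_x.shape[0]
lap_sum_phi = jnp.sum(lap_x)  # phi(x) = x

def neg_log_posterior(w):
    """-log p(w | x) up to an additive constant."""
    return -(lap_alpha + lap_sum_phi) * w + (lap_nu + lap_n) * bernoulli_log_Z(w)

# Step 1: grad log Z(w_hat) = sigmoid(w_hat) = target, so w_hat = logit(target)
lap_target = (lap_alpha + lap_sum_phi) / (lap_nu + lap_n)
lap_w_hat = jnp.log(lap_target) - jnp.log1p(-lap_target)
lap_grad_at_mode = jax.grad(neg_log_posterior)(lap_w_hat)  # should be ~0

# Step 2: Psi = negative Hessian of the log posterior at the mode,
# which for this conjugate posterior equals (nu + n) * Hessian of log Z
lap_psi = jax.hessian(neg_log_posterior)(lap_w_hat)
lap_psi_from_log_Z = (lap_nu + lap_n) * jax.hessian(bernoulli_log_Z)(lap_w_hat)

# Step 3: Laplace approximation N(w_hat, Psi^{-1}); w is a scalar here
lap_var = 1.0 / lap_psi

print("w_hat:", lap_w_hat, "gradient at w_hat:", lap_grad_at_mode)
print("Psi:", lap_psi, "=", lap_psi_from_log_Z, "Laplace variance:", lap_var)
```

Here the mode is $\hat{w} = \log 2$ and $\Psi = 12 \cdot \tfrac{2}{3} \cdot \tfrac{1}{3} = 8/3$, so the Laplace variance is $0.375$.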

The root-finding problem in step 1 can be solved using numerical optimization. We can define an objective function that measures how far $\nabla_w \log Z(w)$ is from the target value $\frac{\alpha + \sum \phi(x_i)}{\nu + n}$ and minimize it.

In [None]:
# Demonstrate finding the mode of the log posterior using optimization
# This is equivalent to solving the root-finding problem: grad_log_Z(w_hat) = target

# The target value for the gradient of log Z at the posterior mode is
# (alpha_0 + sum_phi_x) / (nu_0 + n_samples), where alpha_0 and nu_0 are the
# prior parameters from our conceptual example and sum_phi_x, n_samples come from the data.
# (If alpha_posterior and nu_posterior were computed with the conjugate update rule,
# this is the same as alpha_posterior / nu_posterior.)
target_grad_log_Z_for_mode = (alpha_0 + sum_phi_x) / (nu_0 + n_samples)


print(
    f"\nTarget gradient value to find the posterior mode: {target_grad_log_Z_for_mode}"
)

# We want to find w_hat such that grad_log_Z(w_hat) is close to target_grad_log_Z_for_mode.
# We can minimize the squared difference: ||grad_log_Z(w) - target_grad_log_Z_for_mode||^2


def objective_fn_for_mode(w, target):
    """
    Objective function to minimize to find w_hat (posterior mode).
    We want grad_log_Z(w) to be equal to target.
    """
    grad_at_w = grad_log_Z(w)
    return jnp.sum((grad_at_w - target) ** 2)



# We can use an optimizer from Optax (an optimization library built on JAX) to find the w that minimizes this objective.
# Set up the optimizer (Adam is a good general-purpose choice)
optimizer = optax.adam(learning_rate=0.1)

# Initialize optimizer state with an initial guess for w
# Ensure the initial guess for the second component, w[1], is positive.
initial_w_guess = jnp.array([0.0, 1.0])
opt_state = optimizer.init(initial_w_guess)


# Define the training step using JAX's jit for performance
@jax.jit
def train_step_mode_finder(w, opt_state, target):
    # Calculate the loss and its gradient with respect to w
    loss, grads = jax.value_and_grad(objective_fn_for_mode)(w, target)
    # Update the parameters using the optimizer
    updates, opt_state = optimizer.update(grads, opt_state, w)
    w = optax.apply_updates(w, updates)
    return w, opt_state, loss


# Run the optimization loop
num_iterations = 2000  # More iterations might be needed for convergence
w_optimized = initial_w_guess
losses = []

print("Optimizing to find posterior mode...")
for step in range(num_iterations):
    w_optimized, opt_state, loss = train_step_mode_finder(
        w_optimized, opt_state, target_grad_log_Z_for_mode
    )
    losses.append(loss)
    if step % 200 == 0:
        print(f"Step {step}, Loss: {loss:.6f}, Current w: {w_optimized}")

print(f"Optimization finished. Found w_hat (posterior mode): {w_optimized}")
print(f"Final Loss: {losses[-1]:.6f}")

# In a full Laplace approximation, you would then compute the Hessian of the
# negative log posterior at this found mode w_optimized.
# Up to an additive constant, the negative log posterior is
#   -log p(w|x) = -(alpha_0 + sum_phi_x)^T w + (nu_0 + n_samples) * log Z(w)
# so its Hessian w.r.t. w is (nu_0 + n_samples) * (Hessian of log Z(w)).
# You would use jax.hessian to compute this at w_optimized; the covariance of the
# Laplace approximation is the inverse of this Hessian.


Target gradient value to find the posterior mode: [  5.1679263 -15.236288 ]
Optimizing to find posterior mode...
Step 0, Loss: 243.865646, Current w: [0.09999933 0.90000063]
Step 200, Loss: 0.004169, Current w: [1.6672059  0.31928024]
Step 400, Loss: 0.003533, Current w: [1.6666081  0.31888303]
Step 600, Loss: 0.003507, Current w: [1.6652656  0.31863964]
Step 800, Loss: 0.003474, Current w: [1.6635906  0.31833604]
Step 1000, Loss: 0.003434, Current w: [1.661587  0.3179728]
Step 1200, Loss: 0.003388, Current w: [1.6592479  0.31754875]
Step 1400, Loss: 0.003335, Current w: [1.6565611  0.31706166]
Step 1600, Loss: 0.003276, Current w: [1.6535076  0.31650817]
Step 1800, Loss: 0.003209, Current w: [1.6500671  0.31588453]
Optimization finished. Found w_hat (posterior mode): [1.6462348  0.31518978]
Final Loss: 0.003135
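The log above shows the loss still decreasing and `w` still moving at step 1800. Adam with a fixed learning rate has therefore not converged, and the printed `w_hat` is only an approximate mode.

Since step 1 is a root-finding problem, Newton's method is a natural alternative. It is swapped in here and is not the method used by the cell above or prescribed by the slides. Its Jacobian is the Hessian of $\log Z$, which is positive definite.

The sketch below assumes the Gaussian parameterization $\phi(x) = (x, -x^2/2)$, $w = (\mu/\sigma^2, 1/\sigma^2)$, so that $\log Z(w) = \frac{w_1^2}{2 w_2} + \frac{1}{2}\log 2\pi - \frac{1}{2}\log w_2$. The notebook's `gaussian_log_partition_function` is not shown in this section and may use a different convention. The target values are copied from the printed output, and the `assumed_`/`newton_` names are hypothetical.

```python
import jax
import jax.numpy as jnp

# ASSUMED parameterization (may differ from the notebook's gaussian_log_partition_function):
# phi(x) = (x, -x^2 / 2), w = (mu / sigma^2, 1 / sigma^2), defined for w[1] > 0.
def assumed_gauss_log_Z(w):
    return w[0] ** 2 / (2.0 * w[1]) + 0.5 * jnp.log(2.0 * jnp.pi) - 0.5 * jnp.log(w[1])

assumed_grad_log_Z = jax.grad(assumed_gauss_log_Z)
assumed_hess_log_Z = jax.hessian(assumed_gauss_log_Z)

def newton_root(target, w0, max_iter=50, tol=1e-5):
    """Damped Newton iteration for assumed_grad_log_Z(w) = target.

    The Jacobian of grad log Z is the Hessian of log Z, so each step solves
    H(w) @ step = -(grad log Z(w) - target). The step is halved only until
    w[1] stays positive; a more robust solver would also add a line search.
    """
    w = w0
    for _ in range(max_iter):
        residual = assumed_grad_log_Z(w) - target
        if jnp.linalg.norm(residual) < tol:
            break
        step = -jnp.linalg.solve(assumed_hess_log_Z(w), residual)
        scale = 1.0
        while w[1] + scale * step[1] <= 0.0:  # stay in the domain w[1] > 0
            scale *= 0.5
        w = w + scale * step
    return w

# Target values copied from the printed output above
newton_target = jnp.array([5.1679263, -15.236288])
w_newton = newton_root(newton_target, jnp.array([0.0, 1.0]))

# Closed-form root for this parameterization: mu = t[0], sigma^2 = -2 t[1] - t[0]^2
newton_sigma2 = -2.0 * newton_target[1] - newton_target[0] ** 2
w_closed = jnp.array([newton_target[0] / newton_sigma2, 1.0 / newton_sigma2])

# Squared-residual objective (as in objective_fn_for_mode) at the printed Adam result
w_adam = jnp.array([1.6462348, 0.31518978])
adam_loss = jnp.sum((assumed_grad_log_Z(w_adam) - newton_target) ** 2)

print("Newton root:   ", w_newton)
print("Closed form:   ", w_closed)
print("Loss at Adam w:", adam_loss)
```

Under this assumed parameterization the root is $w \approx (1.373, 0.266)$. The same $\log Z$ reproduces the printed final loss of about $0.0031$ at the Adam result. If the notebook uses the same convention, the Adam run stopped short of the mode. More iterations, a decaying learning rate, or a Newton solve would move it closer.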


## Laplace Approximation: Covariance and Credible Intervals

In [64]:
# Laplace Approximation: Compute the covariance at the posterior mode

# The negative log posterior (up to a constant) is:
# -log p(w|x) = - (alpha_0 + sum_phi_x)^T w + (nu_0 + n_samples) * log Z(w)
# The Hessian of -log p(w|x) w.r.t. w is (nu_0 + n_samples) * Hessian of log Z(w) at w_optimized

# Compute the Hessian of log Z(w) using JAX
hessian_log_Z = jax.hessian(gaussian_log_partition_function)

# Evaluate the Hessian at the posterior mode
hess_at_mode = hessian_log_Z(w_optimized)

# Compute the precision matrix (negative Hessian of log posterior)
precision = (nu_0 + n_samples) * hess_at_mode

# Covariance is the inverse of the precision matrix
cov_laplace = jnp.linalg.inv(precision)

print("Posterior mode (w_optimized):", w_optimized)
print("Laplace covariance matrix at mode:\n", cov_laplace)

Posterior mode (w_optimized): [1.6462348  0.31518978]
Laplace covariance matrix at mode:
 [[0.05678565 0.01027475]
 [0.01027475 0.00196721]]


In [65]:
from scipy.stats import norm

# The Laplace approximation gives a Gaussian posterior: N(w_optimized, cov_laplace)
# Marginal 95% credible interval for each parameter: mean ± z * std

w_mean = w_optimized
w_std = jnp.sqrt(jnp.diag(cov_laplace))
z = float(norm.ppf(0.975))  # two-sided 95% quantile of the standard normal, ≈ 1.96

lower = w_mean - z * w_std
upper = w_mean + z * w_std

for i, (l, m, u) in enumerate(zip(lower, w_mean, upper)):
    print(
        f"Parameter w[{i}]: 95% credible interval = [{l:.4f}, {u:.4f}], mean = {m:.4f}, std = {w_std[i]:.4f}"
    )

Parameter w[0]: 95% credible interval = [1.1792, 2.1133], mean = 1.6462, std = 0.2383
Parameter w[1]: 95% credible interval = [0.2283, 0.4021], mean = 0.3152, std = 0.0444
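The intervals above are symmetric about the mode because the Laplace approximation is Gaussian, so any skew in the true posterior is lost. The sketch below gauges how large that effect can be, using a separate, hypothetical Bernoulli example where the exact posterior over $w$ is available.

With prior $\alpha = 2$, $\nu = 4$ and six ones in eight observations, $p(w \mid x) \propto \exp[8w - 12 \log(1 + e^w)]$. In closed form this is the distribution of $w = \operatorname{logit}(p)$ for $p \sim \mathrm{Beta}(8, 4)$. Its mean is $\psi(8) - \psi(4) \approx 0.760$ and its variance is $\psi'(8) + \psi'(4) \approx 0.417$, where $\psi$ is the digamma function. The Laplace approximation instead has mean $\log 2 \approx 0.693$ and variance $0.375$. The code recovers the exact moments by summing over a grid, whose limits are arbitrary.

```python
import jax
import jax.numpy as jnp

# Hypothetical Bernoulli posterior over w = logit(p):
# a = alpha + sum(x) = 2 + 6, b = nu + n = 4 + 8
post_a, post_b = 8.0, 12.0

def log_post_unnorm(w):
    """Unnormalized log posterior a * w - b * log(1 + e^w)."""
    return post_a * w - post_b * jax.nn.softplus(w)

# Exact moments (up to grid error) by summing over a fine uniform grid
w_grid = jnp.linspace(-30.0, 30.0, 60001)
log_dens = log_post_unnorm(w_grid)
grid_weights = jnp.exp(log_dens - jnp.max(log_dens))  # subtract the max for stability
grid_weights = grid_weights / jnp.sum(grid_weights)   # uniform spacing cancels out
exact_mean_w = jnp.sum(grid_weights * w_grid)
exact_var_w = jnp.sum(grid_weights * (w_grid - exact_mean_w) ** 2)

# Laplace approximation: sigmoid(w_hat) = a / b, precision b * s * (1 - s)
s_hat = post_a / post_b
laplace_mean_w = jnp.log(s_hat / (1.0 - s_hat))
laplace_var_w = 1.0 / (post_b * s_hat * (1.0 - s_hat))

print("exact   mean, var:", exact_mean_w, exact_var_w)
print("Laplace mean, var:", laplace_mean_w, laplace_var_w)
```

The exact mean sits about 0.07 to the right of the mode, a shift that a symmetric Gaussian interval cannot represent. The gap shrinks as $n$ grows.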


## Summary

This notebook explored the fundamental concepts of Exponential Families as presented in the lecture slides:
- The connection between probabilistic inference, sufficient statistics, and conjugate priors.
- The definition of an Exponential Family and its components ($h(x)$, $\phi(x)$, $w$, $Z(w)$).
- How the univariate Gaussian fits into the Exponential Family framework.
- The data reduction property provided by sufficient statistics.
- The simple parameter update rule for conjugate priors.
- The crucial property relating the gradient of the log-partition function to the empirical mean of sufficient statistics at the Maximum Likelihood Estimate.
- How numerical optimization can be used to find the mode of the posterior distribution, a key step in the Laplace approximation when the conjugate prior's normalization constant is intractable.

By working through these examples, you should gain a deeper understanding of why Exponential Families are so important in probabilistic modeling and how their