# Exponential family

## Properties of exponential family distributions

We will need
- Formulation
- Natural parameter
- Whether they are closed under products and sums
- Definition of conjugate

# EM algorithm

## Motivation

Suppose that we have a sample $X$ drawn from a distribution with density $p(\cdot)$, parametrized by unknown parameters $\theta$ that we'd like to estimate from the sample by maximum likelihood:
$$
X = (x_1, x_2, \ldots, x_N) \sim p(x | \theta)
$$
$$
p(X|\theta) = \prod_{i=1}^N p(x_i|\theta) \to \max_\theta
$$

What should we do if:

- $p(x | \theta)$ is $\mathcal{N}(\mu, \sigma)$

- $p(x | \theta)$ is from the exponential family $p(x | \theta) = \frac{f(x)}{g(\theta)} \exp \left( \theta^\top u(x) \right)$

- $p(x | \theta)$ is not from the exponential family
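
As a quick sanity check of the exponential family form above, here is the Bernoulli distribution with success probability $p$ (an illustrative choice, not one of the cases listed above) rewritten in that form:
$$
p(x | p) = p^x (1 - p)^{1 - x} = (1 - p) \exp \left( x \log \frac{p}{1 - p} \right), \quad x \in \{0, 1\}
$$
so $\theta = \log \frac{p}{1 - p}$, $u(x) = x$, $f(x) = 1$ and $g(\theta) = \frac{1}{1 - p} = 1 + e^{\theta}$.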

## Motivation

If $p(x | \theta)$ is not from the exponential family, we can introduce **latent variables** $z$ into the model, so that $p(x | z , \theta)$ is from the exponential family.

### Example: mixture models
$$
p(x|\theta) = \sum\limits_{k=1}^K \alpha_k p_k(x | \theta_k)
$$
You can verify that, in general, $p(x|\theta)$ does not belong to the exponential family.

Let's insert variables $z$, such that
- $z_k \in \{0, 1\}$
- $\sum_k z_k = 1$
- $q(z_k = 1) = \alpha_k$

Then, conditioned on $z$,
$$
p(x | z, \theta) = \prod\limits_{k=1}^K \left( p_k(x | \theta_k) \right)^{z_k}
$$
and the joint density is $p(x, z|\theta) = \prod_{k=1}^K \left( \alpha_k p_k(x | \theta_k) \right)^{z_k}$.

You can verify that, if every $p_k$ comes from the same exponential family, $p(x|z,\theta)$ belongs to that family with natural parameter $\sum_k z_k \theta_k$.
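
The latent-variable construction above can be sketched as a two-stage sampling procedure: first draw the one-hot vector $z$, then draw $x$ from the selected component. This is a minimal sketch assuming Gaussian components; the weights, means and standard deviations are illustrative values, not taken from the text.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical two-component Gaussian mixture (values chosen for illustration).
alpha = np.array([0.3, 0.7])      # mixture weights alpha_k
mu = np.array([-2.0, 3.0])        # component means
sigma = np.array([0.5, 1.0])      # component standard deviations

n_samples = 100_000

# Step 1: draw the latent component index, i.e. the position k with z_k = 1.
k = rng.choice(len(alpha), size=n_samples, p=alpha)
z = np.eye(len(alpha), dtype=int)[k]   # one-hot encoding, shape (n_samples, K)

# Step 2: draw x from the component selected by z.
x = rng.normal(loc=mu[k], scale=sigma[k])

# The empirical frequencies of z approach alpha, and the sample mean of x
# approaches the mixture mean sum_k alpha_k * mu_k.
z_freq = z.mean(axis=0)
mixture_mean = float(alpha @ mu)
print(z_freq, x.mean(), mixture_mean)
```

Marginalizing over $z$ in this procedure gives back the mixture density $\sum_k \alpha_k p_k(x | \theta_k)$.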

## Derivation

Suppose that we have a sample $X$ drawn from a distribution with density $p(\cdot)$, parametrized by unknown parameters $\theta$ that we'd like to estimate from the sample:
$$
X = (x_1, x_2, \ldots, x_N) \sim p(x | \theta)
$$
$$
p(X|\theta) = \prod_{i=1}^N p(x_i|\theta) \to \max_\theta
$$

Suppose that $p(x_i|\theta)$ is not from the exponential family. We therefore introduce latent variables $Z$ that follow a distribution $q(\cdot)$, such that $p(x_i, z_i|\theta)$ is from the exponential family:
$$
Z = (z_1, z_2, \ldots, z_N) \sim q(z)
$$

## Derivation

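For any distribution $q(z)$ over the latent variables, $\int q(z) \rm{d} z = 1$, and $\log p(x|\theta)$ does not depend on $z$, so:
$$
L = \log p(x|\theta) = \log p(x|\theta) \cdot \int q(z) \rm{d} z = \int q(z) \log p(x|\theta) \rm{d} z
$$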

Now use the product rule $p(x, z|\theta) = p(z|x,\theta) p(x|\theta)$:
$$
L = \int q(z) \log p(x|\theta) \rm{d} z = \int q(z) \log \frac{p(x, z|\theta)}{p(z|x,\theta)} \rm{d} z = \int q(z) \log \frac{p(x, z|\theta) q(z)}{p(z|x,\theta) q(z)} \rm{d} z
$$

Now let's use some properties of $\log$:
$$
L = \int q(z) \log \frac{p(x, z|\theta) q(z)}{p(z|x,\theta) q(z)} \rm{d} z = \int q(z) \left( \log \frac{p(x, z|\theta)}{q(z)} + \log \frac{q(z)}{p(z|x,\theta)} \right) \rm{d} z
$$

Finally use the linearity of the integral:
$$
L = \int q(z) \log \frac{p(x, z|\theta)}{q(z)} \rm{d} z + \underbrace{\int q(z) \log \frac{q(z)}{p(z|x,\theta)} \rm{d} z}_{?}
$$

## KL divergence

$$
D_{\rm{KL}}(p || q) \equiv KL(p || q) = \int p(x) \log \frac{p(x)}{q(x)} \rm{d} x
$$

Properties:
- $KL(p || q) \neq KL(q || p)$
- $KL(p || q) \geqslant 0$ (prove)
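
Both properties are easy to check numerically for discrete distributions (non-negativity itself follows from Jensen's inequality applied to the concave $\log$). The sketch below uses two arbitrary three-state distributions chosen only for illustration; the helper name `kl` is hypothetical.

```python
import numpy as np

def kl(p, q):
    """KL(p || q) for discrete distributions with full support."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return float(np.sum(p * np.log(p / q)))

# Two arbitrary distributions over three states (illustrative values).
p = np.array([0.7, 0.2, 0.1])
q = np.array([0.4, 0.4, 0.2])

kl_pq = kl(p, q)   # KL(p || q)
kl_qp = kl(q, p)   # KL(q || p), generally a different number
print(kl_pq, kl_qp, kl(p, p))
```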

## Derivation


Overall,

$$
L = \log p(x|\theta) = \int q(z) \log \frac{p(x, z|\theta)}{q(z)} \rm{d} z + KL(q(z)||p(z|x, \theta)) \geqslant \int q(z) \log \frac{p(x, z|\theta)}{q(z)} \rm{d} z
$$

This quantity is called the **variational lower bound** (also known as the evidence lower bound, ELBO):
$$
\mathcal{L}(q, \theta) = \int q(z) \log \frac{p(x, z|\theta)}{q(z)} \rm{d} z
$$

We replace the original problem with $\mathcal{L}(q, \theta) \to \max_{q,\theta}$ and solve it by **coordinate ascent**, i.e. by alternately maximizing over each of the two arguments:
1. $q^\ast = \arg\max_{q} \mathcal{L}(q, \theta^\ast)$ (**E-step**)
2. $\theta^\ast = \arg\max_{\theta} \mathcal{L}(q^\ast, \theta)$ (**M-step**)
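
The decomposition $\log p(x|\theta) = \mathcal{L}(q, \theta) + KL(q(z)||p(z|x, \theta))$ can be verified numerically for a discrete latent variable. This is a minimal sketch: the joint probabilities and the distribution $q$ are made-up numbers, and the helper names `elbo` and `kl` are hypothetical.

```python
import numpy as np

# Hypothetical joint p(x, z | theta) for one fixed observation x and a
# latent z with three states (numbers chosen only for illustration).
p_xz = np.array([0.10, 0.25, 0.05])

p_x = p_xz.sum()                # marginal likelihood p(x | theta)
posterior = p_xz / p_x          # p(z | x, theta)

def elbo(q):
    """Variational lower bound: sum_z q(z) log(p(x, z | theta) / q(z))."""
    return float(np.sum(q * np.log(p_xz / q)))

def kl(q, p):
    """KL(q || p) for discrete distributions with full support."""
    return float(np.sum(q * np.log(q / p)))

q = np.array([0.5, 0.3, 0.2])   # an arbitrary distribution over z

log_px = float(np.log(p_x))
gap = log_px - elbo(q)          # equals KL(q || posterior)
print(log_px, elbo(q), kl(q, posterior), elbo(posterior))
```

The bound is tight exactly when $q$ equals the posterior, which is what the E-step exploits.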

## Tricks

### E-step

Let's recall that $q(\cdot)$ was not present in the original likelihood, therefore $\partial L / \partial q \equiv 0$.

Also recall that at some point in derivation, we had the following equality: $L = \mathcal{L}(q, \theta) + KL(q(z)||p(z|x, \theta))$.

Therefore, maximizing $\mathcal{L}(q, \theta)$ w.r.t. $q$ is equivalent to minimizing $KL(q(z)||p(z|x, \theta))$ w.r.t. $q$!

Think, where does KL-divergence achieve its minimum?
$$
KL(p || q) = \int p(x) \log \frac{p(x)}{q(x)} \rm{d} x
$$

$$
\arg\min_{p} KL(p||q) = q
$$
Therefore we have an exact solution for E-step (one limitation is obvious, does anyone notice?):
$$
q^\ast = \arg\max_{q} \mathcal{L}(q, \theta^\ast) = p(z|x, \theta^\ast)
$$
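
For a mixture model the exact E-step is just Bayes' rule over the components. The sketch below assumes a 1-D Gaussian mixture and computes the posterior in log space for numerical stability; the function name `e_step` and all parameter values are illustrative assumptions.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

def e_step(x, alpha, mu, sigma):
    """Return r[i, k] = p(z_ik = 1 | x_i, theta) for a 1-D Gaussian mixture."""
    # log(alpha_k) + log N(x_i | mu_k, sigma_k), shape (N, K)
    log_joint = np.log(alpha) + norm.logpdf(x[:, None], loc=mu, scale=sigma)
    # Normalize over components in log space for numerical stability.
    return np.exp(log_joint - logsumexp(log_joint, axis=1, keepdims=True))

x = np.array([-2.1, 0.5, 3.2])
r = e_step(
    x,
    alpha=np.array([0.3, 0.7]),
    mu=np.array([-2.0, 3.0]),
    sigma=np.array([0.5, 1.0]),
)
print(r)  # each row is a distribution over the two components
```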

## Tricks

### M-step

$$
\begin{aligned}
\arg\max_\theta \mathcal{L}(q^\ast, \theta) & = \arg\max_\theta \int q^\ast(z) \log \frac{p(x, z|\theta)}{q^\ast(z)} \rm{d} z = \\
& = \arg\max_\theta \left( \int q^\ast(z) \log p(x, z|\theta) \rm{d} z - \int q^\ast(z) \log q^\ast(z) \rm{d} z \right) = \\
& = \arg\max_\theta \int q^\ast(z) \log p(x, z|\theta) \rm{d} z
\end{aligned}
$$
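The entropy term does not depend on $\theta$, so the M-step objective reduces to the expected complete-data log-likelihood under the posterior found in the E-step:
$$
\int q^\ast(z) \log p(x, z|\theta) \rm{d} z = \int p(z|x, \theta^\ast) \log p(x, z|\theta) \rm{d} z = \mathbb{E}_{p(z|x, \theta^\ast)} \log p(x, z|\theta)
$$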

## Tricks

### M-step

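For the full sample we have, up to a constant that does not depend on $\theta$,
$$
\mathcal{L}(q^\ast, \theta) = \sum_{i=1}^N \mathbb{E}_{p(z_i|x_i, \theta^\ast)} \log p(x_i, z_i|\theta)
$$

Evaluating this sum at every step is impractical for large datasets.

Solution:
- Use a Monte Carlo estimate of the expectation
- Use stochastic gradient steps

$$
\theta_{t+1} = \theta_t + \eta_t \cdot N \cdot \nabla_\theta \log p(x_i, z_i|\theta) \big|_{\theta = \theta_t}
$$
where $x_i$ is a randomly chosen sample and $z_i \sim p(z_i|x_i, \theta_t)$.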

## Tricks
### Final algorithm

Iterate until convergence:
1. $q(z_i) = p(z_i|x_i, \theta_t)$
2. $\theta_{t+1} = \theta_t + \eta_t \cdot N \cdot \nabla_\theta \mathbb{E}_{q(z_i)} \log p(x_i, z_i|\theta)$
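
To see the E-step/M-step alternation end to end, here is a minimal sketch for a two-component 1-D Gaussian mixture. Note the swap: instead of the stochastic gradient update above, it uses the closed-form M-step that is available for Gaussians and processes the full sample at every iteration. The function name `em_gmm_1d`, the initialization and the synthetic data are all assumptions made for illustration.

```python
import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm

def em_gmm_1d(x, n_iter=100):
    """EM for a two-component 1-D Gaussian mixture with a closed-form M-step."""
    # Deterministic initialization (an arbitrary choice for this sketch).
    alpha = np.array([0.5, 0.5])
    mu = np.array([x.min(), x.max()])
    sigma = np.array([x.std(), x.std()])
    log_liks = []
    for _ in range(n_iter):
        # E-step: q(z_i) = p(z_i | x_i, theta), computed in log space.
        log_joint = np.log(alpha) + norm.logpdf(x[:, None], loc=mu, scale=sigma)
        log_px = logsumexp(log_joint, axis=1)
        r = np.exp(log_joint - log_px[:, None])
        log_liks.append(log_px.sum())  # log-likelihood before this update
        # M-step: maximize sum_i E_q log p(x_i, z_i | theta) in closed form.
        nk = r.sum(axis=0)
        alpha = nk / len(x)
        mu = (r * x[:, None]).sum(axis=0) / nk
        sigma = np.sqrt((r * (x[:, None] - mu) ** 2).sum(axis=0) / nk)
    return alpha, mu, sigma, np.array(log_liks)

# Synthetic data from a known mixture (illustrative parameter values).
rng = np.random.default_rng(0)
z = rng.choice(2, size=5000, p=[0.4, 0.6])
x = rng.normal(loc=np.array([2.0, 5.0])[z], scale=np.array([0.5, 1.0])[z])

alpha, mu, sigma, log_liks = em_gmm_1d(x)
print(alpha, mu, sigma)
```

With an exact E-step and an exact M-step the log-likelihood never decreases from one iteration to the next, which makes `log_liks` a convenient debugging signal.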

# Code

## Problem

Consider two coins, A and B, with different probabilities of success $\theta_A$ and $\theta_B$. The experiment is as follows: we randomly choose a coin, then flip it $n$ times and record the number of successes.

If we recorded which coin we used for each sample, we have complete information and can estimate $\theta_A$ and $\theta_B$ in closed form.

- What is the probabilistic model of this experiment?

$$
z \sim Be\left(\frac12\right), \qquad X = z \cdot Bi(\theta_A, n) + (1 - z) \cdot Bi(\theta_B, n)
$$

- What are the MLE estimators for $\theta_A$ and $\theta_B$?

$$
\theta^{\rm{MLE}}_A = \frac{\text{number of successes for A}}{\text{number of trials for A}}
$$

In [1]:
import numpy as np
import scipy.stats as sts

In [2]:
n = 1000

theta_A = 0.8
theta_B = 0.35

theta_true = [theta_A, theta_B]

coin_A = sts.bernoulli(theta_A)
coin_B = sts.bernoulli(theta_B)

coins = [coin_A, coin_B]

In [3]:
zs = np.array([0, 0, 1, 0, 1])
zs_bool = zs.astype(bool)
xs = np.array([coins[coin].rvs(n).sum() for coin in zs])

In [4]:
# MLE: total successes divided by total flips, separately for each coin
ml_A = xs[~zs_bool].sum() / ((~zs_bool).sum() * n)
ml_B = xs[zs_bool].sum() / (zs_bool.sum() * n)
ml_A, ml_B

(0.7956666666666666, 0.3545)

## Problem

Consider two coins, A and B, with different probabilities of success $\theta_A$ and $\theta_B$. The experiment is as follows: we randomly choose a coin, then flip it $n$ times and record the number of successes and failures.

But if we don't record the coin we used, we have missing data and the problem of estimating $\theta$ is harder to solve. One way to solve it is to use the EM algorithm.

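We add a latent variable $z_i \in \{0, 1\}$ that indicates which coin generated sample $i$, and denote by $w_i$ the probability that sample $i$ was generated by coin A. Let $h_i$ and $t_i = n - h_i$ be the numbers of successes and failures in sample $i$, and let $\theta = (p_0, p_1)$ be the success probabilities of coins A and B. The expected number of successes attributed to coin A is then:
$$
\#A = \sum_i w_i h_i
$$

Denote $x_i = (h_i, t_i)$. Up to a binomial coefficient that does not depend on $\theta$, the likelihood of one sample is:
$$
p(x_i|z_i, \theta) \propto \left( p_0^{h_i} (1 - p_0)^{t_i} \right)^{1 - z_i} \left( p_1^{h_i} (1 - p_1)^{t_i} \right)^{z_i}
$$

The prior distribution over the coins is uniform:
$$
q(z_i| \theta) = Be\left(\frac12\right)
$$

The posterior distribution of $z_i$ is:
$$
p(z_i|x_i, \theta) = \frac{p(x_i|z_i, \theta)q(z_i| \theta)}{\sum_{z'} p(x_i|z', \theta)q(z'| \theta)} = \frac{p(x_i|z_i, \theta)}{\sum_{z'} p(x_i|z', \theta)}
$$

So the E-step is to set $w_i = p(z_i = 0|x_i, \theta)$. The M-step is to set $\theta$ to the MLE under fixed $w$: $p_0$ is the fraction of successes among all flips weighted by $w_i$, and $p_1$ is the same fraction weighted by $1 - w_i$:
$$
p_0 = \frac{\sum_i w_i h_i}{\sum_i w_i (h_i + t_i)}, \qquad p_1 = \frac{\sum_i (1 - w_i) h_i}{\sum_i (1 - w_i) (h_i + t_i)}
$$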
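
As a small numeric illustration of the E-step described above, take a single sample with 8 successes and 2 failures and current guesses 0.8 and 0.7 for the success probabilities of the two coins. All of these values are illustrative assumptions, and the prior over coins is taken to be uniform.

```python
import numpy as np

heads, tails = 8, 2            # one sample: 8 successes, 2 failures
p = np.array([0.8, 0.7])       # current guesses for the two coins (illustrative)

# Likelihood of the sample under each coin; the binomial coefficient is the
# same for both coins and cancels in the normalization below.
lik = p**heads * (1 - p)**tails

# E-step: posterior probability of each coin under a uniform prior.
w = lik / lik.sum()
print(w)
```

The sample is attributed to the first coin with a probability only slightly above one half, which is why the two coins are hard to separate when their success probabilities are close.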

In [5]:
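def em(xs, thetas, max_iter=100, tol=1e-6):
    """Expectation-maximization for coin sample problem.

    xs: array of shape (N, 2) with [successes, failures] per sample.
    thetas: array of shape (2, 2) with [p, 1 - p] per coin.
    """

    ll_old = -np.inf  # np.infty was removed in NumPy 2.0
    for i in range(max_iter):
        # Log-likelihood of every sample under each coin, shape (2, N)
        ll = np.array([np.sum(xs * np.log(theta), axis=1) for theta in thetas])
        lik = np.exp(ll)
        # E-step: posterior probability of each coin for every sample
        ws = lik/lik.sum(0)
        # M-step: weighted success/failure frequencies for each coin
        vs = np.array([w[:, None] * xs for w in ws])
        thetas = np.array([v.sum(0)/v.sum() for v in vs])
        # Expected complete-data log-likelihood, used as the stopping criterion
        ll_new = np.sum([w*l for w, l in zip(ws, ll)])
        if np.abs(ll_new - ll_old) < tol:
            break
        ll_old = ll_new
    return i, thetas, ll_new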

In [6]:
np.random.seed(1234)

n = 100
p0 = 0.8 # 0.51
p1 = 0.7 # 0.53
# n flips per sample; n // 2 samples from each coin
xs = np.concatenate([np.random.binomial(n, p0, n // 2), np.random.binomial(n, p1, n // 2)])
xs = np.column_stack([xs, n-xs])
np.random.shuffle(xs)

In [7]:
st_point = np.random.random((2,1))
st_point = np.column_stack([st_point, 1-st_point])

In [8]:
st_point

array([[0.3573748 , 0.6426252 ],
       [0.63721697, 0.36278303]])

In [9]:
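# Note: every call starts from the same st_point and em() is deterministic,
# so the 10 runs are identical; draw a new start per run for real restarts.
results = [em(xs, st_point, max_iter=10000) for _ in range(10)]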
i, thetas, ll = sorted(results, key=lambda x: x[-1])[-1]
print(i)
for theta in thetas:
    print(theta)
print(ll)

22
[0.70051739 0.29948261]
[0.7934922 0.2065078]
-5585.5899811092095


## Problem

In your office there is a coffee machine that is used by two people. One person likes espresso and the other likes latte. The espresso drinker, however, drinks fewer cups of coffee: only 2/3 as many as the latte drinker. You have collected data from the coffee machine about how much coffee it uses per cup, but the data is not tied to a particular drink. You would nevertheless like to understand how much coffee is needed for each drink.

- What is the probabilistic model of this experiment?

$$
z \sim Be\left(0.4\right), \qquad X = z \cdot \mathcal{N}(\mu_1, \sigma_1) + (1 - z) \cdot \mathcal{N}(\mu_2, \sigma_2)
$$

In [10]:
import torch
import torch.optim as optim
import torch.nn.functional as F

In [29]:
torch.manual_seed(0)

<torch._C.Generator at 0x127533670>

In [30]:
true_pi = np.array([0.4, 0.6])
true_mu = np.array([2.0, 5.0])
true_sigma = np.array([0.5, 1.0])

In [31]:
N = 500

In [32]:
z = np.random.choice(2, size=N, p=true_pi)

In [33]:
data = np.random.normal(loc=true_mu[z], scale=true_sigma[z], size=N)
X = torch.tensor(data, dtype=torch.float32)

In [42]:
num_em_iters = 35
num_gd_steps = 10
learning_rate = 0.1

In [43]:
pi_logits = torch.randn(2, requires_grad=True)
mu = torch.randn(2, requires_grad=True)
log_sigma = torch.randn(2, requires_grad=True)

for em_iter in range(num_em_iters):
    # ---------
    # E-step: Compute Responsibilities
    # ---------
    pi = F.softmax(pi_logits, dim=0)
    sigma = torch.exp(log_sigma)
    
    X_expanded = X.unsqueeze(1)
    mu_expanded = mu.unsqueeze(0)
    sigma_expanded = sigma.unsqueeze(0)
    
    log_prob = -0.5 * torch.log(2 * torch.pi * sigma_expanded**2) \
               - 0.5 * ((X_expanded - mu_expanded)**2 / (sigma_expanded**2))
    
    log_weighted = torch.log(pi) + log_prob
    log_r = log_weighted - torch.logsumexp(log_weighted, dim=1, keepdim=True)
    r = torch.exp(log_r).detach()
    
    # ---------
    # M-step: Update Parameters via Gradient Descent
    # ---------
    optimizer = optim.Adam([pi_logits, mu, log_sigma], lr=learning_rate)
    
    for gd_step in range(num_gd_steps):
        optimizer.zero_grad()
        
        pi = F.softmax(pi_logits, dim=0)
        sigma = torch.exp(log_sigma)
        
        log_prob = -0.5 * torch.log(2 * torch.pi * sigma**2) \
                   - 0.5 * ((X.unsqueeze(1) - mu.unsqueeze(0))**2 / (sigma**2))
        log_component = torch.log(pi) + log_prob
        
        Q = torch.sum(r * log_component)
        loss = -Q
        
        loss.backward()
        optimizer.step()
    
    print(f"EM Iteration {em_iter+1}:")
    print("  Mixture weights (pi):", F.softmax(pi_logits, dim=0).detach().numpy())
    print("  Means (mu):", mu.detach().numpy())
    print("  Std devs (sigma):", torch.exp(log_sigma).detach().numpy())
    print("-" * 50)

EM Iteration 1:
  Mixture weights (pi): [0.02249245 0.97750753]
  Means (mu): [-0.2051959  1.0063932]
  Std devs (sigma): [1.2421808 3.2791967]
--------------------------------------------------
EM Iteration 2:
  Mixture weights (pi): [0.00467231 0.9953277 ]
  Means (mu): [0.6784194 2.0100498]
  Std devs (sigma): [2.0159705 2.4054646]
--------------------------------------------------
EM Iteration 3:
  Mixture weights (pi): [0.00213645 0.9978636 ]
  Means (mu): [1.6234579 3.0021899]
  Std devs (sigma): [1.8568144 1.7868333]
--------------------------------------------------
EM Iteration 4:
  Mixture weights (pi): [0.00152791 0.99847203]
  Means (mu): [2.5901089 3.849558 ]
  Std devs (sigma): [1.40472   1.5768342]
--------------------------------------------------
EM Iteration 5:
  Mixture weights (pi): [0.00175368 0.9982463 ]
  Means (mu): [2.4835446 3.7757068]
  Std devs (sigma): [1.2579525 1.6625428]
--------------------------------------------------
EM Iteration 6:
  Mixture weights