# Mixed Dirichlet


A mixed Dirichlet random variable $Y$ takes values in the probability simplex $\Delta_{K-1}$. An assignment $Y=y$ has probability density given by 

\begin{align}
P_{Y}(y|\alpha, w) &= \sum_{f} \mathrm{Gibbs}(f|w) \times \mathrm{Dirichlet}(y|\alpha \odot f)
\end{align}

where $w \in \mathbb R^K$, $\alpha \in \mathbb R^K_{>0}$, and $f$ ranges over the non-empty faces of the simplex. By $\alpha \odot f$ we mean the sub-vector of $\alpha$ whose coordinates are associated with the vertices in $f$. 

The distribution over non-empty faces has probability mass function:
\begin{align}
\mathrm{Gibbs}(f|w) = \frac{\exp(w^\top \phi(f))}{\sum_{f'} \exp(w^\top \phi(f'))}
\end{align}
where $\phi(f) \in \{-1, 1\}^K$ is such that $\phi_k(f) = 1$ if the vertex $\mathbf e_k$ is in the face, and $-1$ otherwise. 
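
For small $K$, the Gibbs distribution over faces can be checked by explicit enumeration. The following is a minimal brute-force sketch in plain Python, not the library's implementation; the names `enumerate_faces` and `gibbs_pmf` are hypothetical and do not come from the `mixed` or `bitvector` modules.

```python
import itertools
import math

def enumerate_faces(K):
    """All non-empty faces of the simplex, as 0/1 tuples of length K."""
    return [f for f in itertools.product((0, 1), repeat=K) if any(f)]

def gibbs_pmf(w):
    """Return {face: probability} under Gibbs(f|w), by explicit enumeration."""
    faces = enumerate_faces(len(w))
    # phi_k(f) = +1 if vertex k is in the face, -1 otherwise
    scores = [sum(w_k if bit else -w_k for w_k, bit in zip(w, f)) for f in faces]
    # Normalise with log-sum-exp for numerical stability
    m = max(scores)
    log_Z = m + math.log(sum(math.exp(s - m) for s in scores))
    return {f: math.exp(s - log_Z) for f, s in zip(faces, scores)}

pmf = gibbs_pmf([0.0, 0.0, 0.0])
print(len(pmf))  # 7 non-empty faces for K = 3
```

With $w = \mathbf 0$ every face receives the same probability, $1/(2^K - 1)$. Enumeration costs $\mathcal O(2^K)$, so this is only useful as a reference for small $K$.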


In [None]:
import torch
import torch.distributions as td
from mixed import MixedDirichlet

In [None]:
MixedDirichlet(concentration=torch.ones(3), scores=torch.zeros(3)).sample([10])

In [None]:
MixedDirichlet(concentration=torch.ones(3)/100, scores=torch.zeros(3)).sample([10])

## Understanding the parts 

An efficient GPU-friendly implementation of Mixed Dirichlet distributions relies on two auxiliary distributions: a GPU-friendly discrete exponential family over non-empty faces, and a GPU-friendly Dirichlet distribution (for which we can batch distributions of varying dimensionality, from $1$ to $K$).

In [None]:
from bitvector import NonEmptyBitVector

We can use bit-vectors to encode the faces of the simplex. A non-empty face has $1$ to $K$ vertices, so we use a $K$-dimensional bit-vector $f$: if $f_k = 1$, the vertex $\mathbf e_k$ is in the face.

A distribution over the non-empty faces can be obtained by scoring each of the vertices independently, i.e., 

$P_F(f) \propto \exp\left(\sum_{k=1}^K (-1)^{1-f_k}w_k\right)$

The normaliser of this expression sums over the set of non-empty faces, thus it excludes the zero bit-vector $\mathbf 0$.
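
Because the unnormalised score factorises over vertices, the sum over all $2^K$ bit-vectors is a product, and excluding $\mathbf 0$ subtracts a single term. The step below is a worked derivation added for illustration; it is not necessarily how the library computes the normaliser.

\begin{align}
Z(w) &= \sum_{f \in \{0,1\}^K \setminus \{\mathbf 0\}} \exp\left(\sum_{k=1}^K (-1)^{1-f_k} w_k\right) \\
     &= \prod_{k=1}^K \left(e^{w_k} + e^{-w_k}\right) - \exp\left(-\sum_{k=1}^K w_k\right)
\end{align}

For $w = \mathbf 0$ this gives $Z = 2^K - 1$, the number of non-empty faces.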

The class `bitvector.NonEmptyBitVector` implements a GPU-friendly version of the necessary procedure, which is based on a directed acyclic graph (DAG) of size $\mathcal O(K)$. 
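
To make the DAG idea concrete, here is a minimal sketch of one possible construction in plain Python; the library's actual graph may differ. Each node is a pair (position $k$, whether some earlier bit is already set), which gives $2(K+1)$ nodes. The names `backward_weights` and `sample_face` are hypothetical, and the sketch works with raw exponentials rather than in log-space, so it assumes moderate scores.

```python
import math
import random

def backward_weights(w):
    """beta[k][s]: total weight of completing bits k..K-1 into a non-empty
    bit-vector, where s = 1 if some earlier bit is set and s = 0 otherwise."""
    K = len(w)
    beta = [[0.0, 0.0] for _ in range(K + 1)]
    beta[K][1] = 1.0  # beta[K][0] stays 0: the all-zero vector is excluded
    for k in reversed(range(K)):
        on, off = math.exp(w[k]), math.exp(-w[k])
        beta[k][1] = (on + off) * beta[k + 1][1]
        beta[k][0] = on * beta[k + 1][1] + off * beta[k + 1][0]
    return beta

def sample_face(w, rng=random):
    """Sample one non-empty bit-vector by walking the DAG left to right."""
    beta = backward_weights(w)
    state, face = 0, []
    for k in range(len(w)):
        # Setting bit k always moves to (or stays in) state 1
        p_on = math.exp(w[k]) * beta[k + 1][1] / beta[k][state]
        bit = 1 if rng.random() < p_on else 0
        face.append(bit)
        state = max(state, bit)
    return tuple(face)

Z = backward_weights([0.0, 0.0, 0.0])[0][0]  # normaliser; 2**3 - 1 for w = 0
face = sample_face([0.0, 0.0, 0.0], random.Random(0))
```

If no bit has been set by the last position, `p_on` equals 1, so the zero vector is never produced. Both the normaliser and a sample cost $\mathcal O(K)$.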

In [None]:
NonEmptyBitVector(torch.zeros(3)).enumerate_support()

Here is a batch of two such distributions, each with a different parameter vector:

In [None]:
NonEmptyBitVector(torch.stack([torch.zeros(3), torch.ones(3)], 0)).sample([1000]).mean(0)

If $A_k \sim \mathrm{Gamma}(\alpha_k, 1)$ independently for $k = 1, \ldots, K$, and $T = \sum_{k=1}^K A_k$, then
\begin{align}
    \left(\frac{A_1}{T}, \ldots, \frac{A_K}{T} \right)^\top & \sim \mathrm{Dirichlet}(\alpha)
\end{align}


We can use this fact to implement a Dirichlet distribution parametrized by a shared vector of $K$ concentration parameters and a *mask* that identifies the face of the simplex on which the Dirichlet is supported.

The class `dirichlet.MaskedDirichlet` implements such a GPU-friendly distribution.
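
The Gamma construction can be sketched in a few lines of plain Python. This is an illustration of the idea only, not the code in `dirichlet.MaskedDirichlet`; the function name `sample_masked_dirichlet` is hypothetical, and the sketch assumes the mask selects at least one vertex.

```python
import random

def sample_masked_dirichlet(alpha, mask, rng=random):
    """Draw one point on the face selected by `mask` via normalised Gammas."""
    # Vertices outside the face contribute 0, so their coordinate stays at 0
    gammas = [rng.gammavariate(a, 1.0) if m else 0.0 for a, m in zip(alpha, mask)]
    total = sum(gammas)
    return [g / total for g in gammas]

# Face containing e_2 and e_3 only: the first coordinate is exactly 0.0
y = sample_masked_dirichlet([1.0, 1.0, 1.0], [False, True, True])
print(y)
```

Because every sample has length $K$ regardless of the face, samples from faces of different dimensionality can share a single batch.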

In [None]:
from dirichlet import MaskedDirichlet

In [None]:
f = NonEmptyBitVector(torch.zeros(3)).sample([2])
f

In the following example, one Dirichlet is defined over the face that contains $\mathbf e_2$ and $\mathbf e_3$; the other is defined over the entire simplex.

In [None]:
MaskedDirichlet(
    mask=torch.stack([torch.tensor([False, True, True]), torch.tensor([True, True, True])], 0), 
    concentration=torch.stack([torch.ones(3), torch.ones(3)/100], 0)
).sample([2])

# Uniform F and Uniform Y|f

In [None]:
import matplotlib.pyplot as plt

In [None]:
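def plot_marginals(samples, bins=100):
    """Plot a density histogram for each coordinate of `samples` (shape [..., D])."""
    D = samples.shape[-1]
    # One subplot per coordinate, stacked vertically with a shared x-axis
    fig, ax = plt.subplots(D, 1, figsize=(4, 2*D), sharex=True)
    for d in range(D):
        _ = ax[d].hist(samples[...,d].flatten().numpy(), bins=bins, density=True)
    return fig, ax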

In [None]:
p3d = MixedDirichlet(concentration=torch.ones(3), scores=torch.zeros(3))

In [None]:
p3d.sample([10])

In [None]:
p3d.entropy(), p3d.cross_entropy(p3d), td.kl_divergence(p3d, p3d)

In [None]:
_ = plot_marginals(p3d.sample([1000]), bins=100)

In [None]:
_p = p3d.expand([2, 1])
_p.sample().shape

In [None]:
_p.entropy(), _p.cross_entropy(_p), td.kl_divergence(_p, _p)

In [None]:
_p.faces.cross_entropy(_p.faces).shape

# Max-Ent F and Uniform Y|f

In [None]:
from bitvector import MaxEntropyFaces

In [None]:
pm3d = MixedDirichlet(concentration=torch.ones(3), pmf_n=MaxEntropyFaces.pmf_n(3, 1))

In [None]:
_ = plot_marginals(pm3d.sample(torch.Size([1000])), bins=100)

In [None]:
pm3d.entropy(), pm3d.cross_entropy(pm3d), td.kl_divergence(pm3d, pm3d)

In [None]:
_pm = pm3d.expand([2, 1])
_pm.sample().shape

In [None]:
_pm.entropy(), _pm.cross_entropy(_pm), td.kl_divergence(_pm, _pm)

# Variational inference (VI)

In [None]:
p = MixedDirichlet(concentration=torch.ones(5), pmf_n=MaxEntropyFaces.pmf_n(5, 1))
q = MixedDirichlet(concentration=torch.ones(5)/10, scores=torch.zeros(5))

In [None]:
p.batch_shape, p.event_shape

In [None]:
p.sample(torch.Size([10]))

In [None]:
f = p.faces.enumerate_support()

In [None]:
p.faces.log_prob(f).exp(), f.sum(-1)

In [None]:
p.cross_entropy(q)

In [None]:
p.Y(f).cross_entropy(q.Y(f))

In [None]:
p.Y(f).entropy()

In [None]:
p.cross_entropy(q)

In [None]:
td.kl_divergence(p, q)

In [None]:
td.kl_divergence(p.faces, q.faces)
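
Since the face $f$ is determined by which coordinates of $y$ are non-zero, cross-entropy and KL divergence between two mixed Dirichlet distributions decompose by the chain rule into a face term and an expected within-face term. The identities below are standard and are stated here for orientation; that the library's `cross_entropy` and `kl_divergence` are computed in exactly this way is an assumption. Here $p_F$ corresponds to `p.faces` and $p_{Y|f}$ to `p.Y(f)`.

\begin{align}
H(p, q) &= H(p_F, q_F) + \mathbb E_{f \sim p_F}\left[ H(p_{Y|f}, q_{Y|f}) \right] \\
\mathrm{KL}(p \,\|\, q) &= \mathrm{KL}(p_F \,\|\, q_F) + \mathbb E_{f \sim p_F}\left[ \mathrm{KL}(p_{Y|f} \,\|\, q_{Y|f}) \right]
\end{align}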