# Dirichlet-multinomial distribution

### Probability mass function

In RNA-seq, for example, [count data](https://arxiv.org/abs/2001.04343) is commonly modeled using a Dirichlet-multinomial distribution, where the multinomial probabilities $\mathbf{p} = (p_1,\ldots, p_d)^T$ follow a Dirichlet distribution with the parameter vector $\boldsymbol{\alpha} = (\alpha_1,\ldots, \alpha_d)^T$ and probability density function (pdf)

$$
\pi(\mathbf{p}) =  \frac{\Gamma(|\boldsymbol \alpha|)}{\prod_{j=1}^d \Gamma(\alpha_j)} \prod_{j=1}^d p_j^{\alpha_j-1},
$$

where $\alpha_j>0$ and $|\boldsymbol \alpha|=\sum_{j=1}^d \alpha_j$.

Given that $\pi(\mathbf{p})$ is a pdf, $\int_{\Delta_d} \pi(\mathbf{p}) d \mathbf{p} = 1$ and hence $\int_{\Delta_d} \prod_{j=1}^d p_j^{\alpha_j-1} d \mathbf{p} = \frac{\prod_{j=1}^d \Gamma(\alpha_j)}{\Gamma(|\boldsymbol \alpha|)}$, where $\Delta_d$ is the unit simplex in $d$ dimensions. Using this property, it is straightforward to show that $\mathbb{E}[p_j] = \frac{\alpha_j}{|\boldsymbol \alpha|}$ and $\mathbb{E}[p_j^2] = \frac{\alpha_j(\alpha_j + 1)}{|\boldsymbol \alpha|(|\boldsymbol \alpha| + 1)}$.
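As a worked instance of the property above, the first moment follows by absorbing $p_j$ into the Dirichlet kernel (which raises $\alpha_j$ by one) and using $\Gamma(a+1) = a\,\Gamma(a)$:

$$
\mathbb{E}[p_j] = \frac{\Gamma(|\boldsymbol \alpha|)}{\prod_{i=1}^d \Gamma(\alpha_i)} \int_{\Delta_d} p_j \prod_{i=1}^d p_i^{\alpha_i-1} \, d \mathbf{p}
= \frac{\Gamma(|\boldsymbol \alpha|)}{\prod_{i=1}^d \Gamma(\alpha_i)} \cdot \frac{\Gamma(\alpha_j+1) \prod_{i \neq j} \Gamma(\alpha_i)}{\Gamma(|\boldsymbol \alpha|+1)}
= \frac{\alpha_j}{|\boldsymbol \alpha|}.
$$

The second moment is obtained the same way, with $\alpha_j$ raised by two.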

Then for a multivariate count vector $\mathbf{x}=(x_1, \ldots, x_d)^T$ with batch size $|\mathbf{x}|=\sum_{j=1}^d x_j$, the probability mass function (pmf) of the Dirichlet-multinomial distribution is

$$
f(\mathbf{x} \mid \boldsymbol \alpha) = \int_{\Delta_d} f(\mathbf{x} \mid \mathbf{p}, \boldsymbol \alpha) \cdot \pi(\mathbf{p}) d \mathbf{p}
$$

$$
= \int_{\Delta_d} \binom{|\mathbf{x}|}{\mathbf{x}} \prod_{j=1}^d p_j^{x_j} \cdot \pi(\mathbf{p}) \, d \mathbf{p}  
= \binom{|\mathbf{x}|}{\mathbf{x}} \frac{\prod_{j=1}^d \Gamma(\alpha_j+x_j)}{\prod_{j=1}^d \Gamma(\alpha_j)} \frac{\Gamma(|\boldsymbol \alpha|)}{\Gamma(|\boldsymbol \alpha|+|\mathbf{x}|)}.
$$

### Log-likelihood

Given independent data points $\mathbf{x}_1, \ldots, \mathbf{x}_n$, the log-likelihood is

$$
L(\alpha) = \sum_{i=1}^n \ln \binom{|\mathbf{x}_i|}{\mathbf{x}_i} + \sum_{i=1}^n \sum_{j=1}^d [\ln \Gamma(\alpha_j + x_{ij}) - \ln \Gamma(\alpha_j)] - \sum_{i=1}^n [\ln \Gamma(|\boldsymbol \alpha|+|\mathbf{x}_i|) - \ln \Gamma(|\boldsymbol \alpha|)]
$$

$$
= \sum_{i=1}^n \ln \binom{|\mathbf{x}_i|}{\mathbf{x}_i}
+\sum_{i=1}^n \sum_{j=1}^d \sum_{k=0}^{x_{ij}-1} \ln(\alpha_j+k) - \sum_{i=1}^n \sum_{k=0}^{|\mathbf{x}_i|-1} \ln(|\boldsymbol \alpha|+k).
$$

The last equality holds, since $\frac{\Gamma(a + k)}{\Gamma(a)} = a (a + 1) \cdots (a+k-1)$.
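The sum-of-logs form needs only `log`, which makes it convenient for a quick check. The following is a minimal sketch for a single data point; the name `dm_logpmf` is hypothetical and not part of the original notebook.

```julia
# Log-pmf of one count vector `x` under parameter `α`, using
# ln Γ(a + k) - ln Γ(a) = Σ_{m=0}^{k-1} ln(a + m).
# `dm_logpmf` is an illustrative name, not the notebook's function.
function dm_logpmf(x::AbstractVector{<:Integer}, α::AbstractVector{<:Real})
    n = sum(x)
    αsum = sum(α)
    ll = 0.0
    # ln |x|! from the multinomial coefficient
    for k in 1:n
        ll += log(k)
    end
    for j in eachindex(x, α)
        # subtract ln x_j!
        for k in 1:x[j]
            ll -= log(k)
        end
        # add Σ_{k=0}^{x_j-1} ln(α_j + k)
        for k in 0:x[j]-1
            ll += log(α[j] + k)
        end
    end
    # subtract Σ_{k=0}^{|x|-1} ln(|α| + k)
    for k in 0:n-1
        ll -= log(αsum + k)
    end
    return ll
end

# With d = 2 and α = (1, 1) the pmf is uniform over x_1 ∈ {0, ..., |x|},
# so every outcome with |x| = 3 has probability 1/4.
p = exp(dm_logpmf([1, 2], [1.0, 1.0]))
```

For large counts the `loggamma` form is cheaper, since the loops above cost time proportional to $|\mathbf{x}|$.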

### Gradient and Hessian

Given $\frac{\partial}{\partial x} \ln \Gamma(x) = \frac{\Gamma'(x)}{\Gamma(x)} = \Psi(x)$ and $\frac{\partial^2}{\partial x^2} \ln \Gamma(x) = \Psi'(x)$, the score function is $\nabla L(\boldsymbol \alpha) = d L(\boldsymbol \alpha)^T$, where 

$$
\frac{\partial}{\partial \alpha_j} L(\alpha) = \sum_{i=1}^n [\Psi(\alpha_j + x_{ij}) - \Psi(\alpha_j)] - \sum_{i=1}^n [\Psi(|\boldsymbol \alpha|+|\mathbf{x}_i|) - \Psi(|\boldsymbol \alpha|)]
$$

$$
= \sum_{i=1}^n \sum_{k=0}^{x_{ij}-1} \frac{1}{\alpha_j+k} - \sum_{i=1}^n \sum_{k=0}^{|\mathbf{x}_i|-1} \frac{1}{|\boldsymbol \alpha|+k}.
$$
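The score formula can be checked numerically against a finite difference of the log-likelihood. The sketch below assumes each column of `X` is one data point; the names `dm_loglik_kernel` and `dm_score` are hypothetical, and the multinomial coefficient is dropped because it does not depend on $\boldsymbol\alpha$.

```julia
# α-dependent part of the log-likelihood (sum-of-logs form).
# Each column of X is one count vector. Illustrative helper only.
function dm_loglik_kernel(X::AbstractMatrix{<:Integer}, α::AbstractVector{<:Real})
    αsum = sum(α)
    ll = 0.0
    for i in axes(X, 2)
        for j in axes(X, 1), k in 0:X[j, i]-1
            ll += log(α[j] + k)
        end
        for k in 0:sum(view(X, :, i))-1
            ll -= log(αsum + k)
        end
    end
    return ll
end

# Score vector from the sum-of-reciprocals form.
function dm_score(X::AbstractMatrix{<:Integer}, α::AbstractVector{<:Real})
    αsum = sum(α)
    g = zeros(length(α))
    for i in axes(X, 2)
        # term shared by every coordinate: Σ_k 1 / (|α| + k)
        s = 0.0
        for k in 0:sum(view(X, :, i))-1
            s += 1 / (αsum + k)
        end
        for j in axes(X, 1)
            for k in 0:X[j, i]-1
                g[j] += 1 / (α[j] + k)
            end
            g[j] -= s
        end
    end
    return g
end

# Compare with a central finite difference in the first coordinate.
X = [1 0; 2 3; 0 1]
α = [0.5, 1.5, 2.0]
h = 1e-6
e1 = [h, 0.0, 0.0]
fd1 = (dm_loglik_kernel(X, α .+ e1) - dm_loglik_kernel(X, α .- e1)) / (2h)
g = dm_score(X, α)   # g[1] should be close to fd1
```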

The observed information is $-\nabla^2L(\alpha) = - d^2 L(\boldsymbol \alpha)$, where

$$
-\frac{\partial^2}{\partial \alpha_j \partial \alpha_l} L(\alpha) = 
\begin{cases}
- \sum_{i=1}^n [\Psi'(\alpha_j + x_{ij}) - \Psi'(\alpha_j)] + \sum_{i=1}^n [\Psi'(|\boldsymbol \alpha|+|\mathbf{x}_i|) - \Psi'(|\boldsymbol \alpha|)], & l = j \\
\sum_{i=1}^n [\Psi'(|\boldsymbol \alpha|+|\mathbf{x}_i|) - \Psi'(|\boldsymbol \alpha|)], & l \neq j
\end{cases}
$$

$$
= 
\begin{cases}
\sum_{i=1}^n \sum_{k=0}^{x_{ij}-1} \frac{1}{(\alpha_j+k)^2} - \sum_{i=1}^n \sum_{k=0}^{|\mathbf{x}_i|-1} \frac{1}{(|\boldsymbol \alpha|+k)^2}, & l = j \\
- \sum_{i=1}^n \sum_{k=0}^{|\mathbf{x}_i|-1} \frac{1}{(|\boldsymbol \alpha|+k)^2}, & l \neq j
\end{cases}
$$
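In the sum form, the observed information is a diagonal matrix plus the same constant in every entry. The following sketch computes those two pieces; `dm_obsinfo_parts` is a hypothetical name, and each column of `X` is assumed to be one data point.

```julia
# Return (dvec, c) such that the observed information equals
# Diagonal(dvec) + c * ones(d, d). Illustrative helper only.
function dm_obsinfo_parts(X::AbstractMatrix{<:Integer}, α::AbstractVector{<:Real})
    αsum = sum(α)
    dvec = zeros(length(α))
    c = 0.0
    for i in axes(X, 2)
        # diagonal part: Σ_k 1 / (α_j + k)^2
        for j in axes(X, 1), k in 0:X[j, i]-1
            dvec[j] += 1 / (α[j] + k)^2
        end
        # constant part: -Σ_k 1 / (|α| + k)^2
        for k in 0:sum(view(X, :, i))-1
            c -= 1 / (αsum + k)^2
        end
    end
    return dvec, c
end

# One data point x = (1, 2) with α = (1, 2):
# dvec = (1, 1/4 + 1/9) and c = -(1/9 + 1/16 + 1/25).
dvec, c = dm_obsinfo_parts(reshape([1, 2], 2, 1), [1.0, 2.0])
```

Because `c` is negative whenever some $|\mathbf{x}_i| > 0$, a category with `dvec[j] == 0` (never observed) has a negative diagonal entry `dvec[j] + c`.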

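Note that the log-likelihood is not concave in general, since $-\frac{\partial^2}{\partial \alpha_1^2} L(\alpha) = - \sum_{i=1}^n [\Psi'(\alpha_1 + x_{i1}) - \Psi'(\alpha_1)] + \sum_{i=1}^n [\Psi'(|\boldsymbol \alpha|+|\mathbf{x}_i|) - \Psi'(|\boldsymbol \alpha|)]$ can be negative. For example, if $x_{i1} = 0$ for all $i$, the first sum vanishes, and the second sum is negative whenever some $|\mathbf{x}_i| > 0$ because $\Psi'$ is decreasing.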

The expected Fisher information is $\mathbb{E}[-\nabla^2L(\alpha)]$, where

$$
\mathbb{E}\big[-\frac{\partial^2}{\partial \alpha_j \partial \alpha_l} L(\alpha)\big] = 
\begin{cases}
\sum_{i=1}^n \sum_{x=0}^{|\mathbf{x}_i|} f_{ij}(x) \sum_{k=0}^{x-1} \frac{1}{(\alpha_j+k)^2} - \sum_{i=1}^n \sum_{k=0}^{|\mathbf{x}_i|-1} \frac{1}{(|\boldsymbol \alpha|+k)^2}, & l = j \\
- \sum_{i=1}^n \sum_{k=0}^{|\mathbf{x}_i|-1} \frac{1}{(|\boldsymbol \alpha|+k)^2}, & l \neq j
\end{cases}
$$

Here the batch sizes $|\mathbf{x}_i|$ are treated as fixed, and $f_{ij}$ is the marginal pmf of $x_{ij}$ under $f(\mathbf{x}_i \mid \boldsymbol \alpha)$, which is a beta-binomial distribution with parameters $(|\mathbf{x}_i|, \alpha_j, |\boldsymbol \alpha| - \alpha_j)$. Fisher scoring is therefore inefficient for computing the maximum likelihood estimator (MLE), because the expected information matrix is expensive to evaluate. We instead use a positive definite matrix approximated from the observed information.

### Other properties

Suppose $(p_1,\ldots,p_d) \in \Delta_d = \{\mathbf{p}: p_i \ge 0, \sum_i p_i = 1\}$ follows a Dirichlet distribution with parameter $\boldsymbol \alpha = (\alpha_1,\ldots,\alpha_d)$. Because the density integrates to one, differentiating both sides of the following identity with respect to $\alpha_k$ gives the expectation of $\ln(p_k)$:

$$\int_{\Delta_d}\frac{\Gamma(|\boldsymbol \alpha|)}{\prod_{j=1}^d \Gamma(\alpha_j)}\prod_{j=1}^d p_j^{\alpha_j-1} \, d\mathbf{p} = 1$$

$$
\rightarrow \frac{\partial}{\partial \alpha_k}\int_{\Delta_d}\frac{\Gamma(|\alpha|)}{\prod_{j=1}^d \Gamma(\alpha_j)}\prod_{j=1}^d p_j^{\alpha_j-1} \, d\mathbf{p} = 0
$$

$$
\int_{\Delta_d}\bigg[\bigg(\frac{\Gamma'(|\boldsymbol \alpha|)\prod_{j=1}^{d}\Gamma(\alpha_j)-\Gamma(|\boldsymbol \alpha|)\Gamma'(\alpha_k)\prod_{j\neq k}^{d}\Gamma(\alpha_j)}{\prod_{j=1}^{d}\Gamma(\alpha_j)^{2}}\bigg)\prod_{j=1}^d p_j^{\alpha_j-1}+\frac{\Gamma(|\boldsymbol \alpha|)}{\prod_{j=1}^{d}\Gamma(\alpha_j)}\ln(p_k)\prod_{j=1}^d p_j^{\alpha_j-1}\bigg] \, d\mathbf{p}=0
$$

$$
\therefore \mathbb{E}\big[\ln(p_k)\big] = \int_{\Delta_d}\bigg(\frac{\Gamma(|\boldsymbol \alpha|)\Gamma'(\alpha_k)\prod_{j\neq k}^{d}\Gamma(\alpha_j)-\Gamma'(|\boldsymbol \alpha|)\prod_{j=1}^{d}\Gamma(\alpha_j)}{\prod_{j=1}^{d}\Gamma(\alpha_j)^{2}}\bigg)\prod_{j=1}^d p_j^{\alpha_j-1}d\mathbf{p}
$$

$$
= \frac{\Gamma'(\alpha_k)}{\Gamma(\alpha_k)} - \frac{\Gamma'(|\boldsymbol \alpha|)}{\Gamma(|\boldsymbol \alpha|)} 
= \Psi(\alpha_k) - \Psi(|\boldsymbol \alpha|).
$$
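As a quick sanity check of this identity, take $d = 2$ and $\boldsymbol \alpha = (1, 1)$, so that $p_1$ is uniform on $(0, 1)$. Direct integration and the recurrence $\Psi(x+1) = \Psi(x) + \frac{1}{x}$ give the same value:

$$
\mathbb{E}\big[\ln(p_1)\big] = \int_0^1 \ln(p) \, dp = \big[p \ln(p) - p\big]_0^1 = -1,
\qquad
\Psi(1) - \Psi(2) = \Psi(1) - \big(\Psi(1) + 1\big) = -1.
$$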

### Alternative Hessian matrix

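As noted above, the observed information matrix is not necessarily positive definite. It is, however, the sum of a diagonal matrix and a constant multiple of $\mathbf{1}\mathbf{1}^T$ (a rank-one matrix), so the Woodbury (Sherman-Morrison) formula applies, and we can exploit this structure to construct a positive definite approximation.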
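To illustrate why this structure is useful, the sketch below solves a linear system with a matrix of the form $\mathbf{D} + c \mathbf{1}\mathbf{1}^T$ in $O(d)$ operations via the Sherman-Morrison formula. The function name and arguments are hypothetical; it assumes all diagonal entries are nonzero and $1 + c \sum_j 1/d_j \neq 0$, and it does not by itself make the matrix positive definite.

```julia
using LinearAlgebra

# Solve (Diagonal(dvec) + c * ones(d, d)) * x = g without forming the matrix.
# Sherman-Morrison: x = D⁻¹g - [c (1ᵀD⁻¹g) / (1 + c 1ᵀD⁻¹1)] D⁻¹1.
# Illustrative helper; assumes every dvec[j] is nonzero.
function solve_diag_plus_rank1(dvec::AbstractVector{<:Real}, c::Real,
                               g::AbstractVector{<:Real})
    dinv_g = g ./ dvec          # D⁻¹ g
    dinv_1 = 1 ./ dvec          # D⁻¹ 1
    scale = c * sum(dinv_g) / (1 + c * sum(dinv_1))
    return dinv_g .- scale .* dinv_1
end

# Compare with a dense solve on a small example.
dvec = [2.0, 3.0, 4.0]
c = -0.1
g = [1.0, 2.0, 3.0]
x_fast = solve_diag_plus_rank1(dvec, c, g)
x_dense = (Diagonal(dvec) + c * ones(3, 3)) \ g
```

The two solutions should agree up to rounding; the fast version avoids the $O(d^3)$ factorization, which matters when a Newton-type step is repeated many times.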

### Starting point

The method-of-moments estimator for $\boldsymbol{\alpha}$ is a good starting point for iterative algorithms.
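One possible moment-matching heuristic, sketched below, treats the observed proportions $x_{ij}/|\mathbf{x}_i|$ as if they were draws of $p_j$ and solves the two Dirichlet moment formulas for $|\boldsymbol \alpha|$, which gives $|\boldsymbol \alpha| = (m_j - s_j)/(s_j - m_j^2)$ with $m_j$ and $s_j$ the first and second sample moments. This ignores the multinomial sampling noise and is not necessarily the estimator the author has in mind; the name `mom_start` and the fallback value are assumptions.

```julia
# Hypothetical helper: moment-matching starting value for α.
# Assumes each column of X is one count vector with a positive total.
function mom_start(X::AbstractMatrix{<:Real})
    d, n = size(X)
    P = X ./ sum(X, dims = 1)                 # observed proportions, d × n
    m1 = vec(sum(P, dims = 2)) ./ n           # sample first moments
    m2 = vec(sum(abs2, P, dims = 2)) ./ n     # sample second moments
    total, used = 0.0, 0
    for j in 1:d
        v = m2[j] - m1[j]^2
        # use only categories where the ratio is well defined and positive
        if v > 0 && m1[j] > m2[j]
            total += (m1[j] - m2[j]) / v      # per-category estimate of |α|
            used += 1
        end
    end
    αsum = used > 0 ? total / used : 1.0      # arbitrary fallback |α| = 1
    # entries are zero for categories that never occur; nudge them upward
    # before using the result, since α_j must be positive
    return αsum .* m1
end

α0 = mom_start([1 3; 3 1])   # proportions (1/4, 3/4) and (3/4, 1/4)
```

For this small example both categories give $|\boldsymbol \alpha| = 3$, so the starting value is $(1.5, 1.5)$.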

### Implement Newton's algorithm

In [1]:
using SpecialFunctions
polygamma(0, 0.5)  # digamma(0.5)
polygamma(1, 0.5);  # trigamma(0.5)

In [2]:
"""
    dirmult_logpdf(x::Vector, α::Vector)
    
Compute the log-pdf of Dirichlet-multinomial distribution with parameter `α` 
at data point `x`.
"""
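function dirmult_logpdf(x::AbstractVector, α::AbstractVector)
    xsum = sum(x)
    αsum = sum(α)
    # log multinomial coefficient starts with ln |x|!
    loglike = logfactorial(xsum)
    # eachindex(x, α) also checks that x and α have matching indices
    for j in eachindex(x, α)
        loglike = loglike - logfactorial(x[j]) + loggamma(α[j] + x[j]) - loggamma(α[j])
    end
    loglike = loglike - loggamma(xsum + αsum) + loggamma(αsum)
    return loglike
end

# In-place version: fill r[i] with the log-pmf of the i-th column of X.
function dirmult_logpdf!(r::AbstractVector, X::AbstractMatrix, α::AbstractVector)
    for i in axes(X, 2)
        # a view avoids copying each column
        r[i] = dirmult_logpdf(view(X, :, i), α)
    end
    return r
end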

"""
    dirmult_logpdf(X, α)
    
Compute the log-pdf of Dirichlet-multinomial distribution with parameter `α` 
at each data point in `X`. Each column of `X` is one data point.
"""
function dirmult_logpdf(X::Matrix, α::Vector)
    r = zeros(size(X, 2))
    dirmult_logpdf!(r, X, α)
end

dirmult_logpdf

### Load data 

Below we build a classifier for handwritten digit recognition. The following figure shows example bitmaps of handwritten digits from U.S. postal envelopes. 

<img src="./handwritten_digits.png" width="250" align="center"/>

Each digit is represented by a $32 \times 32$ bitmap in which each element indicates one pixel with a value of white or black. Each $32 \times 32$ bitmap is divided into blocks of $4 \times 4$, and the number of white pixels is counted in each block. Therefore each handwritten digit is summarized by a vector $\mathbf{x} = (x_1, \ldots, x_{64})$ of length 64, where each element is a count between 0 and 16. 
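The block-counting step can be sketched as follows. This is only an illustration of the preprocessing described above: the data file already contains the counts, and the function name, the encoding of a white pixel as `true`, and the row-by-row ordering of blocks are all assumptions.

```julia
# Summarize a bitmap by counting `true` pixels in non-overlapping blocks.
# Blocks are visited row by row (an assumed ordering). Illustrative only.
function block_counts(bitmap::AbstractMatrix{<:Integer}; blocksize::Int = 4)
    nr, nc = size(bitmap)
    nr % blocksize == 0 && nc % blocksize == 0 ||
        throw(ArgumentError("bitmap size must be a multiple of blocksize"))
    counts = Int[]
    for bi in 1:blocksize:nr, bj in 1:blocksize:nc
        rows = bi:bi+blocksize-1
        cols = bj:bj+blocksize-1
        push!(counts, sum(view(bitmap, rows, cols)))
    end
    return counts
end

# An all-white 32 × 32 bitmap yields 64 blocks with 16 pixels each.
x = block_counts(trues(32, 32))
```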


In [3]:
using DelimitedFiles
using LinearAlgebra

optdigits = readdlm("./optdigits.tra", ',', Int64)
size(optdigits)

(3823, 65)

In [4]:
data = copy(transpose(optdigits[:, 1:64]))
digits = optdigits[:, 65];

The training data consist of 3823 handwritten digits. Each row of `optdigits` contains the 64 counts of a digit, and the last (65th) element is the digit label. After transposing, each column of `data` is one digit.

In [5]:
dirmult_logpdf(data, ones(size(data, 1)));

In [6]:
versioninfo()

Julia Version 1.6.0
Commit f9720dc2eb (2021-03-24 12:55 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin19.6.0)
  CPU: Intel(R) Core(TM) i9-9880H CPU @ 2.30GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, skylake)
