# Dirichlet-multinomial distribution

### Probability mass function

In RNA-seq, for example, [count data](https://arxiv.org/abs/2001.04343) is commonly modeled using a Dirichlet-multinomial distribution, where the multinomial probabilities $\mathbf{p} = (p_1,\ldots, p_d)^T$ follow a Dirichlet distribution with the parameter vector $\boldsymbol{\alpha} = (\alpha_1,\ldots, \alpha_d)^T$ and probability density function (pdf)

$$
\pi(\mathbf{p}) =  \frac{\Gamma(|\boldsymbol \alpha|)}{\prod_{j=1}^d \Gamma(\alpha_j)} \prod_{j=1}^d p_j^{\alpha_j-1},
$$

where $\alpha_j>0$ and $|\boldsymbol \alpha|=\sum_{j=1}^d \alpha_j$.

Given that $\pi(\mathbf{p})$ is a pdf, $\int_{\Delta_d} \pi(\mathbf{p}) d \mathbf{p} = 1$ and hence $\int_{\Delta_d} \prod_{j=1}^d p_j^{\alpha_j-1} d \mathbf{p} = \frac{\prod_{j=1}^d \Gamma(\alpha_j)}{\Gamma(|\boldsymbol \alpha|)}$, where $\Delta_d$ is the unit simplex in $d$ dimensions. Using this property, it is straightforward to show that $\mathbb{E}[p_j] = \frac{\alpha_j}{|\boldsymbol \alpha|}$ and $\mathbb{E}[p_j^2] = \frac{\alpha_j(\alpha_j + 1)}{|\boldsymbol \alpha|(|\boldsymbol \alpha| + 1)}$.
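For example, the first moment follows by absorbing $p_j$ into the integrand. This raises the exponent of $p_j$ from $\alpha_j - 1$ to $\alpha_j$, so the integral identity applies with $\alpha_j$ replaced by $\alpha_j + 1$. The last step uses $\Gamma(a+1) = a\Gamma(a)$:

$$
\mathbb{E}[p_j] = \frac{\Gamma(|\boldsymbol \alpha|)}{\prod_{i=1}^d \Gamma(\alpha_i)} \int_{\Delta_d} p_j \prod_{i=1}^d p_i^{\alpha_i-1} \, d\mathbf{p}
= \frac{\Gamma(|\boldsymbol \alpha|)}{\prod_{i=1}^d \Gamma(\alpha_i)} \cdot \frac{\Gamma(\alpha_j+1)\prod_{i\neq j} \Gamma(\alpha_i)}{\Gamma(|\boldsymbol \alpha|+1)}
= \frac{\Gamma(\alpha_j+1)}{\Gamma(\alpha_j)} \cdot \frac{\Gamma(|\boldsymbol \alpha|)}{\Gamma(|\boldsymbol \alpha|+1)}
= \frac{\alpha_j}{|\boldsymbol \alpha|}.
$$

The same argument with $p_j^2$, which replaces $\alpha_j$ by $\alpha_j + 2$, gives $\mathbb{E}[p_j^2]$.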

Then for a multivariate count vector $\mathbf{x}=(x_1, \ldots, x_d)^T$ with batch size $|\mathbf{x}|=\sum_{j=1}^d x_j$, the probability mass function (pmf) of the Dirichlet-multinomial distribution is obtained by integrating the multinomial pmf against the Dirichlet density:

$$
f(\mathbf{x} \mid \boldsymbol \alpha) = \int_{\Delta_d} f(\mathbf{x} \mid \mathbf{p}) \cdot \pi(\mathbf{p}) \, d \mathbf{p}
$$

$$
= \int_{\Delta_d} \binom{|\mathbf{x}|}{\mathbf{x}} \prod_{j=1}^d p_j^{x_j} \cdot \pi(\mathbf{p}) \, d \mathbf{p}  
= \binom{|\mathbf{x}|}{\mathbf{x}} \frac{\prod_{j=1}^d \Gamma(\alpha_j+x_j)}{\prod_{j=1}^d \Gamma(\alpha_j)} \frac{\Gamma(|\boldsymbol \alpha|)}{\Gamma(|\boldsymbol \alpha|+|\mathbf{x}|)}.
$$
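As a quick check of the closed form, take every $\alpha_j = 1$. Then $\Gamma(1 + x_j) = x_j!$, $\Gamma(|\boldsymbol\alpha|) = (d-1)!$ and $\Gamma(|\boldsymbol\alpha| + |\mathbf{x}|) = (|\mathbf{x}| + d - 1)!$, so

$$
f(\mathbf{x} \mid \mathbf{1}) = \frac{|\mathbf{x}|!}{\prod_{j=1}^d x_j!} \cdot \prod_{j=1}^d x_j! \cdot \frac{(d-1)!}{(|\mathbf{x}|+d-1)!} = \binom{|\mathbf{x}|+d-1}{d-1}^{-1}.
$$

This is the uniform distribution over the $\binom{|\mathbf{x}|+d-1}{d-1}$ count vectors with a given batch size, so the pmf sums to one in this case.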

### Log-likelihood

Given independent data points $\mathbf{x}_1, \ldots, \mathbf{x}_n$, the log-likelihood is

$$
L(\boldsymbol \alpha) = \sum_{i=1}^n \ln \binom{|\mathbf{x}_i|}{\mathbf{x}_i} + \sum_{i=1}^n \sum_{j=1}^d [\ln \Gamma(\alpha_j + x_{ij}) - \ln \Gamma(\alpha_j)] - \sum_{i=1}^n [\ln \Gamma(|\boldsymbol \alpha|+|\mathbf{x}_i|) - \ln \Gamma(|\boldsymbol \alpha|)]
$$

$$
= \sum_{i=1}^n \ln \binom{|\mathbf{x}_i|}{\mathbf{x}_i}
+\sum_{i=1}^n \sum_{j=1}^d \sum_{k=0}^{x_{ij}-1} \ln(\alpha_j+k) - \sum_{i=1}^n \sum_{k=0}^{|\mathbf{x}_i|-1} \ln(|\boldsymbol \alpha|+k).
$$

The last equality holds since $\frac{\Gamma(a + k)}{\Gamma(a)} = a (a + 1) \cdots (a+k-1)$ for any integer $k \ge 0$. When $k = 0$ the product is empty and equals 1.
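The finite-sum form needs only logarithms when the counts are integers. The following minimal Julia sketch evaluates $L(\boldsymbol\alpha)$ this way using Base Julia only. The names `logfact` and `dm_loglik_sums` are hypothetical helpers, not part of the notebook. For integer counts, the result should agree with `sum(dirmult_logpdf(X, α))` up to floating-point rounding.

```julia
# log(m!) as a sum of logs (Base Julia only; log(0!) = log(1!) = 0)
logfact(m::Integer) = m < 2 ? 0.0 : sum(log, 2:m)

# Log-likelihood L(α) in the finite-sum form; each column of X is one count
# vector x_i and α is the Dirichlet parameter vector.
function dm_loglik_sums(X::AbstractMatrix{<:Integer}, α::AbstractVector{<:Real})
    d, n = size(X)
    αsum = sum(α)
    ll = 0.0
    for i in 1:n
        m = sum(view(X, :, i))                          # batch size |x_i|
        ll += logfact(m) - sum(logfact, view(X, :, i))  # log multinomial coefficient
        for j in 1:d, k in 0:X[j, i]-1
            ll += log(α[j] + k)                         # Σ_{k<x_ij} log(α_j + k)
        end
        for k in 0:m-1
            ll -= log(αsum + k)                         # Σ_{k<|x_i|} log(|α| + k)
        end
    end
    return ll
end

# One observation x = (1, 2, 1) with α = (1, 1, 1)
ll = dm_loglik_sums(reshape([1, 2, 1], 3, 1), ones(3))
```

With all $\alpha_j = 1$ the pmf is uniform over the $\binom{6}{2} = 15$ count vectors of size 4, so `ll` should equal $\ln(1/15)$.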

### Gradient and Hessian

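Let $\Psi(x) = \frac{\partial}{\partial x} \ln \Gamma(x) = \frac{\Gamma'(x)}{\Gamma(x)}$ denote the digamma function and $\Psi'(x) = \frac{\partial^2}{\partial x^2} \ln \Gamma(x)$ the trigamma function. Differentiating $\ln \Gamma(a+k) - \ln \Gamma(a) = \sum_{m=0}^{k-1} \ln(a+m)$ gives $\Psi(a+k) - \Psi(a) = \sum_{m=0}^{k-1} \frac{1}{a+m}$ and $\Psi'(a+k) - \Psi'(a) = -\sum_{m=0}^{k-1} \frac{1}{(a+m)^2}$, which yield the finite-sum forms below. The score function is $\nabla L(\boldsymbol \alpha) = (\text{D} L(\boldsymbol \alpha))^T$, where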

$$
\frac{\partial}{\partial \alpha_j} L(\boldsymbol \alpha) = \sum_{i=1}^n [\Psi(\alpha_j + x_{ij}) - \Psi(\alpha_j)] - \sum_{i=1}^n [\Psi(|\boldsymbol \alpha|+|\mathbf{x}_i|) - \Psi(|\boldsymbol \alpha|)]
$$

$$
= \sum_{i=1}^n \sum_{k=0}^{x_{ij}-1} \frac{1}{\alpha_j+k} - \sum_{i=1}^n \sum_{k=0}^{|\mathbf{x}_i|-1} \frac{1}{|\boldsymbol \alpha|+k}.
$$
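To check the score formula numerically, the Julia sketch below compares the finite-sum score with central finite differences of the $\boldsymbol\alpha$-dependent part of $L$. It uses Base Julia only. The names `dm_kernel` and `dm_score` and the small count matrix are hypothetical. The multinomial coefficients are dropped because they do not depend on $\boldsymbol\alpha$.

```julia
# α-dependent part of L(α) in finite-sum form (multinomial coefficients dropped)
function dm_kernel(X::AbstractMatrix{<:Integer}, α::AbstractVector{<:Real})
    d, n = size(X)
    αsum = sum(α)
    ll = 0.0
    for i in 1:n
        for j in 1:d, k in 0:X[j, i]-1
            ll += log(α[j] + k)
        end
        for k in 0:sum(view(X, :, i))-1
            ll -= log(αsum + k)
        end
    end
    return ll
end

# Score: ∂L/∂α_j = Σ_i Σ_{k<x_ij} 1/(α_j+k) - Σ_i Σ_{k<|x_i|} 1/(|α|+k)
function dm_score(X::AbstractMatrix{<:Integer}, α::AbstractVector{<:Real})
    d, n = size(X)
    αsum = sum(α)
    g = zeros(d)
    for i in 1:n
        common = 0.0                      # term shared by every coordinate
        for k in 0:sum(view(X, :, i))-1
            common += 1 / (αsum + k)
        end
        for j in 1:d
            for k in 0:X[j, i]-1
                g[j] += 1 / (α[j] + k)
            end
            g[j] -= common
        end
    end
    return g
end

# Central finite differences on a small example (columns are observations)
X = [3 0 2; 1 4 2; 0 1 1]
α = [0.7, 1.3, 2.1]
h = 1e-6
unitvec(j) = [Float64(l == j) for l in 1:3]
fd = [(dm_kernel(X, α + h * unitvec(j)) - dm_kernel(X, α - h * unitvec(j))) / (2h) for j in 1:3]
```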

The observed information is $-\nabla^2L(\boldsymbol \alpha) = - \text{D} (\text{D} L(\boldsymbol \alpha))^T$, where

$$
-\frac{\partial^2}{\partial \alpha_j \partial \alpha_l} L(\boldsymbol \alpha) = 
\begin{cases}
- \sum_{i=1}^n [\Psi'(\alpha_j + x_{ij}) - \Psi'(\alpha_j)] + \sum_{i=1}^n [\Psi'(|\boldsymbol \alpha|+|\mathbf{x}_i|) - \Psi'(|\boldsymbol \alpha|)], & l = j \\
\sum_{i=1}^n [\Psi'(|\boldsymbol \alpha|+|\mathbf{x}_i|) - \Psi'(|\boldsymbol \alpha|)], & l \neq j
\end{cases}
$$

$$
= 
\begin{cases}
\sum_{i=1}^n \sum_{k=0}^{x_{ij}-1} \frac{1}{(\alpha_j+k)^2} - \sum_{i=1}^n \sum_{k=0}^{|\mathbf{x}_i|-1} \frac{1}{(|\boldsymbol \alpha|+k)^2}, & l = j \\
- \sum_{i=1}^n \sum_{k=0}^{|\mathbf{x}_i|-1} \frac{1}{(|\boldsymbol \alpha|+k)^2}, & l \neq j
\end{cases}
$$
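The finite-sum form shows that the observed information is a diagonal matrix plus a constant times the all-ones matrix. A minimal Julia sketch that assembles it is below. It uses the LinearAlgebra standard library for `Diagonal`. The name `dm_obsinfo` and the toy observation are hypothetical. The function returns the diagonal vector `a`, the scalar `c`, and the full matrix.

```julia
using LinearAlgebra

# Observed information -∇²L(α) = Diagonal(a) + c·11ᵀ, where
#   a_j = Σ_i Σ_{k<x_ij} 1/(α_j+k)^2   and   c = -Σ_i Σ_{k<|x_i|} 1/(|α|+k)^2.
function dm_obsinfo(X::AbstractMatrix{<:Integer}, α::AbstractVector{<:Real})
    d, n = size(X)
    αsum = sum(α)
    a = zeros(d)
    c = 0.0
    for i in 1:n
        for j in 1:d, k in 0:X[j, i]-1
            a[j] += 1 / (α[j] + k)^2
        end
        for k in 0:sum(view(X, :, i))-1
            c -= 1 / (αsum + k)^2
        end
    end
    return a, c, Diagonal(a) + c * ones(d, d)
end

# One observation x = (2, 1) evaluated at α = (1, 1)
a, c, I_obs = dm_obsinfo(reshape([2, 1], 2, 1), [1.0, 1.0])
```

Here $a_1 = 1 + \frac{1}{4}$, $a_2 = 1$ and $c = -(\frac{1}{4} + \frac{1}{9} + \frac{1}{16}) = -\frac{61}{144}$.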

The log-likelihood is not concave in general, because a diagonal entry of the observed information, such as $-\frac{\partial^2}{\partial \alpha_1^2} L(\boldsymbol \alpha) = - \sum_{i=1}^n [\Psi'(\alpha_1 + x_{i1}) - \Psi'(\alpha_1)] + \sum_{i=1}^n [\Psi'(|\boldsymbol \alpha|+|\mathbf{x}_i|) - \Psi'(|\boldsymbol \alpha|)]$, can be negative.
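For a concrete instance, take a single hypothetical observation $\mathbf{x} = (0, 3)^T$ with $d = 2$ and evaluate at $\boldsymbol\alpha = (1, 1)^T$, so $|\boldsymbol\alpha| = 2$ and $|\mathbf{x}| = 3$. The finite-sum form of the observed information gives

$$
-\frac{\partial^2}{\partial \alpha_1^2} L(\boldsymbol \alpha) = \underbrace{\sum_{k=0}^{-1} \frac{1}{(1+k)^2}}_{=\,0} - \sum_{k=0}^{2} \frac{1}{(2+k)^2} = -\left(\frac{1}{4} + \frac{1}{9} + \frac{1}{16}\right) = -\frac{61}{144} < 0,
$$

so $L$ is locally convex along the $\alpha_1$ direction at this point.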

The expected Fisher information is $\mathbb{E}[-\nabla^2L(\boldsymbol \alpha)]$, where

$$
\mathbb{E}\left[-\frac{\partial^2}{\partial \alpha_j \partial \alpha_l} L(\boldsymbol \alpha)\right] = 
\begin{cases}
\sum_{i=1}^n \sum_{x_{ij}=0}^{|\mathbf{x}_i|} \sum_{k=0}^{x_{ij}-1} \frac{1}{(\alpha_j+k)^2} f(x_{ij}) - \sum_{i=1}^n \sum_{k=0}^{|\mathbf{x}_i|-1} \frac{1}{(|\boldsymbol \alpha|+k)^2}, & l = j \\
- \sum_{i=1}^n \sum_{k=0}^{|\mathbf{x}_i|-1} \frac{1}{(|\boldsymbol \alpha|+k)^2}, & l \neq j
\end{cases}
$$

Here $f(x_{ij})$ denotes the marginal pmf of $x_{ij}$, which is beta-binomial with parameters $|\mathbf{x}_i|$, $\alpha_j$ and $|\boldsymbol\alpha| - \alpha_j$. Every diagonal entry requires a sum over this marginal distribution, so the expected information is costly to compute, and Fisher scoring is unattractive for computing the maximum likelihood estimate (MLE). We instead use a positive definite approximation to the observed information.
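To make the expected-information entry concrete, here is a Julia sketch in Base Julia only. It evaluates the beta-binomial marginal with rising factorials and swaps the order of summation, using $\sum_{x} f(x) \sum_{k<x} g(k) = \sum_{k} g(k) \Pr(x_{ij} > k)$. The names `betabinom_pmf` and `expected_diag` are hypothetical. `expected_diag` returns the contribution of a single observation with batch size `m`, and summing it over $i$ gives the $l = j$ entry above.

```julia
# Beta-binomial pmf P(X = x) for X ~ BetaBin(m, a, b), computed as
# C(m, x) (a)_x (b)_{m-x} / (a+b)_m with rising factorials.
function betabinom_pmf(x::Integer, m::Integer, a::Real, b::Real)
    logp = 0.0
    for k in 1:x
        logp += log(m - x + k) - log(k)   # log C(m, x)
    end
    for k in 0:x-1
        logp += log(a + k)
    end
    for k in 0:m-x-1
        logp += log(b + k)
    end
    for k in 0:m-1
        logp -= log(a + b + k)
    end
    return exp(logp)
end

# Expected (j, j) information from one observation with batch size m:
# Σ_{k<m} P(x_j > k)/(α_j+k)^2 - Σ_{k<m} 1/(|α|+k)^2
function expected_diag(m::Integer, α::AbstractVector{<:Real}, j::Integer)
    αsum = sum(α)
    pmf = [betabinom_pmf(x, m, α[j], αsum - α[j]) for x in 0:m]
    tail = 1.0                       # P(x_j > -1) = 1
    val = 0.0
    for k in 0:m-1
        tail -= pmf[k + 1]           # now tail = P(x_j > k)
        val += tail / (α[j] + k)^2 - 1 / (αsum + k)^2
    end
    return val
end
```

With $a = b = 1$ the beta-binomial is uniform on $\{0, \ldots, m\}$, which gives a simple sanity check.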

### Other properties

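Suppose $\mathbf{p} = (p_1,\ldots,p_d) \in \Delta_d = \{\mathbf{p}: p_i \ge 0, \sum_i p_i = 1\}$ follows a Dirichlet distribution with parameter $\boldsymbol \alpha = (\alpha_1,\ldots,\alpha_d)$. Its density integrates to one. Differentiating both sides of this identity with respect to $\alpha_k$, with the derivative passed under the integral sign, yields $\mathbb{E}[\ln p_k]$: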

$$\int_{\Delta_d}\frac{\Gamma(|\boldsymbol \alpha|)}{\prod_{j=1}^d \Gamma(\alpha_j)}\prod_{j=1}^d p_j^{\alpha_j-1} \, d\mathbf{p} = 1$$

$$
\rightarrow \frac{\partial}{\partial \alpha_k}\int_{\Delta_d}\frac{\Gamma(|\boldsymbol \alpha|)}{\prod_{j=1}^d \Gamma(\alpha_j)}\prod_{j=1}^d p_j^{\alpha_j-1} \, d\mathbf{p} = 0
$$

$$
\int_{\Delta_d}\left[\frac{\Gamma'(|\boldsymbol \alpha|)\prod_{j=1}^{d}\Gamma(\alpha_j)-\Gamma(|\boldsymbol \alpha|)\Gamma'(\alpha_k)\prod_{j\neq k}\Gamma(\alpha_j)}{\left(\prod_{j=1}^{d}\Gamma(\alpha_j)\right)^{2}}+\frac{\Gamma(|\boldsymbol \alpha|)}{\prod_{j=1}^{d}\Gamma(\alpha_j)}\ln(p_k)\right]\prod_{j=1}^d p_j^{\alpha_j-1} \, d\mathbf{p}=0
$$

$$
\therefore \mathbb{E}\left[\ln(p_k)\right] = \int_{\Delta_d}\frac{\Gamma(|\boldsymbol \alpha|)\Gamma'(\alpha_k)\prod_{j\neq k}\Gamma(\alpha_j)-\Gamma'(|\boldsymbol \alpha|)\prod_{j=1}^{d}\Gamma(\alpha_j)}{\left(\prod_{j=1}^{d}\Gamma(\alpha_j)\right)^{2}}\prod_{j=1}^d p_j^{\alpha_j-1} \, d\mathbf{p}
$$

$$
= \frac{\Gamma'(\alpha_k)}{\Gamma(\alpha_k)} - \frac{\Gamma'(|\boldsymbol \alpha|)}{\Gamma(|\boldsymbol \alpha|)} 
= \Psi(\alpha_k) - \Psi(|\boldsymbol \alpha|).
$$
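A numerical check for $d = 2$ in Julia, using Base Julia only. With $\boldsymbol\alpha = (2, 3)$, $p_1$ follows a Beta(2, 3) distribution. For positive integers, $\Psi(\alpha_1) - \Psi(|\boldsymbol\alpha|) = -\sum_{k=\alpha_1}^{|\boldsymbol\alpha|-1} 1/k$, so no special functions are needed. The sketch compares this closed form with a midpoint-rule approximation of $\int_0^1 \ln(p)\,\pi(p)\,dp$. The parameter values and grid size are arbitrary choices for illustration.

```julia
α1, α2 = 2, 3                          # integer Dirichlet (here Beta) parameters
B = factorial(α1 - 1) * factorial(α2 - 1) / factorial(α1 + α2 - 1)   # Beta function B(α1, α2)
N = 1_000_000
h = 1 / N
midpoints = ((1:N) .- 0.5) .* h
# midpoint rule for E[ln p_1] = ∫ ln(p) p^(α1-1) (1-p)^(α2-1) / B dp
numeric = h * sum(p^(α1 - 1) * (1 - p)^(α2 - 1) / B * log(p) for p in midpoints)
# Ψ(α1) - Ψ(α1 + α2) = -Σ_{k=α1}^{α1+α2-1} 1/k for positive integers
closed_form = -sum(1 / k for k in α1:(α1 + α2 - 1))
println("midpoint rule: ", numeric, "   Ψ(α1) - Ψ(|α|): ", closed_form)
```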

### Alternate Hessian matrix

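As noted above, the observed information matrix need not be positive definite. However, it has a diagonal-plus-rank-one structure, $-\nabla^2 L(\boldsymbol\alpha) = \mathbf{A} + c \cdot \mathbf{1}\mathbf{1}^T$, where $\mathbf{A}$ is diagonal with entries $a_j = \sum_{i=1}^n \sum_{k=0}^{x_{ij}-1} (\alpha_j+k)^{-2}$ and $c = -\sum_{i=1}^n \sum_{k=0}^{|\mathbf{x}_i|-1} (|\boldsymbol\alpha|+k)^{-2}$. Its inverse is therefore available in closed form from the Sherman-Morrison formula, a special case of the Woodbury identity. The same formula suggests a simple positive definite approximation. Specifically,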

$$
(\mathbf{A} + c \cdot \boldsymbol{u} \boldsymbol{u}^T)^{-1} = \mathbf{A}^{-1} - c \cdot \frac{1}{1 + c \cdot \boldsymbol{u}^T \mathbf{A}^{-1} \boldsymbol{u}} \mathbf{A}^{-1} \boldsymbol{u} \boldsymbol{u}^T \mathbf{A}^{-1},
$$

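where $\mathbf{A}$ is a diagonal matrix with positive entries, $c$ is a negative constant, and here $\boldsymbol{u} = \mathbf{1}$. For such $\mathbf{A}$ and $c$, the matrix $\mathbf{A} + c \cdot \boldsymbol{u} \boldsymbol{u}^T$ is positive definite if and only if $1 + c \cdot \boldsymbol{u}^T \mathbf{A}^{-1} \boldsymbol{u} > 0$, since $\det(\mathbf{A} + c \cdot \boldsymbol{u}\boldsymbol{u}^T) = \det(\mathbf{A})(1 + c \cdot \boldsymbol{u}^T \mathbf{A}^{-1} \boldsymbol{u})$ and a negative rank-one update can make at most one eigenvalue non-positive. So whenever $1 + c \cdot \boldsymbol{u}^T \mathbf{A}^{-1} \boldsymbol{u} \le 0$, we replace $c$ by $-0.95 \cdot (\boldsymbol{u}^T \mathbf{A}^{-1} \boldsymbol{u})^{-1}$. This makes the quantity equal to $0.05$ and guarantees that the approximate observed information, and hence its inverse used for the Newton direction, is positive definite.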
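A minimal Julia sketch of this safeguard, using the LinearAlgebra standard library. The function name `pd_inverse` and the example values of `a` and `c` are hypothetical. It applies the $0.95$ rule and then forms the inverse with the Sherman-Morrison formula. The result can be checked against a dense `inv`.

```julia
using LinearAlgebra

# Inverse of Diagonal(a) + c·uuᵀ via Sherman-Morrison, after replacing c by
# -0.95/(uᵀA⁻¹u) whenever 1 + c·uᵀA⁻¹u ≤ 0 (which would break positive definiteness).
function pd_inverse(a::AbstractVector{<:Real}, c::Real, u::AbstractVector{<:Real})
    @assert all(a .> 0) "diagonal entries must be positive"
    Ainv_u = u ./ a                    # A⁻¹u for diagonal A
    q = dot(u, Ainv_u)                 # uᵀA⁻¹u
    if 1 + c * q <= 0
        c = -0.95 / q                  # now 1 + c*q = 0.05 > 0
    end
    # (A + c uuᵀ)⁻¹ = A⁻¹ - c/(1 + c q) · (A⁻¹u)(A⁻¹u)ᵀ
    Hinv = Diagonal(1 ./ a) - (c / (1 + c * q)) * (Ainv_u * Ainv_u')
    return Hinv, c
end

a = [1.0, 2.0, 0.5]
Hinv, c_used = pd_inverse(a, -3.0, ones(3))
```

Here $\boldsymbol{u}^T \mathbf{A}^{-1} \boldsymbol{u} = 3.5$ and $1 + (-3)(3.5) < 0$, so the safeguard replaces $c$ by $-0.95/3.5$.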

### Starting point

The following (quasi) method of moments estimator for $\boldsymbol{\alpha}$ provides a reasonable starting point for iterative algorithms:

$$
\alpha_j = \frac{\mathbb{E}[p_j]^2 - \mathbb{E}[p_j]\mathbb{E}[p_j^2]}{\mathbb{E}[p_j^2] - \mathbb{E}[p_j]^2}.
$$
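To see why this recovers $\alpha_j$ when $\mathbf{p}$ is Dirichlet, write $m_j = \mathbb{E}[p_j] = \alpha_j / |\boldsymbol\alpha|$. The moments given earlier yield

$$
\mathbb{E}[p_j^2] - \mathbb{E}[p_j]^2 = \frac{m_j(1-m_j)}{|\boldsymbol \alpha|+1}, \qquad
\mathbb{E}[p_j]^2 - \mathbb{E}[p_j]\mathbb{E}[p_j^2] = m_j\left(m_j - \mathbb{E}[p_j^2]\right) = \frac{m_j^2(1-m_j)\,|\boldsymbol \alpha|}{|\boldsymbol \alpha|+1},
$$

so the ratio equals $m_j |\boldsymbol\alpha| = \alpha_j$. In the implementation below, $\mathbb{E}[p_j]$ and $\mathbb{E}[p_j^2]$ are replaced by sample averages of the observed proportions $x_{ij}/|\mathbf{x}_i|$. This plug-in step appears to be why the estimator is described as only quasi method of moments.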

### Implement Newton's algorithm

In [1]:
using SpecialFunctions, LinearAlgebra
polygamma(0, 0.5)  # digamma(0.5)
polygamma(1, 0.5);  # trigamma(0.5)

In [2]:
"""
    dirmult_logpdf(x::Vector, α::Vector)
    
Compute the log-pdf of Dirichlet-multinomial distribution with parameter `α` 
at data point `x`.
"""
function dirmult_logpdf(x::Vector, α::Vector)
    xsum = sum(x)
    αsum = sum(α)
    loglike = logfactorial(xsum)
    for j in 1:length(x)
        loglike = loglike - logfactorial(x[j]) + loggamma(α[j] + x[j]) - loggamma(α[j])
    end
    loglike = loglike - loggamma(xsum + αsum) + loggamma(αsum)
    return loglike
end

function dirmult_logpdf!(r::Vector, X::Matrix, α::Vector)
    for i in 1:size(X, 2)
        r[i] = dirmult_logpdf(X[:, i], α)
    end
    return r
end

"""
    dirmult_logpdf(X, α)
    
Compute the log-pdf of Dirichlet-multinomial distribution with parameter `α` 
at each data point in `X`. Each column of `X` is one data point.
"""
function dirmult_logpdf(X::Matrix, α::Vector)
    r = zeros(size(X, 2))
    dirmult_logpdf!(r, X, α)
end

dirmult_logpdf

In [3]:
"""
    dirmult_newton(X)

Find the MLE of Dirichlet-multinomial distribution using Newton's method.

# Argument
* `X`: an `d`-by-`n` matrix of counts; each column is one data point.

# Optional argument  
* `α0`: a `d` vector of starting point (optional). 
* `maxiters`: the maximum allowable Newton iterations (default 100). 
* `tolfun`: the tolerance for  relative change in objective values (default 1e-6). 

# Output
* `maximum`: the log-likelihood at MLE.   
* `estimate`: the MLE. 
* `gradient`: the gradient at MLE. 
* `hessian`: the Hessian at MLE. 
* `se`: a `d` vector of standard errors. 
* `iterations`: the number of iterations performed.
"""
function dirmult_newton(X::Matrix; α0::Union{Vector, Nothing} = nothing, 
            maxiters::Int = 100, tolfun::Float64 = 1e-6)
    # set default starting point as method of moment estimates
    d, n = size(X)
    tot = sum(X, dims = 1)
    if α0 == nothing
        p = X * Diagonal(1 ./ tot)
        Ep = sum(p, dims = 2) / n
        Ep² = sum(p.^2, dims = 2) / n
        α0 = vec((Ep.^2 - Ep .* Ep²) ./ (Ep² - Ep.^2))
    end
    # calculate initial log-likelihood
    loglold = sum(dirmult_logpdf(X, α0))
    # initialize arrays
    ∇α = zeros(d)
    ∇²α = ones(d, d)
    dir = zeros(d, d)
    A = zeros(d)
    A⁻¹ = zeros(d)
    c = 0
    α = copy(α0)
    logl = loglold
    # Newton loop
    iter = 1
    for outer iter in 1:maxiters
        # evaluate gradient (score)
        α0sum = sum(α0)
        for j in 1:d
            for i in 1:n
                ∇α[j] += polygamma(0, X[j, i] + α0[j]) - polygamma(0, α0[j]) -
                    polygamma(0, tot[i] + α0sum) + polygamma(0, α0sum)
            end
        end
        # compute Newton's direction
        for j in 1:d
            for i in 1:n 
                A[j] -= polygamma(1, X[j, i] +  α0[j]) - polygamma(1, α0[j])
                c += polygamma(1, tot[i] + α0sum) - polygamma(1, α0sum)
            end
            A⁻¹[j] = 1 / A[j]
        end
        # compute Hessian
        ∇²α .= c .* ∇²α
        for j in 1:d
            ∇²α[j, j] += A[j]
        end
        if (1 + c * sum(A⁻¹)) <= 0
            c = -0.95 / sum(A⁻¹)
        end
        BLAS.ger!(-c / (1 + c * sum(A⁻¹)), A⁻¹, A⁻¹, dir)
        for j in 1:d
            dir[j, j] += A⁻¹[j]
        end
        # line search loop
        for lsiter in 1:20
            # step halving
            s = 2.0^(1 - lsiter)
            copy!(α, α0)
            BLAS.gemv!('N', s, dir, ∇α, 1.0, α)
            if (minimum(α) > 0)
                logl = sum(dirmult_logpdf(X, α))
                if (logl > loglold)
                    break
                end
            end
        end
        # check convergence criterion
        if abs(logl - loglold) < tolfun * (abs(loglold) + 1)
            break
        end
        copy!(∇α, zeros(d))
        copy!(∇²α, ones(d, d))
        copy!(dir, zeros(d, d))
        copy!(A, zeros(d))
        c = 0
        copy!(α0, α)
        loglold = logl
    end
    # output
    return logl, α, ∇α, ∇²α, iter
end

dirmult_newton

### Load example data 

Below we build a classifier for handwritten digit recognition. The following figure shows example bitmaps of handwritten digits from U.S. postal envelopes.

<img src="./handwritten_digits.png" width="250" align="center"/>

Each digit is represented by a $32 \times 32$ bitmap in which each element is one pixel that is either white or black. Each bitmap is divided into non-overlapping blocks of $4 \times 4$ pixels, and the number of white pixels in each block is counted. Therefore each handwritten digit is summarized by a vector $\mathbf{x} = (x_1, \ldots, x_{64})$ of length $64 = 8 \times 8$, where each element is a count between 0 and 16.
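A hypothetical Julia sketch of this block-count preprocessing, using Base Julia only. The function name `block_counts` is illustrative. It assumes that `true` encodes a white pixel and that the 64 blocks are read row by row, left to right. That ordering is an assumption and may not match the column order of `optdigits.tra`.

```julia
# Count white pixels (true) in each non-overlapping block×block tile of a bitmap.
# Blocks are visited row by row, left to right (assumed ordering).
function block_counts(bitmap::AbstractMatrix{Bool}; block::Int = 4)
    nr, nc = size(bitmap)
    @assert nr % block == 0 && nc % block == 0 "bitmap size must be a multiple of block"
    counts = Int[]
    for bi in 1:block:nr, bj in 1:block:nc      # bj varies fastest
        push!(counts, count(bitmap[bi:bi+block-1, bj:bj+block-1]))
    end
    return counts
end

# A single white pixel in row 5, column 1 falls into the first block of the second block-row
bm = falses(32, 32)
bm[5, 1] = true
x = block_counts(bm)
```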


In [4]:
using DelimitedFiles

optdigits = readdlm("./optdigits.tra", ',', Int64)
size(optdigits)

(3823, 65)

In [5]:
data = copy(transpose(optdigits[:, 1:64]))
digits = optdigits[:, 65];

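The training data consists of 3,823 handwritten digits. Each row of `optdigits.tra` contains the 64 block counts of one digit, and the last (65th) element is the digit label. After transposing, each column of `data` is one data point, which is the layout `dirmult_logpdf` and `dirmult_newton` expect.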

In [6]:
dirmult_logpdf(data, ones(size(data, 1)));

### MLE for example data 
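We estimate the MLE separately for each of the digits 0, 1, ..., 9. For a given digit, blocks whose counts are zero in every training image are dropped before fitting, because their method of moments starting value would be $0/0$. The corresponding entries of `out` are left at 0.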

In [7]:
out = zeros(size(data, 1), 10)   # column digit + 1 holds the MLE for that digit
for digit in 0:9
    X = data[:, digits .== digit]
    # keep only blocks with at least one nonzero count for this digit
    ind = findall(!iszero, vec(sum(X, dims = 2)))
    out[ind, digit + 1] = dirmult_newton(X[ind, :])[2]
end
out

64×10 Matrix{Float64}:
  0.0        0.0         0.0         …  0.0         0.0         0.0
  0.0374846  0.00802891  0.387042       0.140215    0.0880215   0.0764813
  4.99914    0.514275    3.7928         2.51112     2.49913     1.53133
 14.9145     2.09965     5.21171        5.089       5.83687     3.70784
 12.1794     3.05823     2.49522        5.48423     5.58558     3.74053
  2.45887    1.3218      0.30935     …  4.48988     2.42446     1.51447
  0.0635862  0.137833    0.0157198      1.73032     0.199805    0.298736
  0.0        0.0         0.0            0.179722    0.0         0.0159526
  0.0        0.0         0.0            0.0         0.00219868  0.0
  1.02489    0.0573021   1.72338        0.255966    1.02342     0.667389
 14.6158     1.04117     5.46516     …  3.19455     6.1076      4.02737
 14.7766     3.24159     4.72352        4.01438     4.56622     3.35465
 13.7623     4.1066      4.41509        4.19522     4.08925     3.2949
  ⋮                                  ⋱      

### Comparison with multinomial distribution 

As $|\boldsymbol\alpha| \to \infty$ with $\boldsymbol\alpha / |\boldsymbol\alpha| \to \mathbf{p}$, the Dirichlet-multinomial distribution converges to a multinomial distribution with parameter $\mathbf{p}$. The multinomial can therefore be viewed as the limiting special case of the Dirichlet-multinomial with $|\boldsymbol\alpha| = \infty$. For each of the digits 0, 1, ..., 9, we perform a likelihood ratio test (LRT) of whether the Dirichlet-multinomial fits better than the multinomial.
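A small numerical illustration of this limit in Julia, using Base Julia only. The helper names `logfact`, `dm_pmf` and `mult_pmf` are hypothetical, and the count vector and $\mathbf{p}$ are arbitrary. Setting $\boldsymbol\alpha = s\,\mathbf{p}$ and increasing $s = |\boldsymbol\alpha|$ should drive the Dirichlet-multinomial pmf toward the multinomial pmf.

```julia
# log(m!) as a sum of logs
logfact(m::Integer) = m < 2 ? 0.0 : sum(log, 2:m)

# Dirichlet-multinomial pmf via rising factorials
function dm_pmf(x::AbstractVector{<:Integer}, α::AbstractVector{<:Real})
    logf = logfact(sum(x)) - sum(logfact, x)
    for j in eachindex(x), k in 0:x[j]-1
        logf += log(α[j] + k)
    end
    for k in 0:sum(x)-1
        logf -= log(sum(α) + k)
    end
    return exp(logf)
end

# Multinomial pmf (assumes p_j > 0 wherever needed)
mult_pmf(x, p) = exp(logfact(sum(x)) - sum(logfact, x) + sum(x .* log.(p)))

x = [1, 2]
p = [0.3, 0.7]
for s in (1.0, 10.0, 1e3, 1e6)
    println("|α| = ", s, ":  DM = ", dm_pmf(x, s .* p), ",  multinomial = ", mult_pmf(x, p))
end
```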

In [8]:
"""
    mult_logpdf(x::Vector, p::Vector)
    
Compute the log-pdf of multinomial distribution with parameter `p` 
at data point `x`.
"""
function mult_logpdf(x::Vector, p::Vector)
    xsum = sum(x)
    loglike = logfactorial(xsum)
    for j in 1:length(x)
        loglike = loglike - logfactorial(x[j])
    end
    loglike = loglike + dot(x, log.(p))
    return loglike
end

function mult_logpdf!(r::Vector, X::Matrix, p::Vector)
    for i in 1:size(X, 2)
        r[i] = mult_logpdf(X[:, i], p)
    end
    return r
end

"""
    mult_logpdf(X, p)
    
Compute the log-pdf of multinomial distribution with parameter `p` 
at each data point in `X`. Each column of `X` is one data point.
"""
function mult_logpdf(X::Matrix, p::Vector)
    r = zeros(size(X, 2))
    mult_logpdf!(r, X, p)
end

mult_logpdf

In [9]:
loglikeout = zeros(10, 2)
for digit in 0:9
    X = data[:, digits .== digit]
    # keep only blocks with at least one nonzero count for this digit
    ind = findall(!iszero, vec(sum(X, dims = 2)))
    X = X[ind, :]
    loglikeout[digit + 1, 1] = dirmult_newton(X)[1]
    p = vec(sum(X, dims = 2)) / sum(X)
    loglikeout[digit + 1, 2] = sum(mult_logpdf(X, p))
end
loglikeout

10×2 Matrix{Float64}:
 -37358.4  -39592.2
 -42179.3  -54039.2
 -39985.3  -49111.5
 -40519.5  -47089.1
 -43488.8  -57344.1
 -41191.3  -51713.0
 -37702.5  -42597.3
 -40304.1  -49473.0
 -43130.8  -49695.9
 -43709.7  -54577.8

In [10]:
teststat = 2 * (loglikeout[:, 1] - loglikeout[:, 2])

10-element Vector{Float64}:
  4467.550000359493
 23719.94312609102
 18252.218848676377
 13139.208044741405
 27710.49772595527
 21043.37302165093
  9789.588765652516
 18337.851525632417
 13130.003409554833
 21736.264647779986

### Classification in an independent dataset

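Now we construct a simple Bayes rule for handwritten digit recognition:
$$
\mathbf{x} \mapsto \arg \max_k \widehat \pi_k f(\mathbf{x} \mid \widehat{\boldsymbol \alpha}_k).
$$
Here we use the proportion of digit $k$ in the training set as the prior probability $\widehat \pi_k$. Blocks that were zero in every training image of digit $k$ were dropped when fitting $\widehat{\boldsymbol\alpha}_k$, so the code below evaluates $f(\mathbf{x} \mid \widehat{\boldsymbol\alpha}_k)$ on the remaining blocks only. We report the accuracy of the classifier on a test set of 1,797 digits.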

In [11]:
test = readdlm("./optdigits.tes", ',', Int64)
size(test)

(1797, 65)

In [12]:
testdata = copy(transpose(test[:, 1:64]))
testdigits = test[:, 65];

In [13]:
prior = zeros(10)
predict = zeros(size(testdata)[2], 11)
for digit in 0:9
    prior[digit + 1] = sum(digits .== digit) / length(digits)
    ind = findall(!iszero, out[:, digit + 1])
    predict[:, digit + 1] = dirmult_logpdf(testdata[ind, :], out[ind, digit + 1]) .+ log(prior[digit + 1])
end

for i in 1:size(predict)[1]
    predict[i, end] = findmax(predict[i, 1:10])[2] - 1
end

sum(testdigits .== predict[:, end]) / size(predict)[1]

0.8747913188647746

In [14]:
versioninfo()

Julia Version 1.6.0
Commit f9720dc2eb (2021-03-24 12:55 UTC)
Platform Info:
  OS: macOS (x86_64-apple-darwin19.6.0)
  CPU: Intel(R) Core(TM) i9-9880H CPU @ 2.30GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-11.0.1 (ORCJIT, skylake)
