# The Variational EM algorithm of LDA

In [5]:
import utilities
import numpy as np
import re
import string
import collections
import random
from scipy.special import gammaln, psi, polygamma
from functools import reduce
from warnings import warn
import random
import sys

In [6]:
# an example of parsing a document
doc='two-three one, Sth*, five, else ^#%$@#@#'
vocab=dict(zip(('one','two','three','sth','four','else'),range(6)))
parsed_doc=utilities.parse_doc(doc,vocab)
parsed_doc


[(0, 1), (1, 1), (2, 1), (3, 1), (5, 1)]

## The equivalent class of words
Suppose we have a single document. Let 

$$\tilde{w_n}=\text{j such that }w_n^j=1$$
To carry out the inference, first note that for fixed $i,j$, $\{\phi_{ni}: \tilde{w_n}=j\}$ will all be the same. This suggests us to define the following quantity:

$$\tilde{\phi_{ji}}=\phi_{ni}, \text{ where }\tilde{w_n}=j$$
We only need to update $\tilde{\phi}$ in the variational inference (E-step). For fixed $i$, the length of $\tilde{\phi}_{:,i}$ is the cardinality of $\{\tilde{w_n}: n=1,\cdots, N\}$.


In [7]:
def variational_inference(N,k,V,alpha,beta,w,conv_threshold=1e-5,max_iter=int(1e6)):
    """
    Variational inference algorithm for document-specific parameters of a single doc in LDA (figure 6 in the paper)
    Arguments:
    N: number of words in the document
    k: number of topics
    V: length of vocabulary
    alpha: corpus-level Dirichlet parameter, k-vector
    beta: corpus-level multinomial parameter, k * V matrix
    w: word id obtained from parsing the document
    conv_threshold: threshold for convergence
    max_iter: maximum number of iterations
    Output:
    A tuple of document specific optimizing parameters $(\gamma^*, \phi^*)$ obtained from variational inference.  
    First element: $\gamma^*$, k-vector
    Second element: the second sum in Eq(9), k*V matrix
    """
    phi0=np.full(shape=(N,k),fill_value=1/k) 
    phi1=phi0
    gamma0=alpha+N/k
    for it in range(max_iter):
        for n in range(N):
            for i in range(k):
                phi1[n,i]=beta[i,w[n]]*np.exp(psi(gamma0[i]))
            phi1[n,]=phi1[n,]/np.sum(phi1[n,])
        gamma1=alpha+np.sum(phi1,axis=0)
        # stop if gamma has converged
        if np.mean(np.abs(gamma0-gamma1))<conv_threshold:
            break
        gamma0=gamma1
        phi0=phi1 
    suff_stat=np.zeros(shape=(V,k))
    for n in range(N):
        suff_stat[w[n],]=suff_stat[w[n],]+phi1[n,]
    return (gamma1,suff_stat.T) 

In [63]:
def e_step(N,k,V,alpha,beta,word_dict,conv_threshold=1e-5,max_iter=int(1e6)):
    """
    Variational inference algorithm for document-specific parameters of a single doc in LDA with the equivalent class representation.
    Arguments:
    N: number of words
    k: number of topics
    V: length of vocabulary
    alpha: corpus-level Dirichlet parameter, k-vector
    beta: corpus-level multinomial parameter, k * V matrix
    word_dict: word_dict from parse_doc
    conv_threshold: threshold for convergence
    max_iter: maximum number of iterations
    Output:
    A tuple of document specific optimizing parameters $(\gamma^*, \phi^*)$ obtained from variational inference.  
    First element: $\gamma^*$, k-vector
    Second element: the second sum in Eq(9), k*V matrix
    """
    conv=False
    wordid=list(map(lambda x:x[0],word_dict))
    wordcnt=list(map(lambda x:x[1],word_dict))
    phi0=np.full(shape=(len(wordid),k),fill_value=1/k) # phi_tilde 
    phi1=np.zeros(shape=(len(wordid),k))
    gamma0=alpha+N/k
    for it in range(max_iter):
        for j in range(len(wordid)):
            # the jth row of phi1 corresponds to the word labelled as wordid[j]
            for i in range(k):
                #phi1[j,i]=beta[i,wordid[j]]*np.exp(psi(gamma0[i]))*wordcnt[j]
                phi1[j,i]=beta[i,wordid[j]]*np.exp(psi(gamma0[i]))
            phi1[j,]=phi1[j,]/np.sum(phi1[j,])
        gamma1=alpha+np.sum(phi1*(np.array(wordcnt).reshape((-1,1))),axis=0)
        #gamma1=alpha+np.sum(phi1,axis=0)
        # stop if gamma has converged
        if np.max(np.abs((gamma0-gamma1)))<conv_threshold:
            conv=True
            break
        gamma0=gamma1
        phi0=phi1 
    if not conv:
        warn('Variational inference has not converged. Try more iterations.')
    suff_stat=np.zeros(shape=(V,k))
    suff_stat[wordid,]=phi1*(np.array(wordcnt).reshape((-1,1)))
    print(suff_stat)
    return (gamma1,suff_stat.T) 

## The M-step

The parameters $\beta$ and $\alpha$ are updated in the M step. 

$\beta$ is updated with Eq(9). 

$\alpha$ is updated with Newton-Raphson:
$$\alpha_{new}=\alpha_{old}-H(\alpha_{old})^{-1}g(\alpha_{old}),$$
where $H(\alpha)=(\frac{\partial^2\mathcal{L}}{\partial\alpha_i\partial\alpha_j})_{k\times k}$ is the Hessian matrix and $g(\alpha)=(\frac{\partial \mathcal{L}}{\partial \alpha_i})_{i=1}^k$ is the gradient. 

A.4.2 shows that 
$$\frac{\partial \mathcal{L}}{\partial \alpha_i}=M \left(\Psi\left(\sum_{j=1}^k \alpha_j\right)-\Psi(\alpha_i)\right) + \sum_{d=1}^M \left(\Psi(\gamma_{di})-\Psi(\sum_{j=1}^k\gamma_{dj}) \right),$$
$$\frac{\partial^2\mathcal{L}}{\partial\alpha_i\partial\alpha_j}=\delta(i,j)M\Psi'(\alpha_i)-\Psi'\left(\sum_{j=1}^k\alpha_j\right),$$
i.e.,
$$H(\alpha)=diag(h)+1z1^T,$$
where $$z=-\Psi'\left(\sum_{j=1}^k\alpha_j\right),h=M\Psi'(\alpha)$$
By A.2, we have 
$$(H^{-1}g)_i=\frac{g_i-c}{h_i},$$
where 
$$c=\frac{\sum_{j=1}^k g_j/h_j}{1/z+\sum_{j=1}^k h_j^{-1}}$$

In [64]:
def m_step(M,k,V,suff_stat_list,gamma_list,alpha0,conv_threshold=1e-5,max_iter=int(1e6)):
    """
    M-step in variational EM, maximizing the lower bound on log-likelihood w.r.t. alpha and beta. (Section 5.3)
    Arguments:
    M: number of documents in the corpus
    k: number of topics
    V: length of vocab
    suff_stat_list: M-list of sufficient statistics (k * V matrices), one for each doc
    gamma_list: M-list of gamma's (k-vectors), one for each doc
    alpha0: initialization of alpha in Newton-Raphson
    conv_threshold: convergence threshold in Newton-Raphson
    max_iter: maximum number of iterations in Newton-Raphson
    Output:
    A 2-tuple. 
    First element: beta (k*V matrix)
    Second element: alpha (k*1)
    """
    conv=False
    # update beta
    beta=reduce(lambda x,y: x+y, suff_stat_list)
    beta=beta/np.sum(beta,axis=1).reshape((-1,1))
    # update alpha (Newton-Raphson)
    alpha0=alpha0.reshape(k,1)
    psi_sum_gamma=np.array(list(map(lambda x: psi(np.sum(x)),gamma_list))).reshape((M,1)) # M*1 
    psi_gamma=psi(np.array(gamma_list)) # M*k matrix
    #print('gamma',psi_sum_gamma,psi_gamma)
    for it in range(max_iter):
        psi_sum_alpha=psi(np.sum(alpha0))
        #print('psi_sum_alpha',psi_sum_alpha)
        g=M*(psi_sum_alpha-psi(alpha0)).reshape((k,1))+np.sum(psi_gamma-psi_sum_gamma,axis=0).reshape((k,1)) # k*1 
        #print('g',g)
        h=M*polygamma(1,alpha0)
        #print('h',h)
        z=-polygamma(1,np.sum(alpha0))
        c=np.sum(g/h.reshape((k,1)))/(1/z+np.sum(1/h))
        #print('c',c)
        invHg=(g-c)/h.reshape((k,1))
        #print('invHg',invHg)
        alpha1=alpha0-invHg 
        alpha1=np.clip(alpha1,1e-10,1e100) # for numerical stability
        if np.max(np.abs((alpha1-alpha0)))<conv_threshold:
            print('newton finished at iteration',it)
            conv=True
            break
        alpha0=alpha1
        #print(alpha1)
    if not conv:
        warn('Newton-Raphson has not converged. Try more iterations.')
    return (beta,alpha1)

## The variational EM without smoothing 

In [65]:
def variational_em(Nd,alpha0,beta0,word_dicts,vocab,M,k, conv_threshold=1e-5,max_iter=int(1e6),niter=int(1e6)):
    """
    Input:
    Nd: list of length of documents 
    alpha0: initialization of alpha
    beta0: initialization of beta. DO NOT initialize with identical rows!
    word_dicts: list of word_dict of documents, in the same order as N
    vocab: vocabulary
    M: number of documents
    k: number of topics
    """
    V=len(vocab)
    for it in range(niter):
        print(it)
        e_estimates=list(map(lambda x,y: e_step(x,k,V,alpha0,beta0,y,conv_threshold=conv_threshold,max_iter=max_iter), Nd,word_dicts))
        gamma_list=list(map(lambda x:x[0],e_estimates))
        #print('gammalist',gamma_list)
        suff_stat_list=list(map(lambda x:x[1],e_estimates))
        m_estimates=m_step(M,k,V,suff_stat_list,gamma_list,alpha0,conv_threshold=conv_threshold,max_iter=max_iter)
        alpha1=m_estimates[1]
        beta1=m_estimates[0]
        #print('alpha1=',alpha1,'beta1=',beta1)
        if np.max(np.abs((beta1-beta0)/beta0))<conv_threshold:
            print('vem finished at iteration',it)
            break
        alpha0=alpha1.reshape(k)
        beta0=beta1
        #print(alpha1)
        #print(beta1)
    return (alpha0,beta0)

In [69]:
random.seed(2)
doc1='you '*10+' fish '
doc2='fish '*10+' you '
docs=[doc1,doc2]*100
docs=list(map(lambda x:x*10,docs))
vocab=utilities.make_vocab_from_docs(docs)
word_dicts=list(map(lambda x: utilities.parse_doc(x,vocab),docs))
Nd=list(map(len,docs))
k,M,V=2,len(docs),len(vocab)
alpha0=np.random.random(2)*10
beta0=np.random.random((2,2))
beta0=beta0/np.sum(beta0,axis=1).reshape((-1,1))
conv_threshold=1e-4
max_iter=int(1e2)
niter=10

In [70]:
vocab

{'you': 0, 'fish': 1}

In [72]:
vem=variational_em(Nd,alpha0,beta0,word_dicts,vocab,M,k, conv_threshold=1e-5,max_iter=int(1e4),niter=int(1))

0
[[97.49124718  2.50875282]
 [ 4.3569972   5.6430028 ]]
[[ 1.88905522  8.11094478]
 [ 0.46061431 99.53938569]]
[[97.49124718  2.50875282]
 [ 4.3569972   5.6430028 ]]
[[ 1.88905522  8.11094478]
 [ 0.46061431 99.53938569]]
[[97.49124718  2.50875282]
 [ 4.3569972   5.6430028 ]]
[[ 1.88905522  8.11094478]
 [ 0.46061431 99.53938569]]
[[97.49124718  2.50875282]
 [ 4.3569972   5.6430028 ]]
[[ 1.88905522  8.11094478]
 [ 0.46061431 99.53938569]]
[[97.49124718  2.50875282]
 [ 4.3569972   5.6430028 ]]
[[ 1.88905522  8.11094478]
 [ 0.46061431 99.53938569]]
[[97.49124718  2.50875282]
 [ 4.3569972   5.6430028 ]]
[[ 1.88905522  8.11094478]
 [ 0.46061431 99.53938569]]
[[97.49124718  2.50875282]
 [ 4.3569972   5.6430028 ]]
[[ 1.88905522  8.11094478]
 [ 0.46061431 99.53938569]]
[[97.49124718  2.50875282]
 [ 4.3569972   5.6430028 ]]
[[ 1.88905522  8.11094478]
 [ 0.46061431 99.53938569]]
[[97.49124718  2.50875282]
 [ 4.3569972   5.6430028 ]]
[[ 1.88905522  8.11094478]
 [ 0.46061431 99.53938569]]
[[97.491

In [74]:
vem

(array([1.e-010, 1.e+100]),
 array([[0.9537648 , 0.0462352 ],
        [0.09170558, 0.90829442]]))