# The Variational EM algorithm of LDA

In [None]:
import utilities
import numpy as np
import re
import string
import collections
import random
from scipy.special import gammaln, psi, polygamma
from functools import reduce
from warnings import warn


In [None]:
# Example: parse a raw string against a toy vocabulary
doc='two-three one, Sth*, five, else ^#%$@#@#'
# word -> integer id, assigned in listing order
vocab={word:idx for idx,word in enumerate(('one','two','three','sth','four','else'))}
parsed_doc=utilities.parse_doc(doc,vocab)
parsed_doc


## Equivalence classes of words
Suppose we have a single document. Let 

$$\tilde{w_n}=\text{j such that }w_n^j=1$$
To carry out the inference, first note that for fixed $i,j$, the values $\{\phi_{ni}: \tilde{w_n}=j\}$ are all the same. This suggests defining the following quantity:

$$\tilde{\phi_{ji}}=\phi_{ni}, \text{ where }\tilde{w_n}=j$$
We only need to update $\tilde{\phi}$ in the variational inference (E-step). For fixed $i$, the length of $\tilde{\phi}_{:,i}$ is the cardinality of $\{\tilde{w_n}: n=1,\cdots, N\}$, i.e., the number of distinct words in the document.


In [None]:
def variational_inference(N,k,V,alpha,beta,w,conv_threshold=1e-5,max_iter=int(1e6)):
    """
    Variational inference algorithm for document-specific parameters of a single doc in LDA (figure 6 in the paper)
    Arguments:
    N: number of words in the document
    k: number of topics
    V: length of vocabulary
    alpha: corpus-level Dirichlet parameter, k-vector
    beta: corpus-level multinomial parameter, k * V matrix
    w: word id of each of the N tokens, obtained from parsing the document
    conv_threshold: threshold on the mean absolute change of gamma for convergence
    max_iter: maximum number of iterations
    Output:
    A tuple of document specific optimizing parameters obtained from variational inference.
    First element: gamma*, k-vector
    Second element: the second sum in Eq(9), k*V matrix (phi* accumulated per word id)
    A warning is issued if gamma has not converged within max_iter iterations.
    """
    alpha=np.asarray(alpha,dtype=float)
    w=np.asarray(w,dtype=int)[:N]
    # beta_w[n,i] = beta[i,w[n]]: the topic probabilities of the word at position n
    beta_w=np.asarray(beta)[:,w].T
    # starting values; they are also what is returned when max_iter is 0
    phi1=np.full(shape=(N,k),fill_value=1/k)
    gamma0=alpha+N/k
    gamma1=gamma0
    conv=False
    for it in range(max_iter):
        # phi[n,i] is proportional to beta[i,w[n]]*exp(psi(gamma[i])), normalized over topics
        phi1=beta_w*np.exp(psi(gamma0))
        phi1=phi1/np.sum(phi1,axis=1,keepdims=True)
        gamma1=alpha+np.sum(phi1,axis=0)
        # stop if gamma has converged
        if np.mean(np.abs(gamma0-gamma1))<conv_threshold:
            conv=True
            break
        gamma0=gamma1
    if not conv:
        warn('Variational inference has not converged. Try more iterations.')
    suff_stat=np.zeros(shape=(V,k))
    # np.add.at accumulates correctly when the same word id occurs at several positions
    np.add.at(suff_stat,w,phi1)
    return (gamma1,suff_stat.T)

In [None]:
def e_step(N,k,V,alpha,beta,word_dict,conv_threshold=1e-5,max_iter=int(1e6)):
    """
    Variational inference algorithm for document-specific parameters of a single doc in LDA with the equivalent class representation.
    Arguments:
    N: number of words
    k: number of topics
    V: length of vocabulary
    alpha: corpus-level Dirichlet parameter, k-vector
    beta: corpus-level multinomial parameter, k * V matrix
    word_dict: word_dict from parse_doc; its keys are used as word ids (column indices of beta) and its values as the counts of those words
    conv_threshold: threshold on the mean absolute change of gamma for convergence
    max_iter: maximum number of iterations
    Output:
    A tuple of document specific optimizing parameters obtained from variational inference.
    First element: gamma*, k-vector
    Second element: the second sum in Eq(9), k*V matrix (phi* weighted by word counts)
    A warning is issued if gamma has not converged within max_iter iterations.
    """
    alpha=np.asarray(alpha,dtype=float)
    wordid=np.array(list(word_dict.keys()),dtype=int)
    # counts as a column so that they broadcast over the k topics
    wordcnt=np.array(list(word_dict.values()),dtype=float).reshape((-1,1))
    # beta_w[j,i] = beta[i,wordid[j]]
    beta_w=np.asarray(beta)[:,wordid].T
    # starting values; they are also what is returned when max_iter is 0
    phi1=np.full(shape=(len(wordid),k),fill_value=1/k) # phi_tilde
    gamma0=alpha+N/k
    gamma1=gamma0
    conv=False
    for it in range(max_iter):
        # one row per distinct word; a per-word count factor would cancel in the
        # row normalization, so it is only applied when accumulating gamma
        phi1=beta_w*np.exp(psi(gamma0))
        phi1=phi1/np.sum(phi1,axis=1,keepdims=True)
        gamma1=alpha+np.sum(phi1*wordcnt,axis=0)
        # stop if gamma has converged
        if np.mean(np.abs(gamma0-gamma1))<conv_threshold:
            conv=True
            break
        gamma0=gamma1
    if not conv:
        warn('Variational inference has not converged. Try more iterations.')
    suff_stat=np.zeros(shape=(V,k))
    suff_stat[wordid,]=phi1*wordcnt
    return (gamma1,suff_stat.T)

In [None]:
# make sure the two functions produce close results
np.random.seed(1) # beta is drawn with np.random, so numpy's generator is the one that must be seeded
k=3
V=6
alpha=np.array([1,2,3])
beta=np.random.randint(low=1,high=10,size=(k,V))
beta=beta/np.sum(beta,axis=1).reshape((-1,1))
word_dict=parsed_doc
# expand {word id: count} into one id per token, so that both functions are run on the same document
w=[word_id for word_id,cnt in word_dict.items() for _ in range(cnt)]
N=len(w)
res1=variational_inference(N,k,V,alpha,beta,w,conv_threshold=1e-5)
res2=e_step(N,k,V,alpha,beta,word_dict,conv_threshold=1e-5)
res1[0]-res2[0]

In [None]:
# difference between the k*V sufficient statistics returned by variational_inference and e_step
res1[1]-res2[1]

## The M-step

The parameters $\beta$ and $\alpha$ are updated in the M step. 

$\beta$ is updated with Eq(9). 

$\alpha$ is updated with Newton-Raphson:
$$\alpha_{new}=\alpha_{old}-H(\alpha_{old})^{-1}g(\alpha_{old}),$$
where $H(\alpha)=(\frac{\partial^2\mathcal{L}}{\partial\alpha_i\partial\alpha_j})_{k\times k}$ is the Hessian matrix and $g(\alpha)=(\frac{\partial \mathcal{L}}{\partial \alpha_i})_{i=1}^k$ is the gradient. 

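A.4.2 shows that 
$$\frac{\partial \mathcal{L}}{\partial \alpha_i}=M \left(\Psi\left(\sum_{j=1}^k \alpha_j\right)-\Psi(\alpha_i)\right) + \sum_{d=1}^M \left(\Psi(\gamma_{di})-\Psi(\sum_{j=1}^k\gamma_{dj}) \right),$$
and differentiating this gradient once more with respect to $\alpha_j$ gives
$$\frac{\partial^2\mathcal{L}}{\partial\alpha_i\partial\alpha_j}=M\left(\Psi'\left(\sum_{l=1}^k\alpha_l\right)-\delta(i,j)\Psi'(\alpha_i)\right),$$
i.e.,
$$H(\alpha)=diag(h)+\mathbf{1}z\mathbf{1}^T,$$
where $$z=M\Psi'\left(\sum_{l=1}^k\alpha_l\right),\quad h=-M\Psi'(\alpha).$$
(The signs follow from the gradient above: $\mathcal{L}$ is concave in $\alpha$, so $H$ is negative definite.)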
By A.2, we have 
$$(H^{-1}g)_i=\frac{g_i-c}{h_i},$$
where 
$$c=\frac{\sum_{j=1}^k g_j/h_j}{1/z+\sum_{j=1}^k h_j^{-1}}$$

In [None]:
def m_step(M,k,V,suff_stat_list,gamma_list,conv_threshold=1e-5,max_iter=int(1e6)):
    """
    M-step in variational EM, maximizing the lower bound on log-likelihood w.r.t. alpha and beta. (Section 5.3)
    Arguments:
    M: number of documents in the corpus
    k: number of topics
    V: length of vocab (not used in the computation)
    suff_stat_list: M-list of sufficient statistics (k * V matrices), one for each doc
    gamma_list: M-list of gamma's (k-vectors), one for each doc
    conv_threshold: convergence threshold in Newton-Raphson (on the largest absolute change of alpha)
    max_iter: maximum number of iterations in Newton-Raphson
    Output:
    A 2-tuple.
    First element: beta (k*V matrix), rows normalized to sum to one
    Second element: alpha (k*1)
    A warning is issued if Newton-Raphson has not converged within max_iter iterations.

    Newton-Raphson uses
        g_i  = M*(Psi(sum(alpha)) - Psi(alpha_i)) + sum_d (Psi(gamma_di) - Psi(sum_j gamma_dj))
        H_ij = M*(Psi'(sum(alpha)) - delta(i,j)*Psi'(alpha_i)) = diag(h) + 1 z 1^T
    with h = -M*Psi'(alpha) and z = M*Psi'(sum(alpha)), so that
    (H^-1 g)_i = (g_i - c)/h_i, c = sum(g/h)/(1/z + sum(1/h)).
    """
    # update beta: add up the per-document statistics, then normalize every topic (row)
    beta=np.sum(np.asarray(suff_stat_list,dtype=float),axis=0)
    beta=beta/np.sum(beta,axis=1,keepdims=True)
    # the gamma term of the gradient does not depend on alpha: compute it once
    gammas=np.asarray(gamma_list,dtype=float).reshape((M,k))
    gamma_stat=np.sum(psi(gammas)-psi(np.sum(gammas,axis=1,keepdims=True)),axis=0) # k-vector
    # update alpha (Newton-Raphson)
    alpha0=np.full(k,fill_value=1/k)
    alpha1=alpha0 # returned as is when max_iter is 0
    conv=False
    for it in range(max_iter):
        sum_alpha=np.sum(alpha0)
        g=M*(psi(sum_alpha)-psi(alpha0))+gamma_stat
        h=-M*polygamma(1,alpha0)
        z=M*polygamma(1,sum_alpha)
        c=np.sum(g/h)/(1/z+np.sum(1/h))
        step=(g-c)/h
        alpha1=alpha0-step
        # alpha must stay strictly positive: shorten the step until it does
        while np.any(alpha1<=0):
            step=step/2
            alpha1=alpha0-step
        if np.max(np.abs(alpha1-alpha0))<conv_threshold:
            conv=True
            break
        alpha0=alpha1
    if not conv:
        warn('Newton-Raphson has not converged. Try more iterations.')
    return (beta,alpha1.reshape((k,1)))

In [None]:
# Example: run the M-step on a toy corpus in which every document reuses the e_step result above
M=10
gamma_list=[res2[0] for _ in range(M)]
suff_stat_list=[res2[1] for _ in range(M)]
# same inputs, two different Newton-Raphson iteration budgets
res3=m_step(M,k,V,suff_stat_list,gamma_list,conv_threshold=1e-5,max_iter=100)
res4=m_step(M,k,V,suff_stat_list,gamma_list,conv_threshold=1e-5,max_iter=10000)
res3

## The variational EM without smoothing 

In [None]:
def variational_em(Nd,alpha0,beta0,word_dicts,vocab,M,k, conv_threshold=1e-5,max_iter=int(1e6)):
    """
    Variational EM for LDA without smoothing: alternate e_step on every document and m_step on the corpus until beta stops changing.
    Input:
    Nd: list of length of documents
    alpha0: initialization of alpha
    beta0: initialization of beta
    word_dicts: list of word_dict of documents, in the same order as Nd
    vocab: vocabulary (only its length is used)
    M: number of documents
    k: number of topics
    conv_threshold: EM stops once the largest absolute change of beta between two iterations is below this value
    max_iter: maximum number of EM iterations
    Output:
    A 2-tuple (alpha, beta) holding the most recent estimates: alpha as a k-vector, beta as returned by m_step.
    With max_iter=0 the initial values are returned unchanged.
    """
    V=len(vocab)
    for it in range(max_iter):
        print(it) # progress: index of the current EM iteration
        # E-step: document-level inference, one call per document
        e_estimates=[e_step(n_words,k,V,alpha0,beta0,wd,conv_threshold=1e-5,max_iter=int(1e6)) for n_words,wd in zip(Nd,word_dicts)]
        gamma_list=[est[0] for est in e_estimates]
        suff_stat_list=[est[1] for est in e_estimates]
        # M-step: corpus-level update; m_step returns (beta, alpha)
        beta1,alpha1=m_step(M,k,V,suff_stat_list,gamma_list,conv_threshold=1e-5,max_iter=int(1e6))
        beta_change=np.max(np.abs(beta1-beta0))
        # keep the newest estimates before testing for convergence, so that they are
        # what is returned whether the loop stops early or runs out of iterations
        alpha0=alpha1.reshape(k)
        beta0=beta1
        if beta_change<conv_threshold:
            break
    return (alpha0,beta0)

In [None]:
# Fit a two-document corpus in which both documents are the parsed example document
M=2
variational_em([N]*M,alpha,beta,[word_dict]*M,vocab,M,k)


In [None]:
# first element (alpha) of the variational EM result; the original line had a second,
# identical variational_em(...) call fused onto its end, which made the cell a syntax error
es=list(variational_em([N,N],alpha,beta,[word_dict,word_dict],vocab,M,k))[0]
# a (3,1) column reshaped back into a flat 3-vector
aaa=np.array([1,2,3]).reshape((3,1))
aaa.reshape(3)

In [None]:
def test1(x):
    """Return the concatenation of x and ['s']; x itself is not modified."""
    suffix=['s']
    return x+suffix

In [None]:
def test2(x):
    """Return x*10; the argument is only read, never modified in place."""
    return x*10

In [None]:
# probe: does calling test2 change the caller's list?
x=['a']
test2(x) # return value discarded
print(x) # x*10 inside test2 builds a new object, so x is still ['a'] here
x=test2(x) # now rebind x to the returned value
test1(x)

In [None]:
# run variational EM on a corpus made of two copies of word_dict
variational_em([N,N],alpha,beta,[word_dict,word_dict],vocab,M,k)