# The Variational EM Algorithm for LDA

In [32]:
import numpy as np
import re
import string
import collections
import random
from scipy.special import psi

In [33]:
def parse_doc(doc,vocab):
    """
    Parse a single document into a bag of vocabulary ids.
    Arguments:
    doc: document string
    vocab: a dictionary that maps words to integers
    Output:
    A collections.Counter whose keys are the integer ids (from vocab) of the words that
        appear in the doc (the set of $\tilde{w_n}$), and whose values are the counts of those words.
    The text is lower-cased, hyphens are treated as word separators, and every character other
    than a-z and whitespace is dropped. Words that are not in vocab are ignored.
    """
    doc=doc.lower().replace('-',' ')
    # Keep all whitespace (spaces, tabs, newlines) as separators; stripping it would glue
    # words on different lines together (e.g. "one\ntwo" -> "onetwo").
    doc=re.sub(r'[^a-z\s]','',doc)
    # split() already treats any run of whitespace as a single separator.
    return collections.Counter(vocab[word] for word in doc.split() if word in vocab)

In [62]:
# Example: parse a toy document. "five" is not in the vocabulary and the
# punctuation/symbols are stripped, so only one, two, three and four are counted.
doc='two-three one, Four*, five, one ^#%$@#@#'
vocab={word:idx for idx,word in enumerate(('one','two','three','four','sth','else'))}
parsed_doc=parse_doc(doc,vocab)
parsed_doc

Counter({1: 1, 2: 1, 0: 2, 3: 1})

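## The equivalence classes of words
Suppose we have a single document. For each position $n$, let $\tilde{w_n}$ be the vocabulary index of the $n$-th word:

$$\tilde{w_n}=j \quad\text{such that}\quad w_n^j=1$$
To carry out the inference, first note that the update $\phi_{ni}\propto\beta_{i\tilde{w_n}}\exp(\Psi(\gamma_i))$ depends on $n$ only through $\tilde{w_n}$. Hence, for fixed $i,j$, the values in $\{\phi_{ni}: \tilde{w_n}=j\}$ are all the same. This suggests defining the following quantity:

$$\tilde{\phi}_{ji}=\phi_{ni}, \quad\text{where }\tilde{w_n}=j$$
We only need to update $\tilde{\phi}$ in the variational inference (E-step). For fixed $i$, the length of $\tilde{\phi}_{:,i}$ is the number of distinct words in the document, i.e. the cardinality of $\{\tilde{w_n}: n=1,\cdots, N\}$. The word counts enter through the $\gamma$ update: $\gamma_i=\alpha_i+\sum_n\phi_{ni}=\alpha_i+\sum_j c_j\tilde{\phi}_{ji}$, where $c_j$ is the number of times word $j$ occurs in the document.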


In [71]:
def variational_inference(N,k,V,alpha,beta,w,conv_threshold=1e-5,max_iter=int(1e6)):
    """
    Variational inference algorithm for document-specific parameters of a single doc in LDA
    (figure 6 in the paper), with one row of phi per word position.
    Arguments:
    N: number of words in the document (the length of w)
    k: number of topics
    V: length of vocabulary
    alpha: corpus-level Dirichlet parameter, k-vector (array or list)
    beta: corpus-level multinomial parameter, k * V matrix
    w: word ids of the document in order, one per position (length N)
    conv_threshold: stop once the mean absolute change of gamma falls below this
    max_iter: maximum number of iterations (0 returns the initial values)
    Output:
    A tuple (gamma, suff_stat):
    First element: $\gamma^*$, k-vector
    Second element: V * k matrix whose row j is $\sum_n \phi^*_{n,:} w_n^j$,
        i.e. this document's contribution to the second sum in Eq(9)
    """
    alpha=np.asarray(alpha,dtype=float)
    beta=np.asarray(beta)
    w=np.asarray(w,dtype=int)
    phi=np.full(shape=(N,k),fill_value=1/k)
    gamma=alpha+N/k
    beta_w=beta[:,w].T # N * k: beta_{i,w_n} for every position n and topic i
    for _ in range(max_iter):
        # phi_{ni} is proportional to beta_{i,w_n} * exp(psi(gamma_i)), normalized over topics i
        phi=beta_w*np.exp(psi(gamma))
        phi=phi/np.sum(phi,axis=1,keepdims=True)
        gamma_new=alpha+np.sum(phi,axis=0)
        converged=np.mean(np.abs(gamma-gamma_new))<conv_threshold
        gamma=gamma_new
        # stop if gamma has converged
        if converged:
            break
    suff_stat=np.zeros(shape=(V,k))
    # accumulate phi rows by word id; repeated occurrences of a word are summed
    np.add.at(suff_stat,w,phi)
    return (gamma,suff_stat)

In [72]:
def e_step(N,k,V,alpha,beta,word_dict,conv_threshold=1e-5,max_iter=int(1e6)):
    """
    Variational inference algorithm for document-specific parameters of a single doc in LDA
    with the equivalence-class representation: one row of phi per distinct word ($\tilde{\phi}$)
    instead of one row per word position.
    Arguments:
    N: number of words in the document (expected to equal the sum of the counts in word_dict)
    k: number of topics
    V: length of vocabulary
    alpha: corpus-level Dirichlet parameter, k-vector (array or list)
    beta: corpus-level multinomial parameter, k * V matrix
    word_dict: mapping word id -> count in the document, e.g. the Counter from parse_doc
    conv_threshold: stop once the mean absolute change of gamma falls below this
    max_iter: maximum number of iterations (0 returns the initial values)
    Output:
    A tuple (gamma, suff_stat):
    First element: $\gamma^*$, k-vector
    Second element: V * k matrix whose row j is count_j * $\tilde{\phi}^*_{j,:}$ (zero for words
        absent from the doc), i.e. this document's contribution to the second sum in Eq(9)
    """
    alpha=np.asarray(alpha,dtype=float)
    beta=np.asarray(beta)
    wordid=np.array(list(word_dict.keys()),dtype=int)
    wordcnt=np.array(list(word_dict.values()),dtype=float).reshape((-1,1))
    phi=np.full(shape=(len(wordid),k),fill_value=1/k) # phi_tilde
    gamma=alpha+N/k
    beta_w=beta[:,wordid].T # (#distinct words) * k: beta_{i,j} for each distinct word j
    for _ in range(max_iter):
        # phi_tilde_{ji} is proportional to beta_{i,j} * exp(psi(gamma_i)), normalized over i;
        # the word count would cancel in the normalization, so it is not applied here.
        phi=beta_w*np.exp(psi(gamma))
        phi=phi/np.sum(phi,axis=1,keepdims=True)
        # each distinct word contributes its phi_tilde row once per occurrence
        gamma_new=alpha+np.sum(phi*wordcnt,axis=0)
        converged=np.mean(np.abs(gamma-gamma_new))<conv_threshold
        gamma=gamma_new
        # stop if gamma has converged
        if converged:
            break
    suff_stat=np.zeros(shape=(V,k))
    suff_stat[wordid]=phi*wordcnt # word ids are unique dict keys, so plain assignment is safe
    return (gamma,suff_stat)

In [73]:
# make sure the two functions produce close results
# beta is drawn with np.random, so seed NumPy's generator (random.seed would not affect it)
np.random.seed(1)
N=5
k=3
V=6
alpha=np.array([1,2,3])
# beta must be a k * V matrix; normalize each row into a distribution over the vocabulary
beta=np.random.randint(low=1,high=10,size=(k,V))
beta=beta/np.sum(beta,axis=1).reshape((-1,1))
word_dict=parsed_doc
w=[1,2,0,3,0] # word ids of "two three one four one", in document order
assert len(w)==N==sum(word_dict.values())
res1=variational_inference(N,k,V,alpha,beta,w,conv_threshold=1e-5,max_iter=int(1e6))
res2=e_step(N,k,V,alpha,beta,word_dict,conv_threshold=1e-5,max_iter=int(1e6))
assert np.allclose(res1[0],res2[0]) and np.allclose(res1[1],res2[1])
res1[0]-res2[0]

array([0., 0., 0.])

In [69]:
# element-wise difference of the two V * k sufficient-statistic matrices (expected to be ~0)
np.subtract(res1[1],res2[1])

array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]])