Latent Dirichlet Allocation - Variational Inference
====

Based on the paper "Latent Dirchlet Allocation" by David M. Blei, Andrew Y. Ng, Michael I. Jordan

In [1]:
import numpy as np
from numpy import sqrt,mean,square
import numpy.linalg as la
from scipy.special import digamma, polygamma

## Parameters

document:    $m = 1,...,M$

topic:       $z = 1,...,k$

word:        $w = 1,...,N_m$

vocabulary : $v = 1,...,V$

$\alpha: 1 \times k$ vector of topic distribution probabilities

$\beta: k \times v$ matrix of word probabilities for each topic

$\phi: M \times N_m \times k$ matrix of topic probabilities for each word in each document

$\gamma: M \times k$ matrix of topic probabilities for each document

In [2]:
np.random.seed(1337)

### Test data and pre-processing

Run the following first:
```
pip install -U nltk
pip install stop-words
pip install -U gensim
```

In [52]:
!pip install -U nltk
!pip install stop-words
!conda install -y gensim

Requirement already up-to-date: nltk in /opt/conda/lib/python3.4/site-packages
Fetching package metadata: ......
Solving package specifications: .........
  - r-irkernel-0.5-r3.2.2_1a.tar.bz2
  - r-irkernel-0.5-r3.2.2_2.tar.bz2

Package plan for installation in environment /opt/conda:

The following packages will be UPDATED:

    r-irkernel: 0.5-r3.2.2_2 --> 0.5-r3.2.2_1a

Unlinking packages ...
[      COMPLETE      ]|###################################################| 100%
Linking packages ...
[      COMPLETE      ]|###################################################| 100%


In [53]:
from nltk.tokenize import RegexpTokenizer
from stop_words import get_stop_words
from nltk.stem.porter import PorterStemmer
from gensim import corpora, models
import gensim

tokenizer = RegexpTokenizer(r'\w+')

# create English stop words list
en_stop = get_stop_words('en')

# Create p_stemmer of class PorterStemmer
p_stemmer = PorterStemmer()
    
# create sample documents
doc_a = "Brocolli is good to eat. My brother likes to eat good brocolli, but not my mother."
doc_b = "My mother spends a lot of time driving my brother around to baseball practice."
doc_c = "Some health experts suggest that driving may cause increased tension and blood pressure."
doc_d = "I often feel pressure to perform well at school, but my mother never seems to drive my brother to do better."
doc_e = "Health professionals say that brocolli is good for your health." 

# compile sample documents into a list
doc_set = [doc_a, doc_b, doc_c, doc_d, doc_e]

# list for tokenized documents in loop
texts = []

# loop through document list
for i in doc_set:
    
    # clean and tokenize document string
    raw = i.lower()
    tokens = tokenizer.tokenize(raw)

    # remove stop words from tokens
    stopped_tokens = [i for i in tokens if not i in en_stop]
    
    # stem tokens
    stemmed_tokens = [p_stemmer.stem(i) for i in stopped_tokens]
    
    # add tokens to list
    texts.append(stemmed_tokens)

# turn our tokenized documents into a id <-> term dictionary
dictionary = corpora.Dictionary(texts)
    
# convert tokenized documents into a document-term matrix
corpus = [dictionary.doc2bow(text) for text in texts]

In [54]:
corpus

[[(0, 1), (1, 1), (2, 2), (3, 2), (4, 1), (5, 2)],
 [(0, 1), (4, 1), (6, 1), (7, 1), (8, 1), (9, 1), (10, 1), (11, 1), (12, 1)],
 [(9, 1),
  (13, 1),
  (14, 1),
  (15, 1),
  (16, 1),
  (17, 1),
  (18, 1),
  (19, 1),
  (20, 1),
  (21, 1)],
 [(0, 1),
  (4, 1),
  (9, 1),
  (21, 1),
  (22, 1),
  (23, 1),
  (24, 1),
  (25, 1),
  (26, 1),
  (27, 1),
  (28, 1),
  (29, 1)],
 [(3, 1), (5, 1), (19, 2), (30, 1), (31, 1)]]

In [3]:
M = 3
k = 10
N = np.random.randint(50,size=M)
V = 20

print('N: {0}'.format(N))

N: [23 28 40]


In [4]:
# Generate random "documents"
doc1 = np.random.randint(V,size=N[0])
doc2 = np.random.randint(V,size=N[1])
doc3 = np.random.randint(V,size=N[2])

w = np.array((doc1,doc2,doc3))
w

array([ array([ 7,  7, 18, 18,  8,  9,  6,  1,  6, 18,  2,  9,  8, 11,  1, 19, 14,
       17, 15, 19,  3,  8,  8]),
       array([ 7,  9,  4,  1, 12,  6,  4, 18, 14, 18,  3,  9,  2, 10,  4,  8, 14,
        6,  3,  8, 15,  7,  4,  3, 18,  9, 17, 13]),
       array([18, 12, 16, 13, 19, 18,  1, 13,  0, 17,  9, 13,  2, 16, 16,  7,  6,
       12, 19,  0, 16, 18,  5,  9, 16, 18, 18, 19,  6, 10, 13, 10, 19, 13,
       16,  3,  9, 12, 13, 15])], dtype=object)

In [11]:
# Word #11 in document 2 (w_dn)
w[1][10]

3

In [28]:
[doc == 3 for doc in w]

[array([False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False,  True, False, False], dtype=bool),
 array([False, False, False, False, False, False, False, False, False,
        False,  True, False, False, False, False, False, False, False,
         True, False, False, False, False,  True, False, False, False, False], dtype=bool),
 array([False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False,  True,
        False, False, False, False], dtype=bool)]

In [29]:
[doc == 3 for doc in w]*w

array([ array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0]),
       array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0,
       3, 0, 0, 0, 0]),
       array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0])], dtype=object)

### Initialize parameters $\alpha, \beta, \phi$ and $\gamma$

In [5]:
alpha = np.random.dirichlet(np.ones(k),1)
beta = np.random.dirichlet(np.ones(k),V)

phi = np.array([1/k*np.ones([N[m],k]) for m in range(M)])
gamma = np.tile(alpha,(M,1)) + np.tile(N/k,(k,1)).T

In [5]:
alpha

array([[ 0.03614628,  0.0682012 ,  0.04293727,  0.08103403,  0.03368725,
         0.41405727,  0.14638083,  0.01358403,  0.05414736,  0.10982448]])

In [36]:
beta.shape

(20, 10)

In [42]:
phi.shape

(3,)

In [39]:
gamma.shape

(3, 10)

### Optimize variational parameters $\phi$ and $\gamma$

In [7]:
# TODO: Split phi and gamma optimization apart for parallelization purposes
# TODO: See if some sort of vectorization is possible for speed-up
def optVarParams(alpha,beta,phi,gamma,words):
    ## Optimize phi
    for m in range(M):
        for n in range(N[m]):
            for i in range(k):
                phi[m][n,i] = beta[words[m][n],i] * np.exp(digamma(gamma[m,i]) - digamma(np.sum(gamma[m,:])))
            # Normalize across states so phi represents probability over states for each word
            phi[m][n,:] = phi[m][n,:]/sum(phi[m][n,:])
    
    ## Optimize gamma
    gamma = np.tile(alpha,(M,1)) + np.array(list(map(lambda x: np.sum(x,axis=0),phi)))
    
    return phi,gamma

In [124]:
phi,gamma = optVarParams(alpha,beta,phi,gamma,w)

### Estimate model parameters $\alpha$ and $\beta$

In [8]:
def estModParams(alpha,beta,phi,gamma,words):
    ## Optimize beta
    for j in range (V):
        w_dnj = [np.tile((word == j),(k,1)).T for word in w]
        beta[j,:] = np.sum(np.array(list(map(lambda x: np.sum(x,axis=0),phi*w_dnj))),axis=0)
        # Normalize across states so beta represents probability over states for each word
        beta[j,:] = beta[j,:]/sum(beta[j,:])
    
    ## Optimize alpha
    nr_max_iters = 1000
    tol = 10**-4
    for it in range(nr_max_iters):
        alpha_old = alpha
        
        #  Calculate gradient 
        g = M*(digamma(np.sum(alpha))-digamma(alpha)) + np.sum(digamma(gamma)-np.tile(digamma(np.sum(gamma,axis=1)),(k,1)).T,axis=0)
        #  Calculate Hessian diagonal component
        h = -M*polygamma(1,alpha) 
        #  Calculate Hessian constant component
        z = polygamma(1,np.sum(alpha))
        #  Calculate constant
        c = np.sum(g/h)/(z**(-1)+np.sum(h**(-1)))

        #  Update alpha
        alpha = alpha - (g-c)/h
        
        #  Check convergence
        if sqrt(mean(square(alpha-alpha_old)))<tol:
            break

    return alpha,beta

In [126]:
estModParams(alpha,beta,phi,gamma,w)

(array([[ 2.81550093,  3.38058491,  2.79467498,  2.94223941,  3.41450454,
          2.86980884,  2.86901428,  2.84627664,  2.87613474,  2.92062181]]),
 array([[ 0.097743  ,  0.10585469,  0.09744617,  0.09955311,  0.10634451,
          0.09851787,  0.09850653,  0.09818197,  0.09860821,  0.09924393],
        [ 0.09742693,  0.10667486,  0.09708858,  0.09949035,  0.10723337,
          0.09831022,  0.09829729,  0.09792731,  0.0984132 ,  0.09913789],
        [ 0.09680203,  0.10829699,  0.09638168,  0.09936588,  0.10899151,
          0.09789947,  0.0978834 ,  0.09742371,  0.09802742,  0.0989279 ],
        [ 0.09766106,  0.10606733,  0.09735346,  0.09953684,  0.10657495,
          0.09846404,  0.09845228,  0.09811595,  0.09855765,  0.09921644],
        [ 0.09755311,  0.10634772,  0.09723137,  0.09951523,  0.10687889,
          0.09839302,  0.09838073,  0.09802892,  0.09849095,  0.09918007],
        [ 0.09674642,  0.10844134,  0.09631877,  0.09935481,  0.10914796,
          0.09786292,  0.09784

### Expectation Maximization

#### Convergence Criterion
The variational inference parameter $\gamma$ contains the topic likelihoods of every document and is thus what is of interest here.

Calculate root-mean-square of the change in $\gamma$

In [6]:
def converged(gamma,gamma_old,convergence):
    print(sqrt(mean(square(gamma-gamma_old))))
    return sqrt(mean(square(gamma-gamma_old))) < convergence

#### Inference by iterative EM
Continue until convergence criterion above met

In [11]:
convergence = 10**(-2)
successfully_Converged = False
max_iters = 10**2

for iters in range(max_iters):
    print(iters)
    gamma_old = gamma
    phi,gamma  = optVarParams(alpha,beta,phi,gamma,w)
    alpha,beta = estModParams(alpha,beta,phi,gamma,w)
    if converged(gamma,gamma_old,convergence):
        successfully_Converged = True
        break

0
2.20433214055
1
2.20334273545
2
2.20235415433
3
2.20136640577
4
2.20037951567
5
2.19939343946
6
2.19840820019
7
2.19742379179
8
2.19644021742
9
2.19545746284
10
2.19447553648
11
2.19349444871
12
2.19251417303
13
2.19153472446
14
2.19055611618
15
2.18957829101
16
2.1886012964
17
2.18762513947
18
2.18664978653
19
2.18567523958
20
2.18470152222
21
2.1837286232
22
2.18275653381
23
2.18178525013
24
2.18081477581
25
2.17984510993
26
2.17887625735
27
2.17790821894
28
2.17694096689
29
2.17597454223
30
2.17500889492
31
2.17404406208
32
2.17308003428
33
2.17211679674
34
2.17115437442
35
2.17019274188
36
2.16923190095
37
2.16827186003
38
2.16731261905
39
2.16635415461
40
2.1653964896
41
2.16443961308
42
2.1634835372
43
2.16252823951
44
2.16157374906
45
2.16062002744
46
2.15966707817
47
2.1587149326
48
2.15776357297
49
2.15681297017
50
2.15586317507
51
2.15491414945
52
2.15396590942
53
2.15301843778
54
2.1520717463
55
2.15112581445
56
2.15018067125
57
2.14923630154
58
2.14829271492
59
2.14734987

In [12]:
alpha

array([[  1.41852574e+00,   1.09298790e+00,   6.03848641e-01,
          1.51726379e+00,   6.89034002e-01,   5.98795853e-01,
          7.45652048e-01,   2.10881461e+04,   3.79338058e+00,
          1.14812382e+00]])

In [13]:
gamma

array([[  1.41963960e+00,   1.09376360e+00,   6.04160439e-01,
          1.51848247e+00,   6.89420052e-01,   5.99103375e-01,
          7.46089436e-01,   2.11044758e+04,   3.79710371e+00,
          1.14895576e+00],
       [  1.41963960e+00,   1.09376360e+00,   6.04160439e-01,
          1.51848247e+00,   6.89420052e-01,   5.99103375e-01,
          7.46089436e-01,   2.11094758e+04,   3.79710371e+00,
          1.14895576e+00],
       [  1.41963960e+00,   1.09376360e+00,   6.04160439e-01,
          1.51848247e+00,   6.89420052e-01,   5.99103375e-01,
          7.46089436e-01,   2.11214758e+04,   3.79710371e+00,
          1.14895576e+00]])

### Tests 
Testing out syntax and array dimensions

In [11]:
np.sum(np.array(list(map(lambda x: np.sum(x,axis=0),phi))),axis=0)

array([ 9.1,  9.1,  9.1,  9.1,  9.1,  9.1,  9.1,  9.1,  9.1,  9.1])

In [49]:
[doc == 3 for doc in w]*w

array([ array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0]),
       array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0,
       3, 0, 0, 0, 0]),
       array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0])], dtype=object)