Latent Dirichlet Allocation - Variational Inference
====

Based on the paper "Latent Dirchlet Allocation" by David M. Blei, Andrew Y. Ng, Michael I. Jordan

In [1]:
import numpy as np
from numpy import sqrt,mean,square
import numpy.linalg as la
from scipy.special import digamma, polygamma

In [45]:
!git config --global user.email "kevinjliang2011@gmail.com"
!git config --global user.name "Kevin Liang"

## Parameters

document:    $m = 1,...,M$

topic:       $z = 1,...,k$

word:        $w = 1,...,N_m$

vocabulary : $v = 1,...,V$

$\alpha: 1 \times k$ Model parameter - vector of topic distribution probabilities for each document

$\beta: k \times v$ Model parameter - matrix of word probabilities for each topic

$\phi: M \times N_m \times k$ Variational parameter - matrix of topic probabilities for each word in each document

$\gamma: M \times k$ Variational parameter - matrix of topic probabilities for each document

In [138]:
np.random.seed(1337)

In [139]:
M = 300
k = 10
N = np.random.randint(150,200,size=M)
V = 30

print('N: {0}'.format(N))

N: [173 178 190 189 175 189 176 168 170 158 159 156 176 173 174 151 177 179
 156 172 152 191 190 161 151 173 169 196 167 197 177 153 170 158 158 157
 177 159 154 181 183 162 156 154 196 170 168 176 153 191 184 192 154 158
 164 188 153 174 179 158 197 157 154 185 168 159 199 178 163 162 198 195
 183 195 150 199 159 170 195 184 198 198 177 157 170 171 188 194 150 166
 168 155 191 175 198 179 173 169 156 160 195 160 195 166 177 177 153 191
 162 195 165 150 162 157 161 151 188 183 190 178 159 154 157 183 157 181
 160 157 172 153 161 155 192 165 180 191 170 167 150 173 173 152 154 154
 191 156 199 188 181 179 162 164 173 159 178 150 187 167 168 177 155 184
 167 196 193 167 151 169 157 154 157 194 172 194 156 191 194 180 186 186
 152 197 156 151 163 180 166 174 166 158 179 169 176 195 177 188 151 169
 153 187 191 184 189 181 194 172 171 188 151 164 188 180 151 177 187 197
 150 164 167 152 182 186 163 191 155 151 183 197 173 165 187 154 154 172
 181 198 194 181 180 192 193 155 159 184 151 193

### Test data and pre-processing
Randomly generate test data

#### Completely random structure

In [87]:
# Generate completely random "documents"
w_rand = list();

for m in range(M):
    doc = np.random.randint(V,size=N[m])
    w_rand.append(doc)
    
w_rand

[array([ 9, 19,  5, 10,  9, 18, 11, 20, 12,  1, 16, 24, 29,  8, 29, 21, 19,
         0, 12, 25, 12,  9, 18, 22, 10,  9, 26, 20,  3, 23,  4,  9, 29, 25,
        20, 20, 26, 21, 23,  9, 25,  3, 27, 14, 14, 29,  5, 17,  2, 21, 17,
        19,  9, 14, 22,  0,  9,  9,  1, 25, 16, 21,  4, 13, 27, 26, 13,  0,
         0, 18, 16, 29,  2,  2, 12, 29, 20, 20, 20, 13, 15, 17, 16, 15, 14,
         0, 10, 26, 18,  7, 13, 25, 25, 23, 17, 20, 15,  3,  8, 13, 19,  7,
        28, 26,  3,  8, 25, 16, 18, 28, 20, 29, 15,  6, 22, 25, 24, 10, 17,
        10, 26, 29,  3,  0,  4,  2, 11,  7, 13, 16, 18, 22, 24, 25, 28, 27,
        17,  9,  5,  0, 18, 24, 14, 16, 18, 27,  3, 17, 25, 14, 25, 13, 23,
        22, 18,  8, 25,  0,  8, 23, 15, 11,  7, 11, 29,  9,  8]),
 array([ 3, 20, 28,  2,  1, 22, 16,  9, 22, 10, 15,  9, 24,  3, 23, 26, 20,
        24, 15, 17, 11,  3, 24, 25,  2,  5, 26,  6, 23, 22, 13,  3,  0,  9,
         6,  5, 23,  4, 15, 27, 21,  1, 13,  4,  3, 20, 17,  7,  9, 19,  0,
        25, 19, 23, 16

#### Add some structure
Generate data according to the LDA model

In [140]:
alpha_gen = np.array((1,1,10,1,1,20,1,15,1,1))

beta_probs = np.ones((V,k)) + np.array([np.arange(V)%k==i for i in range(k)]).T*19
beta_gen = np.array(list(map(lambda x: np.random.dirichlet(x),beta_probs.T))).T

w_struct = list();
theta = np.empty((M,k))

for m in range(M):
    theta[m,:] = np.random.dirichlet(alpha_gen,1)[0]
    doc = np.array([])
    for n in range(N[m]):
        z_n = np.random.choice(np.arange(k),p=theta[m,:])
        w_n = np.random.choice(np.arange(V),p=beta_gen[:,z_n])
        doc = np.append(doc,w_n)
    w_struct.append(doc)

In [141]:
w_struct

[array([  7.,  15.,  20.,  17.,  23.,  21.,  17.,   2.,   9.,  21.,   3.,
         15.,  27.,  20.,  21.,  17.,  27.,  20.,   5.,   5.,  12.,  22.,
         17.,   7.,   8.,  16.,   7.,  15.,  16.,   2.,   1.,   7.,  15.,
         15.,   2.,   9.,  22.,  17.,  15.,   7.,  17.,  25.,   9.,   7.,
         14.,   7.,   7.,  11.,  25.,  27.,  25.,   8.,  23.,  28.,  17.,
         15.,  17.,  27.,  25.,  17.,  26.,  17.,  15.,  22.,   2.,  25.,
         16.,   9.,   7.,  25.,  22.,  25.,   9.,   7.,  25.,  22.,  25.,
         25.,   6.,  27.,   7.,  15.,   7.,  15.,   0.,  22.,  17.,  28.,
         25.,   4.,  17.,   9.,  17.,  15.,  23.,   8.,   1.,   5.,  27.,
         15.,  25.,  15.,  15.,  25.,   2.,   2.,  25.,  15.,  25.,  25.,
          6.,  29.,  25.,   2.,  15.,  28.,  29.,  25.,  27.,   5.,  21.,
          3.,  20.,  15.,  17.,  25.,   0.,   5.,   2.,  15.,  22.,  15.,
          2.,   5.,   5.,   8.,  25.,  16.,  17.,  24.,  20.,   1.,   2.,
          2.,   1.,   7.,  25.,  12., 

### Initialize parameters $\alpha, \beta, \phi$ and $\gamma$
Random Initialization

In [142]:
alpha = 100*np.random.dirichlet(np.ones(k),1)[0]
beta = np.random.dirichlet(np.ones(V),k).T

phi = np.array([1/k*np.ones([N[m],k]) for m in range(M)])
gamma = np.tile(alpha,(M,1)) + np.tile(N/k,(k,1)).T

In [119]:
alpha

array([ 13.37166453,  32.10371013,  15.66444589,   1.94841544,
         4.24851563,   9.8259512 ,  13.78475835,   4.00556469,
         2.1894614 ,   2.85751273])

In [17]:
beta.shape

(25, 10)

In [18]:
phi.shape

(300,)

In [19]:
gamma.shape

(300, 10)

### Optimize variational parameters $\phi$ and $\gamma$

In [120]:
## Optimize variational parameter phi
def opt_phi(beta,gamma,words,M,N,k):
    for m in range(M):
        for n in range(N[m]):
            for i in range(k):
                phi[m][n,i] = beta[words[m][n],i] * np.exp(digamma(gamma[m,i]) - digamma(np.sum(gamma[m,:])))
            # Normalize across states so phi represents probability over states for each word
            phi[m][n,:] = phi[m][n,:]/sum(phi[m][n,:])
    return phi


## Optimize variational parameter gamma
def opt_gamma(alpha,phi,M):
    gamma = np.tile(alpha,(M,1)) + np.array(list(map(lambda x: np.sum(x,axis=0),phi)))
    return gamma

### Estimate model parameters $\alpha$ and $\beta$

In [121]:
## Optimize beta
def est_beta(phi,words,k,V):
    for j in range (V):
        w_dnj = [np.tile((word == j),(k,1)).T for word in words]
        beta[j,:] = np.sum(np.array(list(map(lambda x: np.sum(x,axis=0),phi*w_dnj))),axis=0)
        
    # Normalize across states so beta represents probability of each word given the state
    for i in range(k):
        beta[:,i] = beta[:,i]/sum(beta[:,i])
        
    return beta


## Optimize alpha
#  (Newton-Raphson method, for a Hessian with special structure)
def est_alpha(alpha,gamma,M,k,nr_max_iters = 1000,tol = 10**-2.0):
    for it in range(nr_max_iters):
        alpha_old = alpha
        
        #  Calculate gradient 
        g = M*(digamma(np.sum(alpha))-digamma(alpha)) + np.sum(digamma(gamma)-np.tile(digamma(np.sum(gamma,axis=1)),(k,1)).T,axis=0)
        #  Calculate Hessian diagonal component
        h = -M*polygamma(1,alpha) 
        #  Calculate Hessian constant component
        z = polygamma(1,np.sum(alpha))
        #  Calculate constant
        c = np.sum(g/h)/(z**(-1.0)+np.sum(h**(-1.0)))

        #  Update alpha
        alpha = alpha - (g-c)/h
        
        #  Check convergence
        if sqrt(mean(square(alpha-alpha_old)))<tol:
            break
        
    return alpha

### Expectation Maximization (EM)

#### Convergence Criterion
The variational inference parameter $\gamma$ contains the topic likelihoods of every document and is thus what is of interest here.

Calculate root-mean-square of the change in $\gamma$

In [122]:
def converged(gamma,gamma_old,convergence):
    print(sqrt(mean(square(gamma-gamma_old))))
    return sqrt(mean(square(gamma-gamma_old))) < convergence

#### Inference by iterative EM
Continue until convergence criterion above met

In [123]:
convergence = 10**(-2.0)
successfully_Converged = False
max_iters = 10**3

for iters in range(max_iters):
    print(iters)
    gamma_old = gamma
    
    ## Expectation step: Update variational parameters
    phi   = opt_phi(beta,gamma,w_struct,M,N,k)
    gamma = opt_gamma(alpha,phi,M)
    
    ## Maximization step: Update model parameters
    beta  = est_beta(phi,w_struct,k,V)
    alpha = est_alpha(alpha,gamma,M,k)
    
    if converged(gamma,gamma_old,convergence):
        successfully_Converged = True
        break

0




8.2745175209
1
17.6787147464
2
15.3391265319
3
15.0322603719
4
15.3472664725
5
15.5387258137
6
15.5544593764
7
15.4536907317
8
15.2827241424
9
15.0852863997
10
14.8617216873
11
14.6267095803
12
14.3918169683
13
14.1471420348
14
13.902444408
15
13.6569391209
16
13.4197462854
17
13.1797630632
18
12.9461532245
19
12.7079225198
20
12.4743381756
21
12.2545043299
22
12.0274490705
23
11.802642078
24
11.5893903899
25
11.3769052862
26
11.1646425006
27
10.9520913518
28
10.7487703381
29
10.5540826107
30
10.35749851
31
10.1586739736
32
9.96727159128
33
9.7828454306
34
9.60498310349
35
9.42331954826
36
9.24762535457
37
9.06757473066
38
8.90296273625
39
8.73338622961
40
8.56868836716
41
8.40860919926
42
8.25290715492
43
8.09136799687
44
7.94386896575
45
7.79011788484
46
7.65001418602
47
7.50329868951
48
7.35989601093
49
7.21964600297
50
7.0823988637
51
6.94801354039
52
6.81635708291
53
6.68730407435
54
6.56073610334
55
6.43654127776
56
6.31461377622
57
6.19485343292
58
6.07716535333
59
5.96145956165

KeyboardInterrupt: 

In [124]:
alpha

array([  858.73857867,  2164.55827097,   957.76606314,   490.35766491,
         478.60723065,   766.42672585,   871.68499509,   437.37751913,
         470.67335317,   527.71691324])

In [125]:
alpha_gen

array([ 1,  1, 10,  1,  1, 20,  1, 15,  1,  1])

In [127]:
beta

array([[  3.75035363e-02,   8.14527419e-04,   4.44501525e-04,
          5.92638792e-03,   2.32266675e-02,   6.89858322e-04,
          2.51804109e-02,   1.08186060e-02,   1.56012903e-02,
          5.82280891e-04],
       [  2.28776907e-02,   3.23706248e-03,   1.86294461e-02,
          1.85904761e-02,   1.94716778e-02,   4.68548876e-02,
          2.92932841e-03,   7.36813209e-04,   4.52946163e-03,
          1.23468160e-02],
       [  1.58567578e-01,   7.22711436e-02,   2.26822583e-02,
          6.34786728e-02,   8.00051168e-02,   9.48517532e-03,
          2.73523477e-02,   1.25347086e-01,   1.02558769e-02,
          4.95752160e-02],
       [  3.38630796e-04,   2.72355224e-04,   3.99534098e-02,
          4.96352457e-05,   4.08810675e-03,   2.41097667e-02,
          6.45553886e-02,   8.57382067e-02,   1.59318315e-02,
          3.07715867e-02],
       [  4.95614387e-03,   8.88289703e-04,   1.53483577e-02,
          2.35849282e-03,   2.49617883e-02,   3.00749117e-02,
          6.30388463e-03

In [128]:
beta_gen

array([[  1.92630833e-01,   7.96838262e-03,   1.26198508e-02,
          3.26341597e-03,   1.15949739e-02,   4.16956977e-03,
          1.82895645e-04,   6.60617169e-03,   2.18922128e-04,
          2.69303124e-03],
       [  7.08867919e-03,   1.64506624e-01,   2.78197712e-02,
          1.63371903e-02,   6.90505078e-04,   6.80210270e-03,
          1.44811359e-02,   3.57795170e-03,   1.15734751e-02,
          5.51124967e-03],
       [  4.46280313e-03,   1.46760963e-02,   2.92470943e-01,
          3.33398939e-02,   1.47566053e-02,   1.04534640e-02,
          1.38507554e-02,   2.96497456e-03,   8.46360181e-03,
          1.28260552e-02],
       [  8.42249493e-03,   5.58480645e-02,   8.41090944e-03,
          2.77449075e-01,   1.23807002e-02,   2.93062574e-02,
          1.39291499e-02,   6.76956941e-03,   3.11834049e-04,
          4.01561442e-03],
       [  3.50137747e-03,   1.98324988e-03,   2.24544533e-03,
          6.41205386e-03,   2.89596600e-01,   3.30308805e-04,
          1.69095142e-02

### Tests 
Testing out syntax and array dimensions

In [130]:
theta

array([[  1.08500023e-01,   7.84225582e-04,   1.98761180e-01, ...,
          2.83771133e-01,   2.55650825e-02,   1.09599201e-03],
       [  3.24849786e-03,   1.59940524e-02,   2.40733290e-01, ...,
          2.25549086e-01,   9.94384115e-03,   7.40190716e-03],
       [  1.18447168e-03,   4.49989733e-03,   2.19290897e-01, ...,
          3.52210388e-01,   1.83315002e-03,   7.02030413e-02],
       ..., 
       [  6.93213988e-03,   9.87572575e-03,   9.86264195e-02, ...,
          2.86190633e-01,   2.85150874e-02,   7.55173007e-02],
       [  6.19163948e-06,   1.58442945e-02,   1.66552888e-01, ...,
          3.64573561e-01,   9.50799311e-03,   4.12626363e-02],
       [  2.92087755e-02,   1.19140646e-02,   2.98645756e-01, ...,
          3.18916499e-01,   1.17548067e-02,   3.07274550e-04]])

In [135]:
gamma/np.sum(gamma,axis=1)[:,None]

array([[ 0.10704713,  0.27019815,  0.11876094, ...,  0.05447204,
         0.05856842,  0.06577053],
       [ 0.10703206,  0.26958463,  0.11930584, ...,  0.0544902 ,
         0.05866049,  0.06582478],
       [ 0.10677928,  0.26974624,  0.1194745 , ...,  0.05434493,
         0.0588217 ,  0.06601595],
       ..., 
       [ 0.10699703,  0.26954884,  0.1196047 , ...,  0.05454484,
         0.05861532,  0.06576293],
       [ 0.10701705,  0.26947679,  0.11944292, ...,  0.05467499,
         0.05875392,  0.06580721],
       [ 0.10720662,  0.26966493,  0.11944798, ...,  0.05462183,
         0.05861889,  0.06582103]])

In [30]:
# Word #11 in document 2 (w_dn)
w_rand[1][10]

17

In [31]:
[doc == 3 for doc in w_rand]

[array([False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False,  True, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False,  True, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False], dtype=bool),
 array([False, False, False, False, False, False, False, False, False,
        False, False, False, False, False,  True, False, False, False,
        False, False, False, False, False, False, False,  True, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False,  True, False, False, False, False, False,
        False, False, False, False, False, False, False,

In [33]:
[doc == 3 for doc in w_rand]*w_rand

TypeError: can't multiply sequence by non-int of type 'list'

In [34]:
np.sum(np.array(list(map(lambda x: np.sum(x,axis=0),phi))),axis=0)

array([ 2249.,  2249.,  2249.,  2249.,  2249.,  2249.,  2249.,  2249.,
        2249.,  2249.])