In [1]:
%matplotlib inline
import numpy as np
import scipy as sp
import scipy.sparse as spar
import scipy.special as spec
import sys

In [2]:
V = 1000 # nr words in vocabulary
M = 10 # nr documents
K = 5 # nr of topics
alpha = .1 # dirichlet hyperparameter

X = np.random.binomial(1,.1, size=M*V).reshape(M,V)
X = spar.csr_matrix(X, dtype=float)

In [3]:
# For even a reasonable setup like 10K vocabulary, 5K documents and 20 topics, the size of the tensor indexed by
# <document, word, topic> simply explodes to 7.5G. This is why we can't explicitly keep all of $\phi$ in the memory.
# Instead, we iterate over the documents one by one, and accumulate the phi parameter

In [4]:
nr_terms = X.sum(axis=1) 
nr_terms = np.array(nr_terms).squeeze()

In [15]:
# model parameters
beta = np.zeros((K, V)) + 1./V

# variational and temp variables
gamma = np.zeros((K, M)) + alpha + (nr_terms/float(K)) # mth document, i th topic
beta_acc = np.ones((K, V))

for epoch in range(5):
    # E-step
    for m in range(2): # iterate over all documents
        print "new doc"
        phi = np.zeros((K, V), dtype=float) + 1./K

        ixw = (X[m, :] > 0).toarray().squeeze() # an index to words which have appeared in the document
        gammad = gamma[:, m] # slice for the document only once

        for ctr in range(int(100)): 
            # store the previous values
            phi_prev = phi.copy()
            gammad_prev = gammad.copy()
            
            print (np.linalg.norm(phi_prev), np.linalg.norm(gammad_prev))

            # update phi
            # WARN: exp digamma underflows < 1e-3! 
            # TODO: digamma update is wrong! -spec.digamma(np.sum(gammad))
            phi[:, ixw] = ((beta[:, ixw]).T * np.exp(spec.digamma(gammad))).T 
            phi = phi / np.sum(phi, 0) # normalize phi columns

            # update gamma
            gammad = alpha + np.sum(phi, axis=1)

            # check for convergence
            dphinorm = np.linalg.norm(phi - phi_prev, "fro") #/ float(K * V)
            dgammadnorm = np.linalg.norm(gammad - gammad_prev)
            
            if dphinorm < .01 and dgammadnorm < .01:
                break
            else:
                print (dphinorm, dgammadnorm, np.linalg.norm(phi, "fro"), np.linalg.norm(gammad))

        gamma[:, m] = gammad
        beta_acc[:, ixw] += phi[:, ixw]

    # M-step
    # TODO: check for numerical stability
    beta = (beta_acc.T / np.sum(beta_acc, axis=1)).T # normalize beta rows

new doc
(14.14213562373085, 53.889238257744935)
(0.0, 393.54796403996306, 14.14213562373085, 447.43720229770798)
(14.14213562373085, 447.43720229770798)
new doc
(14.14213562373085, 48.075461516245475)
(0.0, 399.36174078146252, 14.14213562373085, 447.43720229770798)
(14.14213562373085, 447.43720229770798)
new doc
(14.14213562373085, 447.43720229770798)
new doc
(14.14213562373085, 447.43720229770798)
new doc
(14.14213562373085, 447.43720229770798)
new doc
(14.14213562373085, 447.43720229770798)
new doc
(14.14213562373085, 447.43720229770798)
new doc
(14.14213562373085, 447.43720229770798)
new doc
(14.14213562373085, 447.43720229770798)
new doc
(14.14213562373085, 447.43720229770798)


In [None]:
beta.shape

In [None]:
beta * ixw.toarray()

In [None]:
beta