In [10]:
import numpy as np
from gensim.corpora.dictionary import Dictionary
from gensim.models.ldamodel import LdaModel as ldamodel
from gensim.models.ldamodel import LdaState as ldastate

# Question 1

In [11]:
def lda(vocabulary, beta, alpha, xi):
    '''
    Generate one synthetic document with the LDA generative process.

    The document length is drawn as N ~ Poisson(xi) and a single topic
    mixture theta ~ Dirichlet(alpha) is drawn for the whole document.  Each
    of the N words is then produced by drawing a topic z ~ Categorical(theta)
    followed by a word w ~ Categorical(beta[z]).

    Args:
        vocabulary - list (of length V) of strings
        beta - topic-word matrix, numpy array of size (k, V); each row must
            be a probability vector over the vocabulary
        alpha - topic distribution parameter vector, of length k
        xi - Poisson parameter (scalar) for document size distribution
    Returns:
        w - list of words (strings) in a document; its length is the Poisson
            draw, so it varies between calls and may be empty
    '''
    beta = np.asarray(beta, dtype=float)
    num_topics, vocab_size = beta.shape

    # The document length is the Poisson draw itself, not the parameter xi.
    n = np.random.poisson(xi)
    # One topic mixture per document, shared by every word in it.
    theta = np.random.dirichlet(alpha)

    # Topic assignment for each word position (an empty array when n == 0).
    topics = np.random.choice(num_topics, size=n, p=theta)

    doc = []
    for z in topics:
        word_index = np.random.choice(vocab_size, p=beta[z])
        doc.append(vocabulary[word_index])
    return doc

# Question 2

In [12]:
def init_corpus(vocabulary, beta, alpha, xi, num_docs):
    '''
    Build a synthetic corpus by calling lda once per document.

    Args:
        vocabulary, beta, alpha, xi - forwarded unchanged to lda
        num_docs - number of documents to generate
    Returns:
        list of num_docs documents, each being the value returned by
        lda(vocabulary, beta, alpha, xi)
    '''
    return [lda(vocabulary, beta, alpha, xi) for _ in range(num_docs)]

In [13]:
def train_LDA(texts, num_topics=100):
    '''
    Fit a gensim LDA model to a tokenised corpus.

    Args:
        texts - list of documents, each a list of word strings
        num_topics - number of latent topics to fit.  Defaults to 100, which
            is gensim's own default and therefore what this function used
            implicitly when no topic count was passed.
    Returns:
        the trained gensim LdaModel, with alpha and eta learned ('auto')
    '''
    common_dictionary = Dictionary(texts)
    common_corpus = [common_dictionary.doc2bow(text) for text in texts]
    # Passing id2word lets the model report topics in terms of the actual
    # words instead of stringified integer ids.
    model = ldamodel(
        common_corpus,
        num_topics=num_topics,
        id2word=common_dictionary,
        alpha='auto',
        eta='auto',
    )
    return model

In [14]:
def model_params(lda, xi, num_topics):
    '''
    Collect the priors of a model and draw fresh samples from them.

    Note that theta and beta are new random Dirichlet draws parameterised by
    the model's alpha and eta; they are not read from any fitted state of
    the model.

    Args:
        lda - object exposing `alpha` and `eta` attributes (the parameter
            name shadows the lda() generator function inside this body)
        xi - number of theta vectors to draw from Dirichlet(lda.alpha)
        num_topics - number of beta vectors to draw from Dirichlet(lda.eta)
    Returns:
        dict with keys 'alpha', 'theta', 'eta', 'beta'
    '''
    doc_topic_prior = lda.alpha
    topic_word_prior = lda.eta

    # Draw order matters for the global RNG stream: theta first, then beta.
    params = {'alpha': doc_topic_prior}
    params['theta'] = np.random.dirichlet(doc_topic_prior, xi)
    params['eta'] = topic_word_prior
    params['beta'] = np.random.dirichlet(topic_word_prior, num_topics)
    return params

In [16]:
# Toy vocabulary of six words used by the synthetic generator.
vocabulary = ['bass', 'pike', 'deep', 'tuba', 'horn', 'catapult']
# Topic-word matrix: one row per topic (3 rows), one column per vocabulary
# word (6 columns).
beta = np.array([
    [0.4, 0.4, 0.2, 0.0, 0.0, 0.0],
    [0.0, 0.3, 0.1, 0.0, 0.3, 0.3],
    [0.3, 0.0, 0.2, 0.3, 0.2, 0.0]
])
# Dirichlet parameter for the topic mixture, one entry per row of beta.
alpha = np.array([1, 3, 8])
# Poisson parameter handed to lda for the document-size distribution.
xi = 50
# Number of synthetic documents in the training corpus.
num_docs = 50
# A single sample document; it is not used again in the visible part of the
# notebook.
gen = lda(vocabulary, beta, alpha, xi)
# Generate the corpus, fit the model on it, then collect its parameters.
corpus = init_corpus(vocabulary, beta, alpha, xi, num_docs)
model = train_LDA(corpus)
infer = model_params(model, xi, len(beta))
# NOTE(review): no random seed is set, so every run prints different values.
# NOTE(review): the printed 'alpha' has far more than len(beta) == 3 entries,
# so the fitted model presumably uses the library's default topic count
# rather than 3 -- confirm.
print(infer)

{'alpha': array([0.00954627, 0.00954627, 0.01057715, 0.0105691 , 0.00996073,
       0.00954627, 0.00954627, 0.00954627, 0.00954627, 0.00954627,
       0.00954627, 0.01066995, 0.01141566, 0.01057802, 0.00954627,
       0.00995284, 0.00954627, 0.00995826, 0.00954627, 0.01016877,
       0.00954627, 0.00975517, 0.00974815, 0.01079038, 0.00954627,
       0.00975471, 0.00954627, 0.00954627, 0.01037403, 0.00975229,
       0.00954627, 0.00954627, 0.00954627, 0.01036659, 0.00954627,
       0.00954627, 0.01058628, 0.00954627, 0.01079707, 0.00954627,
       0.00954627, 0.01037332, 0.01058258, 0.01016742, 0.00975493,
       0.00954627, 0.00954627, 0.00954627, 0.00954627, 0.00975155,
       0.00958078, 0.00954627, 0.00954627, 0.00954627, 0.00954627,
       0.00954627, 0.00954627, 0.00975104, 0.00954627, 0.00954627,
       0.00954627, 0.00954627, 0.00954627, 0.00954627, 0.00954627,
       0.00954627, 0.00954627, 0.00954627, 0.01078735, 0.01058296,
       0.00954627, 0.00954627, 0.00954627, 0.0095462