## Latent Dirichlet Allocation
This is a demostration of LDA topic model using Gibbs sampling on a "perfect dataset"   
Thanks to the clear [tutorial](https://www.cnblogs.com/pinard/p/6831308.html) provided by Pinard Liu  
Author: kUNQI jIANG   
Date: 2019/1/22  

### Corpus generation
As Gibbs sampling in LDA essentially based on bag-of-words so the order of words does not matter, I use completely seperated wordset of different topic to generate pure topic documents as corpus. This is the extreme case where words and topics will be completely clustered after LDA as we can see in the result. While in real word, a word exist in different topics, and a document can cover multi-topics.

In [1]:
food_set = ["broccoli","banana","spinach","smoothie","breakfast","ham","cream","eat","vegetable","dinner","lunch",
            "apple","peach","pork","beef","rice","noodle","chicken","KFC","restaurant","cream","tea","pan","beacon"]
animal_set = ["dog","cat","fish","chinchilla","kitten","cute","hamster","munching","bird","elephant","monkey","zoo",
              "zoology","pig","piggy","duck","mice","micky","tiger","lion","horse","dragon","panda","bee","rabbit"]
soccer_set = ["football","pitch","play","player","cup","ballon","messi","ronald","manU","liverpool","chelase","ozil",
              "practice","hard","dream","stadium","fast","speed","strong","move","shot","attack","defense","win"]

In [2]:
import numpy as np
def generate(topic_set):
    sent = np.random.choice(topic_set,10)
    return " ".join(sent)

In [3]:
topics_set = [food_set,animal_set,soccer_set]
corpus = []
for i in range(100):
    corpus.append(generate(topics_set[0]).split())
    corpus.append(generate(topics_set[1]).split())
    corpus.append(generate(topics_set[2]).split())

In [4]:
import numpy as np

all_words = [word for document in corpus for word in document]
vocab = set(all_words)
num_docs = len(corpus)
num_words = len(vocab)
word2id = {w:i for i,w in enumerate(vocab)}
id2word = {i:w for i,w in enumerate(vocab)}

In [5]:
# model 3 latent topics 
num_topics = 3
# Dirichlet prior
alpha = np.ones([num_topics])
#ita = term_freq
ita = 0.1 * np.ones([num_words])

### Random assignment
At the start randomly assign topic to each word in each document

In [6]:
topic_assignments = []
docs_topics = np.zeros([num_docs,num_topics]) # counts of topic assignments of each word in each doc
words_topics = np.zeros([num_words,num_topics]) # counts of topic distributes of each word over all doc
topics_words = np.zeros([num_topics,num_words]) # counts of word distributes of each topic over all doc

for d,document in enumerate(corpus):
    theta = np.random.dirichlet(alpha, 1)[0]
    doc_topics = []
    for n,word in enumerate(document):
        sample = np.random.multinomial(1, theta, size=1)[0]
        topic = list(sample).index(1)
        doc_topics.append(topic)
        docs_topics[d,topic] += 1
        words_topics[word2id[word],topic] += 1
        topics_words[topic,word2id[word]] += 1
    topic_assignments.append(doc_topics)
    

### Gibbs Sampling

In [7]:
def Gibbs_sampling(d,word_id,words_topics,docs_topics,topics_words,alpha,ita):
    
    topic_probs = (docs_topics[d] + alpha) / np.sum(docs_topics[d] + alpha)
    word_sum = np.sum(topics_words + ita, axis = 1)
    word_probs = (words_topics[word_id] + ita[word_id]) / word_sum
    # posterior probs
    probs = topic_probs * word_probs
    # normalize
    sample_probs = probs / np.sum(probs)
    #print(sample_probs)
    # sample new topic for current word
    new_topic = list(np.random.multinomial(1, sample_probs, size=1)[0]).index(1)
    return new_topic

In [8]:
num_iterations = 9
for j in range(num_iterations):
    for d in range(len(corpus)):
        document = corpus[d]
        for n in range(len(document)):
            word = document[n]
            word_id = word2id[word]
            topic = topic_assignments[d][n]
            # exclude current word and topic
            docs_topics[d][topic] -= 1
            topics_words[topic][word_id] -=1
            words_topics[word_id,topic] -= 1
            new_topic = Gibbs_sampling(d,word_id,words_topics,docs_topics,topics_words,alpha,ita)
            # update topic and word state
            docs_topics[d][new_topic] += 1
            topics_words[new_topic][word_id] += 1
            words_topics[word_id,new_topic] += 1
            topic_assignments[d][n] = new_topic

### Evaluation

In [9]:
docs_topics

array([[10.,  0.,  0.],
       [ 0.,  0., 10.],
       [ 0., 10.,  0.],
       [10.,  0.,  0.],
       [ 0.,  0., 10.],
       [ 0., 10.,  0.],
       [10.,  0.,  0.],
       [ 0.,  0., 10.],
       [ 0., 10.,  0.],
       [10.,  0.,  0.],
       [ 0.,  0., 10.],
       [ 0., 10.,  0.],
       [10.,  0.,  0.],
       [ 0.,  0., 10.],
       [ 0., 10.,  0.],
       [10.,  0.,  0.],
       [ 0.,  0., 10.],
       [ 0., 10.,  0.],
       [10.,  0.,  0.],
       [ 0.,  0., 10.],
       [ 0., 10.,  0.],
       [10.,  0.,  0.],
       [ 0.,  0., 10.],
       [ 0., 10.,  0.],
       [10.,  0.,  0.],
       [ 0.,  0., 10.],
       [ 0., 10.,  0.],
       [10.,  0.,  0.],
       [ 0.,  0., 10.],
       [ 0., 10.,  0.],
       [10.,  0.,  0.],
       [ 0.,  0., 10.],
       [ 0., 10.,  0.],
       [10.,  0.,  0.],
       [ 0.,  0., 10.],
       [ 0., 10.,  0.],
       [10.,  0.,  0.],
       [ 0.,  0., 10.],
       [ 0., 10.,  0.],
       [10.,  0.,  0.],
       [ 0.,  0., 10.],
       [ 0., 10.

In [10]:
import matplotlib.pyplot as plt
for i,state in enumerate(topics_words):
    # sorted descending word frequence within each topic
    topic_id_freq = sorted(range(len(state)), key=lambda k: state[k], reverse=True)
    topic_word_freq = [id2word[i] for i in topic_id_freq]
    print("Topic: ", i)
    print(topic_word_freq)

Topic:  0
['cream', 'rice', 'spinach', 'dinner', 'beef', 'smoothie', 'vegetable', 'restaurant', 'chicken', 'peach', 'banana', 'tea', 'broccoli', 'beacon', 'apple', 'lunch', 'pan', 'noodle', 'pork', 'ham', 'breakfast', 'eat', 'KFC', 'messi', 'strong', 'micky', 'tiger', 'win', 'speed', 'football', 'liverpool', 'move', 'ballon', 'cat', 'fast', 'manU', 'horse', 'pig', 'lion', 'shot', 'hard', 'pitch', 'mice', 'ozil', 'hamster', 'elephant', 'chinchilla', 'zoo', 'stadium', 'panda', 'monkey', 'dragon', 'piggy', 'bird', 'zoology', 'rabbit', 'chelase', 'dog', 'fish', 'dream', 'defense', 'player', 'munching', 'kitten', 'attack', 'cute', 'cup', 'play', 'practice', 'bee', 'duck', 'ronald']
Topic:  1
['ozil', 'strong', 'ronald', 'liverpool', 'play', 'pitch', 'football', 'defense', 'cup', 'win', 'stadium', 'messi', 'fast', 'chelase', 'player', 'practice', 'ballon', 'shot', 'dream', 'move', 'attack', 'speed', 'manU', 'hard', 'pork', 'micky', 'tiger', 'banana', 'smoothie', 'beef', 'breakfast', 'beacon'

In [11]:
topics_words

array([[ 0., 38.,  0.,  0.,  0., 40., 47., 51., 36.,  0.,  0.,  0., 39.,
         0.,  0., 53.,  0., 47., 35.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
        42.,  0., 52.,  0.,  0.,  0., 39., 39.,  0., 40., 78.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0., 29.,  0.,  0., 39.,  0., 46.,  0.,
         0.,  0.,  0., 39., 52.,  0., 42.,  0.,  0., 37.,  0., 40.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.],
       [41.,  0., 54.,  0.,  0.,  0.,  0.,  0.,  0., 45., 30., 47.,  0.,
        52., 35.,  0., 37.,  0.,  0.,  0., 39., 30.,  0.,  0.,  0., 37.,
         0., 28.,  0., 48.,  0., 55.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0., 42.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0., 39.,
         0.,  0., 37.,  0.,  0., 47.,  0., 39.,  0.,  0.,  0.,  0., 31.,
         0., 47., 49., 38.,  0.,  0., 53.],
       [ 0.,  0.,  0., 41., 40.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0., 47.,  0.,  0., 38., 32., 49.,  0.,
         0.,  0.,  0.,  0., 45.,  0.

In [12]:
for i in range(len(words_topics)):
    print(words_topics[i],id2word[i])

[ 0. 41.  0.] messi
[38.  0.  0.] pork
[ 0. 54.  0.] strong
[ 0.  0. 41.] micky
[ 0.  0. 40.] tiger
[40.  0.  0.] banana
[47.  0.  0.] smoothie
[51.  0.  0.] beef
[36.  0.  0.] breakfast
[ 0. 45.  0.] win
[ 0. 30.  0.] speed
[ 0. 47.  0.] football
[39.  0.  0.] beacon
[ 0. 52.  0.] liverpool
[ 0. 35.  0.] move
[53.  0.  0.] rice
[ 0. 37.  0.] ballon
[47.  0.  0.] vegetable
[35.  0.  0.] eat
[ 0.  0. 47.] cat
[ 0. 39.  0.] fast
[ 0. 30.  0.] manU
[ 0.  0. 38.] horse
[ 0.  0. 32.] pig
[ 0.  0. 49.] lion
[ 0. 37.  0.] shot
[42.  0.  0.] chicken
[ 0. 28.  0.] hard
[52.  0.  0.] spinach
[ 0. 48.  0.] pitch
[ 0.  0. 45.] mice
[ 0. 55.  0.] ozil
[39.  0.  0.] apple
[39.  0.  0.] lunch
[ 0.  0. 24.] hamster
[40.  0.  0.] tea
[78.  0.  0.] cream
[ 0.  0. 41.] elephant
[ 0.  0. 44.] chinchilla
[ 0.  0. 45.] zoo
[ 0. 42.  0.] stadium
[ 0.  0. 42.] panda
[ 0.  0. 40.] monkey
[ 0.  0. 48.] dragon
[ 0.  0. 35.] piggy
[29.  0.  0.] KFC
[ 0.  0. 46.] bird
[ 0.  0. 33.] zoology
[39.  0.  0.] pan
[ 0.  

### Comparison
Justify my result with gensim LDA model

In [13]:
import gensim
from gensim import corpora
text_data = corpus
dictionary = corpora.Dictionary(text_data)
id_corpus = [dictionary.doc2bow(text) for text in text_data]

ldamodel = gensim.models.ldamodel.LdaModel(id_corpus, num_topics = num_topics, id2word=dictionary, passes=12)
#ldamodel.save('model5.gensim')
topics = ldamodel.print_topics(num_words=num_words)
for topic in topics:
    print(topic)

(0, '0.076*"cream" + 0.051*"rice" + 0.050*"dinner" + 0.050*"spinach" + 0.049*"beef" + 0.046*"vegetable" + 0.046*"smoothie" + 0.045*"restaurant" + 0.041*"chicken" + 0.041*"peach" + 0.039*"tea" + 0.039*"broccoli" + 0.039*"banana" + 0.038*"lunch" + 0.038*"pan" + 0.038*"beacon" + 0.038*"noodle" + 0.038*"apple" + 0.037*"pork" + 0.036*"ham" + 0.035*"breakfast" + 0.034*"eat" + 0.028*"KFC" + 0.001*"rabbit" + 0.001*"bird" + 0.001*"bee" + 0.001*"dragon" + 0.001*"pig" + 0.001*"monkey" + 0.001*"kitten" + 0.001*"elephant" + 0.001*"duck" + 0.001*"chinchilla" + 0.001*"lion" + 0.001*"cat" + 0.001*"zoo" + 0.001*"tiger" + 0.001*"dog" + 0.001*"panda" + 0.001*"munching" + 0.001*"piggy" + 0.001*"mice" + 0.001*"fish" + 0.001*"hamster" + 0.001*"zoology" + 0.001*"micky" + 0.001*"cute" + 0.000*"horse" + 0.000*"ronald" + 0.000*"stadium" + 0.000*"football" + 0.000*"attack" + 0.000*"pitch" + 0.000*"shot" + 0.000*"messi" + 0.000*"defense" + 0.000*"play" + 0.000*"speed" + 0.000*"ballon" + 0.000*"cup" + 0.000*"fast"