In [1]:
import os
import pickle

import numpy as np
from sklearn.decomposition import LatentDirichletAllocation
from sklearn.feature_extraction.text import CountVectorizer

# Module-level vectorizer. LDA_sklearn._cv() creates its own instance, so this one
# is not referenced by any cell visible here; kept in case other cells rely on it.
cv = CountVectorizer()

# Path of the pickled corpus handed to LDA_sklearn. Set the NEWSGROUP_PICKLE
# environment variable to run the notebook on another machine; the default is
# the original author's local path.
_dir = os.environ.get(
    "NEWSGROUP_PICKLE",
    "/Users/shinbo/Desktop/metting/LDA/0. data/20news-bydate/newsgroup_preprocessed.pickle",
)


class LDA_sklearn:
    """Fit scikit-learn's LatentDirichletAllocation on a random sample of a pickled corpus.

    The pickle is consumed as a sequence of documents, each an iterable of string
    tokens (``_make_vocab`` joins the tokens of a document with spaces).

    Attributes set by ``__init__``: ``data``, ``K``, ``alpha``, ``eta``.
    Attributes set by ``_make_vocab``: ``vocab``, ``w2idx``, ``idx2w``, ``doc2idx``
    (and ``data`` is replaced by the space-joined documents).
    Attributes set by ``_cv``: ``cv`` (the fitted CountVectorizer) and ``df``
    (the result of ``cv.fit_transform``).

    NOTE(review): CountVectorizer tokenises the joined text with its own default
    settings, so ``cv``'s vocabulary may differ from ``self.vocab`` -- confirm
    before mixing the two index spaces.
    """

    def __init__(self, path_data, alpha, eta, K, n_docs=1000, seed=0):
        """Load the corpus and keep a reproducible random subset of it.

        path_data: path of the pickle file holding the tokenised documents.
        alpha: passed to the model as ``doc_topic_prior``.
        eta: passed to the model as ``topic_word_prior``.
        K: number of topics (``n_components``).
        n_docs: maximum number of documents to sample (default 1000, the
            value that used to be hard-coded). Smaller corpora are used whole.
        seed: value given to ``np.random.seed`` before sampling (default 0).
        """
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        with open(path_data, 'rb') as fh:
            corpus = pickle.load(fh)

        # Seeds NumPy's global RNG, exactly as the original code did.
        np.random.seed(seed)
        # Cap the sample size so a corpus shorter than n_docs does not make
        # choice(..., replace=False) raise.
        n_keep = min(n_docs, len(corpus))
        idx = np.random.choice(len(corpus), n_keep, replace=False)
        # Set membership is O(1); documents keep their original corpus order.
        keep = set(idx.tolist())
        self.data = [doc for i, doc in enumerate(corpus) if i in keep]
        self.K = K
        self.alpha = alpha
        self.eta = eta

    def _make_vocab(self):
        """Build the vocabulary and index mappings, then join each document into a string.

        Safe to call repeatedly: the token lists are remembered on the first
        call, because ``self.data`` is overwritten with joined strings and
        re-processing those would split every document into characters.
        """
        if not hasattr(self, '_tokens'):
            self._tokens = self.data
        docs = self._tokens

        self.vocab = sorted({word for doc in docs for word in doc})
        self.w2idx = {word: i for i, word in enumerate(self.vocab)}
        self.idx2w = {i: word for word, i in self.w2idx.items()}
        self.doc2idx = [[self.w2idx[word] for word in doc] for doc in docs]
        self.data = [' '.join(doc) for doc in docs]

    def _cv(self):
        """Vectorise the joined documents with a fresh CountVectorizer into ``self.df``."""
        self._make_vocab()
        self.cv = CountVectorizer()
        self.df = self.cv.fit_transform(self.data)

    def _train(self, max_iter=1000):
        """Vectorise the sample, fit a batch LDA model on it and return the fitted estimator.

        max_iter: iteration cap handed to LatentDirichletAllocation
            (default 1000, the value that used to be hard-coded).
        """
        # _cv() already builds the vocabulary; the former bare `self._make_vocab`
        # attribute access here did nothing and has been removed.
        self._cv()
        lda = LatentDirichletAllocation(n_components=self.K,
                                        doc_topic_prior=self.alpha, topic_word_prior=self.eta,
                                        learning_method='batch', max_iter=max_iter)
        lda.fit(self.df)
        return lda

In [23]:
# Build the sampler/trainer wrapper: alpha is used as the document-topic prior,
# eta as the topic-word prior and K as the number of topics.
lda = LDA_sklearn(path_data=_dir, alpha=5, eta=0.1, K=10)
# `result` is the fitted LatentDirichletAllocation estimator returned by _train().
result = lda._train()

In [24]:
# One weight vector per fitted topic. Iterating components_ yields its rows, so the
# list follows the model's actual number of topics instead of a hard-coded 10.
lda_lam = list(result.components_)

def print_top_words(lam, feature_names, n_top_words):
    """Print the highest-weighted words of every topic.

    For each weight vector in ``lam`` a header ``Topic Nr.<k>:`` (1-based,
    preceded by a blank line) is printed, followed by a single line with the
    ``n_top_words`` largest entries formatted as ``<word> <weight> | ``, in
    descending weight order; weights are rounded to two decimals.

    lam: iterable of 1-D arrays (anything offering ``argsort``), one per topic.
    feature_names: sequence mapping a column index to its word.
    n_top_words: number of words to list per topic.
    """
    for topic_number, weights in enumerate(lam, start=1):
        print('\nTopic Nr.%d:' % int(topic_number))
        # argsort is ascending: reverse it, then keep the leading entries.
        ranked = weights.argsort()[::-1][:n_top_words]
        pieces = []
        for col in ranked:
            pieces.append(feature_names[col] + ' ' + str(round(weights[col], 2)) + ' | ')
        print(''.join(pieces))
# CountVectorizer.get_feature_names() was deprecated in scikit-learn 1.0 and removed
# in 1.2; use get_feature_names_out() when it exists and fall back to the old name.
get_names = getattr(lda.cv, 'get_feature_names_out', None) or lda.cv.get_feature_names
print_top_words(lda_lam, list(get_names()), 10)


Topic Nr.1:
right 204.85 | game 150.91 | people 144.01 | writes 143.6 | would 141.26 | well 137.47 | year 122.36 | article 113.84 | team 105.81 | government 103.5 | 

Topic Nr.2:
say 170.88 | said 165.65 | one 134.63 | go 113.25 | people 112.32 | going 96.85 | day 91.28 | dont 89.25 | time 86.39 | well 81.59 | 

Topic Nr.3:
god 267.02 | think 180.35 | believe 156.47 | one 156.11 | would 147.44 | say 118.72 | people 113.65 | thing 104.36 | dont 103.54 | like 102.93 | 

Topic Nr.4:
key 179.45 | system 159.33 | also 107.91 | keyboard 97.06 | one 83.19 | price 82.14 | pc 75.4 | access 74.23 | de 73.55 | use 68.74 | 

Topic Nr.5:
writes 163.52 | article 141.53 | one 126.43 | israel 98.1 | subject 96.38 | israeli 90.1 | would 89.81 | like 88.32 | number 83.49 | line 76.8 | 

Topic Nr.6:
or 302.1 | do 121.04 | mr 59.5 | font 57.1 | subject 52.99 | help 52.64 | um 44.1 | organization 38.8 | world 37.09 | looking 36.37 | 

Topic Nr.7:
line 407.15 | organization 293.7 | subject 284.25 | nntppos

In [4]:
# NOTE: unpickling can execute arbitrary code; only load pickles you trust.
# The context manager closes the file handle that the bare open() call leaked.
with open('lda_model.pickle', 'rb') as fh:
    model = pickle.load(fh)

In [7]:
# One column of model.lam per topic; the column count is read from the array
# instead of a hard-coded 10.
# NOTE(review): assumes model.lam is a 2-D NumPy-style array exposing .shape -- confirm.
lda_lam = [model.lam[:, k] for k in range(model.lam.shape[1])]
# NOTE(review): this repeats the print_top_words definition from the earlier cell.
def print_top_words(lam, feature_names, n_top_words):
    """Print the highest-weighted words of every topic.

    For each weight vector in ``lam`` a header ``Topic Nr.<k>:`` (1-based,
    preceded by a blank line) is printed, followed by a single line with the
    ``n_top_words`` largest entries formatted as ``<word> <weight> | ``, in
    descending weight order; weights are rounded to two decimals.

    lam: iterable of 1-D arrays (anything offering ``argsort``), one per topic.
    feature_names: sequence mapping a column index to its word.
    n_top_words: number of words to list per topic.
    """
    for topic_number, weights in enumerate(lam, start=1):
        print('\nTopic Nr.%d:' % int(topic_number))
        # argsort is ascending: reverse it, then keep the leading entries.
        ranked = weights.argsort()[::-1][:n_top_words]
        pieces = []
        for col in ranked:
            pieces.append(feature_names[col] + ' ' + str(round(weights[col], 2)) + ' | ')
        print(''.join(pieces))
# get_feature_names() no longer exists on CountVectorizer in scikit-learn >= 1.2;
# use get_feature_names_out() when the vectorizer has it, else the original method.
# NOTE(review): model.cv is presumably a fitted CountVectorizer -- confirm.
get_names = getattr(model.cv, 'get_feature_names_out', None) or model.cv.get_feature_names
print_top_words(lda_lam, list(get_names()), 10)


Topic Nr.1:
line 867.85 | subject 739.78 | organization 700.42 | university 409.2 | nntppostinghost 360.69 | distribution 233.85 | anyone 216.1 | please 200.1 | computer 187.51 | new 182.84 | 

Topic Nr.2:
would 634.36 | like 449.24 | good 230.19 | think 227.23 | get 218.78 | people 187.47 | im 174.78 | much 160.13 | thing 143.79 | make 132.07 | 

Topic Nr.3:
question 253.1 | may 211.09 | get 170.17 | group 162.1 | one 161.26 | also 145.68 | find 134.79 | article 127.63 | course 114.1 | answer 107.1 | 

Topic Nr.4:
well 212.85 | year 196.81 | right 161.84 | game 159.1 | point 143.42 | team 137.1 | second 123.1 | state 112.31 | last 100.14 | every 100.1 | 

Topic Nr.5:
writes 504.76 | article 450.11 | organization 311.78 | line 292.35 | subject 273.38 | replyto 122.12 | world 110.81 | david 104.93 | space 95.32 | research 93.1 | 

Topic Nr.6:
window 343.1 | file 308.1 | use 258.33 | system 224.77 | image 216.1 | program 213.1 | jpeg 208.1 | version 190.1 | information 180.1 | available

In [8]:
# Show the stored perplexity values, then the ELBO history, one per line.
print(model.perplexity, model._ELBO_history, sep='\n')

[473.5030612993274, 165.96644608870963, 164.58448665312082, 164.31820617487926, 164.2323556293911, 164.19539253285754, 164.17206435577862, 164.15392484743478, 164.14119950109372, 164.12841210838496, 164.12143814942982, 164.1196620103764, 164.11414938972212, 164.11270673994284, 164.1126734963347, 164.11266403426097]
[-877194.2326908677, -727908.0509850737, -726717.375740957, -726486.8041059296, -726412.3866635272, -726380.334127943, -726360.1013935744, -726344.3668504879, -726333.3276133211, -726322.2336887941, -726316.1829453944, -726314.6418912802, -726309.858799906, -726308.6070410116, -726308.5781960557, -726308.569985962]
