In [410]:
import numpy as np
import pdb
import time
from scipy.special import digamma, polygamma, gammaln
from tqdm import tqdm, trange

def prepare_article(path, word2idx):
    result=dict()
    with open(path, "r") as word_list:
        words = word_list.read().split(' ')
    for word in words:
        if word in word2idx.keys():
            if word2idx[word] in result.keys():
                result[word2idx[word]] += 1
            else:
                result[word2idx[word]] = 1
    with open('new_article_prepared.txt', 'w') as f:
        print(result, file=f)
    

def load_vocab(path):
    word2idx = {}
    idx2word = {}
    with open(path, 'r') as txtfile:
        for i, line in enumerate(txtfile):
            word = line.split('\n')[0]
            word2idx[word] = i
            idx2word[i] =word
    return word2idx, idx2word

def load_article(path, N=200, ths=150):
    article_vec = []  
    article_dict = []
    train_set = []
    with open(path, 'r') as txtfile:
        for line in txtfile:
            vec = []
            dic = {}
            line = line.split('\n')[0].split(',')
            for item in line:
                key = int(item.split(':')[0])
                val = int(item.split(':')[1])
                vec += [key] * val
                dic[key] = val
            if len(vec)>=ths:
                sample = np.random.choice(vec, size=N, replace=True) if len(vec)<N else np.random.choice(vec, size=N, replace=False)
                train_set.append(sample)
            article_vec.append(vec)
            article_dict.append(dic)
    return article_vec, article_dict, np.array(train_set)

def generate_text(gamma, beta, idx2word, n_word=30):
    gamma = gamma / gamma.sum()
    text = []
    for i in range(n_word):
        topic = np.argmax(np.random.multinomial(10, gamma, size=1))
        beta_new = beta[topic,:] / beta[topic,:].sum()
        text.append(idx2word[np.argmax(np.random.multinomial(10, beta_new, size=1))])
    print(', '.join(text))



In [444]:
class MyLDA:
    def __init__(self, data, vocab, K=25):
        self.data = data
        self.M, self.N = self.data.shape
        self.K = K
        self.vocab = vocab
        self.W = len(self.vocab)
        self.alpha = 0.1 * np.ones(self.K) 
        self.beta = np.random.uniform(size=(self.K, self.W)) 
        self.beta = self.beta/self.beta.sum(axis=1,keepdims=True)
        self.gamma = self.alpha + self.N/self.K * np.ones((self.M, self.K))  
        self.phi = 1/self.K * np.ones((self.M, self.N, self.K)) 

    def update_alpha(self, tol=1e-4):
        alpha = self.alpha
        diff = np.Inf
        
        while diff > tol:
            g = self.M * (digamma(alpha.sum()) - digamma(alpha)) + (digamma(self.gamma)-digamma(self.gamma.sum(axis=1,keepdims=True))).sum(axis=0)
            h = 1 / (self.M * polygamma(1, alpha))
            z = 1 / (-1.0 * polygamma(1, alpha.sum()))
            c = np.sum(g*h) / (z + np.sum(h))
            new_alpha = alpha + (g-c)*h
            diff = np.linalg.norm(new_alpha-alpha)
            alpha = new_alpha
        self.alpha = alpha
        print(self.alpha)
        
    def update_beta(self):
        self.beta = np.array([self.phi[(self.data==i),:].sum(axis=0) for i in range(self.W)]).T
        self.beta = self.beta/self.beta.sum(axis=1,keepdims=True)

    def update_gamma(self):
        return self.alpha + self.phi.sum(axis=1) 
    
    def update_phi(self):
        phi = self.phi
        phi = np.einsum('nmk->mnk',(np.einsum('kmn->nmk',((self.beta[:, self.data.flatten()]).reshape((self.K,self.M,self.N))))\
                * np.exp(digamma(self.gamma))))
        return phi / phi.sum(axis=2,keepdims=True)

    
    def Expectation(self, tol=1e-2):
        diff = np.Inf
        while diff > tol:
            prev_phi = self.phi
            prev_gamma = self.gamma
            self.phi = self.update_phi()
            self.gamma = self.update_gamma()
            diff = 0.5/self.M *(np.linalg.norm(self.phi-prev_phi) + np.linalg.norm(self.gamma-prev_gamma))
        
    def Maximization(self):
        self.update_beta()
        self.update_alpha()
        
    def fit(self, max_iter=100, tol=1e-2):
        for i in range(max_iter):
            print('Iteration ',i)
            self.Expectation(tol=tol)
            self.Maximization()
        
    

In [446]:
word2idx, idx2word = load_vocab('nyt_vocab.txt')
article_vec, article_dict, train_set = load_article('nyt_data.txt')
prepare_article('nyt_new_article.txt', word2idx)

LDA = MyLDA(train_set, word2idx)
LDA.fit()

Iteration  0
[0.0994484  0.09525816 0.1005702  0.13460548 0.10303972 0.10357508
 0.10326062 0.0986354  0.13308281 0.11784303 0.13830212 0.11848041
 0.10906249 0.099576   0.10515304 0.111228   0.10248081 0.09195133
 0.11995982 0.10729947 0.10840283 0.10000401 0.10565409 0.12010575
 0.09950006]
Iteration  1
[0.08665046 0.08205951 0.08817025 0.1305116  0.09046347 0.0922167
 0.09221291 0.08582986 0.1245665  0.10770659 0.13244115 0.10880461
 0.09863789 0.08779594 0.09427022 0.10095897 0.09046766 0.07794105
 0.11172991 0.09592214 0.09759375 0.0873661  0.0937726  0.10981142
 0.08674338]
Iteration  2
[0.07372361 0.06982973 0.07501432 0.11278257 0.07694985 0.07926582
 0.07983415 0.0732154  0.10514588 0.09069922 0.11313243 0.09339154
 0.08531318 0.0747353  0.08192658 0.08700612 0.07816177 0.0659633
 0.09705862 0.08171586 0.08456906 0.07483315 0.08044192 0.09388617
 0.07394471]
Iteration  3
[0.0633614  0.05999446 0.06487596 0.097518   0.06659387 0.06864131
 0.06971151 0.06325467 0.08904199 0.0767

[0.01630789 0.01697871 0.01817361 0.02977076 0.02120532 0.02017709
 0.02262093 0.01727647 0.02616617 0.02252788 0.02474648 0.02358637
 0.02188312 0.01550134 0.02275072 0.02315861 0.02063732 0.01429496
 0.02548387 0.02042292 0.02108273 0.01836775 0.02111148 0.0236416
 0.01667221]
Iteration  29
[0.01602393 0.01670893 0.01790332 0.02938653 0.02089054 0.01986164
 0.02231505 0.01697737 0.02582206 0.02220311 0.024387   0.02323062
 0.02155157 0.01521377 0.02246253 0.02283161 0.0203516  0.0140378
 0.02516695 0.02011577 0.02075705 0.01804832 0.02076017 0.02330456
 0.01635392]
Iteration  30
[0.01576145 0.01645245 0.01765677 0.02902886 0.02058813 0.01956285
 0.02202401 0.01669959 0.02550176 0.02189623 0.02405007 0.02289277
 0.02124437 0.01492839 0.02218946 0.02252525 0.02008666 0.01379393
 0.02487518 0.01981937 0.02046155 0.01775911 0.02043555 0.02299747
 0.01604196]
Iteration  31
[0.01551384 0.01621849 0.01742678 0.02868485 0.02029256 0.01929262
 0.02174486 0.01644375 0.02520004 0.02160244 0.023

[0.01227595 0.01283876 0.01399938 0.02340096 0.01669897 0.01541849
 0.01799267 0.01294996 0.02061568 0.01714358 0.01950537 0.01816223
 0.01729598 0.01131743 0.01770015 0.01822835 0.01626516 0.01055139
 0.02020221 0.01581745 0.01671116 0.01364802 0.01599111 0.01858446
 0.01191364]
Iteration  57
[0.01219105 0.01274668 0.01392688 0.02321619 0.01662581 0.0153354
 0.01789742 0.01287984 0.02049541 0.01703461 0.01941995 0.01806642
 0.0172034  0.011245   0.01759596 0.01812987 0.01618614 0.01049024
 0.02010708 0.01573875 0.01663693 0.01356639 0.01589285 0.01848071
 0.0118369 ]
Iteration  58
[0.01211084 0.01266118 0.01385888 0.02303389 0.0165572  0.01525297
 0.01780089 0.01281341 0.02037097 0.01693214 0.01934028 0.0179677
 0.01711627 0.01117704 0.01748381 0.01803873 0.01610914 0.01043158
 0.02001894 0.0156653  0.01656515 0.01348619 0.01580219 0.0183833
 0.01176137]
Iteration  59
[0.01203215 0.01258122 0.01379467 0.02287073 0.01649256 0.01517182
 0.01770358 0.01274433 0.02024332 0.01683051 0.0192

[0.01078673 0.01125886 0.01254292 0.02050243 0.01504775 0.01367944
 0.01578993 0.01140849 0.01826334 0.01478187 0.01786039 0.01608479
 0.01514235 0.01000222 0.015594   0.01610856 0.01466977 0.00914335
 0.01814367 0.01436117 0.01522684 0.0119598  0.01406567 0.01652674
 0.01019235]
Iteration  85
[0.01074694 0.01121882 0.01251161 0.02044056 0.01500004 0.01363187
 0.01571584 0.01136615 0.01819368 0.01471854 0.01782792 0.01603216
 0.01509731 0.00996838 0.01554012 0.0160565  0.01463947 0.00910927
 0.01807856 0.01432637 0.01518256 0.01192341 0.01403182 0.01648271
 0.01014679]
Iteration  86
[0.01070624 0.01117523 0.01248168 0.02037444 0.01495317 0.01358658
 0.01564832 0.01132629 0.01811996 0.01464857 0.0177964  0.01598345
 0.0150506  0.00993366 0.0154902  0.01600822 0.01460201 0.00907689
 0.01801528 0.01429146 0.01513296 0.0118886  0.01399957 0.01643656
 0.01010386]
Iteration  87
[0.0106677  0.01113428 0.01245288 0.02030956 0.01490623 0.01354076
 0.01558652 0.01128868 0.01804777 0.01457856 0.0

In [447]:
idx_5 = np.argsort(-LDA.beta,axis=1)[:,0:10]

In [448]:
words_5 = [[idx2word[t] for t in idx_5[z]] for z in range(25)]

In [449]:
for z in range(25):
    print(', '.join(words_5[z]))

percent, price, market, company, rate, buy, sell, product, business, increase
play, game, team, season, player, thing, coach, win, ask, lot
military, force, war, leader, american, government, country, attack, official, political
government, official, percent, company, plan, large, states, money, agency, report
man, woman, tell, life, place, live, room, write, little, feel
place, small, home, car, lot, thing, area, house, open, old
write, tell, book, thing, life, american, man, question, point, world
company, percent, sale, money, large, sell, pay, business, buy, stock
change, country, television, state, ago, group, national, keep, big, thing
percent, plan, price, pay, official, market, number, increase, deal, company
political, vote, campaign, republican, election, candidate, party, voter, issue, state
art, food, sell, small, add, place, price, large, restaurant, buy
great, life, world, play, man, thing, write, character, present, book
game, play, player, team, win, season, second, coa

In [419]:
LDA.phi[-1].sum(axis=0)

array([4.45614129e-43, 1.63306981e-33, 9.29869164e-42, 3.47065785e-34,
       7.56076071e-34, 1.42134816e-31, 9.90845292e-38, 7.58248545e-36,
       1.20101191e+02, 8.50881314e-45, 5.26998209e-37, 1.04274055e-50,
       8.67576759e-40, 6.59344886e-55, 5.47314162e-45, 3.68310517e-39,
       2.32559074e-65, 2.40421958e-80, 1.04886888e-38, 1.94760976e-41,
       7.98988092e+01, 1.14633635e-38, 4.75807997e-40, 3.50657616e-34,
       5.91723634e-50])

In [422]:
np.argsort(-LDA.phi[-1].sum(axis=0))[0:2]

array([ 8, 20])

In [425]:
generate_text(LDA.gamma[-1], LDA.beta, idx2word, n_word=30)

mrs, problem, late, young, put, thing, big, system, world, woman, problem, consider, president, player, woman, offer, woman, home, shot, point, report, company, face, away, man, group, public, percent, home, child
