In [3]:
# 在colab中，切换到工作目录
import os
if "news.txt" not in os.listdir():
    os.chdir('drive/MyDrive/LDA')
assert "news.txt" in os.listdir()

In [4]:
import numpy as np
import time

marks = ",.!~`+-_=—“”‘’()[]\n"
stop_words = ["i","the",'in',"to","of","and","or","for","on","that","he",
"she","it","is","was","were","his","mr","with","you","from","a","an","not",
"at","but","as","are","be","has","have","we","who","they","by","had","would",
"its","their","which","this","said","about","my","been","her","after","one",
"will","there","ms","when","what","after","new","more","if","also","than",
"him","them","so","me","some","other","all","can","could"]

def initialize_dataset(filename):
  word_ids = dict()
  ids_word = dict()
  docs = []
  current_doc = []
  current_word = 0
  for line in open(filename,'r',encoding='utf-8').readlines():
    # 统一小写
    new_line = line.lower()
    # 去除标点
    for mk in marks:
      new_line = new_line.replace(mk,' ')
    # 去除stop_words
    new_line = new_line.split()
    words = [w for w in new_line if w not in stop_words]
    for w in words:
      if w in word_ids:
        current_doc.append(word_ids[w])
      else:
        current_doc.append(current_word)
        word_ids[w] = current_word
        ids_word[current_word] = w
        current_word = current_word + 1
    
    docs.append(current_doc);
    current_doc = []
  return docs, word_ids, ids_word

def initialize(docs):
	#Initialization with Online Gibbs Sampling
	for d, doc in enumerate(docs):
		tw = []
		for word in doc:
			p_t = np.divide(np.multiply(ntd[:,d], nwt[word,:]), nt)
			t = np.random.multinomial(1, p_t / p_t.sum()).argmax()
			tw.append(t)
			ntd[t][d] = ntd[t][d] + 1
			nwt[word,t] = nwt[word,t] + 1
			nt[t] = nt[t] + 1
		twd.append(np.array(tw))

def gibbs_iteration(docs):
	#Collapsed Gibbs Sampling Iteration
	for d, doc in enumerate(docs):
		for w, word in enumerate(doc):
			# Decrement counts for old topic of the word
			t = twd[d][w]
			ntd[t][d] = ntd[t][d] - 1
			nwt[word,t] = nwt[word,t] - 1
			nt[t] = nt[t] - 1
			
			# Sample new topic 
			p_t = np.divide(np.multiply(ntd[:,d], nwt[word,:]), nt)
			t = np.random.multinomial(1, p_t / p_t.sum()).argmax()
			
			# Increment counts for new topic of the word
			twd[d][w] = t 
			ntd[t][d] = ntd[t][d] + 1
			nwt[word,t] = nwt[word,t] + 1
			nt[t] = nt[t] + 1

def perplexity():
	nd = np.sum(ntd, 0)
	n = 0
	ll = 0.0
	for d, doc in enumerate(docs):
		for word in enumerate(doc):
			ll = ll + np.log(((nwt[word,:]/nt) * (ntd[:,d]/nd[d])).sum())
			n = n + 1
	return np.exp(ll/(-n))

def topwords(topic,num):
  ids_top_words = np.argsort(-nwt[:,topic])
  top_words = []
  top_freq = []
  for j in ids_top_words:
    top_words.append(ids_word[j])
    top_freq.append(nwt[:,topic][j])
  top_freq /= sum(top_freq)
  return top_words[:num],top_freq[:num]


print(time.strftime('%X'), "Reading Data")
docs, word_ids, ids_word = initialize_dataset("news_easy.txt")

alpha = 5 # 超参数
beta = 0.1 # 超参数
n_iter = 10 # 吉布斯采样循环次数，一次约2分钟，原版10
for k in [5,10,20]:
  print("=====主题数：%d====="%k)
  # topic of word w in doc d
  twd = []
  # number of words of topic t in doc d
  ntd = np.zeros((k, len(docs))) + alpha
  # number of times word w is in topic t
  nwt = np.zeros((len(word_ids), k)) + beta
  # number of words in topic t
  nt = np.zeros(k) + (len(word_ids) * beta)

  print(time.strftime('%X'), "Initialize Gibbs Sampling")
  initialize(docs)
  print(time.strftime('%X'), "Initialization Complete")

  for i in range(n_iter):
    gibbs_iteration(docs)
    print(time.strftime('%X'), "Iteration: ", i, " Completed", " Perplexity: ", perplexity())

  for t in range(k):
    words,freqs = topwords(t,10) # 前10项
    wf_dict = [(w,f) for w,f in zip(words,freqs)]
    print("主题%d:"%(t),wf_dict)

08:33:32 Reading Data
=====主题数：5=====
08:33:32 Initialize Gibbs Sampling
08:33:32 Initialization Complete
08:33:32 Iteration:  0  Completed  Perplexity:  451.2561790303506
08:33:32 Iteration:  1  Completed  Perplexity:  440.94126978520467
08:33:32 Iteration:  2  Completed  Perplexity:  429.5312014588302
08:33:32 Iteration:  3  Completed  Perplexity:  427.3143386032163
08:33:32 Iteration:  4  Completed  Perplexity:  421.65356629617
08:33:32 Iteration:  5  Completed  Perplexity:  414.6360343227386
08:33:32 Iteration:  6  Completed  Perplexity:  410.0997056993607
08:33:32 Iteration:  7  Completed  Perplexity:  399.55707093867045
08:33:33 Iteration:  8  Completed  Perplexity:  386.7895042314187
08:33:33 Iteration:  9  Completed  Perplexity:  388.0599594578569
主题0: [('just', 0.02051376824986042), ('re', 0.014969506560708954), ('out', 0.014969506560708954), ('nationals', 0.011273332101274644), ('kalisz', 0.009425244871557489), ('200', 0.009425244871557489), ('second', 0.009425244871557489), 