In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
basedir = '../'
sys.path.append(basedir)

import numpy as np
from numpy.random import RandomState
import pandas as pd
from IPython.display import display

from synth_data import HldaDataGenerator
from hlda import NCRPNode

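Synthetic-data test for hierarchical LDA (hLDA) inference: generate a corpus from a known topic tree, then check whether the hLDA sampler recovers that tree and its topics.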

# 1. Generate Vocab

In [140]:
# Lay out the vocabulary as an n_rows x n_cols grid of tokens 'w0', 'w1', ...
# in row-major order, so that any row or column of the grid can later serve as
# the word support of one synthetic topic.
n_rows = 10
n_cols = 10
# `np.object` was only an alias of the builtin `object` and was removed in
# NumPy 1.24; the builtin works on every NumPy version.
vocab_mat = np.array(['w%d' % i for i in range(n_rows * n_cols)],
                     dtype=object).reshape(n_rows, n_cols)

print(vocab_mat)

[['w0' 'w1' 'w2' 'w3' 'w4' 'w5' 'w6' 'w7' 'w8' 'w9']
 ['w10' 'w11' 'w12' 'w13' 'w14' 'w15' 'w16' 'w17' 'w18' 'w19']
 ['w20' 'w21' 'w22' 'w23' 'w24' 'w25' 'w26' 'w27' 'w28' 'w29']
 ['w30' 'w31' 'w32' 'w33' 'w34' 'w35' 'w36' 'w37' 'w38' 'w39']
 ['w40' 'w41' 'w42' 'w43' 'w44' 'w45' 'w46' 'w47' 'w48' 'w49']
 ['w50' 'w51' 'w52' 'w53' 'w54' 'w55' 'w56' 'w57' 'w58' 'w59']
 ['w60' 'w61' 'w62' 'w63' 'w64' 'w65' 'w66' 'w67' 'w68' 'w69']
 ['w70' 'w71' 'w72' 'w73' 'w74' 'w75' 'w76' 'w77' 'w78' 'w79']
 ['w80' 'w81' 'w82' 'w83' 'w84' 'w85' 'w86' 'w87' 'w88' 'w89']
 ['w90' 'w91' 'w92' 'w93' 'w94' 'w95' 'w96' 'w97' 'w98' 'w99']]


In [141]:
# Flatten the grid row by row into a plain Python list; a token's position in
# this list is its integer word id.
vocab = list(vocab_mat.ravel())
print(vocab)

['w0', 'w1', 'w2', 'w3', 'w4', 'w5', 'w6', 'w7', 'w8', 'w9', 'w10', 'w11', 'w12', 'w13', 'w14', 'w15', 'w16', 'w17', 'w18', 'w19', 'w20', 'w21', 'w22', 'w23', 'w24', 'w25', 'w26', 'w27', 'w28', 'w29', 'w30', 'w31', 'w32', 'w33', 'w34', 'w35', 'w36', 'w37', 'w38', 'w39', 'w40', 'w41', 'w42', 'w43', 'w44', 'w45', 'w46', 'w47', 'w48', 'w49', 'w50', 'w51', 'w52', 'w53', 'w54', 'w55', 'w56', 'w57', 'w58', 'w59', 'w60', 'w61', 'w62', 'w63', 'w64', 'w65', 'w66', 'w67', 'w68', 'w69', 'w70', 'w71', 'w72', 'w73', 'w74', 'w75', 'w76', 'w77', 'w78', 'w79', 'w80', 'w81', 'w82', 'w83', 'w84', 'w85', 'w86', 'w87', 'w88', 'w89', 'w90', 'w91', 'w92', 'w93', 'w94', 'w95', 'w96', 'w97', 'w98', 'w99']


# 2. Assign Documents to Paths in the Tree (nested CRP prior)

In [86]:
# Draw a root-to-leaf path for every document from the nested Chinese restaurant
# process (nCRP) prior; parent_node.select(gamma) picks the node at each level.
# NOTE(review): no random seed is set, so the sampled tree (and therefore the
# number of nodes the topic-assignment cell must cover) changes on every run.
# TODO seed whichever RNG NCRPNode.select uses.
# Reset NCRPNode's class-level counters (presumably so node ids restart at 0 on re-run).
NCRPNode.total_nodes = 0
NCRPNode.last_node_id = 0
num_levels = 3
gamma = 1  # concentration passed to select(); TODO confirm semantics in hlda.NCRPNode
num_docs = 100

root_node = NCRPNode(num_levels, vocab)
document_path = {}  # document index -> array of nodes, one per level, root first
unique_nodes = set()
unique_nodes.add(root_node)
for d in range(num_docs):

    # populate nodes into the path of this document
    # (`np.object` was removed in NumPy 1.24; the builtin `object` is equivalent)
    path = np.zeros(num_levels, dtype=object)
    path[0] = root_node
    root_node.customers += 1  # always add to the root node first
    for level in range(1, num_levels):
        # at each level, a node is selected by its parent node based on the CRP prior
        parent_node = path[level - 1]
        level_node = parent_node.select(gamma)
        level_node.customers += 1
        path[level] = level_node
        unique_nodes.add(level_node)

    # store the full root-to-leaf path (not just the leaf) for this document
    document_path[d] = path

# from here on unique_nodes is a list ordered by node id
unique_nodes = sorted(unique_nodes, key=lambda x: x.node_id)
print(len(unique_nodes))
    
def print_node(node, indent=0, node_topic=None):
    """Recursively print the subtree rooted at `node`, one line per node.

    Each line shows the node id, its level and its customer count (the number
    of documents whose path passes through the node), indented four spaces per
    depth. If `node` has an entry in `node_topic`, its topic words are appended.

    Parameters
    ----------
    node : object with `node_id`, `level`, `customers` and `children` attributes.
    indent : int, depth used for indentation (default 0 for the root).
    node_topic : dict mapping node -> (probs, words), or None to print no words.
    """
    if node_topic is None:
        node_topic = {}
    out = '    ' * indent
    out += 'node %d (level=%d, documents=%d): ' % (node.node_id, node.level, node.customers)
    if node in node_topic:
        _, words = node_topic[node]  # only the words are shown, not their probabilities
        out += ' '.join(words)
    # print() with a single argument behaves identically on Python 2 and 3
    print(out)
    for child in node.children:
        print_node(child, indent + 1, node_topic)

# Show only the sampled tree structure; no topic words have been assigned yet.
node_topic = dict()
print_node(node=root_node, indent=0, node_topic=node_topic)

11
node 0 (level=0, documents=100):		
    node 1 (level=1, documents=71):		
        node 2 (level=2, documents=62):		
        node 8 (level=2, documents=7):		
        node 10 (level=2, documents=2):		
    node 3 (level=1, documents=29):		
        node 4 (level=2, documents=3):		
        node 5 (level=2, documents=12):		
        node 6 (level=2, documents=9):		
        node 7 (level=2, documents=3):		
        node 9 (level=2, documents=2):		


# 3. Assign Each Node Along the Tree to a Topic

In [87]:
def get_words(vocab_mat, alpha, pos, dim, random_state=None):
    """Build one synthetic topic from a single row or column of the vocabulary grid.

    Parameters
    ----------
    vocab_mat : 2-D numpy array of word tokens.
    alpha : float, concentration of the symmetric Dirichlet from which the
        topic's word probabilities are drawn.
    pos : int, index of the row or column to use.
    dim : 'row' or 'col'.
    random_state : optional numpy RandomState for reproducible draws; when None
        the global numpy RNG is used.

    Returns
    -------
    (probs, words) : `words` is the 1-D array of k tokens in the selected row or
        column; `probs` is a length-k probability vector drawn from
        Dirichlet([alpha] * k).

    Raises
    ------
    ValueError : if `dim` is neither 'row' nor 'col'.
    """
    if dim == 'row':
        words = vocab_mat[pos]
    elif dim == 'col':
        words = vocab_mat[:, pos]
    else:
        # without this branch `words` would be unbound (UnboundLocalError)
        raise ValueError("dim must be 'row' or 'col', got %r" % (dim,))

    rng = np.random if random_state is None else random_state
    concentration = [alpha] * len(words)
    probs = rng.dirichlet(concentration)
    return probs, words
    
# Sanity-check get_words on one column of the grid: the drawn word probabilities
# must form a distribution.
# Define eta here so this cell also runs on a fresh kernel; it matches the value
# set in the topic-assignment cell.
eta = 1
pos = 0
probs, words = get_words(vocab_mat, eta, pos, 'col')
print(probs)
print(words)
print(np.sum(probs))
assert np.isclose(np.sum(probs), 1.0)

[ 0.15371927  0.04026764  0.20572376  0.0547151   0.12450964  0.00613161
  0.00932277  0.05206163  0.11992228  0.23362631]
['word_0' 'word_10' 'word_20' 'word_30' 'word_40' 'word_50' 'word_60'
 'word_70' 'word_80' 'word_90']
1.0


In [143]:
# Give every tree node its own topic. Nodes are taken in node-id order: the first
# n_grid_rows nodes each get one row of the vocabulary grid, and any remaining
# nodes get one column each, starting from column 1. The number of nodes comes
# from a random nCRP draw, so it is not hard-coded here.
eta = 1  # concentration of the symmetric Dirichlet over each topic's words
node_topic = {}
n_grid_rows, n_grid_cols = vocab_mat.shape
for row, node in enumerate(unique_nodes[:n_grid_rows]):
    node_topic[node] = get_words(vocab_mat, eta, row, 'row')
for col, node in enumerate(unique_nodes[n_grid_rows:], 1):
    if col >= n_grid_cols:
        raise ValueError('tree has %d nodes but only %d row/column topics are available'
                         % (len(unique_nodes), n_grid_rows + n_grid_cols - 1))
    node_topic[node] = get_words(vocab_mat, eta, col, 'col')
print(len(node_topic))

11


In [144]:
# Print the tree again, now with each node's topic words appended.
print_node(node=root_node, indent=0, node_topic=node_topic)

node 0 (level=0, documents=100): 	w0 w1 w2 w3 w4 w5 w6 w7 w8 w9
    node 1 (level=1, documents=71): 	w10 w11 w12 w13 w14 w15 w16 w17 w18 w19
        node 2 (level=2, documents=62): 	w20 w21 w22 w23 w24 w25 w26 w27 w28 w29
        node 8 (level=2, documents=7): 	w80 w81 w82 w83 w84 w85 w86 w87 w88 w89
        node 10 (level=2, documents=2): 	w1 w11 w21 w31 w41 w51 w61 w71 w81 w91
    node 3 (level=1, documents=29): 	w30 w31 w32 w33 w34 w35 w36 w37 w38 w39
        node 4 (level=2, documents=3): 	w40 w41 w42 w43 w44 w45 w46 w47 w48 w49
        node 5 (level=2, documents=12): 	w50 w51 w52 w53 w54 w55 w56 w57 w58 w59
        node 6 (level=2, documents=9): 	w60 w61 w62 w63 w64 w65 w66 w67 w68 w69
        node 7 (level=2, documents=3): 	w70 w71 w72 w73 w74 w75 w76 w77 w78 w79
        node 9 (level=2, documents=2): 	w90 w91 w92 w93 w94 w95 w96 w97 w98 w99


# 4. Generate Words in a Document Based on Its Path

In [145]:
def generate_document(topics, theta, doc_len, random_state=None):
    """Sample one synthetic document from the topics on its tree path.

    For each of the `doc_len` word positions, a topic index k is drawn from the
    document's mixing proportions `theta`. A word is then drawn from topic k's
    word distribution.

    Parameters
    ----------
    topics : sequence of (probs, words) pairs, one per level of the path.
    theta : probability vector over the len(topics) topics.
    doc_len : int, number of words to generate.
    random_state : optional numpy RandomState for reproducible draws; when None
        the global numpy RNG is used.

    Returns
    -------
    list of `doc_len` word tokens.
    """
    rng = np.random if random_state is None else random_state
    doc = []
    for _ in range(doc_len):
        # sample which topic (level of the path) generates this word position
        k = rng.multinomial(1, theta).argmax()
        # sample the word itself from topic k's word distribution
        probs, words = topics[k]
        w = rng.multinomial(1, probs).argmax()
        doc.append(words[w])
    return doc

In [146]:
# Generate the corpus. Each document mixes the topics on its own tree path with
# proportions theta ~ Dirichlet(alpha); alpha has one entry per path level
# (root first).
corpus = []
alpha = [2.0, 1.0, 0.5]
doc_len = 250
for d in range(num_docs):
    topics = [node_topic[node] for node in document_path[d]]
    # np.random.dirichlet is the public name of the same global-RNG draw that
    # np.random.mtrand.dirichlet exposes
    theta = np.random.dirichlet(alpha)
    corpus.append(generate_document(topics, theta, doc_len))

In [174]:
import os

# Write each document to doc_<d>.txt as one line of space-separated tokens.
# The directory can be overridden with the HLDA_SYNTH_OUTDIR environment
# variable. By default it is data/synthetic/ under basedir (the directory added
# to sys.path in the first cell), instead of an absolute path that only exists
# on one machine.
outdir = os.environ.get('HLDA_SYNTH_OUTDIR',
                        os.path.join(basedir, 'data', 'synthetic'))
if not os.path.isdir(outdir):
    os.makedirs(outdir)
for d, doc in enumerate(corpus):
    file_path = os.path.join(outdir, 'doc_%d.txt' % d)
    with open(file_path, 'w') as f:
        f.write("%s\n" % ' '.join(doc))

# 5. Run hLDA

In [147]:
# vocabulary size, number of documents, length of the first document
print('%d %d %d' % (len(vocab), len(corpus), len(corpus[0])))

100 100 250


Convert corpus words into integer vocabulary indices.

In [148]:
# Replace every token by its integer id (its position in `vocab`). Building the
# lookup table once makes this O(total tokens) instead of one linear
# vocab.index scan per token.
word_to_idx = {}
for idx, word in enumerate(vocab):
    word_to_idx.setdefault(word, idx)  # keep the first occurrence, like list.index
new_corpus = [[word_to_idx[word] for word in doc] for doc in corpus]

In [151]:
# Spot-check the conversion on the first document: its tokens, then their ids.
print('%d %d' % (len(vocab), len(new_corpus)))
first_tokens = corpus[0]
print(first_tokens)
first_ids = new_corpus[0]
print(first_ids)

100 100
['w8', 'w3', 'w8', 'w3', 'w7', 'w8', 'w1', 'w3', 'w5', 'w1', 'w4', 'w1', 'w0', 'w1', 'w9', 'w10', 'w1', 'w8', 'w1', 'w2', 'w6', 'w4', 'w8', 'w1', 'w1', 'w0', 'w6', 'w2', 'w3', 'w1', 'w6', 'w1', 'w7', 'w1', 'w3', 'w3', 'w7', 'w6', 'w6', 'w6', 'w7', 'w16', 'w9', 'w6', 'w6', 'w7', 'w3', 'w6', 'w3', 'w3', 'w2', 'w8', 'w3', 'w7', 'w8', 'w6', 'w19', 'w1', 'w6', 'w10', 'w4', 'w3', 'w4', 'w1', 'w1', 'w3', 'w1', 'w1', 'w6', 'w2', 'w2', 'w7', 'w1', 'w6', 'w6', 'w1', 'w15', 'w6', 'w19', 'w9', 'w1', 'w16', 'w7', 'w1', 'w7', 'w5', 'w4', 'w6', 'w12', 'w7', 'w7', 'w7', 'w1', 'w8', 'w6', 'w6', 'w1', 'w3', 'w1', 'w1', 'w6', 'w6', 'w1', 'w2', 'w5', 'w10', 'w18', 'w19', 'w3', 'w8', 'w1', 'w10', 'w3', 'w1', 'w7', 'w1', 'w6', 'w6', 'w8', 'w3', 'w1', 'w6', 'w7', 'w0', 'w6', 'w6', 'w3', 'w3', 'w18', 'w14', 'w11', 'w2', 'w10', 'w6', 'w7', 'w3', 'w3', 'w7', 'w1', 'w6', 'w1', 'w3', 'w2', 'w2', 'w4', 'w2', 'w7', 'w10', 'w7', 'w8', 'w22', 'w14', 'w10', 'w6', 'w7', 'w3', 'w6', 'w19', 'w9', 'w7', 'w1', 'w3'

In [165]:
from hlda import HierarchicalLDA

In [169]:
# hyperparameters used to generate the synthetic data
print('%s %s %s' % (alpha, gamma, eta))

[2.0, 1.0, 0.5] 1 1


In [175]:
# Fit hLDA to the synthetic corpus using a 3-level tree, the same depth used to
# generate it.
# NOTE(review): the corpus was generated with per-level alpha=[2.0, 1.0, 0.5] and
# eta=1, but the sampler runs with alpha=1 and eta=1.0 — confirm this is intended.
n_samples = 1000
hlda = HierarchicalLDA(new_corpus, vocab,
                       num_levels=3,
                       eta=1.0,
                       gamma=1.0,
                       alpha=1)
hlda.estimate(n_samples,
              with_weights=False,
              n_words=10,
              display_topics=50)

HierarchicalLDA sampling
..................................................
topic 0 (level=0, total_words=14193, documents=100): w1, w6, w3, w7, w8, w4, w2, w9, w0, w5, 
    topic 1 (level=1, total_words=289, documents=24): w77, w71, w76, w36, w74, w72, w30, w73, w6, w33, 
        topic 2 (level=2, total_words=1848, documents=22): w37, w36, w38, w30, w34, w31, w7, w33, w1, w52, 
        topic 18 (level=2, total_words=248, documents=2): w88, w10, w82, w83, w19, w14, w15, w87, w84, w16, 
    topic 3 (level=1, total_words=12, documents=17): w24, w3, w6, w20, w23, w29, w36, w28, w30, w31, 
        topic 14 (level=2, total_words=1783, documents=17): w10, w27, w14, w19, w26, w16, w11, w28, w15, w22, 
    topic 5 (level=1, total_words=1049, documents=38): w10, w14, w19, w16, w15, w27, w22, w28, w11, w1, 
        topic 7 (level=2, total_words=134, documents=4): w10, w14, w11, w19, w15, w16, w26, w3, w9, w7, 
        topic 15 (level=2, total_words=3111, documents=34): w10, w14, w19, w16, w11, w

In [178]:
# Second hLDA run with the same settings as the previous cell. No seed is set in
# this notebook, so presumably the sampler starts from a different random state.
n_samples = 1000
hlda = HierarchicalLDA(new_corpus, vocab,
                       num_levels=3,
                       eta=1.0,
                       gamma=1.0,
                       alpha=1)
hlda.estimate(n_samples,
              with_weights=False,
              n_words=10,
              display_topics=50)

HierarchicalLDA sampling
..................................................
topic 0 (level=0, total_words=14389, documents=100): w1, w6, w3, w7, w8, w4, w2, w9, w0, w5, 
    topic 1 (level=1, total_words=4473, documents=86): w10, w14, w19, w16, w15, w11, w12, w1, w17, w13, 
        topic 2 (level=2, total_words=2588, documents=53): w27, w26, w28, w22, w25, w24, w29, w21, w23, w20, 
        topic 6 (level=2, total_words=61, documents=4): w91, w71, w61, w81, w0, w10, w1, w31, w4, w11, 
        topic 7 (level=2, total_words=198, documents=6): w88, w83, w82, w84, w87, w81, w6, w7, w86, w89, 
        topic 9 (level=2, total_words=1703, documents=18): w37, w36, w38, w30, w34, w31, w33, w52, w56, w1, 
        topic 26 (level=2, total_words=4, documents=1): w6, w23, w7, w99, w37, w28, w29, w30, w31, w32, 
        topic 30 (level=2, total_words=84, documents=2): w48, w43, w47, w38, w3, w40, w45, w31, w34, w37, 
        topic 32 (level=2, total_words=181, documents=2): w37, w36, w38, w30, w97, w

KeyboardInterrupt: 