In [1]:
import os
import pdb
import _pickle as cPickle

import numpy as np
import tensorflow as tf

from collections import defaultdict
from scipy.special import gammaln

# data 

In [2]:
def del_all_flags(FLAGS):
    flags_dict = FLAGS._flags()    
    keys_list = [keys for keys in flags_dict]    
    for keys in keys_list:
        FLAGS.__delattr__(keys)

del_all_flags(tf.flags.FLAGS)

flags = tf.app.flags

flags.DEFINE_string('data_path', 'data/synthetic/instances_ncrp.pkl', 'path of data')

flags.DEFINE_string('f', '', 'kernel')

config = flags.FLAGS

In [3]:
instances_train, instances_dev, instances_test, word_to_idx, idx_to_word, bow_idxs = cPickle.load(open(config.data_path,'rb'))
docs_raw = [[[bow_idxs[bow_index]]*int(instance.bow[bow_index]) for bow_index in np.where(instance.bow > 0)[0]] for instance in instances_train]
docs = [[idx for idxs in doc for idx in idxs] for doc in docs_raw][:100]

In [4]:
n_doc = len(docs)
n_vocab = len(np.unique(docs))

In [37]:
n_vocab

899

# initialization 

In [6]:
alpha = 1.
beta = 1.
gam = 0.5
eta = 1.
n_depth = 3
verbose = True

## assign docs to tree

In [29]:
class Topic:
    def __init__(self, idx, parent, depth, n_doc, n_vocab):
        self.idx = idx
        self.parent = parent
        self.children = []
        self.depth = depth
        self.cnt_doc = 0
        self.n_doc = n_doc
        self.n_vocab = n_vocab
        self.cnt_words = np.zeros([n_doc, n_vocab])
        self.verbose = verbose
    
    def sample_child(self, doc_idx, doc, gam, verbose=False):
        p_child_prior = self.get_p_child_prior(gam)
        p_child_likelihood = self.get_p_child_likelihood(doc_idx, doc, eta)
        p_child = np.array(p_child_prior * p_child_likelihood) / np.sum(p_child_prior * p_child_likelihood)
        
        child_index = np.random.multinomial(1, p_child).argmax()
        if verbose: print(self.depth, p_child, p_child_likelihood, child_index)
        
        if child_index < len(self.children):
            child = self.children[child_index]
        else:
            child = self.add_child()
        return child
    
    def get_p_child_prior(self, gam):
        p_child_prior = [child.cnt_doc for child in self.children]
        p_child_prior += [gam]
        return p_child_prior
    
    def get_p_child_likelihood(self, doc_idx, doc, eta):
        if len(self.children) > 0:
            children_cnt_words = np.array([child.cnt_words for child in self.children]) # Children x Document x Vocabulary
            children_cnt_words = np.concatenate([children_cnt_words, np.zeros([1, self.n_doc, self.n_vocab])], 0) # (Children+1) x Document x Vocabulary
        else:
            children_cnt_words = np.zeros([1, self.n_doc, self.n_vocab]) # 1 x Document x Vocabulary
        
        children_cnt_words_sum = np.sum(children_cnt_words, 1) # (Children + 1) x Vocabulary
        children_cnt_words_doc = children_cnt_words_sum - children_cnt_words[:, doc_idx, :] # (Children + 1) x Children x Vocabulary

        logits_prior = gammaln(np.sum(children_cnt_words_doc, -1) + n_vocab*eta) - np.sum(gammaln(children_cnt_words_doc[:, doc] + eta), -1)
        logits_later = gammaln(np.sum(children_cnt_words_sum, -1) + n_vocab*eta) - np.sum(gammaln(children_cnt_words_sum[:, doc] + eta), -1)
        logits_likelihood = logits_prior - logits_later
        p_child_likelihood = np.exp(logits_likelihood)
        return p_child_likelihood
    
    def add_child(self):
        idx = self.idx * 10 + len(self.children)+1
        depth = self.depth+1
        child = Topic(idx=idx, parent=self, depth=depth, n_doc=self.n_doc, n_vocab=self.n_vocab)
        self.children += [child]
        
        return child
    
    def delete_topic(self):
        self.parent.children.remove(self)

## sample doc path

$$p({\bf c}_{m}\hspace{0.5ex}|\hspace{0.5ex}{\bf w}, {\bf c}_{-m}, {\bf z})\propto p({\bf w}_{m}\hspace{0.5ex}|\hspace{0.5ex}{\bf c}, {\bf w}_{-m}, {\bf z})\cdot p({\bf c}_{m}\hspace{0.5ex}|\hspace{0.5ex}{\bf c}_{-m})$$

$$p({\bf w}_{m}\hspace{0.5ex}|\hspace{0.5ex}{\bf c}, {\bf w}_{-m}, {\bf z})=\prod_{\ell=1}^{L}\left(\frac{\Gamma(n_{c_{m,\ell},-m}^{(\cdot)}+W\eta)}{\prod_{w}\Gamma(n_{c_{m,\ell},-m}^{(w)}+\eta)}\frac{\prod_{w}\Gamma(n_{c_{m,\ell},-m}^{(w)}+n_{c_{m,\ell},m}^{(w)}+\eta)}{\Gamma(n_{c_{m,\ell},-m}^{(\cdot)}+n_{c_{m,\ell},m}^{(\cdot)}+W\eta)}\right)$$

In [30]:
def sample_doc_topics(docs, topic_root):
    for doc_idx, doc in enumerate(docs):
        # reset count of docs
        if doc_idx in doc_topics:
            for topic in doc_topics[doc_idx]:
                topic.cnt_doc -= 1
                if topic.cnt_doc == 0: topic.delete_topic()

        topic = topic_root
        topic.cnt_doc += 1
        doc_topics[doc_idx] += [topic]
        for depth in range(n_depth):
            topic = topic.sample_child(doc_idx, doc, gam, verbose)
            topic.cnt_doc += 1
            doc_topics[doc_idx] += [topic]
            
    return doc_topics

## assign words to topics

\begin{align*}
p(z_{i}=j\hspace{0.5ex}|\hspace{0.5ex}{\bf z}_{-i},{\bf w})\propto\frac{n_{-i,j}^{(w_{i})}+\beta}{n_{-i,j}^{(\cdot)}+W\beta}\frac{n_{-i,j}^{(d_{i})}+\alpha}{n_{-i,\cdot}^{(d_{i})}+T\alpha}
\end{align*}

In [31]:
def sample_word_topics(docs, doc_topics):
    for doc_idx, doc in enumerate(docs):
        topics = doc_topics[doc_idx]

        for topic in topics:
            for word_idx in doc:
                topic.cnt_words[doc_idx, word_idx] = 0

        s_docs = np.array([np.sum(topic.cnt_words[doc_idx, :])+alpha for topic in topics]) # L

        for word_idx in doc:
            s_words = np.array([np.sum(topic.cnt_words[:, word_idx])+beta for topic in topics]) # L
            z_words = np.array([np.sum(topic.cnt_words)+n_vocab*beta for topic in topics]) # L

            s_topics = s_docs*s_words/z_words
            p_topics = s_topics/np.sum(s_topics) # L

            word_topic = topics[np.argmax(np.random.multinomial(1, p_topics))]
            word_topic.cnt_words[doc_idx, word_idx] += 1

## run 

In [32]:
def assert_sum_cnt_words(topic_root):
    def recur_cnt_words(topic):
        cnt_words = np.sum(topic.cnt_words)
        for child in topic.children:
            cnt_words += recur_cnt_words(child)
        return cnt_words

    sum_cnt_words = recur_cnt_words(topic_root)
    assert sum_cnt_words == sum([len(doc) for doc in docs])

In [35]:
n_sample = 100
topic_root = Topic(idx=0, parent=None, depth=0, n_doc=n_doc, n_vocab=n_vocab)
doc_topics = defaultdict(list)

for i_sample in range(n_sample):
    # sample path of doc
    doc_topics = sample_doc_topics(docs, topic_root)
    
    # assign words to each topic
    sample_word_topics(docs, doc_topics)

IndexError: index 899 is out of bounds for axis 1 with size 899

In [36]:
%debug

> [0;32m<ipython-input-29-962b9a120e58>[0m(42)[0;36mget_p_child_likelihood[0;34m()[0m
[0;32m     40 [0;31m        [0mchildren_cnt_words_doc[0m [0;34m=[0m [0mchildren_cnt_words_sum[0m [0;34m-[0m [0mchildren_cnt_words[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mdoc_idx[0m[0;34m,[0m [0;34m:[0m[0;34m][0m [0;31m# (Children + 1) x Children x Vocabulary[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     41 [0;31m[0;34m[0m[0m
[0m[0;32m---> 42 [0;31m        [0mlogits_prior[0m [0;34m=[0m [0mgammaln[0m[0;34m([0m[0mnp[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mchildren_cnt_words_doc[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m [0;34m+[0m [0mn_vocab[0m[0;34m*[0m[0meta[0m[0;34m)[0m [0;34m-[0m [0mnp[0m[0;34m.[0m[0msum[0m[0;34m([0m[0mgammaln[0m[0;34m([0m[0mchildren_cnt_words_doc[0m[0;34m[[0m[0;34m:[0m[0;34m,[0m [0mdoc[0m[0;34m][0m [0;34m+[0m [0meta[0m[0;34m)[0m[0;34m,[0m [0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m[0

## print 

In [57]:
def print_child_idxs(topic):
    print('  '*topic.depth, topic.idx, ':', [child.idx for child in topic.children], topic.cnt_doc, np.sum(topic.cnt_words))
    for topic in topic.children:
        print_child_idxs(topic)

print_child_idxs(topic_root)

 0 : [1, 2, 3] 1000 25295.0
   1 : [11, 12, 13] 919 23059.0
     11 : [111, 112, 113, 114, 115, 116, 117, 118] 889 22281.0
       111 : [] 450 10917.0
       112 : [] 352 8586.0
       113 : [] 13 324.0
       114 : [] 1 19.0
       115 : [] 16 402.0
       116 : [] 50 1262.0
       117 : [] 5 129.0
       118 : [] 2 58.0
     12 : [121] 18 469.0
       121 : [] 18 453.0
     13 : [131] 12 316.0
       131 : [] 12 329.0
   2 : [21, 22, 23] 80 1959.0
     21 : [211, 212, 213] 73 1819.0
       211 : [] 40 1041.0
       212 : [] 32 816.0
       213 : [] 1 35.0
     22 : [221] 6 150.0
       221 : [] 6 160.0
     23 : [231] 1 30.0
       231 : [] 1 16.0
   3 : [31] 1 25.0
     31 : [311] 1 25.0
       311 : [] 1 25.0


8586.0