## HDP code

In [4]:
import numpy as np
from scipy.special import gammaln
import data_preproc
from data_preproc import data_preproc

[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/ecoronado/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/ecoronado/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [123]:
# Load the corpus from the CSV next to the notebook directory.
# Later cells read `vocab.shape[0]` as the vocabulary size and use `len(docs)` / `docs[j]`
# to walk the documents.
# NOTE(review): presumably each docs[j] is a sequence of integer word ids that index the rows
# of the word-topic matrix - confirm against data_preproc.
vocab, docs = data_preproc("../tm_test_data.csv") # load vocab and docs

In [274]:
###########################################
########### SAMPLING T FUNCTIONS ##########
##########################################

def sampling_t(doc_j, i, word, n_kv, m_k, k_idx):
    '''Resample the table assignment of the i-th word of one document.

       The word is first taken off its current table (if it has one), then a table is drawn
       from the table posterior; when the "new table" slot (index 0) is drawn, a topic for
       that table is drawn as well and a topic/table is created as needed.

       Inputs: doc_j (per-document dict with 'jt_info', 'w_tbl_idx', 'n_jtw'), i (position of
               the word in the document), word (word id), n_kv (word-topic counts),
               m_k (tables per topic), k_idx (list of topic ids in use)
       Output: doc_j, n_kv, m_k, k_idx (in this order)
       Raises: IndexError if the stored table index of the word lies outside 'jt_info'

       Reads alpha, gamma, beta and V from module scope.'''

    t_idx = doc_j['w_tbl_idx'][i]
    n_cols = doc_j['jt_info'].shape[1]
    if t_idx + 1 > n_cols:
        # Same exception type the array lookup below would raise, but with a usable message
        raise IndexError(
            "word %d points at table %d but the document only has %d table columns"
            % (i, t_idx, n_cols))
    k_jt = doc_j['jt_info'][1, t_idx]

    ### Remove word if assigned to table (i.e. -x_ji)
    # Only indices > 0 denote a real table; other values mean the word is not seated yet
    if t_idx > 0:
        assert k_jt > 0

        doc_j, m_k, k_idx, n_kv = remove_x_ji(doc_j, t_idx, k_jt, m_k, k_idx, word)

    #### Sampling t ####
    fk = fk_m_xji(n_kv, word)

    # Un-normalized posterior weights, then normalize
    t_post = post_pvals_t(doc_j, k_jt, fk, m_k, alpha, gamma)
    t_post /= t_post.sum()

    # Draw one table column at random according to t_post
    t_samp_idx = np.random.multinomial(1, t_post).argmax()
    new_t = doc_j['jt_info'][0, t_samp_idx]

    ## New table is selected
    if new_t == 0:

        ### Sampling k when t is NEW ###
        kt_post = post_pvals_k_new_t(m_k, fk, k_idx, gamma, V)
        kt_post /= kt_post.sum()

        # Draw one topic at random for the new table
        kt_samp_idx = np.random.multinomial(1, kt_post).argmax()
        new_k = k_idx[kt_samp_idx]

        ## New topic selected
        if new_k == 0:

            # Create new topic
            new_k, n_kv, m_k, k_idx = new_topic(n_kv, m_k, k_idx, beta, V)

        m_k[new_k] += 1  # add to table cnt for topic k

        # Add new table
        new_t, doc_j = new_table(new_k, m_k, doc_j, word)

    # Assign word to table
    # NOTE(review): the word position is not passed on here; confirm that assign_to_table
    # records the assignment for this same `i`.
    doc_j, n_kv = assign_to_table(doc_j, new_t, n_kv, word)

    return doc_j, n_kv, m_k, k_idx



def remove_x_ji(doc_j, t_idx, k_jt, m_k, k_idx, word, word_topic_counts=None):
    '''Remove word if assigned to table (i.e. -x_ji), calls on remove_table helper function
       Inputs: doc_j, table idx, topic for table t, m_k (tables in topic k), k_idx (topics), word,
               word_topic_counts (optional word-topic matrix to update; when omitted the
               module-level n_kv is used, which is what existing callers rely on)
       Outputs: doc_j, m_k, k_idx and the updated word-topic matrix (in this order);
                doc_j, m_k and k_idx are additionally updated by remove_table when the
                table becomes empty'''

    # Fall back to the notebook-level matrix so the original 6-argument call keeps working
    counts = n_kv if word_topic_counts is None else word_topic_counts

    table_words = doc_j['n_jtw'][t_idx]
    table_words[word] -= 1  # remove from dictionary table cnt n_jtw
    if table_words[word] == 0:
        del table_words[word]

    doc_j['jt_info'][2, t_idx] -= 1  # remove from table cnt n_jt
    counts[word, k_jt] -= 1  # remove from topic count

    # if table is empty, remove table
    if doc_j['jt_info'][2, t_idx] == 0:
        doc_j, m_k, k_idx = remove_table(t_idx, doc_j, m_k, k_idx)

    return doc_j, m_k, k_idx, counts

                      
def remove_table(t_idx, doc_j, m_k, k_idx):
    '''Retire an empty table (n_jt == 0) of a document.
       The table's id in row 0 of 'jt_info' is reset to 0 (the column itself is kept) and the
       table is subtracted from its topic's table count; a topic left without tables is
       dropped from k_idx.
       Inputs: table idx, doc_j, m_k (tables in topic k), k_idx (topics)
       Outputs: doc_j, m_k, k_idx (all updated in place)'''

    table_info = doc_j['jt_info']
    topic = table_info[1, t_idx]

    # Mark the column as free and take the table away from its topic
    table_info[0, t_idx] = 0
    m_k[topic] -= 1

    if m_k[topic] != 0:
        return doc_j, m_k, k_idx

    # No table serves this topic any more
    k_idx.remove(topic)
    return doc_j, m_k, k_idx



def fk_m_xji(n_kv, word):
    '''Conditional density of x_ji given k and all data items except x_ji:
       the word's row of n_kv divided by the per-topic column totals (one value per column)'''
    topic_totals = n_kv.sum(axis=0)
    word_row = n_kv[word, :]
    return word_row / topic_totals



def post_pvals_t(doc_j, k_jt, fk, m_k, alpha, gamma, topics=None, vocab_size=None):
    '''Generate un-normalized posterior weights for selecting a new or an existing table.

       Inputs: doc_j (uses 'jt_info': row 1 = topic of each table, row 2 = words per table),
               k_jt (unused, kept so existing calls stay valid), fk (likelihood of the word
               under every topic column), m_k (tables per topic), alpha, gamma,
               topics (optional list of topic ids in use; defaults to module-level k_idx),
               vocab_size (optional; defaults to module-level V)
       Output: array with one weight per column of 'jt_info'; entry 0 is the weight of
               opening a new table'''

    topic_ids = k_idx if topics is None else topics
    n_vocab = V if vocab_size is None else vocab_size

    table_info = doc_j['jt_info']

    # if t is NOT NEW: words at the table times the likelihood under THAT table's topic
    n_jt = table_info[2, :]
    t_post = n_jt * fk[table_info[1, :]]

    # if t is NEW: mix over the topics actually in use (by id, not by list position).
    # Id 0 is the placeholder for "new topic" and is covered by the gamma / V term, which
    # keeps this sum consistent with the weights built in post_pvals_k_new_t.
    real_topics = [k for k in topic_ids if k != 0]
    n_tables = sum(m_k[k] for k in real_topics)
    p_xji = sum(m_k[k] * fk[k] for k in real_topics) + (gamma / n_vocab)

    t_post[0] = (alpha * p_xji) / (n_tables + gamma)  # t if new store as first

    return t_post


def post_pvals_k_new_t(m_k, fk, k_idx, gamma, V):
    '''Un-normalized posterior weights for the topic of a newly opened table.
       Existing topics get (tables in topic) * (likelihood of the word); the first slot,
       which stands for a brand new topic, gets gamma / V.
       Output: array aligned with k_idx'''
    per_topic = np.multiply(m_k, fk)

    # Keep only the topics in use (indexing with the k_idx list yields a fresh array)
    kt_post = per_topic[k_idx]

    # Slot 0 is the "new topic" option
    kt_post[0] = gamma / V

    return kt_post




def new_topic(n_kv, m_k, k_idx, beta, V):
    '''If new topic selected, pick the id of the new topic and prepare the structures
       k_idx (topic idx), n_kv (word-topic matrix) and m_k (tables per topic) for later updates.

       The smallest id that is not in use is taken, so ids freed by removed topics are
       recycled; the arrays only grow when no existing column is free.
       Assumes k_idx is kept sorted in increasing order (this function preserves that).

       Inputs: n_kv, m_k, k_idx, beta (value every word count of the new topic starts at),
               V (number of rows of n_kv)
       Output: new topic id, n_kv, m_k, k_idx (n_kv and m_k may be new, larger arrays)'''

    # First position whose stored id differs from the position is a free id;
    # without any gap the next id after the last one is used.
    new_k = len(k_idx)
    for pos, k in enumerate(k_idx):
        if pos != k:
            new_k = pos
            break

    # Grow the count structures by one topic when no existing column can be reused
    if new_k >= n_kv.shape[1]:
        n_kv = np.c_[n_kv, np.zeros((V, 1), dtype=int)]
        m_k = np.r_[m_k, 0]
    assert new_k < n_kv.shape[1]

    # Register the topic (keeps k_idx sorted) and reset its column and table count
    k_idx.insert(new_k, new_k)
    n_kv[:, new_k] = np.ones(V, dtype=int) * beta
    m_k[new_k] = 0

    assert new_k in k_idx

    return new_k, n_kv, m_k, k_idx


def new_table(new_k, m_k, doc_j, word):
    '''If new table selected, get new table idx and prepare doc_j's 'jt_info' and 'n_jtw' for
       later updates.

       A column whose id in row 0 no longer matches its position was freed earlier and is
       reused; the structures are extended by one column / one dict only when no such
       column exists.

       Inputs: new_k (topic of the new table), m_k (unused here), doc_j, word
       Output: new table idx and doc_j'''

    n_cols = doc_j['jt_info'].shape[1]

    # Look for the first freed column; default to appending a fresh one
    t_idx = n_cols
    for pos, t in enumerate(doc_j['jt_info'][0, :]):
        if pos != t:
            t_idx = pos
            break

    if t_idx == n_cols:
        # Extend the per-table word counts and the 'jt_info' array together so they stay aligned
        doc_j['n_jtw'].append({})
        doc_j['jt_info'] = np.c_[doc_j['jt_info'], np.zeros((3, 1), dtype=int)]

    # Set id and topic of the table and create the word's entry in the table's count dict
    doc_j['jt_info'][0, t_idx] = t_idx
    doc_j['jt_info'][1, t_idx] = new_k
    doc_j['n_jtw'][t_idx][word] = 0

    return t_idx, doc_j



def assign_to_table(doc_j, new_t, n_kv, word, word_idx=None):
    '''Assign word to table new_t in doc_j, add counts to overall table count,
       word-topic matrix and discretized table word counts.

       Inputs: doc_j, new_t (table idx, must be listed in row 0 of 'jt_info'), n_kv, word,
               word_idx (optional position of the word in the document; when omitted the
               module-level loop variable `i` is used, which is what existing callers rely on)
       Outputs: updated doc_j and n_kv'''

    assert new_t in doc_j['jt_info'][0,:]

    # Fall back to the notebook-level index so the original 4-argument call keeps working
    pos = i if word_idx is None else word_idx

    # Record the word-table assignment and add 1 to the table's word total
    doc_j['w_tbl_idx'][pos] = new_t
    doc_j['jt_info'][2, new_t] += 1

    # Get topic of table (either new or old)
    new_k = doc_j['jt_info'][1, new_t]

    # Add 1 for word in word-topic count matrix
    n_kv[word, new_k] += 1

    # Increment (not overwrite) the table's count of this word, so repeated words at the
    # same table are tracked and can later be removed one at a time
    table_words = doc_j['n_jtw'][new_t]
    table_words[word] = table_words.get(word, 0) + 1

    return doc_j, n_kv




###########################################
########### SAMPLING K FUNCTIONS ##########
##########################################

def sampling_k(doc_j, tbl, n_kv, m_k, k_idx):
    '''Resample the topic of one TABLE of document j (doc_j) and update the topic
       assignment together with the count structures jt_info, m_k and n_kv.
       A table id of 0 (the dummy index) is left untouched.
       Reads V and beta from module scope.
       Output: doc_j, n_kv, m_k, k_idx (in this order)'''

    # Index 0 is the dummy slot, never a real table
    if tbl == 0:
        return doc_j, n_kv, m_k, k_idx

    # Take the table away from its current topic
    doc_j, m_k, k_idx = remove_Xvec_ji(doc_j, tbl, m_k, k_idx)

    # Posterior weights over the topics in k_idx, normalized in place
    weights = post_pvals_k(doc_j, tbl, n_kv, m_k, k_idx, V, beta)
    weights /= weights.sum()

    # Draw one topic at random according to the weights
    draw = np.random.multinomial(1, weights)
    new_k = k_idx[draw.argmax()]

    # Slot 0 stands for "open a new topic"
    if new_k == 0:
        new_k, n_kv, m_k, k_idx = new_topic(n_kv, m_k, k_idx, beta, V)

    # The table now counts towards the chosen topic
    m_k[new_k] += 1

    doc_j, n_kv = rearranging_k_counts(doc_j, tbl, new_k, n_kv)

    return doc_j, n_kv, m_k, k_idx



def remove_Xvec_ji(doc_j, tbl, m_k, k_idx):
    '''Detach table tbl from its topic: the topic's table count is reduced by one.
       If the topic is left without tables it is removed from k_idx and the table's
       topic entry is reset to 0.
       Output: doc_j, m_k, k_idx (all updated in place)'''

    table_info = doc_j['jt_info']
    old_k = table_info[1, tbl]

    # One table fewer for the old topic
    m_k[old_k] -= 1

    topic_now_unused = m_k[old_k] == 0
    if topic_now_unused:
        k_idx.remove(old_k)
        table_info[1, tbl] = 0

    return doc_j, m_k, k_idx


def post_pvals_k(doc_j, tbl, n_kv, m_k, k_idx, V, beta):
    '''Compute explicit posterior weights for the topic of table tbl, based on the
       dirichlet-multinomial form.

       Inputs: doc_j (uses 'jt_info' and 'n_jtw'), tbl (table idx), n_kv (word-topic counts,
               not modified: a copy is edited), m_k (tables per topic), k_idx (topic ids in
               use), V, beta
       Output: array aligned with k_idx, scaled so that its largest entry is 1 (NOT yet
               normalized to sum to 1); entry 0 is the weight of a new topic

       Reads gamma from module scope; the two diagnostic prints also read the module-level
       loop variable j, so they fail with NameError if j is undefined when they fire.'''
    
    # Topic k of table t
    k_jt = doc_j['jt_info'][1, tbl]

    # Remove all counts associated with topic k in table t, from overall topic counts (n_k)
    
    # Work on a private copy: rows of n_kv are edited in place further down and those
    # edits must not reach the caller's matrix
    n_kv = n_kv.copy()
    n_k = n_kv.sum(axis = 0)
    n_jt = doc_j['jt_info'][2, tbl]
    n_k[k_jt] -= n_jt
    n_k = n_k[k_idx]

    # Initialized k posterior in log-form for simplicity, this computes f_k^{-X_ji} 
    # has Dirichlet-Multinomial form
    log_post_k = np.log(m_k[k_idx]) + gammaln(n_k) - gammaln(n_k + n_jt)
    log_post_k_new = np.log(gamma) + gammaln(V*beta) - gammaln((V*beta) + n_jt)

    # Remove individual word counts associated with topic k
    # add their contributions to k posterior
    for w_key, w_cnt in doc_j['n_jtw'][tbl].items():

        assert w_cnt >= 0
        if w_cnt == 0: # if word count is 0 skip
            continue

        # For word w, get counts across topics - if zero set as beta
        # (this is a row of the local copy, so the writes below stay local)
        w_cnt_k = n_kv[w_key, :]
        w_cnt_k[w_cnt_k == 0] = beta
        
        # Diagnostic only: reports non-positive counts before the table's words are removed
        if np.any(w_cnt_k <= 0): print("pre- check", j, tbl, k_jt, w_key, w_cnt_k, w_cnt)

        # For specific topic k, remove count from associated table t
        w_cnt_k[k_jt] -= w_cnt
        w_cnt_k = w_cnt_k[k_idx]

        # Slot 0 is the new-topic entry; its value is a placeholder because
        # log_post_k[0] is overwritten after the loop
        w_cnt_k[0] = 1
        # Diagnostic only: a non-positive count here would feed gammaln a non-positive argument
        if np.any(w_cnt_k <= 0): print("check", j, tbl, k_jt, w_key, w_cnt_k, w_cnt)

        # Add contributions
        log_post_k += gammaln(w_cnt_k  + w_cnt) - gammaln(w_cnt_k)
        log_post_k_new += gammaln(beta + w_cnt) - gammaln(beta)


    # set K new
    log_post_k[0] = log_post_k_new

    # Bring back to non-log realm; subtracting the maximum first makes the largest weight
    # exactly 1 (the caller divides by the sum afterwards)
    post_k = np.exp(log_post_k - log_post_k.max())

    return post_k


def rearranging_k_counts(doc_j, tbl, new_k, n_kv):
    '''Apply a sampled topic to table tbl: when new_k differs from the table's current
       topic, store new_k for the table and move the table's word counts in the
       word-topic matrix from the old topic to new_k. An old topic of 0 means there is
       nothing to subtract.
       Output: doc_j, n_kv (both updated in place)'''
    old_k = doc_j['jt_info'][1, tbl]

    # Same topic drawn again: nothing moves
    if new_k == old_k:
        return doc_j, n_kv

    doc_j['jt_info'][1, tbl] = new_k

    had_topic = old_k != 0
    for word_id, count in doc_j['n_jtw'][tbl].items():
        if had_topic:
            n_kv[word_id, old_k] -= count
        n_kv[word_id, new_k] += count

    return doc_j, n_kv
   
    




In [None]:
#############################
#### Initialize params ######
#############################

# --- Hyper-parameters -------------------------------------------------------
# NOTE: no random seed is set, so alpha and gamma differ from run to run.
beta = 0.5                                    # word concentration (LDA)
alpha = np.random.gamma(shape=1, scale=1)     # concentration used in the table posterior
gamma = np.random.gamma(shape=1, scale=1)     # concentration used in the topic posterior

# --- Corpus dimensions ------------------------------------------------------
V = vocab.shape[0]   # length of vocabulary
D = len(docs)        # numb docs

# --- Per-document state -----------------------------------------------------
# One dict per document j with three entries:
#   'jt_info'   3 x (number of table columns) integer array, starting with a single dummy column 0:
#               row 0 = table idx, row 1 = topic idx of the table (k_jt), row 2 = words at the table (n_jt)
#   'w_tbl_idx' one entry per word holding its table idx (t_ji); -1 = not assigned yet
#   'n_jtw'     list with one dict per table column mapping word -> count at that table;
#               the entry for the dummy column only holds a {beta: beta} placeholder
docs_dict = {
    j: {
        'jt_info': np.zeros((3, 1), dtype=int, order='F'),
        'w_tbl_idx': np.full(len(docs[j]), -1, dtype=int),
        'n_jtw': [{beta: beta}],
    }
    for j in range(D)
}

# --- Global counts ----------------------------------------------------------
# V x (number of topic columns) matrix, one row per word, so column sums give n_k.
# It starts with the single dummy column 0 and every entry equal to beta.
n_kv = np.full((V, 1), beta, dtype=float)

# Tables per topic, one entry per topic column (the dummy topic starts at 1)
m_k = np.full(1, 1, dtype=int)

# Topic ids in use; id 0 is the dummy / "new topic" slot
k_idx = [0]

# Data: the documents themselves
x_ji = docs

In [284]:

#############################
#### HDP ######
#############################
# Sampler driver: 5 full sweeps. Each sweep first resamples the table of every word in
# every document (sampling_t) and then the topic of every table (sampling_k).
# NOTE: the loop variables `i` and `j` must keep these exact names - helper functions
# defined above read them from module scope (assign_to_table uses `i`, the diagnostic
# prints in post_pvals_k use `j`).
for z in range(5):
    # Pass 1: table assignment of every word
    for j, x_i in enumerate(x_ji):

        doc_j = docs_dict[j]

        for i, w in enumerate(x_i):

            # n_kv / m_k are rebound on every call because they may be replaced by larger arrays
            doc_j, n_kv, m_k, k_idx = sampling_t(doc_j, i, w, n_kv, m_k, k_idx)

        docs_dict[j] = doc_j


    # Pass 2: topic of every table
    for j in range(D):
        doc_j = docs_dict[j]
        # Row 0 holds the table ids; entries equal to 0 are skipped inside sampling_k
        for tbl in doc_j['jt_info'][0, :]:

            doc_j, n_kv, m_k, k_idx = sampling_k(doc_j, tbl, n_kv, m_k, k_idx)

        docs_dict[j] = doc_j



In [285]:
def clean_up(dicts):
    '''Drop empty tables from every document dict (in place) and return the same mapping.
       A table is empty when its word count (row 2 of 'jt_info') is 0; the first such
       column is always kept. The matching entries of 'n_jtw' are removed as well.
       Table ids stored in row 0 of 'jt_info' and in 'w_tbl_idx' are not renumbered.'''

    for doc in dicts.values():

        # Columns without words, skipping the first one found
        empty_cols = np.flatnonzero(doc['jt_info'][2, :] == 0)[1:]

        doc['jt_info'] = np.delete(doc['jt_info'], empty_cols, axis=1)

        # Pop from the back so the remaining positions stay valid
        for col in empty_cols[::-1]:
            doc['n_jtw'].pop(col)

    return dicts

In [286]:
# Drop the empty table columns of every document; clean_up edits the dicts in place and
# returns the same mapping, so this rebinding keeps the name pointing at the cleaned state.
docs_dict = clean_up(docs_dict) # final dictionary