In [1]:
import os
import numpy as np
from scipy.special import gammaln
import data_prep
from data_prep import voca
from data_prep import docs

[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/ecoronado/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/ecoronado/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [1]:
# dict that returns a fixed default (self.v) for missing keys without inserting them
class DefaultDict(dict):
    """A dict whose lookups fall back to a fixed value for absent keys.

    Unlike ``collections.defaultdict``, reading a missing key does NOT insert
    it; the fallback ``v`` is simply returned.  ``update`` returns ``self`` so
    it can be chained.
    """

    def __init__(self, v):
        # Value handed back by __getitem__ when a key is absent.
        self.v = v
        dict.__init__(self)

    def __getitem__(self, k):
        if k not in self:
            return self.v
        return dict.__getitem__(self, k)

    def update(self, d):
        dict.update(self, d)
        return self

In [3]:
# Storing default values for start of alg
#
# Conventions shared by every counter below:
#   * slot 0 of each per-topic structure (m_k, n_k, n_kv) is a placeholder
#     meaning "draw a new topic"; real topics are numbered from 1;
#   * slot 0 of each per-document table structure (using_t, k_jt, n_jt, n_jtv)
#     is a placeholder meaning "open a new table"; real tables start at 1.

# Seed NumPy's global RNG so the hyperparameter draws below and every
# np.random.multinomial draw in the sampler are reproducible on Restart & Run All.
SEED = 42
np.random.seed(SEED)

# Hyperparameters (concentration params of DP distributions)
gamma = np.random.gamma(1, 1)
alpha = np.random.gamma(1, 1)
# Smoothing added to every topic-word count (see n_k / n_kv below)
beta = .5

# size of vocabulary
V = voca.size()
# To see words type voca.vocas

# Number of documents
M = len(docs)

# Table indices in use for each document (kept in ascending order);
# initially only the placeholder 0
using_t = [[0] for j in range(M)]

# Topic (dish) indices in use (kept in ascending order); 0 means draw a new topic
k = 0
using_k = [0]


# Naming: x is data, t is table index, k is topic index,
# n is number of terms, m is number of tables

# Vocabulary index of each term of each document - the input data; never modified
x_ji = docs

# Topic served at each table of each document
k_jt = [np.zeros(1 ,dtype=int) for j in range(M)]

# Number of terms seated at each table of each document
n_jt = [np.zeros(1 ,dtype=int) for j in range(M)]

# For each table of each document: {vocabulary index: count}
# (None in the placeholder slot 0)
n_jtv = [[None] for j in range(M)]


# Total number of occupied tables across all documents (placeholder not counted)
m = 0
# Number of tables for each topic.  m_k[0] = 1 is a placeholder that keeps
# np.log(m_k[using_k]) finite; the sampler never changes it, so sum(m_k) == m + 1.
m_k = np.ones(1 ,dtype=int)

# Number of terms for each topic ( + beta * V ); the nonzero n_k[0] keeps the
# division in calc_f_k well defined for the placeholder slot
n_k = np.array([beta * V])

# Number of terms for each topic and vocabulary ( + beta ); slot 0 returns 0
# for every term, which is why calc_f_k(v)[0] == 0
n_kv = [DefaultDict(0)]

# Table for each document and term (-1 means not-assigned)
t_ji = [np.zeros(len(x_i), dtype=int) - 1 for x_i in docs]

In [4]:
## Helpers ## 

# Function that takes v (term index) and returns a list that represents the distribution of a term across topics -- i.e. each element is the proportion of terms in topic k that are term v 
def calc_f_k(v):
    """Per-topic likelihood of vocabulary index ``v``.

    Returns a numpy array with one entry per topic slot in the module-level
    ``n_kv`` / ``n_k``: entry k is n_kv[k][v] / n_k[k], i.e. the (beta-smoothed)
    share of topic k's terms that are term v.
    """
    counts_of_v = np.array([topic_counts[v] for topic_counts in n_kv])
    return counts_of_v / n_k


# Function that calculates the posterior distribution of tables for doc j / arguments: j - doc index,f_k - distribution of term across topics 

def calc_table_posterior(j, f_k, using_t, n_jt):
    """Normalized posterior over the tables of document j for one word.

    Parameters
    ----------
    j : document index.
    f_k : per-topic likelihood of the word (as returned by calc_f_k).
    using_t : per-document lists of table indices in use (slot 0 = new table).
    n_jt : per-document arrays of the number of words at each table.

    Also reads the module-level k_jt, m_k, m, alpha, gamma and V.

    Returns
    -------
    numpy array aligned with using_t[j].  For i > 0 and t = using_t[j][i] the
    weight is n_jt[j][t] * f_k[k_jt[j][t]]; the weight of entry 0 (a new
    table) is (sum_k m_k[k] * f_k[k] + gamma / V) * alpha / (gamma + m).
    The weights are divided by their sum.
    """
    open_tables = using_t[j]
    occupancy = n_jt[j][open_tables]
    word_likelihood = f_k[k_jt[j][open_tables]]
    weights = occupancy * word_likelihood

    # Likelihood of the word at a brand-new table: existing topics weighted by
    # their table counts, plus the new-topic term gamma / V.
    new_table_likelihood = np.inner(m_k, f_k) + gamma / V
    weights[0] = new_table_likelihood * alpha / (gamma + m)

    return weights / weights.sum()


def calc_dish_posterior_w(f_k):
    """Normalized posterior over topics for a table being opened by one word.

    Entries are aligned with the module-level ``using_k``: for i > 0 and
    k = using_k[i] the weight is m_k[k] * f_k[k]; entry 0 (a brand-new topic)
    gets weight gamma / V.
    """
    weighted = m_k * f_k
    posterior = weighted[using_k]
    posterior[0] = gamma / V  # new-topic weight replaces the placeholder slot
    total = posterior.sum()
    return posterior / total
    
    
def calc_dish_posterior_t(j, t, n_k, n_jt, n_jtv):
    """Normalized posterior over topics (dishes) for table t of document j.

    The table's own words are first taken out of its current topic's counts.

    Parameters
    ----------
    j, t : document index and table index within document j.
    n_k : per-topic term totals (each including beta * V); copied, not modified.
    n_jt : per-document arrays of the number of words at each table.
    n_jtv : per-document lists of per-table {vocabulary index: count} dicts.

    Also reads the module-level k_jt, using_k, m_k, n_kv, V, beta and gamma.

    Returns
    -------
    numpy array aligned with ``using_k``: entry 0 is the probability of a
    brand-new topic, entry i the probability of topic using_k[i].
    """
    k_old = k_jt[j][t]     # it may be zero (means a removed dish)

    Vbeta = V * beta
    # Only one scalar is needed, so index directly instead of copying the
    # whole per-document list (which cost O(M) on every call).
    n_jt2 = n_jt[j][t]
    # Topic totals with this table's words taken out, on a private copy.
    n_k = n_k.copy()
    n_k[k_old] -= n_jt2
    n_k = n_k[using_k]
    # log m_k + lnGamma(n_k) - lnGamma(n_k + n_jt) for existing topics, and the
    # same form with gamma and V * beta for a new topic.
    log_p_k = np.log(m_k[using_k]) + gammaln(n_k) - gammaln(n_k + n_jt2)
    log_p_k_new = np.log(gamma) + gammaln(Vbeta) - gammaln(Vbeta + n_jt2)

    gammaln_beta = gammaln(beta)
    # Per-term factors lnGamma(n_kw + n_jtw) - lnGamma(n_kw).
    for w, n_jtw in n_jtv[j][t].items():
        assert n_jtw >= 0
        if n_jtw == 0: continue
        n_kw = np.array([n.get(w, beta) for n in n_kv])
        n_kw[k_old] -= n_jtw
        n_kw = n_kw[using_k]
        n_kw[0] = 1 # dummy for logarithm's warning
        if np.any(n_kw <= 0): print(n_kw) # for debug
        log_p_k += gammaln(n_kw + n_jtw) - gammaln(n_kw)
        log_p_k_new += gammaln(beta + n_jtw) - gammaln_beta

    # Slot 0 (the placeholder) holds the new-topic score.
    log_p_k[0] = log_p_k_new

    # Subtract the max before exponentiating to avoid overflow.
    p_k = np.exp(log_p_k - log_p_k.max())
    return p_k / p_k.sum()



In [5]:
### HPLDA Alg ###
# Gibbs sampler over the table / dish (topic) representation.
# Each epoch makes two sweeps over the module-level state built in the
# initialisation cell, updating it in place:
#   1. every word x_ji is removed from its table and re-seated, possibly at a
#      new table, which in turn may get a new topic (dish);
#   2. every occupied table has its topic resampled.
# NOTE: sweep 2 calls the 5-argument calc_dish_posterior_t(j, t, n_k, n_jt, n_jtv)
# from the helpers cell.  A later cell re-defines that name with a different
# signature, so re-run the helpers cell before re-running this one.

N_EPOCHS = 25

for g in range(N_EPOCHS):

    # --- Sweep 1: resample the table t_ji of every word ---
    # j is the document index, i the position of a word in document j, and
    # x_ji[j][i] that word's vocabulary index (words are listed in voca.vocas).
    for j, x_i in enumerate(x_ji):
        for i in range(len(x_i)):

            # Take the word away from its current table.  t_ji is -1 for a
            # word that has not been seated yet (first epoch), so skip it then.
            t = t_ji[j][i]
            if t  > 0:
                k = k_jt[j][t]
                assert k > 0

                # decrease counters
                v = x_ji[j][i]
                n_kv[k][v] -= 1
                n_k[k] -= 1
                n_jt[j][t] -= 1
                n_jtv[j][t][v] -= 1

                if n_jt[j][t] == 0:
                    # The table is now empty: remove it from document j and
                    # drop it from its topic's table count and from the total m.
                    k = k_jt[j][t]
                    using_t[j].remove(t)
                    m_k[k] -= 1
                    m -= 1
                    assert m_k[k] >= 0

                    # A topic that no table serves any more is retired.
                    if m_k[k] == 0:
                        using_k.remove(k)

            v = x_ji[j][i]

            # f_k[k] = n_kv[k][v] / n_k[k]: smoothed share of topic k's terms that are v
            f_k = calc_f_k(v)
            assert f_k[0] == 0 # f_k[0] is a dummy and will be erased

            # Posterior over document j's tables; entry 0 means "open a new table".
            p_t = calc_table_posterior(j, f_k, using_t, n_jt)

            # Draw one table (not necessarily a new one).
            t_new = using_t[j][np.random.multinomial(1, p_t).argmax()]

            if t_new == 0:
                # Opening a new table: draw its topic; 0 means "new topic".
                p_k = calc_dish_posterior_w(f_k)
                k_new = using_k[np.random.multinomial(1, p_k).argmax()]

                if k_new == 0:
                    # Add new dish: reuse the lowest free topic index, or grow
                    # the per-topic arrays by one slot.
                    for k_new, k in enumerate(using_k):
                        if k_new != k: break
                    else:
                        k_new = len(using_k)
                        if k_new >= len(n_kv):
                            n_k = np.resize(n_k, k_new + 1)
                            m_k = np.resize(m_k, k_new + 1)
                            n_kv.append(None)
                        assert k_new == using_k[-1] + 1
                        assert k_new < len(n_kv)

                    using_k.insert(k_new, k_new)
                    n_k[k_new] = beta * V
                    m_k[k_new] = 0
                    n_kv[k_new] = DefaultDict(beta)

                assert k_new in using_k

                # Add new table: reuse the lowest free table index of document
                # j, or grow its per-table arrays by one slot.
                for t_new, t in enumerate(using_t[j]):
                    if t_new != t: break
                else:
                    t_new = len(using_t[j])
                    n_jt[j].resize(t_new+1)
                    k_jt[j].resize(t_new+1)
                    n_jtv[j].append(None)

                using_t[j].insert(t_new, t_new)
                n_jt[j][t_new] = 0  # to make sure
                n_jtv[j][t_new] = DefaultDict(0)

                k_jt[j][t_new] = k_new
                m_k[k_new] += 1
                m += 1

            # Seat the word at t_new and credit the table's and topic's counters.
            assert t_new in using_t[j]
            t_ji[j][i] = t_new
            n_jt[j][t_new] += 1

            k_new = k_jt[j][t_new]
            n_k[k_new] += 1

            v = x_ji[j][i]
            n_kv[k_new][v] += 1
            n_jtv[j][t_new][v] += 1

    # --- Sweep 2: resample the topic (dish) of every occupied table ---
    for j in range(M):
        for t in using_t[j]:
            if t != 0:
                # Detach the table from its dish.  Only the table counters
                # (m_k, m) decrease; the word counters (n_k, n_kv) stay until
                # the new dish is known.
                k = k_jt[j][t]
                assert k > 0
                assert m_k[k] > 0

                m_k[k] -= 1
                m -= 1
                if m_k[k] == 0:
                    # Retire the dish; k_jt = 0 marks it as removed (see k_old below).
                    using_k.remove(k)
                    k_jt[j][t] = 0

                # sampling of k
                p_k = calc_dish_posterior_t(j, t, n_k, n_jt, n_jtv)
                k_new = using_k[np.random.multinomial(1, p_k).argmax()]

                if k_new == 0:
                    # Add new dish: reuse the lowest free topic index, or grow
                    # the per-topic arrays by one slot.
                    for k_new, k in enumerate(using_k):
                        if k_new != k: break
                    else:
                        k_new = len(using_k)
                        if k_new >= len(n_kv):
                            n_k = np.resize(n_k, k_new + 1)
                            m_k = np.resize(m_k, k_new + 1)
                            n_kv.append(None)
                        assert k_new == using_k[-1] + 1
                        assert k_new < len(n_kv)

                    using_k.insert(k_new, k_new)
                    n_k[k_new] = beta * V
                    m_k[k_new] = 0
                    n_kv[k_new] = DefaultDict(beta)

                m += 1
                m_k[k_new] += 1

                k_old = k_jt[j][t]     # it may be zero (means a removed dish)
                if k_new != k_old:
                    k_jt[j][t] = k_new

                    # Move the table's words from the old topic to the new one.
                    # Index n_jt directly; there is no need to copy the whole
                    # list just to read one scalar.
                    n_jt2 = n_jt[j][t]
                    if k_old != 0: n_k[k_old] -= n_jt2
                    n_k[k_new] += n_jt2
                    for v, n in n_jtv[j][t].items():
                        if k_old != 0: n_kv[k_old][v] -= n
                        n_kv[k_new][v] += n
            

In [12]:
## Sanity Checks ## 

In [6]:
# Doc 0: the words seated across all of its tables should add up to its length
n_jt[0].sum() == len(x_ji[0])


True

In [7]:
# n_jtv[0] (0 indexes first doc) has one entry per table slot of doc 0: None for
# the placeholder slot 0, otherwise a dict mapping vocabulary index -> count of
# that term at the table.
# Summing each dict gives the number of words at that table.  After dropping
# empty (removed) tables, this should match the nonzero entries of n_jt[0].
[size for size in (sum(counts.values()) for counts in n_jtv[0] if counts is not None) if size != 0]

[5, 17, 2, 15, 1, 1]

In [22]:
# Cross-check for doc 0: per-table word totals from n_jtv vs. the nonzero entries of n_jt
[size for size in (sum(counts.values()) for counts in n_jtv[0] if counts is not None) if size != 0] == [n for n in n_jt[0] if n != 0]

True

In [9]:
# Total number of tables according to m_k (tables per topic).  This includes
# the placeholder m_k[0] == 1, so it is one more than the real table count m.
m_k.sum()

3783

In [10]:
# Total number of tables counted from n_jt (words per table of each document).
# Empty tables are removed, so every occupied table has a nonzero entry.
# This equals sum(m_k) - 1: m_k[0] is a placeholder initialised to 1 and never
# changed, which is the "off by 1" versus the previous cell.
sum(int(np.count_nonzero(row)) for row in n_jt)

3782

In [14]:












#### Work Space ##### --- this can be deleted 




def calc_dish_posterior_t(j, t):
    """Normalized posterior over topics (dishes) for table t of document j.

    The table's own words are first taken out of its current topic's counts.
    Reads the module-level sampler state (k_jt, n_k, n_jt, n_jtv, n_kv, m_k,
    using_k, V, beta, gamma) without modifying it.

    Returns a numpy array aligned with ``using_k``; entry 0 is the probability
    of a brand-new topic.
    """
    k_old = k_jt[j][t]     # it may be zero (means a removed dish)

    Vbeta = V * beta
    # Use distinct local names.  Assigning to n_k / n_jt inside this function
    # would make them local, and the first read would raise UnboundLocalError.
    n_jt_table = n_jt[j][t]
    n_k_rest = n_k.copy()
    n_k_rest[k_old] -= n_jt_table
    n_k_rest = n_k_rest[using_k]
    log_p_k = np.log(m_k[using_k]) + gammaln(n_k_rest) - gammaln(n_k_rest + n_jt_table)
    log_p_k_new = np.log(gamma) + gammaln(Vbeta) - gammaln(Vbeta + n_jt_table)

    gammaln_beta = gammaln(beta)
    for w, n_jtw in n_jtv[j][t].items():
        assert n_jtw >= 0
        if n_jtw == 0: continue
        n_kw = np.array([n.get(w, beta) for n in n_kv])
        n_kw[k_old] -= n_jtw
        n_kw = n_kw[using_k]
        n_kw[0] = 1 # dummy for logarithm's warning
        if np.any(n_kw <= 0): print(n_kw) # for debug
        log_p_k += gammaln(n_kw + n_jtw) - gammaln(n_kw)
        log_p_k_new += gammaln(beta + n_jtw) - gammaln_beta

    # Slot 0 (the placeholder) holds the new-topic score.
    log_p_k[0] = log_p_k_new

    # Subtract the max before exponentiating to avoid overflow.
    p_k = np.exp(log_p_k - log_p_k.max())
    return p_k / p_k.sum()





def leave_from_table(j, i):
    """Un-seat word i of document j from its current table.

    Decrements every counter the word contributed to: n_kv[k][v] and n_k[k]
    for the table's topic k, and n_jt[j][t] and n_jtv[j][t][v] for the table t.
    If the table becomes empty it is closed via remove_table(j, t).
    Does nothing when the word is not seated (t_ji[j][i] <= 0; t_ji starts at -1).
    """
    table = t_ji[j][i]
    if table <= 0:
        return

    topic = k_jt[j][table]
    assert topic > 0

    word = x_ji[j][i]

    # topic-side counters
    n_kv[topic][word] -= 1
    n_k[topic] -= 1

    # table-side counters
    n_jt[j][table] -= 1
    n_jtv[j][table][word] -= 1

    if n_jt[j][table] == 0:
        remove_table(j, table)
            
def remove_table(j, t):
    """Close table t of document j after its last word has left.

    Removes t from using_t[j], decrements the table count m_k[k] of the
    table's topic k and the total table count m, and retires topic k from
    using_k when no table serves it any more.
    """
    # m is rebound below, so it must be declared global; without this
    # `m -= 1` raises UnboundLocalError.
    global m

    k = k_jt[j][t]

    using_t[j].remove(t)

    m_k[k] -= 1
    m -= 1
    assert m_k[k] >= 0

    # If number of tables for topic k is 0 remove topic
    if m_k[k] == 0:
        using_k.remove(k)
        
        
        
# Assign guest x_ji to a new table and draw topic (dish) of the table
#def add_new_table(j, k_new, using_t, using_k, n_jt, k_jt, n_jtv, m_k, m):
#    assert k_new in using_k
#    for t_new, t in enumerate(using_t[j]):
#        if t_new != t: break
#    else:
#        t_new = len(using_t[j])
#        n_jt[j].resize(t_new+1)
#        k_jt[j].resize(t_new+1)
#        n_jtv[j].append(None)
#
#    using_t[j].insert(t_new, t_new)
#    n_jt[j][t_new] = 0  # to make sure
#    n_jtv[j][t_new] = DefaultDict(0)
#
#    k_jt[j][t_new] = k_new
#    m_k[k_new] += 1
#    m += 1
#
#    return t_new


#def seat_at_table(j, i, t_new):
#    
#    assert t_new in using_t[j]
#    
#    t_ji[j][i] = t_new
#    n_jt[j][t_new] += 1
#
#    k_new = k_jt[j][t_new]
#    n_k[k_new] += 1
#
#    v = x_ji[j][i]
#    n_kv[k_new][v] += 1
#    n_jtv[j][t_new][v] += 1




## These are functions that still need to be annotated 

def seat_at_table(j, i, t_new):
    """Seat word i of document j at the existing table t_new.

    Records the assignment in t_ji and increments the table's word counts
    (n_jt, n_jtv) and the counts of the topic served at that table (n_k, n_kv).
    """
    assert t_new in using_t[j]

    # table side
    t_ji[j][i] = t_new
    n_jt[j][t_new] += 1

    # topic side
    topic = k_jt[j][t_new]
    n_k[topic] += 1

    word = x_ji[j][i]
    n_kv[topic][word] += 1
    n_jtv[j][t_new][word] += 1

def add_new_table(j, k_new):
    """Open a new table in document j that serves topic k_new; return its index.

    The first index at which using_t[j] has a gap is reused (using_t[j] is
    kept in ascending order).  If there is no gap, the per-table arrays of
    document j are grown by one slot.  The new table starts with no words
    (n_jt = 0, empty DefaultDict(0) counts).  It is counted in m_k[k_new] and
    in the total table count m.  The topic itself is chosen by the caller.
    """
    # m is rebound below, so it must be declared global; without this
    # `m += 1` raises UnboundLocalError.
    global m

    assert k_new in using_k
    for t_new, t in enumerate(using_t[j]):
        if t_new != t: break
    else:
        t_new = len(using_t[j])
        n_jt[j].resize(t_new+1)
        k_jt[j].resize(t_new+1)
        n_jtv[j].append(None)

    using_t[j].insert(t_new, t_new)
    n_jt[j][t_new] = 0  # to make sure
    n_jtv[j][t_new] = DefaultDict(0)

    k_jt[j][t_new] = k_new
    m_k[k_new] += 1
    m += 1

    return t_new






def sampling_k(j, t):
    """Resample the topic (dish) of table t in document j from its posterior.

    Detaches the table from its dish and draws a topic from
    calc_dish_posterior_t.  A draw of index 0 opens a new dish via
    add_new_dish.  The table is then re-attached with seat_at_dish.
    """
    # leave_from_dish is declared as leave_from_dish(self, j, t).  Pass a
    # placeholder for the unused `self` so that j and t land in the right slots.
    leave_from_dish(None, j, t)

    # The calc_dish_posterior_t bound after this cell runs takes the sampler
    # state explicitly: (j, t, n_k, n_jt, k_jt, n_jtv, n_jtw, n_kw, V, beta,
    # using_k).  n_jtw / n_kw are overwritten inside it, so None is passed.
    p_k = calc_dish_posterior_t(j, t, n_k, n_jt, k_jt, n_jtv, None, None, V, beta, using_k)
    k_new = using_k[np.random.multinomial(1, p_k).argmax()]
    if k_new == 0:
        k_new = add_new_dish()

    seat_at_dish(j, t, k_new)
    
    
    

def leave_from_dish(self, j, t):
    """
    This makes the table leave from its dish and only the table counter decrease.
    The word counters (n_k and n_kv) stay.

    ``self`` is unused; callers must still pass a placeholder (e.g. None).
    If the dish loses its last table, it is removed from using_k and
    k_jt[j][t] is reset to 0 to mark the dish as removed.
    """
    # m is rebound below, so it must be declared global; without this
    # `m -= 1` raises UnboundLocalError.
    global m

    k = k_jt[j][t]
    assert k > 0
    assert m_k[k] > 0

    m_k[k] -= 1
    m -= 1
    if m_k[k] == 0:
        using_k.remove(k)
        k_jt[j][t] = 0

        
        
def calc_dish_posterior_t(j, t, n_k, n_jt, k_jt, n_jtv, n_jtw, n_kw, V, beta, using_k):
    """Normalized posterior over topics for table t of document j.

    The table's own words are first taken out of its current topic's counts.

    NOTE(review): this definition shadows the 5-argument
    calc_dish_posterior_t(j, t, n_k, n_jt, n_jtv) that the main sampling loop
    calls.  Re-running that loop after this cell fails with a TypeError.
    ``n_jtw`` and ``n_kw`` are accepted but never read, because both names are
    overwritten inside the word loop.  m_k, gamma and n_kv are still read
    from module scope.

    Returns a numpy array aligned with ``using_k``; entry 0 is the probability
    of a brand-new topic.
    """
    topic_old = k_jt[j][t]  # may be 0: the table's dish was already retired
    words_here = n_jt[j][t]
    v_beta = V * beta

    # Topic totals with this table's words taken out, on a private copy.
    totals = n_k.copy()
    totals[topic_old] -= words_here
    totals = totals[using_k]

    log_p_k = np.log(m_k[using_k]) + gammaln(totals) - gammaln(totals + words_here)
    log_p_new = np.log(gamma) + gammaln(v_beta) - gammaln(v_beta + words_here)

    lg_beta = gammaln(beta)
    for word, count in n_jtv[j][t].items():
        assert count >= 0
        if count == 0:
            continue
        per_topic = np.array([counts.get(word, beta) for counts in n_kv])
        per_topic[topic_old] -= count
        per_topic = per_topic[using_k]
        per_topic[0] = 1  # dummy for logarithm's warning
        if np.any(per_topic <= 0):
            print(per_topic)  # for debug
        log_p_k += gammaln(per_topic + count) - gammaln(per_topic)
        log_p_new += gammaln(beta + count) - lg_beta

    # Slot 0 (the placeholder) holds the new-topic score.
    log_p_k[0] = log_p_new

    # Subtract the max before exponentiating to avoid overflow.
    p_k = np.exp(log_p_k - log_p_k.max())
    return p_k / p_k.sum()


def seat_at_dish(j, t, k_new):
    """Attach table t of document j to topic (dish) k_new.

    Increments the table counters m and m_k[k_new].  If the table's topic
    changes, its word counts (n_k, n_kv) are added to k_new.  They are also
    subtracted from the old topic, unless the old topic is 0, i.e. the dish
    was already removed.
    """
    # m is rebound below, so it must be declared global; without this
    # `m += 1` raises UnboundLocalError.
    global m

    m += 1
    m_k[k_new] += 1

    k_old = k_jt[j][t]     # it may be zero (means a removed dish)
    if k_new != k_old:
        k_jt[j][t] = k_new

        # Use a distinct local name.  Rebinding n_jt here would make it local
        # and the read on the right-hand side would raise UnboundLocalError.
        n_words = n_jt[j][t]
        if k_old != 0: n_k[k_old] -= n_words
        n_k[k_new] += n_words
        for v, n in n_jtv[j][t].items():
            if k_old != 0: n_kv[k_old][v] -= n
            n_kv[k_new][v] += n


def add_new_dish():
    """Open a new topic (dish) and return its index.

    Used by sampling_k.  The first index at which using_k has a gap is reused
    (using_k is kept in ascending order).  If there is no gap, n_k and m_k are
    grown by one slot with np.resize and a slot is appended to n_kv.  The new
    topic starts with n_k = beta * V, m_k = 0 and an empty DefaultDict(beta)
    of per-term counts.
    """
    # n_k and m_k may be rebound by np.resize, so they must be declared global.
    # Otherwise every use of them here is local and raises UnboundLocalError.
    global n_k, m_k

    for k_new, k in enumerate(using_k):
        if k_new != k: break
    else:
        k_new = len(using_k)
        if k_new >= len(n_kv):
            # numpy is imported as `np` in this notebook
            n_k = np.resize(n_k, k_new + 1)
            m_k = np.resize(m_k, k_new + 1)
            n_kv.append(None)
        assert k_new == using_k[-1] + 1
        assert k_new < len(n_kv)

    using_k.insert(k_new, k_new)
    n_k[k_new] = beta * V
    m_k[k_new] = 0
    n_kv[k_new] = DefaultDict(beta)
    return k_new