* Use train word2vec
* Add the text library
* Implement Gaussian LDA
* Add supervision

# Load the necessary packages

In [143]:
from gensim.models.word2vec import Word2Vec
import gensim.downloader as api
import nltk
import pandas as pd
import re
import numpy as np
import scipy as sp

# Get the word embeddings

## Train the model on a sample dataset

In [144]:
# this loads the text8 dataset
corpus = api.load('text8')

# train a Word2Vec model
model_text8 = Word2Vec(corpus,iter=10,size=150, window=10, min_count=1, workers=10)  # train a model from the corpus

## Update the model on the text messages

* Note: May want to train on additional subject matter documents

### Load the text message library

In [145]:
txtLibrary = pd.read_excel('/Users/Nikki/Dropbox/UNC/Causal NLP/Reback_TxtLibrary/Reback_Project Tech Support Text Message Library_NF.xlsx',
             sheet_name = 1, skiprows = 23, names = ['msgID', 'txt'], index_col = 'msgID')

# Housekeeping
txtLibrary = txtLibrary.dropna(axis = 0) # Remove rows that are empty
# Remove the texts about follow up appointments (pre/post texts)
# txtLibrary = txtLibrary[slice('A1a001', 'H3a120')] # Full library
# txtLibrary = txtLibrary[slice('A1a001', 'B3b017')] # First two topics

### Function for text pre-processing

In [146]:
cList = {
  "ain't": "am not",
  "aren't": "are not",
  "can't": "cannot",
  "can't've": "cannot have",
  "'cause": "because",
  "could've": "could have",
  "couldn't": "could not",
  "couldn't've": "could not have",
  "didn't": "did not",
  "doesn't": "does not",
  "don't": "do not",
  "hadn't": "had not",
  "hadn't've": "had not have",
  "hasn't": "has not",
  "haven't": "have not",
  "he'd": "he would",
  "he'd've": "he would have",
  "he'll": "he will",
  "he'll've": "he will have",
  "he's": "he is",
  "how'd": "how did",
  "how'd'y": "how do you",
  "how'll": "how will",
  "how's": "how is",
  "i'd": "i would",
  "i'd've": "i would have",
  "i'll": "i will",
  "i'll've": "i will have",
  "i'm": "i am",
  "i've": "i have",
  "isn't": "is not",
  "it'd": "it had",
  "it'd've": "it would have",
  "it'll": "it will",
  "it'll've": "it will have",
  "it's": "it is",
  "let's": "let us",
  "ma'am": "madam",
  "mayn't": "may not",
  "might've": "might have",
  "mightn't": "might not",
  "mightn't've": "might not have",
  "must've": "must have",
  "mustn't": "must not",
  "mustn't've": "must not have",
  "needn't": "need not",
  "needn't've": "need not have",
  "o'clock": "of the clock",
  "oughtn't": "ought not",
  "oughtn't've": "ought not have",
  "shan't": "shall not",
  "sha'n't": "shall not",
  "shan't've": "shall not have",
  "she'd": "she would",
  "she'd've": "she would have",
  "she'll": "she will",
  "she'll've": "she will have",
  "she's": "she is",
  "should've": "should have",
  "shouldn't": "should not",
  "shouldn't've": "should not have",
  "so've": "so have",
  "so's": "so is",
  "that'd": "that would",
  "that'd've": "that would have",
  "that's": "that is",
  "there'd": "there had",
  "there'd've": "there would have",
  "there's": "there is",
  "they'd": "they would",
  "they'd've": "they would have",
  "they'll": "they will",
  "they'll've": "they will have",
  "they're": "they are",
  "they've": "they have",
  "to've": "to have",
  "wasn't": "was not",
  "we'd": "we had",
  "we'd've": "we would have",
  "we'll": "we will",
  "we'll've": "we will have",
  "we're": "we are",
  "we've": "we have",
  "weren't": "were not",
  "what'll": "what will",
  "what'll've": "what will have",
  "what're": "what are",
  "what's": "what is",
  "what've": "what have",
  "when's": "when is",
  "when've": "when have",
  "where'd": "where did",
  "where's": "where is",
  "where've": "where have",
  "who'll": "who will",
  "who'll've": "who will have",
  "who's": "who is",
  "who've": "who have",
  "why's": "why is",
  "why've": "why have",
  "will've": "will have",
  "won't": "will not",
  "won't've": "will not have",
  "would've": "would have",
  "wouldn't": "would not",
  "wouldn't've": "would not have",
  "y'all": "you all",
  "y'alls": "you alls",
  "y'all'd": "you all would",
  "y'all'd've": "you all would have",
  "y'all're": "you all are",
  "y'all've": "you all have",
  "you'd": "you had",
  "you'd've": "you would have",
  "you'll": "you you will",
  "you'll've": "you you will have",
  "you're": "you are",
  "you've": "you have"
}

c_re = re.compile('(%s)' % '|'.join(cList.keys()))

def expandContractions(text, c_re=c_re):
    def replace(match):
        return cList[match.group(0)]
    return c_re.sub(replace, text)

### Text pre-processing

In [147]:
# Make all the words lower case
txtLibrary.txt = txtLibrary.txt.str.lower()

# Change phone numbers 
phoneNumPattern = '\d{3}-\d{3}-\d{4}'
phoneNumReplacement = 'xxx-xxx-xxxx'
txtLibrary.txt = txtLibrary.txt.str.replace(phoneNumPattern, phoneNumReplacement)

# Remove numbers
txtLibrary.txt = txtLibrary.txt.str.replace(r'\d+', '')

# Expand contractions
txtLibrary.txt = txtLibrary.txt.apply(expandContractions)

# Remove punctuation
txtLibrary.txt = txtLibrary.txt.str.replace(r'\.|\?|\!|,', '')

# Remove white space
txtLibrary.txt = txtLibrary.txt.str.strip()

# Tokenization
txtLibraryList = txtLibrary.txt.tolist()
txtLibraryTokenList = [nltk.tokenize.word_tokenize(w) for w in txtLibraryList]

### Train the existing model on new terms

In [148]:
# train existing model on new terms
model_text8.build_vocab(txtLibraryTokenList, update = True)

model_text8.train(txtLibraryTokenList, total_examples=model_text8.corpus_count, epochs=model_text8.epochs)

(36929, 54180)

# Prepare the data for Gaussian LDA

## Prepare the documents

In [149]:
# Each element in docVec is the document label for each token in txtLibraryTokenList
docVec = [i for i in range(1,len(txtLibraryTokenList)+1) for j in txtLibraryTokenList[i-1]]
docVec = np.array(docVec)

# Total number of documents 
D = len(txtLibraryTokenList)

# Each element in posInDoc is the word position for each token in each document in txtLibraryTokenList
posInDoc = [j+1 for i in range(len(txtLibraryTokenList)) for j in range(len(txtLibraryTokenList[i]))]
posInDoc = np.array(posInDoc)

# Note: docVec and posInDoc should be the same length

# Number of tokens in each document
n_d = [len(i) for i in txtLibraryTokenList]

# Number of tokens in analysis
Ntot = sum(n_d)

# Replace each word with its vector
wordVectors = [model_text8.wv.__getitem__(j) for i in range(1,len(txtLibraryTokenList)+1) for j in txtLibraryTokenList[i-1]]
wordVectors = np.array(wordVectors)

# Gaussian LDA

## Hyperparameters, initialization, and simulation parameters

### Hyperparameters

In [150]:
mu0 = [0 for i in wordVectors[2]]
kappa = 1
Psi = np.identity(len(wordVectors[2]))
nu_k = 1
K = 3
nu = [nu_k for i in range(K)]
alpha = [1 for i in range(K)]

### Initialization

In [151]:
# Sigma for each topic
wordVecLength = len(wordVectors[1])
Sigma_k = np.identity(wordVecLength)
Sigma = [Sigma_k for i in range(K)]

# mu for each topic
mu_k = [0 for i in wordVectors[2]]
mu = [mu_k for i in range(K)]
mu = np.asarray(mu)

# Initialize the topics for each word in each document
topics = np.random.choice(range(1, K+1), size  = len(docVec), replace = True)

### Simulation parameters

In [152]:
# Number of MCMC samples
L = 500

## Gibbs sampling

### Functions for Gibbs sampling

In [153]:
%run -i 'multivariatet.py'

In [154]:
# Places to save the results
mu_chain = [mu]
Sigma_chain = [np.array(Sigma)]
Z_chain = [topics]


In [None]:
for l in range(L):
    
    # Update the Z_chain
    ## Z_new holds the updates, then we'll append it to the end of Z_chain
    ##Z_new = Z_chain[l].copy()
    Z_new = np.empty(shape = (Ntot))
    
    for d in range(1, D+1):
        for wordNum in range(1, n_d[d-1]+1): 
            
            probs = [0 for i in range(K)]
            
            
            for k in range(1, K+1):
                
                
                # First part of the probability
                a = sum(Z_new[docVec == d] == k) + alpha[k-1]
                
                # Second part of the probability
                b = multivariate_t.pdf(wordVectors[(docVec == d) & (posInDoc == wordNum)], 
                                      mean = mu_chain[l][k-1, :],
                                      shape = Sigma_chain[l][k-1, :],
                                      df = nu[k-1] -1 + wordVecLength)
                
                # Save the numerators
                probs[k-1] = a*b
            
            # Draw a new topic 
            ## Normalize the probabilities
            probs_norm = probs/sum(probs)
                
            ## Update Z_new    
            Z_new[(docVec == d) & (posInDoc == wordNum)] = np.random.choice(range(1, K+1), size = 1, p = probs_norm.flatten())
    
    # Add Z_new to the Z_chain
    Z_chain.append(Z_new)
    
    # Update mu and Sigma
#     Sigma_new = Sigma_chain[l]
#     mu_new = mu_chain[l]
    Sigma_new = np.empty(shape = (K, wordVecLength, wordVecLength))
    mu_new = np.empty(shape = (K, wordVecLength))
    
    for k in range(1, K+1):
        nWordsInTopic = sum(Z_new[Z_new == k])
        
        # Update the topic mean vector
        Lambda_n = (1/(1 + nWordsInTopic)) * Sigma_chain[l][k-1, :]
        mu_n = (1/(1 + nWordsInTopic))*(mu0 + nWordsInTopic* np.mean(wordVectors[Z_new == k]))
        mu_new[k-1, :] = np.random.multivariate_normal(mu_n, Lambda_n)
#         np.append(mu_new, np.random.multivariate_normal(mu_n, Lambda_n))
        
        
        # Update the Sigma for the topics
        resid = wordVectors[Z_new == k]- mu_new[k-1, :]
        S_mu = np.dot(np.transpose(resid), resid)
#         wish_shape = np.linalg.inv(Psi + S_mu)
        wish_shape = Psi + S_mu
        Sigma_new[k-1, :, :] = sp.stats.invwishart.rvs(df = nu[k-1] + nWordsInTopic, scale = wish_shape)
    
    # Update the chains
    Sigma_chain.append(Sigma_new)
    mu_chain.append(mu_new)


    

               
            
          

In [None]:
pd.DataFrame(Z_chain).to_csv('Z_chain.csv')
pd.DataFrame(docVec).to_csv('docVec.csv')


In [59]:
# multivariate_t.pdf(np.array([1, 2, 1]), mean=None, shape=1, df=1)

In [33]:
print(Z_chain[0].shape)

(920,)


In [125]:
# print(mu_chain[100][0])
# print(np.asarray(mu_chain[100][0]))

model_text8.wv.similar_by_vector(mu_chain[500][0], topn=5)

[('preparation', 0.41320061683654785),
 ('prepare', 0.4047275185585022),
 ('composting', 0.4004445970058441),
 ('require', 0.3974263668060303),
 ('repairs', 0.38611549139022827)]

In [126]:
print(Sigma_chain)

[array([[[1., 0., 0., ..., 0., 0., 0.],
        [0., 1., 0., ..., 0., 0., 0.],
        [0., 0., 1., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 1., 0., 0.],
        [0., 0., 0., ..., 0., 1., 0.],
        [0., 0., 0., ..., 0., 0., 1.]],

       [[1., 0., 0., ..., 0., 0., 0.],
        [0., 1., 0., ..., 0., 0., 0.],
        [0., 0., 1., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 1., 0., 0.],
        [0., 0., 0., ..., 0., 1., 0.],
        [0., 0., 0., ..., 0., 0., 1.]]]), array([[[ 6.61843911e+00, -2.63399664e+00, -2.84222518e-01, ...,
          6.98476246e-01,  1.66459367e+00,  2.29355868e-03],
        [-2.63399664e+00,  1.15579204e+01, -1.72435212e+00, ...,
         -1.82261501e+00, -3.20342667e+00,  5.78719805e-01],
        [-2.84222518e-01, -1.72435212e+00,  3.85503319e+00, ...,
          8.08131795e-01,  1.17321419e+00,  6.35884834e-01],
        ...,
        [ 6.98476246e-01, -1.82261501e+00,  8.08131795e-01, ...,
          4.32242391e+00,  2.76792873e+00,  

In [25]:
Z_chain[9].shape

(920,)

In [142]:
print(topics)

[3 3 1 2 2 3 1 2 1 1 1 2 3 1 2 2 2 3 1 3 2 3 1 3 2 3 2 3 1 1 2 3 3 3 2 2 2
 1 3 3 1 2 1 2 1 1 3 1 1 3 1 1 2 3 1 3 3 3 3 2 3 2 2 2 2 2 3 1 1 2 3 2 1 1
 3 2 3 2 2 3 1 3 1 3 1 3 1 1 1 1 1 1 2 1 1 1 1 2 1 1 2 3 2 1 1 3 3 3 3 3 2
 3 2 2 2 1 1 3 1 3 3 2 1 3 2 1 3 2 2 3 2 2 2 1 1 3 3 3 3 2 2 2 3 3 2 3 3 1
 1 1 3 2 2 1 3 2 1 1 1 3 2 1 3 2 3 3 1 1 1 1 2 3 1 3 1 2 2 1 3 2 1 3 1 2 1
 2 2 2 3 1 1 3 3 3 3 1 2 2 3 2 1 2 2 1 1 2 2 2 1 1 1 3 1 3 3 3 1 3 1 1 2 1
 3 2 3 1 1 2 2 3 3 3 2 2 2 1 3 1 1 1 3 1 3 2 2 3 1 1 2 1 3 3 1 3 2 1 1 2 2
 3 1 3 1 1 1 1 3 1 3 1 3 3 1 1 3 2 3 2 2 3 1 1 2 3 1 3 1 1 2 3 1 3 3 2 2 2
 2 2 2 3 3 1 3 1 1 1 3 1 1 2 3 1 1 2 1 1 2 2 3 3 1 3 2 3 3 3 1 1 2 3 2 3 1
 3 2 3 2 2 1 1 2 2 2 2 3 2 2 1 1 1 3 1 3 1 3 1 1 3 3 1 1 2 1 2 2 2 1 3 2 1
 1 3 1 2 2 2 2 1 2 2 2 1 1 1 2 3 2 1 2 1 1 2 2 1 1 1 3 3 3 1 2 2 1 3 3 2 3
 2 2 2 3 2 2 3 3 2 3 3 2 3 1 3 1 3 1 3 2 1 3 1 1 1 3 3 1 3 1 1 2 1 3 3 3 1
 3 2 2 1 1 2 1 1 3 2 1 2 2 3 3 3 3 2 3 3 1 3 3 1 2 1 2 2 3 1 2 3 1 3 1 2 3
 1 1 3 1 1 2 1 1 2 1 2 2 