In [1]:
import csv
import nltk
from bs4 import BeautifulSoup
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import string
import numpy as np
from numpy.random import default_rng
rng = default_rng(87)

In [2]:
def lowercase(input):
  """
  Returns lowercase text
  """
  return input.lower()

def remove_punctuation(input):
  """
  Returns text without punctuation
  """
  return input.translate(str.maketrans('','', string.punctuation+'–’”“—·'))

def remove_whitespaces(input):
  """
  Returns text without extra whitespaces
  """
  return " ".join(input.split())
  
def remove_html_tags(input):
  """
  Returns text without HTML tags
  """
  soup = BeautifulSoup(input, "html.parser")
  stripped_input = soup.get_text(separator=" ")
  return stripped_input

def tokenize(input):
  """
  Returns tokenized version of text
  """
  return word_tokenize(input)

def remove_stop_words(input):
  """
  Returns text without stop words
  """
  input = word_tokenize(input)
  return [word for word in input if word not in stopwords.words('english') or word == "no" or word == "not"]

def lemmatize(input):
  """
  Lemmatizes input using NLTK's WordNetLemmatizer
  """
  lemmatizer=WordNetLemmatizer()
  input_str=word_tokenize(input)
  new_words = []
  for word in input_str:
    new_words.append(lemmatizer.lemmatize(word))
  return ' '.join(new_words)


def nlp_pipeline(input):
  """
  Function that calls all other functions together to perform NLP on a given text
  """
  return lemmatize(' '.join(remove_stop_words(remove_whitespaces(remove_punctuation(remove_html_tags(lowercase(input)))))))

In [3]:
data = []
sentences = []
i = 0
with open("../comments.csv", newline='') as csvfile:
    reader = csv.reader(csvfile)
    next(reader)
    for row in reader:
        sentences.append(row[0])
        line = nlp_pipeline(row[0]).split()
        if line:
            data.append((i,nlp_pipeline(row[0]).split()))
            i = i + 1

In [4]:
def sample_multinomial(probs):
    return np.where(rng.multinomial(1, probs) == 1)[0][0]

In [5]:
def logsumexp(x):
    c = x.max()
    return c + np.log(np.sum(np.exp(x - c)))

In [6]:
def get_vocab(texts):
    x = set()
    for d in texts:
        for w in d[1]:
            x.add(w)
    return list(x)

In [7]:
def get_cluster_prob(d,m,n,n_w,beta,alpha,V,D):
    num_w = 0
    for w in set(d[1]):
        num_w_d = 0
        for j in range(d[1].count(w)):
            num_w_d = num_w_d + np.log(n_w[w] + beta + (j + 1) - 1)
        num_w = num_w + num_w_d
    
    den_d = 0
    for i in range(len(d[1])):
        den_d = den_d + np.log(n + (len(V) * beta) + (i + 1) - 1)
    return np.log(m) - np.log(D - 1 + alpha) + num_w - den_d

In [8]:
def get_new_cluster_prob(d,beta,alpha,V,D):
    num_w = 0
    for w in set(d[1]):
        num_w_d = 0
        for j in range(d[1].count(w)):
            num_w_d = num_w_d + np.log(beta + (j + 1) - 1)
        num_w = num_w + num_w_d
    
    den_d = 0
    for i in range(len(d[1])):
        den_d = den_d + np.log((len(V) * beta) + (i + 1) - 1)
    
    return np.log(alpha) - np.log(D - 1 + alpha) + num_w - den_d

In [9]:
def initialize_GSDPMM(texts,K_0 = 1):
    V = get_vocab(texts)
    K = K_0
    m_z = [0]*K
    n_z = [0]*K
    n_w_z = []
    for k in range(K):
        n_w_z.append({})
        for v in V:
            n_w_z[k][v] = 0
    z_d = {}
    for d in texts:
        z = sample_multinomial([1/K]*K)
        z_d[d[0]] = z
        m_z[z] = m_z[z] + 1
        for w in d[1]:
            n_z[z] = n_z[z] + 1
            n_w_z[z][w] = n_w_z[z][w] + 1
    return K,m_z,n_z,n_w_z,z_d

In [10]:
def gibbs_GSDPMM(texts,I,K,m_z,n_z,n_w_z,z_d,beta=0.2,alpha=69.4):
    V = get_vocab(texts)
    D = len(texts)
    for i in range(I):
        for d in texts:
            z = z_d[d[0]]
            m_z[z] = m_z[z] - 1
            for w1 in d[1]:
                n_z[z] = n_z[z] - 1
                n_w_z[z][w1] = n_w_z[z][w1] - 1
            if m_z[z] == 0:
                K = K - 1
                del m_z[z]
                del n_z[z]
                del n_w_z[z]
                z_d[d[0]] = -1
                for d1 in texts:
                    if z_d[d1[0]] > z:
                        z_d[d1[0]] = z_d[d1[0]] - 1
            cluster_probs = []
            for k in range(K):
                cluster_probs.append(get_cluster_prob(d,m_z[k],n_z[k],n_w_z[k],beta,alpha,V,D))
            cluster_probs.append(get_new_cluster_prob(d,beta,alpha,V,D))
            cluster_probs = np.array(cluster_probs)
            cluster_probs = np.exp(cluster_probs - logsumexp(cluster_probs))
            z = sample_multinomial(cluster_probs)
            z_d[d[0]] = z
            if z < K:
                m_z[z] = m_z[z] + 1
                for w2 in d[1]:
                    n_z[z] = n_z[z] + 1
                    n_w_z[z][w2] = n_w_z[z][w2] + 1
            else:
                K = K + 1
                m_z.append(0)
                n_z.append(0)
                n_w_z.append({})
                for v in V:
                    n_w_z[z][v] = 0
                m_z[z] = m_z[z] + 1
                for w in d[1]:
                    n_z[z] = n_z[z] + 1
                    n_w_z[z][w] = n_w_z[z][w] + 1
    
    d_probs = []
    for d in texts:
        z = z_d[d[0]]
        m_z[z] = m_z[z] - 1
        for w1 in d[1]:
            n_z[z] = n_z[z] - 1
            n_w_z[z][w1] = n_w_z[z][w1] - 1
        cluster_probs = []
        for k in range(K):
            cluster_probs.append(get_cluster_prob(d,m_z[k],n_z[k],n_w_z[k],beta,alpha,V,D))
        cluster_probs.append(get_new_cluster_prob(d,beta,alpha,V,D))
        cluster_probs = np.array(cluster_probs)
        cluster_probs = np.exp(cluster_probs - logsumexp(cluster_probs))
        d_probs.append(cluster_probs)
    return K, d_probs

In [11]:
def GSDPMM(texts,I = 5,K_0 = 1,beta = 0.02, alpha = 69.4):
    K,m_z,n_z,n_w_z,z_d = initialize_GSDPMM(texts,K_0)
    return gibbs_GSDPMM(texts,I,K,m_z,n_z,n_w_z,z_d,beta,alpha)

In [12]:
K, d_probs = GSDPMM(data,I = 40, K_0 = 9, beta = 0.1, alpha = 0.05*len(data))

  return np.log(m) - np.log(D - 1 + alpha) + num_w - den_d


In [13]:
K

19

In [14]:
for d in d_probs:
    print(d,"\n")

[1.22946909e-62 7.50739005e-79 2.74266293e-70 1.12599993e-89
 4.54322771e-80 1.00000000e+00 5.53445119e-77 6.08013869e-86
 4.53770251e-75 4.97096610e-84 1.36454117e-84 9.43078924e-82
 1.42094945e-82 2.90171236e-98 8.25546274e-92 6.59514185e-91
 8.31534691e-91 1.96211484e-91 1.10496486e-91 2.77213304e-91] 

[2.44988583e-62 1.49595046e-78 5.46513216e-70 2.24370933e-89
 9.05300451e-80 1.00000000e+00 1.10281533e-76 1.21155105e-85
 9.04199478e-75 9.90533194e-84 2.71903549e-84 1.87921414e-81
 2.83143673e-82 5.78205998e-98 1.64501421e-91 1.31417249e-90
 1.65694695e-90 3.90978301e-91 2.20179407e-91 5.52385540e-91] 

[6.28143459e-12 1.33071972e-25 6.29144464e-22 5.26708269e-36
 1.29009643e-25 1.00000000e+00 1.09239147e-24 2.06531966e-32
 4.39268291e-30 3.30832894e-31 5.03717654e-22 5.73838233e-29
 1.64516001e-32 2.50147848e-38 2.94536896e-35 6.05621350e-34
 5.09155097e-33 2.76286330e-34 4.30345003e-36 2.39788443e-33] 

[4.70779584e-62 2.87467656e-78 1.05020104e-69 4.31159907e-89
 1.73966054e-79

 0.00000000e+00 0.00000000e+00 1.95967990e-14 9.80900295e-12] 

[1.00000000e+00 2.18999872e-25 7.01196789e-14 2.43526715e-30
 1.53494963e-23 1.06534784e-15 4.59354865e-22 2.78608682e-32
 6.27062607e-23 0.00000000e+00 1.16992190e-25 8.70845763e-29
 2.09475830e-35 1.43242406e-33 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 6.33760740e-33 2.89085426e-32] 

[7.86113382e-01 3.66755009e-10 2.13737143e-01 1.76189146e-12
 1.49364493e-04 5.95522746e-08 4.46143159e-08 8.89851404e-14
 2.65418717e-09 0.00000000e+00 3.32255903e-09 5.67454513e-12
 3.38032565e-13 2.39689126e-12 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 7.42296714e-14 7.52777542e-14] 

[6.04144272e-04 6.22027406e-07 9.85758540e-01 1.46439191e-07
 1.29731605e-02 8.60233661e-08 6.07866663e-04 5.50349931e-10
 5.06184756e-05 0.00000000e+00 1.64122162e-07 4.59427142e-06
 9.43806388e-09 6.30766227e-10 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 2.77501861e-10 4.67311324e-08] 

[5.12715171e