In [6]:
# pip install sentence_transformers
# pip install nltk

In [7]:
# 1: use sentence transformer
# https://github.com/UKPLab/sentence-transformers/blob/master/examples/applications/clustering/kmeans.py
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans

In [11]:
import csv
import numpy as np
import pandas as pd
from pandas import read_csv

In [12]:
# imports for NLP
import nltk
nltk.download('punkt')
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('omw-1.4')
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from dateutil import parser
import string

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation as LDA

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...


In [79]:
history = []
with open('key_search_terms_cynthia.csv', 'r') as csvfile:
    datareader = csv.reader(csvfile)
    for row in datareader:
        # prints normalized term
        history.append(row[3])

In [80]:
# NLP functions for cleaning text

def remove_punctuation(input):
  return input.translate(str.maketrans('','', string.punctuation))

def remove_whitespaces(input):
  return " ".join(input.split())

def tokenize(input):
  return word_tokenize(input)

def remove_stop_words(input):
  input = word_tokenize(input)
  return [word for word in input if word not in stopwords.words('english')]

def lemmatize(input):
  lemmatizer=WordNetLemmatizer()
  input_str=word_tokenize(input)
  new_words = []
  for word in input_str:
    new_words.append(lemmatizer.lemmatize(word))
  return ' '.join(new_words)

# pipeline for cleaning
def nlp_pipeline(text):
  return lemmatize(' '.join(remove_stop_words(remove_whitespaces(remove_punctuation(text)))))

In [81]:
def find_topics(word_cluster):
    text = ' '.join(word_cluster)
    # text = nlp_pipeline(question_body)
    # print("cleaned text: ", text)
    count_vectorizer = CountVectorizer(stop_words='english')
    count_data = count_vectorizer.fit_transform([text])

    number_of_tags = 1
    lda = LDA(n_components=1, n_jobs=-1)
    lda.fit(count_data)

    words = count_vectorizer.get_feature_names_out()
    # Get topics from model
    topics = [[words[i] for i in topic.argsort()[:-number_of_tags - 1:-1]] for (topic_idx, topic) in enumerate(lda.components_)]

    topics = np.array(topics).ravel()
    return topics

In [82]:
import re
def clean_text(text):
    # remove numbers
    text_nonum = re.sub(r'\d+', '', text)
    # remove punctuations and convert characters to lower case
    text_nopunct = "".join([char.lower() for char in text_nonum if char not in string.punctuation]) 
    # substitute multiple whitespace with single whitespace
    # Also, removes leading and trailing whitespaces
    text_no_doublespace = re.sub('\s+', ' ', text_nopunct).strip()
    return text_no_doublespace

In [83]:
history = [nlp_pipeline(h) for h in history]
history = [clean_text(h) for h in history]

In [84]:
# Perform kmean clustering
# ideally set num_clusters dynamically 

def generate_clusters(num_clusters, phrase_list):
  sentence_emb = embedder.encode(phrase_list)
  clusters = {}

  clustering_model = KMeans(n_clusters=num_clusters)
  clustering_model.fit(sentence_emb)
  cluster_assignment = clustering_model.labels_

  clustered_sentences = [[] for i in range(num_clusters)]
  for sentence_id, cluster_id in enumerate(cluster_assignment):
      clustered_sentences[cluster_id].append(phrase_list[sentence_id])

  for i, cluster in enumerate(clustered_sentences):
    topic = find_topics(cluster)[0]
    cluster = [i for i in cluster if i]
    clusters[topic] = cluster
    sizes[topic] = len(cluster)
  return clusters

In [85]:
sizes = {}
embedder = SentenceTransformer('all-MiniLM-L6-v2')
roots = generate_clusters(10, history)

In [86]:
queue = []
graph = {}
clusters = roots
size_thresh = 20
num_children = 3

# add original parent topics
for topic in roots.keys():
    if sizes[topic] >= size_thresh:
        queue.append(topic)
        graph[topic] = set()

while len(queue) > 0:
  parent_topic = queue.pop(0)
  print('parent', parent_topic)
  parent_cluster = clusters[parent_topic]
  if len(parent_cluster) >= size_thresh and len(set(parent_cluster)) > num_children:
    new_clusters = generate_clusters(num_children, parent_cluster)
    for child_topic in new_clusters.keys():
      if child_topic not in clusters.keys():
        print('child', child_topic)

        # add links between parent + child to graph
        graph[child_topic] = set()
        graph[child_topic].add(parent_topic)
        graph[parent_topic].add(child_topic)

        new_c = new_clusters[child_topic]
        clusters[child_topic] = new_c
        sizes[child_topic] = len(new_c)

        # add child to queue if we re-cluster again
        if sizes[child_topic] >= size_thresh:
          queue.append(child_topic)

parent belt
parent jacket
child thml
parent lyric
child imagen
child sydel
parent cvsz
parent weather
child statue
child summer
parent plantz
child comedrecreationcenterataddamsparkz
child courtyardbymarriottnearallstonbostonmaz
parent harvard
child graphology
child google
parent chicago
child hotel
child near
parent python
child module
parent derrick
child white
child jenisz
child revel
parent imagen
child liang
child nuro
parent sydel
child electric
child aa
parent summer
child festival
parent comedrecreationcenterataddamsparkz
child plantzdatamb
child yoganearcambridgemaz
parent courtyardbymarriottnearallstonbostonmaz
child grocerywestloopz
child quincystbrooklynnyz
parent graphology
child graph
child visualization
parent google
child policy
child ai
parent hotel
parent near
child bar
child place
parent module
child cache
child named
child github
parent white
child wife
parent jenisz
child brandes
child kendallmarriottz
parent revel
child saladnearbyz
child pennstationz
parent nuro


In [87]:
for k in graph.keys():
  graph[k] = list(graph[k])

In [89]:
graph

{'belt': [],
 'jacket': ['thml'],
 'lyric': ['sydel', 'imagen'],
 'cvsz': [],
 'weather': ['summer', 'statue'],
 'plantz': ['courtyardbymarriottnearallstonbostonmaz',
  'comedrecreationcenterataddamsparkz'],
 'harvard': ['graphology', 'google'],
 'chicago': ['near', 'hotel'],
 'python': ['module'],
 'derrick': ['white', 'jenisz', 'revel'],
 'thml': ['jacket'],
 'imagen': ['liang', 'lyric', 'nuro'],
 'sydel': ['aa', 'lyric', 'electric'],
 'statue': ['weather'],
 'summer': ['festival', 'weather'],
 'comedrecreationcenterataddamsparkz': ['plantzdatamb',
  'plantz',
  'yoganearcambridgemaz'],
 'courtyardbymarriottnearallstonbostonmaz': ['plantz',
  'quincystbrooklynnyz',
  'grocerywestloopz'],
 'graphology': ['harvard', 'visualization', 'graph'],
 'google': ['harvard', 'ai', 'policy'],
 'hotel': ['chicago'],
 'near': ['chicago', 'place', 'bar'],
 'module': ['python', 'cache', 'github', 'named'],
 'white': ['derrick', 'wife'],
 'jenisz': ['kendallmarriottz', 'derrick', 'brandes'],
 'revel':

In [90]:
sizes

{'belt': 36,
 'jacket': 2,
 'lyric': 165,
 'cvsz': 40,
 'weather': 25,
 'plantz': 137,
 'harvard': 46,
 'chicago': 131,
 'python': 26,
 'derrick': 137,
 'thml': 9,
 'imagen': 17,
 'sydel': 11,
 'statue': 13,
 'summer': 13,
 'comedrecreationcenterataddamsparkz': 30,
 'courtyardbymarriottnearallstonbostonmaz': 18,
 'graphology': 15,
 'google': 42,
 'hotel': 23,
 'near': 6,
 'module': 21,
 'white': 2,
 'jenisz': 21,
 'revel': 11,
 'liang': 18,
 'nuro': 67,
 'electric': 8,
 'aa': 4,
 'festival': 12,
 'plantzdatamb': 36,
 'yoganearcambridgemaz': 9,
 'grocerywestloopz': 6,
 'quincystbrooklynnyz': 1,
 'graph': 5,
 'visualization': 6,
 'policy': 22,
 'ai': 19,
 'bar': 8,
 'place': 14,
 'cache': 4,
 'named': 11,
 'github': 6,
 'wife': 18,
 'brandes': 13,
 'kendallmarriottz': 12,
 'saladnearbyz': 12,
 'pennstationz': 10,
 'celtic': 10,
 'swift': 14,
 'soulcyclechicagozdatamb': 16,
 'vaticinvestmentsnewyorkzdatamb': 7,
 'urbanoutfitterszdatamb': 13,
 'zoom': 9,
 'trading': 15,
 'goodreads': 10,
 