In [1]:
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition import LatentDirichletAllocation as LDA
from sklearn.decomposition import NMF, PCA
from sklearn.linear_model import Ridge, LogisticRegression, Lasso
from sklearn.metrics import mean_squared_error as mse, roc_auc_score as roc, accuracy_score as acc, log_loss
from sklearn.neural_network import MLPClassifier, MLPRegressor
import numpy as np
import pandas as pd
from data.dataset import TextResponseDataset
import causal_attribution
import util
from scipy.sparse import csr_matrix
from importlib import reload
import data.dataset as ds
from model.topic_model import TopicModel
from model.adjusted_hstm import HeterogeneousSupervisedTopicModel
from model.model_trainer import ModelTrainer
from torch.utils.data import DataLoader
import evaluation.evaluator as ev
from evaluation.evaluator import Evaluator
import itertools as it
import seaborn as sns
import os

# Datasets that are treated as classification settings.
classification_settings = {
    'peerread',
    'amazon_binary',
    'yelp',
    'framing_corpus',
}

In [2]:
reload(ds)

# Corpus selection; `framing_topic` is only used for the 'framing_corpus' dataset.
dataset = 'yelp'
framing_topic = 'deathpenalty'

# Raw-data location per corpus. Any name not listed here (e.g. 'semantic_scholar')
# falls back to the CS-papers dump.
raw_data_paths = {
    'amazon': '../dat/reviews_Office_Products_5.json',
    'amazon_binary': '../dat/reviews_Grocery_and_Gourmet_Food_5.json',
    'yelp': '../dat/yelp_review_polarity_csv/train.csv',
    'peerread': '../dat/peerread_abstracts.csv',
    'framing_corpus': '../dat/framing/',
}
datafile = raw_data_paths.get(dataset, '../dat/cs_papers.gz')

# Preprocessed-data file; framing data gets one file per framing topic.
proc_tag = dataset + '_' + framing_topic if dataset == 'framing_corpus' else dataset
proc_file = '../dat/proc/' + proc_tag + '_proc.npz'

# Number of topics to fit for each corpus.
components = {'amazon': 30,
              'semantic_scholar': 50,
              'peerread': 50,
              'yelp': 30,
              'amazon_binary': 20,
              'framing_corpus': 10}

text_dataset = ds.TextResponseDataset(dataset, datafile, proc_file,
                                      use_bigrams=False,
                                      framing_topic=framing_topic)

# Short aliases for the loaded corpus pieces.
counts, labels, vocab, docs = (text_dataset.counts,
                               text_dataset.labels,
                               text_dataset.vocab,
                               text_dataset.docs)

n_components = components[dataset]
num_documents = counts.shape[0]
(n_components, num_documents, counts.shape[1])

(30, 19969, 6057)

In [3]:
vocab_size = counts.shape[1]

# Index embedded in the checkpoint file name ('<model>.<dataset>.<run>.model').
# NOTE(review): presumably a run/seed/split index from training — confirm.
MODEL_RUN = 0


def _build_trainer(model, model_name):
    """Wrap `model` in a ModelTrainer that loads its saved checkpoint.

    The checkpoint path is derived from the `dataset` selected in the
    data-loading cell. Previously it was hard-coded to 'yelp', so changing
    `dataset` would load weights trained on a different corpus/vocabulary.

    Args:
        model: the HeterogeneousSupervisedTopicModel instance to wrap.
        model_name: model identifier ('stm' or 'hstm-all'); also used as the
            checkpoint file-name prefix.

    Returns:
        A ModelTrainer configured with load=True for that checkpoint.
    """
    # NOTE(review): for 'framing_corpus' the proc file is keyed by framing_topic;
    # check whether checkpoints are too before relying on this path.
    model_file = '../out/model/{}.{}.{}.model'.format(model_name, dataset, MODEL_RUN)
    return ModelTrainer(model,
                        use_pretrained=True,
                        do_pretraining_stage=False,
                        do_finetuning=True,
                        model_name=model_name,
                        load=True,
                        model_file=model_file)


# Build both models first, then their trainers (same order as before).
stm = HeterogeneousSupervisedTopicModel(n_components,
                                        vocab_size,
                                        num_documents,
                                        response_model='stm')
hstm = HeterogeneousSupervisedTopicModel(n_components,
                                         vocab_size,
                                         num_documents,
                                         response_model='hstm-all')

trainer_stm = _build_trainer(stm, 'stm')
trainer_hstm = _build_trainer(hstm, 'hstm-all')

# NOTE(review): with load=True and no data (None) this presumably just restores
# the checkpoint rather than training — confirm against ModelTrainer.train.
trainer_stm.train(None)
trainer_hstm.train(None)

In [23]:
reload(ev)


def _make_evaluator(model, model_name):
    """Return an ev.Evaluator for `model` over the full loaded corpus."""
    return ev.Evaluator(model,
                        text_dataset.vocab,
                        text_dataset.counts,
                        text_dataset.labels,
                        text_dataset.docs,
                        model_name=model_name)


# Build each evaluator and extract its topics (STM first, then HSTM).
evaluator_stm = _make_evaluator(stm, 'stm')
stm_topics = evaluator_stm.get_topics()

evaluator_hstm = _make_evaluator(hstm, 'hstm-all')
hstm_topics = evaluator_hstm.get_topics()

In [24]:
# Print STM's topics; the returned '&'-separated table-row string is shown as the cell output.
stm_topics_table = evaluator_stm.visualize_topics(format_pretty=True)
stm_topics_table

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Topic 0: sushi roll bar place people get like
Topic 1: coffee staff location starbucks cafe ha espresso
Topic 2: burger fry wa potato sweet bacon cheese
Topic 3: last night happy hour hour time went time wa
Topic 4: der nicht und die ich ist auch
Topic 5: egg breakfast pancake toast wa bacon benedict
Topic 6: nail massage salon pedicure gel wa pedi
Topic 7: customer order location service cashier manager counter
Topic 8: show wa car but get see told
Topic 9: le en de la qui une pa
Topic 10: minute u waited order finally wait table
Topic 11: bar club dance night girl wa dj
Topic 12: sandwich pie wa sandwich wa soup chocolate soup wa
Topic 13: great place food love always good price
Topic 14: wa roll dish sushi but chef sauce
Topic 15: wa food horrible mexican tasted like bland like
Topic 16: flight airline ba

'Topic 0&sushi&roll&bar&place&people&get&like\\\\\nTopic 1&coffee&staff&location&starbucks&cafe&ha&espresso\\\\\nTopic 2&burger&fry&wa&potato&sweet&bacon&cheese\\\\\nTopic 3&last&night&happy hour&hour&time went&time&wa\\\\\nTopic 4&der&nicht&und&die&ich&ist&auch\\\\\nTopic 5&egg&breakfast&pancake&toast&wa&bacon&benedict\\\\\nTopic 6&nail&massage&salon&pedicure&gel&wa&pedi\\\\\nTopic 7&customer&order&location&service&cashier&manager&counter\\\\\nTopic 8&show&wa&car&but&get&see&told\\\\\nTopic 9&le&en&de&la&qui&une&pa\\\\\nTopic 10&minute&u&waited&order&finally&wait&table\\\\\nTopic 11&bar&club&dance&night&girl&wa&dj\\\\\nTopic 12&sandwich&pie&wa&sandwich wa&soup&chocolate&soup wa\\\\\nTopic 13&great&place&food&love&always&good&price\\\\\nTopic 14&wa&roll&dish&sushi&but&chef&sauce\\\\\nTopic 15&wa&food&horrible&mexican&tasted like&bland&like\\\\\nTopic 16&flight&airline&bag&plane&airport&fly&charge\\\\\nTopic 17&wa&steak&steak wa&appetizer&wine&great&filet\\\\\nTopic 18&dog&movie&theater

In [25]:
# Print HSTM's topics; the returned '&'-separated table-row string is shown as the cell output.
hstm_topics_table = evaluator_hstm.visualize_topics(format_pretty=True)
hstm_topics_table

------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
Topic 0: sushi people roll night place one but
Topic 1: coffee ha shop cafe but staff one
Topic 2: burger fry cheese ice sweet but cream
Topic 3: night last last time time went first time time happy hour
Topic 4: ist und ich da auch man nicht
Topic 5: egg toast breakfast benedict biscuit pancake waffle
Topic 6: nail gel massage polish pedicure salon pedi
Topic 7: customer order customer service minute order wa manager service
Topic 8: show car but wa one get said
Topic 9: la le de un est dans je
Topic 10: minute u finally order took waitress waiting
Topic 11: club night dance drink bar girl beer
Topic 12: sandwich sandwich wa soup bread pie turkey chocolate
Topic 13: great always good but pretty good pho place
Topic 14: wa lobster but shrimp pasta piece sauce
Topic 15: bad food tasted like bland wa horrible 

'Topic 0&sushi&people&roll&night&place&one&but\\\\\nTopic 1&coffee&ha&shop&cafe&but&staff&one\\\\\nTopic 2&burger&fry&cheese&ice&sweet&but&cream\\\\\nTopic 3&night&last&last time&time went&first time&time&happy hour\\\\\nTopic 4&ist&und&ich&da&auch&man&nicht\\\\\nTopic 5&egg&toast&breakfast&benedict&biscuit&pancake&waffle\\\\\nTopic 6&nail&gel&massage&polish&pedicure&salon&pedi\\\\\nTopic 7&customer&order&customer service&minute&order wa&manager&service\\\\\nTopic 8&show&car&but&wa&one&get&said\\\\\nTopic 9&la&le&de&un&est&dans&je\\\\\nTopic 10&minute&u&finally&order&took&waitress&waiting\\\\\nTopic 11&club&night&dance&drink&bar&girl&beer\\\\\nTopic 12&sandwich&sandwich wa&soup&bread&pie&turkey&chocolate\\\\\nTopic 13&great&always&good&but&pretty good&pho&place\\\\\nTopic 14&wa&lobster&but&shrimp&pasta&piece&sauce\\\\\nTopic 15&bad&food&tasted like&bland&wa&horrible&tasted\\\\\nTopic 16&flight&airline&plane&la&fly&airport&ticket\\\\\nTopic 17&steak&wa&wine&but&dining&dinner&service wa\

In [28]:
# For each topic index k, print the words among HSTM's first `n_words` entries
# that do not appear among STM's first `n_words` entries for the same k.
# NOTE(review): assumes get_topics() returns vocab indices ordered by weight, and
# that topic k in both models is the same topic (both use use_pretrained=True) — confirm.
n_words = 10
stm_subset = stm_topics[:, :n_words]
hstm_subset = hstm_topics[:, :n_words]

for k in range(hstm_subset.shape[0]):
    print(k, '*' * 60)
    # Build the STM lookup set once per topic rather than once per word.
    stm_words = set(stm_subset[k, :])
    for w_idx in hstm_subset[k, :]:
        if w_idx not in stm_words:
            # Entries are vocab indices (possibly float-typed); cast before indexing.
            print(text_dataset.vocab[int(w_idx)])

0 ************************************************************
night
one
but
eat
1 ************************************************************
but
one
bakery
2 ************************************************************
cream
onion
burger wa
3 ************************************************************
last time
donut
happy
4 ************************************************************
e
war
5 ************************************************************
biscuit
gravy
hash
6 ************************************************************
polish
spa
7 ************************************************************
customer service
minute
order wa
employee
phone
8 ************************************************************
one
said
hair
people
9 ************************************************************
un
est
dans
je
que
10 ************************************************************
took
waitress
waiting
another
came
11 ************************************************************
beer
b