In [1]:
% matplotlib inline
import matplotlib.pyplot as plt
from __future__ import division
import numpy as np
import pandas as pd
import gensim.corpora as corpora
import json
import random
import sys
import copy
import nltk
sys.path.append('..')
from helpers.glove_neighbors import *
from gensim.utils import SaveLoad

In [2]:
# Load the project configuration. The context manager closes the file handle
# as soon as the JSON has been parsed instead of leaving it to the garbage collector.
with open('../config.json', 'r') as config_file:
    config = json.load(config_file)
MALLET = config['MALLET']
INPUT_DIR = config['INPUT_DIR']
OUTPUT_DIR = config['OUTPUT_DIR']
TWEET_DIR = config['TWEET_DIR']

# Dedicated generator seeded from the config, independent of the global `random` state.
RNG = random.Random()
RNG.seed(config['SEED'])

# One list entry per line of event_names.txt.
with open(INPUT_DIR + 'event_names.txt', 'r') as events_file:
    events = events_file.read().splitlines()

K_VALS = range(6, 11)  # numbers of clusters / topics that are evaluated: 6..10
sno = nltk.stem.SnowballStemmer('english')

In [3]:
# Percentage of the vocabulary used to size the closest / furthest word pools (see V_lowest).
LOWEST_PERCENTILE = 5

# Word intrusion

In [4]:
# Number of word-intrusion samples generated for each number of topics, per model.
NO_SAMPLES = 400

In [5]:
# Draw the topic id of every sample once, up front, so that both models are
# evaluated on the same sequence of topic ids for a given number of topics.
SAMPLE_TOPIC_IDS = dict(
    (n_topics, RNG.choices(range(n_topics), k=NO_SAMPLES))
    for n_topics in K_VALS
)

In [6]:
def get_furthest_other_closest(furthest, closest):
    """Restrict each cluster's far words to those that are close to some other cluster.

    Args:
        furthest: dict mapping cluster id -> list of words far from that cluster.
            Updated in place: every list is replaced by its filtered version.
        closest: dict mapping cluster id -> set of words close to that cluster.

    Returns:
        The same `furthest` dict, where each list keeps (in original order) only
        the words found in the `closest` set of at least one *other* cluster.
    """
    for cluster_id, far_words in furthest.items():
        # Union of the close-word sets of every cluster except the current one.
        near_elsewhere = set()
        for other_id, near_words in closest.items():
            if other_id != cluster_id:
                near_elsewhere |= near_words

        # Re-binding an existing key is safe while iterating over .items().
        furthest[cluster_id] = [word for word in far_words if word in near_elsewhere]
    return furthest

In [7]:
def generate_samples(k, closest_5, furthest, dicts):
    """Append NO_SAMPLES word-intrusion samples for a k-topic model to `dicts`.

    Each sample consists of five keywords drawn from one topic plus one
    "intruder" word drawn from that topic's far-word list, in shuffled order.

    Args:
        k: number of topics of the model; also the key into SAMPLE_TOPIC_IDS.
        closest_5: dict mapping topic id -> list of candidate keywords for that
            topic (at least five per topic; the callers pass the top-10 lists).
            The parameter name is kept for backward compatibility.
        furthest: dict mapping topic id -> non-empty list of intruder candidates.
        dicts: list of sample dicts; new samples are appended to it in place.

    Returns:
        The same `dicts` list. Every appended dict has the keys 'odd_idx'
        (position of the intruder), 'word_0'..'word_5', 'topic', 'no_topics'
        and 'sample_idx'.

    Raises:
        ValueError: if a topic has fewer than five candidate keywords.
        IndexError: if a topic has no intruder candidates.
    """
    n_keywords = 5
    for i in range(NO_SAMPLES):
        topic_id = SAMPLE_TOPIC_IDS[k][i]  # topic_id for sample

        # Use the argument rather than a same-named global from the calling cell.
        # RNG.sample already returns a new list, so it can be extended safely.
        keywords = RNG.sample(closest_5[topic_id], n_keywords)
        keywords.append(RNG.choice(furthest[topic_id]))
        intruder_pos = len(keywords) - 1  # the intruder was appended last

        # Shuffle positions instead of words so the intruder's slot can be recorded.
        shuffled_idx = list(range(len(keywords)))
        RNG.shuffle(shuffled_idx)

        d = {'odd_idx': shuffled_idx.index(intruder_pos)}
        for j, idx in enumerate(shuffled_idx):
            d['word_' + str(j)] = keywords[idx]
        d['topic'] = topic_id
        d['no_topics'] = k
        d['sample_idx'] = i
        dicts.append(d)
    return dicts

## kmeans model

In [8]:
# Load the embedding table: a tab-separated file whose first column becomes the index.
# The index is later matched against stemmed stopwords, so it presumably holds
# stemmed tokens (one row per word) — TODO confirm against the file that is written upstream.
glove = pd.read_csv(OUTPUT_DIR+'glove.50d.csv', sep='\t', index_col=0)

In [9]:
# Stem every stopword with the Snowball stemmer and drop the matching rows from
# the embedding table. The context manager closes the stopword file after reading.
with open(INPUT_DIR + 'stopwords.txt', 'r') as stopwords_file:
    stopwords = {sno.stem(w) for w in stopwords_file.read().splitlines()}
glove = glove[~glove.index.isin(stopwords)]

In [10]:
# V: number of rows (words) currently in the embedding table.
V = len(glove)
# Size of the closest / furthest word pools: LOWEST_PERCENTILE percent of V, truncated to an int.
V_lowest = int(V / 100 * LOWEST_PERCENTILE)
V_lowest

83

In [11]:
# Build the word-intrusion samples for the k-means model, for every k in K_VALS.
# NOTE: relies on the module-level V_lowest, which the Mallet section reassigns;
# run the cells in order.
dicts = []
for k in K_VALS:
    print(k, 'clusters')
    means = np.load(OUTPUT_DIR + 'cluster_'+str(k)+'_means.npy')

    closest_10 = {}  # cluster id -> list of the 10 nearest words
    closest = {}     # cluster id -> set of the V_lowest nearest words
    furthest = {}    # cluster id -> list of the V_lowest farthest words
    for i, m in enumerate(means):
        # Rank the vocabulary once per cluster mean and slice the result three
        # times, instead of recomputing the ranking for each slice.
        ranked = neighbors_vector(m, glove)
        closest_10[i] = list(ranked.head(10).index)
        closest[i] = set(ranked.head(V_lowest).index)
        furthest[i] = list(ranked.tail(V_lowest).index)

    # keep only far words that are close to some other cluster
    furthest = get_furthest_other_closest(furthest, closest)

    # generate samples
    dicts = generate_samples(k, closest_10, furthest, dicts)
kmeans_df = pd.DataFrame(dicts, index = range(len(dicts)))

6 clusters
7 clusters
8 clusters
9 clusters
10 clusters


## Mallet

In [12]:
# Load the 6-topic model only to read the vocabulary size off its topic-word
# matrix (one column per vocabulary entry).
topic_words = SaveLoad.load(OUTPUT_DIR + 'ldamallet_model_6.pickle').get_topics()
V = topic_words.shape[1]
# NOTE: this overwrites the V / V_lowest globals computed in the k-means section.
V_lowest = int(V / 100 * LOWEST_PERCENTILE)
V_lowest

83

In [13]:
# Build the word-intrusion samples for the Mallet LDA models, for every k in K_VALS.
dicts = []
for k in K_VALS:
    print(k, 'clusters')
    ldamallet = SaveLoad.load(OUTPUT_DIR + 'ldamallet_model_' + str(k) + '.pickle')
    topic_words = ldamallet.get_topics()
    id2word = ldamallet.id2word

    closest_10 = {}  # topic id -> list of the 10 highest-weighted words
    closest = {}     # topic id -> set of the V_lowest highest-weighted words
    furthest = {}    # topic id -> list of the V_lowest lowest-weighted words
    for i, row in enumerate(topic_words):
        # Sort the word ids once per topic and reuse the ordering for all three pools.
        ascending = row.argsort()
        descending = ascending[::-1]
        # id2word[idx] goes through the mapping's item lookup, which does not
        # depend on the lazily built `id2token` attribute being populated.
        closest_10[i] = [id2word[idx] for idx in descending[:10]]
        closest[i] = {id2word[idx] for idx in descending[:V_lowest]}
        furthest[i] = [id2word[idx] for idx in ascending[:V_lowest]]

    # keep only low-weight words that rank highly in some other topic
    furthest = get_furthest_other_closest(furthest, closest)

    # generate samples
    dicts = generate_samples(k, closest_10, furthest, dicts)
mallet_df = pd.DataFrame(dicts, index = range(len(dicts)))

6 clusters
7 clusters
8 clusters
9 clusters
10 clusters


In [14]:
# Preview the first rows only; the full frame holds NO_SAMPLES rows per value of k.
mallet_df.head(10)

Unnamed: 0,no_topics,odd_idx,sample_idx,topic,word_0,word_1,word_2,word_3,word_4,word_5
0,6,5,0,3,bernardino,orlando,obama,terror,terrorist,fear
1,6,4,1,0,student,texa,shoot,trump,pattern,year
2,6,3,2,1,airport,shot,offic,protest,dead,suspect
3,6,3,3,1,offic,gunman,polic,jesus,shoot,dead
4,6,2,4,4,time,stop,mall,church,law,make
5,6,1,5,4,make,condemn,violenc,good,church,talk
6,6,5,6,5,today,shoot,mass,famili,victim,cop
7,6,4,7,0,trump,student,school,day,citizen,year
8,6,4,8,2,peopl,cop,guy,black,parent,kill
9,6,5,9,0,shoot,trump,high,school,parkland,cover


In [15]:
# Preview the first rows only; the full frame holds NO_SAMPLES rows per value of k.
kmeans_df.head(10)

Unnamed: 0,no_topics,odd_idx,sample_idx,topic,word_0,word_1,word_2,word_3,word_4,word_5
0,6,0,0,3,photo,problem,fix,#guncontolnow,bad,yeah
1,6,0,1,0,impact,#blacklivesmatt,thug,blm,racist,#blm
2,6,4,2,1,detain,fatal,multipl,#breakingnew,honest,updat
3,6,3,3,1,updat,unconfirm,fatal,agre,#updat,#break
4,6,2,4,4,southern,david,effect,veteran,calif,california
5,6,3,5,4,gunman,calif,identifi,#blacklivesmatt,david,bar
6,6,3,6,5,observ,honor,vigil,knife,honour,candlelight
7,6,5,7,0,#blacklivesmatt,label,#blm,racism,radic,heal
8,6,3,8,2,#prayer,affect,sadden,reuter,faculti,prayer
9,6,4,9,0,blm,racist,label,supremacist,los,thug


In [16]:
# Tag every row with the model that produced it (adds a 'model' column to both
# frames in place), then stack the two frames under a fresh 0..n-1 index.
mallet_df['model'], kmeans_df['model'] = 'mallet', 'kmeans'
concat = pd.concat(objs=[mallet_df, kmeans_df], ignore_index=True)

In [17]:
# Preview the first rows only; the combined frame holds the samples of both models.
concat.head(10)

Unnamed: 0,no_topics,odd_idx,sample_idx,topic,word_0,word_1,word_2,word_3,word_4,word_5,model
0,6,5,0,3,bernardino,orlando,obama,terror,terrorist,fear,mallet
1,6,4,1,0,student,texa,shoot,trump,pattern,year,mallet
2,6,3,2,1,airport,shot,offic,protest,dead,suspect,mallet
3,6,3,3,1,offic,gunman,polic,jesus,shoot,dead,mallet
4,6,2,4,4,time,stop,mall,church,law,make,mallet
5,6,1,5,4,make,condemn,violenc,good,church,talk,mallet
6,6,5,6,5,today,shoot,mass,famili,victim,cop,mallet
7,6,4,7,0,trump,student,school,day,citizen,year,mallet
8,6,4,8,2,peopl,cop,guy,black,parent,kill,mallet
9,6,5,9,0,shoot,trump,high,school,parkland,cover,mallet


In [18]:
# Write the combined samples of both models to OUTPUT_DIR; index=False leaves
# the row index out of the CSV.
concat.to_csv(OUTPUT_DIR + 'word_intrusion_data.csv', index=False)