In [1]:
import gensim
import numpy as np 
import pandas as pd

In [2]:
from collections import Counter

In [3]:
from nltk.tokenize import RegexpTokenizer

In [48]:
import pickle 

In [50]:
from sklearn.metrics.pairwise import cosine_similarity

In [4]:
# Load Google's pre-trained Word2Vec model.
model = gensim.models.KeyedVectors.load_word2vec_format('GoogleNews-vectors-negative300.bin', binary=True)  

## Parsing raw reddit posts

In [5]:
import csv
import nltk
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import re

nltk.download('wordnet')
lemmatizer = WordNetLemmatizer()
# print(lemmatizer.lemmatize("cats"))

nltk.download('stopwords')
stop_words = set(stopwords.words('english'))
tokenizer = RegexpTokenizer(r'\w+')

def parse_reddit_csv(filename):
    print("Reading from", filename)
    csv_cols = []
    frequencies = {}
    with open(filename) as csvfile:
        reader = csv.DictReader(csvfile)
        for row in reader:
            #remove numbers 
            row['selftext'] = re.sub(r'\d+', '', row['selftext'])
            row['title'] = re.sub(r'\d+', '', row['title'])
            # Tokenize the post text (selftext) and post title
            #remove punctuations 
            post_tokens = tokenizer.tokenize(row['selftext'])
            title_tokens = tokenizer.tokenize(row['title'])
            # Filter out stopwords
            post_tokens = [w for w in post_tokens if not w in stop_words]
            title_tokens = [w for w in title_tokens if not w in stop_words]
            # Lemmatize the post text (reduce words to word stems i.e. cats->cat, liked->like)
            post_tokens = [lemmatizer.lemmatize(w, 'n') for w in post_tokens]
            post_tokens = [lemmatizer.lemmatize(w, 'v') for w in post_tokens]
            title_tokens = [lemmatizer.lemmatize(w, 'n') for w in title_tokens]
            title_tokens = [lemmatizer.lemmatize(w, 'v') for w in title_tokens]
            csv_cols.append({'author': row['author'],
                             'selftext': post_tokens,
                             'title': title_tokens,
                            'post_id': row['id']})
            # TODO need to collect frequencies of words in the entire corpus
            # TODO update frequencies mapping from word->count and also get a sum
    return csv_cols, frequencies

[nltk_data] Downloading package wordnet to /Users/apple/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to /Users/apple/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [6]:
parsed = parse_reddit_csv('final_proj_data_preprocessed_1000sample.csv')

Reading from final_proj_data_preprocessed_1000sample.csv


## Calculate word embeddings

In [40]:
sen_emb = {}
for i in range(len(parsed[0])):
    counts = Counter(parsed[0][i]['selftext']).items()
    freq = pd.DataFrame(counts)
    freq = freq.rename(columns={0: "word", 1: 'freq'})
    #inverse relative frequency as weight 
    freq['inv_rfreq'] = freq['freq'].sum()/freq['freq']
    unknowns = []
    emb_dict = {}
    for w in freq['word'].to_list():
        try:
            emb = model[w]
            emb_dict.update({w:emb})
        except:
            unknowns.append(w)
    emb_value = pd.DataFrame(emb_dict).transpose().reset_index()
    emb_value = emb_value.rename(columns={'index': "word"})
    emb_value_list = emb_value.iloc[:, 1:301].mul(freq['inv_rfreq'], axis = 0).sum().to_list()
    sen_emb.update({parsed[0][i]['post_id']:emb_value_list})       

In [49]:
with open('sample1000_emb.pickle', 'wb') as handle:
    pickle.dump(sen_emb, handle, protocol=pickle.HIGHEST_PROTOCOL)

## Calculate Pairewise Cosine Similarity 

In [63]:
sen_emb_arr = np.array(list(sen_emb.values()))

In [66]:
sim_mat = cosine_similarity(sen_emb_arr,sen_emb_arr)

In [67]:
sim_mat

array([[1.        , 0.88455658, 0.79086206, ..., 0.86600266, 0.85997844,
        0.82656023],
       [0.88455658, 1.        , 0.86886817, ..., 0.92437809, 0.9194565 ,
        0.94120198],
       [0.79086206, 0.86886817, 1.        , ..., 0.83108216, 0.83927445,
        0.84625584],
       ...,
       [0.86600266, 0.92437809, 0.83108216, ..., 1.        , 0.89908772,
        0.86744328],
       [0.85997844, 0.9194565 , 0.83927445, ..., 0.89908772, 1.        ,
        0.90656519],
       [0.82656023, 0.94120198, 0.84625584, ..., 0.86744328, 0.90656519,
        1.        ]])