In [1]:
import os 
import sys

sys.path.append('/home/kalkiek/projects/reddit-political-affiliation/')

import itertools
import random
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from tqdm.notebook import tqdm

from src.data.make_dataset import build_dataset
from src.models.word2vec.User2Subreddit import User2Subreddit
# from src.models.word2vec.predict_model import predict_user_affiliations, top_n_similar_embeddings, \
#     save_similar_embeddings_to_tsv

from sklearn.metrics import auc, roc_curve

# Testing/Validating the Model

### Load in the datset

In [2]:
year_month = '2019-01'

network_path = '/shared/0/projects/reddit-political-affiliation/data/bipartite-networks/' + year_month + '_filtered.tsv'
flair_directory =  '/shared/0/projects/reddit-political-affiliation/data/flair-affiliations/' + year_month + '.tsv'

dataset, training, validation, pol_validation, vocab = build_dataset(network_path, flair_directory)


Building vocab from file:   0%|          | 0/26729161 [00:00<?, ?it/s][A
Building vocab from file:   0%|          | 8509/26729161 [00:00<23:36, 18863.38it/s][A
Building vocab from file:   0%|          | 71957/26729161 [00:00<16:41, 26608.65it/s][A
Building vocab from file:   1%|          | 143113/26729161 [00:00<11:50, 37412.76it/s][A
Building vocab from file:   1%|          | 178061/26729161 [00:01<10:06, 43773.80it/s][A
Building vocab from file:   1%|          | 244398/26729161 [00:01<07:15, 60814.15it/s][A
Building vocab from file:   1%|          | 309273/26729161 [00:01<05:16, 83521.86it/s][A
Building vocab from file:   1%|▏         | 384769/26729161 [00:01<03:51, 113915.83it/s][A
Building vocab from file:   2%|▏         | 451402/26729161 [00:01<02:53, 151627.28it/s][A
Building vocab from file:   2%|▏         | 520306/26729161 [00:01<02:12, 197942.32it/s][A
Building vocab from file:   2%|▏         | 581999/26729161 [00:01<01:45, 248213.60it/s][A
Building vocab from file

Building vocab from file:  23%|██▎       | 6124150/26729161 [00:12<00:33, 613579.87it/s][A
Building vocab from file:  23%|██▎       | 6192430/26729161 [00:12<00:32, 630453.62it/s][A
Building vocab from file:  23%|██▎       | 6260437/26729161 [00:12<00:32, 636838.92it/s][A
Building vocab from file:  24%|██▎       | 6329852/26729161 [00:12<00:31, 653012.19it/s][A
Building vocab from file:  24%|██▍       | 6402412/26729161 [00:12<00:30, 673213.92it/s][A
Building vocab from file:  24%|██▍       | 6471671/26729161 [00:13<00:30, 671634.33it/s][A
Building vocab from file:  24%|██▍       | 6540192/26729161 [00:13<00:30, 670240.76it/s][A
Building vocab from file:  25%|██▍       | 6608167/26729161 [00:13<00:30, 667674.57it/s][A
Building vocab from file:  25%|██▍       | 6675603/26729161 [00:13<00:30, 655393.85it/s][A
Building vocab from file:  25%|██▌       | 6742547/26729161 [00:13<00:30, 659545.13it/s][A
Building vocab from file:  25%|██▌       | 6808871/26729161 [00:13<00:30, 660057

Building vocab from file:  45%|████▍     | 11899132/26729161 [00:25<00:26, 563674.57it/s][A
Building vocab from file:  45%|████▍     | 11956941/26729161 [00:25<00:26, 567922.10it/s][A
Building vocab from file:  45%|████▍     | 12016643/26729161 [00:25<00:25, 576348.20it/s][A
Building vocab from file:  45%|████▌     | 12078907/26729161 [00:25<00:24, 589495.46it/s][A
Building vocab from file:  45%|████▌     | 12138415/26729161 [00:25<00:24, 589680.07it/s][A
Building vocab from file:  46%|████▌     | 12204852/26729161 [00:25<00:23, 610260.73it/s][A
Building vocab from file:  46%|████▌     | 12271677/26729161 [00:25<00:23, 626569.73it/s][A
Building vocab from file:  46%|████▌     | 12336723/26729161 [00:25<00:22, 633550.23it/s][A
Building vocab from file:  46%|████▋     | 12402197/26729161 [00:25<00:22, 639759.68it/s][A
Building vocab from file:  47%|████▋     | 12468221/26729161 [00:26<00:22, 645768.39it/s][A
Building vocab from file:  47%|████▋     | 12532978/26729161 [00:26<00

Building vocab from file:  65%|██████▌   | 17436291/26729161 [00:37<00:15, 583808.84it/s][A
Building vocab from file:  65%|██████▌   | 17500685/26729161 [00:37<00:15, 600634.80it/s][A
Building vocab from file:  66%|██████▌   | 17560982/26729161 [00:39<01:39, 92241.94it/s] [A
Building vocab from file:  66%|██████▌   | 17615348/26729161 [00:39<01:14, 122841.70it/s][A
Building vocab from file:  66%|██████▌   | 17675130/26729161 [00:39<00:56, 161284.64it/s][A
Building vocab from file:  66%|██████▋   | 17734260/26729161 [00:39<00:43, 206291.47it/s][A
Building vocab from file:  67%|██████▋   | 17796610/26729161 [00:39<00:34, 258103.57it/s][A
Building vocab from file:  67%|██████▋   | 17858296/26729161 [00:39<00:28, 312653.61it/s][A
Building vocab from file:  67%|██████▋   | 17922204/26729161 [00:39<00:23, 369231.30it/s][A
Building vocab from file:  67%|██████▋   | 17981703/26729161 [00:40<00:21, 416250.83it/s][A
Building vocab from file:  67%|██████▋   | 18042044/26729161 [00:40<00

Building vocab from file:  85%|████████▍ | 22696525/26729161 [00:53<00:09, 416622.89it/s][A
Building vocab from file:  85%|████████▌ | 22751796/26729161 [00:53<00:08, 449850.65it/s][A
Building vocab from file:  85%|████████▌ | 22806213/26729161 [00:53<00:08, 472602.39it/s][A
Building vocab from file:  86%|████████▌ | 22860380/26729161 [00:53<00:08, 474962.73it/s][A
Building vocab from file:  86%|████████▌ | 22915770/26729161 [00:53<00:07, 496171.68it/s][A
Building vocab from file:  86%|████████▌ | 22971516/26729161 [00:53<00:07, 513092.33it/s][A
Building vocab from file:  86%|████████▌ | 23027264/26729161 [00:53<00:07, 525648.07it/s][A
Building vocab from file:  86%|████████▋ | 23088271/26729161 [00:53<00:06, 548414.88it/s][A
Building vocab from file:  87%|████████▋ | 23144740/26729161 [00:53<00:06, 552222.34it/s][A
Building vocab from file:  87%|████████▋ | 23201108/26729161 [00:53<00:06, 552694.77it/s][A
Building vocab from file:  87%|████████▋ | 23257180/26729161 [00:54<00

Length of vocab: 5178863
User count: 5120865
Subreddit count: 57998
User to politic counts: 2633
[('Cord_inate8', Counter({'Republican': 2})), ('error404brain', Counter({'Democrat': 1})), ('OTIS_is_king', Counter({'Democrat': 1})), ('Nesano', Counter({'Republican': 1})), ('SinisterPaige', Counter({'Republican': 1})), ('grubas', Counter({'Republican': 1})), ('DrPiccoloPhD', Counter({'Republican': 1})), ('Hillarys_cellmate', Counter({'Republican': 2})), ('guanaco55', Counter({'Republican': 2})), ('nycola', Counter({'Democrat': 1}))]
Saw political affiliations for 2632 users
User to politics training size: {}: 2369
User to politics validation size: {}: 263


Converting data to PyTorch: 100%|██████████| 5120865/5120865 [08:26<00:00, 10115.12it/s]


Train size: 144337470 Validation size: 16037496


In [7]:
# We'll need these later
word_to_ix = {word: i for i, word in enumerate(vocab)}
all_subreddits = {v for v in vocab if v[:2] == 'r/' and v[2:4] != 'u_'}
print("# of subreddits: " + str(len(all_subreddits)))

# of subreddits: 55204


### Load in the model

In [3]:
PATH = '/shared/0/projects/reddit-political-affiliation/working-dir/word2vec-outputs/' + year_month + '/9.pt'
embedding_dim = 50

# Sorry for the hardcoding ... will update later
model = User2Subreddit(5138256, embedding_dim, 80881)
model.load_state_dict(torch.load(PATH))
model.eval()

User2Subreddit(
  (u_embeddings): Embedding(5138256, 50)
  (v_embeddings): Embedding(80881, 50)
  (political_layer): Linear(in_features=50, out_features=1, bias=True)
  (before_pol_dropout): Dropout(p=0.5, inplace=False)
)

### Predict on the validation set

In [4]:
user_ids, pol_labels = [], []

for user, pol_label in pol_validation.items():
    try:
        # User subreddit dataset spans 1 month. Political data spans the year. Some users might not be present
        user_ids.append(dataset.user_to_idx[user])
        pol_labels.append(pol_label)
    except KeyError:
        pass

user_ids = torch.LongTensor(user_ids)
pol_labels = torch.FloatTensor(pol_labels)

emb_p = model.u_embeddings(user_ids)
political_predictions = model.political_layer(emb_p)
political_predictions = torch.sigmoid(political_predictions)

In [5]:
# Quick clean up

preds = []
for val in political_predictions.detach().numpy():
    if val[0] >= 0.5:
        preds.append(1)
    else:
        preds.append(0)

labels = pol_labels.detach().numpy().astype(int)

In [6]:
from sklearn.metrics import accuracy_score


accuracy_score(labels, preds)

0.7735042735042735

### Similar Embeddings 

In [15]:
sample_subreddits = ['r/nba', 'r/CryptoCurrency', 'r/Conservative', 'r/Liberal', 'r/AskReddit',
                     'r/Aww', 'r/Games', 'r/Hunting', 'r/Feminism', 'r/The_Donald',
                     'r/lawnmowers', 'r/juul', 'r/teenagers']

def top_n_similar_embeddings(model, subreddit, all_subreddits, word_to_ix, n):
    cosine_sims = {}
    cos = nn.CosineSimilarity(dim=1, eps=1e-6)

    print('looking up ', subreddit)

    try:
        sub_tensor = torch.tensor([word_to_ix[subreddit]], dtype=torch.long)
    except KeyError:
        return
    count = 0
    for sub in all_subreddits:
        ix = word_to_ix[sub]
        lookup_tensor = torch.tensor([ix], dtype=torch.long)
        try :
            cos_result = cos(model.u_embeddings(sub_tensor), model.u_embeddings(lookup_tensor))
            cosine_sims[sub] = cos_result
        except Exception:
            count += 1
    print(count)        

    # Sort and grab the top N
    cosine_sims = {k: v for k, v in sorted(cosine_sims.items(), key=lambda item: item[1], reverse=True)}
    top_results = dict(itertools.islice(cosine_sims.items(), n))

    # Spit out the results to the console
    print("Top {} similar embeddings for the subreddit: {}".format(n, subreddit))
    for sub, score in top_results.items():
        print(sub, score)

    return top_results


for sample_sub in sample_subreddits:
    top_n_similar_embeddings(model, sample_sub, all_subreddits, word_to_ix, n=10)

looking up  r/nba
453


NameError: name 'itertools' is not defined