In [93]:
import pickle
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
from collections import defaultdict
import random
import pandas as pd

In [2]:
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import cos_sim

from openai.embeddings_utils import get_embedding, cosine_similarity

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def get_embeddings(texts, model_name, progress=True):
    if model_name == 'mpnet':
        return model.encode(texts, batch_size=128,show_progress_bar=progress)
    elif model_name == 'sgpt':
        return sgpt_model.encode(texts, show_progress_bar=progress)
    elif model_name == 'ada':
        return get_ada_embedding(texts, progress)
    elif model_name == 'dw':
        return get_avg_node_embedding(texts)

In [4]:
model = SentenceTransformer('sentence-transformers/stsb-mpnet-base-v2', {'cache_dir':'/data/huggingface_cache'}).to('cuda')

2023-05-06 05:45:11.692353: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-05-06 05:45:11.844164: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-05-06 05:45:12.559844: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/opt/amazo

Using Maximum Sequence Length:  75


In [5]:
def get_stats(array):
    print("Mean:",np.mean(array))
    print("Percentiles")
    for i in [5,25,50,75,90,99, 99.5, 99.7]:
        print(f"{i} - {round(np.percentile(array, i))}")
    print()

In [81]:
trees, user_data, post_data, chosen_uids, chosen_trees = pickle.load(open('so_politics.pkl','rb'))

In [82]:
train_trees, test_trees = train_test_split(chosen_trees, test_size=0.33, random_state=0)

In [83]:
len(train_trees), len(test_trees)

(975, 481)

In [84]:
all_posts = []
all_te_posts = []
flat_uids = []
all_tag_posts = []
doc_freq=defaultdict(int)
topic_to_user = defaultdict(set)

for rid in train_trees:
    tags = post_data[rid]['tags']
    for pid in trees[rid]:
        if post_data[pid]['user_id'] in chosen_uids:
            title = post_data[pid]['title'] if post_data[pid]['title'] else ''
            all_posts.append(title + ' '+post_data[pid]['text'])
            flat_uids.append(post_data[pid]['user_id'])
            all_tag_posts.append(' - '.join(tags))
            all_te_posts.append( all_posts[-1]+' '+all_tag_posts[-1])
            
            if 'topics' not in user_data[post_data[pid]['user_id']]:
                user_data[post_data[pid]['user_id']]['topics']=[]
            user_data[post_data[pid]['user_id']]['topics']+=tags
            for t in tags: topic_to_user[t].add(post_data[pid]['user_id'])
            
    for tag in post_data[rid]['tags']:
        doc_freq[tag]+=1

In [86]:
tags_to_ignore = set([i[0] for i in sorted(doc_freq.items(), reverse=True, key=lambda x:x[1])[:1]])

In [87]:
for uid in chosen_uids:
    tags =[]
    for t in user_data[uid]['topics']: 
        if t not in tags_to_ignore: tags.append(t)

    tag_freq = defaultdict(int)
    for i in tags: tag_freq[i]+=1
    total_freq = sum(tag_freq.values())
    x = {i:[j/total_freq, j] for i,j in tag_freq.items()}
    user_data[uid]['top_conf'] = x
#     print(x)

Batches: 100%|███████████████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00,  1.65it/s]


In [14]:
test_data = []
gts=[]
for i in test_trees:
    gt=set()
    for pid in trees[i]:
        if post_data[pid]['user_id'] in chosen_uids: 
            gt.add(post_data[pid]['user_id'])
    test_data.append([i,gt]) 
    gts.append(len(gt))

In [15]:
get_stats(gts)

Mean: 15.798336798336798
Percentiles
5 - 8
25 - 11
50 - 15
75 - 19
90 - 24
99 - 37
99.5 - 41
99.7 - 41



In [16]:
def get_recall(gt,pred,K):
    return len( gt&pred )/ min(K,len(gt))

In [17]:
all_post_embeds = get_embeddings(all_posts, 'mpnet')

Batches: 100%|███████████████████████████████████████████████████████████████████████| 287/287 [02:15<00:00,  2.12it/s]


In [18]:
all_tag_embeds = get_embeddings(all_tag_posts, 'mpnet')

Batches: 100%|███████████████████████████████████████████████████████████████████████| 287/287 [00:26<00:00, 10.77it/s]


In [19]:
all_te_embeds = get_embeddings(all_te_posts, 'mpnet')

Batches: 100%|███████████████████████████████████████████████████████████████████████| 287/287 [02:20<00:00,  2.04it/s]


In [29]:
all_te_embeds.shape

(36624, 768)

In [39]:
random.seed(0)
rand_te = random.sample([i[0] for i in test_data],20)

In [24]:
def get_contributors(query, tr_embeds):
#     print(query)
#     print()
    qemb = get_embeddings(query, 'mpnet', progress=False)
    cosines = cos_sim(qemb, tr_embeds)[0]
    argort = (-cosines).argsort()
    done=set()
    users = []
    docs=[]
    for i in argort:
        if flat_uids[i] in done: continue
        else:
#             print(flat_uids[i],' - ', all_tag_posts[i])
#             print('-')
            users.append(flat_uids[i]); done.add(flat_uids[i])
#     input()
    return users

In [51]:
rec = []
# post_preds = []
K=[10,20,30]
for idx,[pid,gt] in enumerate(tqdm(test_data)):
#     if not post_data[pid]['tags']: continue
#     tags = ' - '.join(set(post_data[pid]['tags']))
#     text = post_data[pid]['title']+' '+post_data[pid]['text']
#     text = text# + ' ' + tags
#     pred = get_contributors(tags,all_tag_embeds)
# #     post_preds.append(pred)
    done=set()
    pred=[]
    for i,j in zip(te_preds[idx], post_preds[idx]):
        if j not in done: pred.append(j); done.add(j)
        if i not in done: pred.append(i); done.add(i)
        
    row=[]
    for k in K:
        row.append( len( gt&set(pred[:k]) )/ min(k,len(gt)) )
    rec.append(row)
#     print('-'*50)
rec=np.array(rec)
np.mean(rec[:,0]),np.mean(rec[:,1]),np.mean(rec[:,2])

100%|██████████████████████████████████████████████████████████████████████████████| 481/481 [00:00<00:00, 3844.11it/s]


(0.1527984027984028, 0.18286816440841205, 0.23752279718906136)

In [44]:
len(te_preds),len(post_preds)

(481, 481)

In [52]:
all_uids = [i for i in chosen_uids]
all_user_topics = [' - '.join(user_data[i]['topics']) for i in chosen_uids]

In [53]:
all_user_top_embs = get_embeddings(all_user_topics, 'mpnet')

Batches: 100%|███████████████████████████████████████████████████████████████████████████| 5/5 [00:02<00:00,  2.05it/s]


In [54]:
all_user_top_embs.shape

(552, 768)

In [51]:
def get_contributors_tt(query, tr_embeds, k=10, qtopics=None):
    print(query)
    print()
    qemb = get_embeddings(query, 'mpnet', False)

    cosines = cos_sim(qemb, tr_embeds)[0]
    argsort = (-cosines).argsort()
    for i in argsort[:k]:
        print(all_uids[i],all_user_topics[i])
        print('-')

    input()
    return set([all_uids[i] for i in argsort[:k]])

In [52]:
rec = []
K=10
for pid,gt in tqdm(test_data):
    if not post_data[pid]['tags']: continue
    tags = ' - '.join(post_data[pid]['tags'])
    pred = get_contributors_tt(tags,all_user_top_embs,K)
    rec.append( len( gt&pred )/ min(K,len(gt)) )
#     break
np.mean(rec)

  0%|                                                                                                   | 0/481 [00:00<?, ?it/s]

united-states - religion - canada - discrimination - conflict-of-interest

27691 israel - honors - crime - gender - climate-change - canada - united-states - donald-trump - discrimination - racism - armed-conflict - trade - international - guns - international-law - democracy - communism - european-union - russian-federation - trump-impeachment - election - germany - international-court - religion - gender-neutrality - presidential-election - political-transitions
-
8751 international-relations - south-korea - government - discrimination - brexit - china - policy - international-law - european-union - history - affordable-care-act - parties - campaign-finance - public-health - boris-johnson - public-safety - gender - north-america - vaccine - canada - debate - health-insurance - donald-trump - political-theory - article-50 - republican-party - north-korea - monarchy - privacy - nationalism - stimulus - coup - senate - immigration - healthcare - identity-document - united-kingdom - pres

  0%|                                                                                                   | 0/481 [11:10<?, ?it/s]


In [48]:
tags

'united-states - religion - canada - discrimination - conflict-of-interest'

In [56]:
import itertools

In [54]:
tops = all_user_topics[12].split(' - ')

In [None]:
a = list(itertools.permutations(tops))

In [55]:
all_topics = list(set(doc_freq.keys())-tags_to_ignore)
all_topics_emb = get_embeddings(all_topics, 'mpnet')

Batches: 100%|███████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 16.26it/s]


In [73]:
def get_contrib_usr_ind_tt(query, K, qtopics, gt):
    qtopics = [i for i in qtopics if i[0] not in topics_to_ignore]
    if not qtopics:
        qtopics
    total = 0
    argsorts = []
    df=[]
    for topic,conf in qtopics:
        total+=conf
        cosines = cos_sim(model.encode(topic), all_topics_emb)
        ags = (-cosines).argsort()[0].tolist()
        argsorts.append([ags[:30], cosines[0]])
    
    user_comb_score = defaultdict(float)
    for [topic,conf],[ag,coss] in zip(qtopics,argsorts):
        
        ind_rows = []
#         print(topic)
        for idx in ag:
            for u in topic_to_user[all_topics[idx]]:
                if u == 'AutoModerator': continue
                if all_topics[idx] not in user_data[u]['top_conf']: continue
                    
                u_scores = user_data[u]['top_conf'][all_topics[idx]]
                
                user_comb_score[u] += (conf/total) * u_scores[0] * (coss[idx]+1)/2
#                 ind_rows.append([u, all_topics[idx], round(u_scores[0],4), \
#                                  u_scores[1], round((conf/total) * u_scores[0] * coss[idx].item(),4)])
                ind_rows.append([u, all_topics[idx], round(u_scores[0]*100,4), \
                                 u_scores[1], round((conf/total) * u_scores[0] * (coss[idx].item()+1)/2,4)])
                
        ind_rows = sorted(ind_rows, key=lambda x:x[4], reverse=True)[:30]
        df.append([query, [topic,round(conf/total*100,4),conf], \
                   [ [all_topics[i],round( (coss[i].item()+1)/2,4)] for i in ag], ind_rows])
#         input()
#     print(user_comb_score)
    
    users,_ = zip(*sorted(user_comb_score.items(), key=lambda x:x[1], reverse=True))
    pred = set(users[:10])
    for idx,row in enumerate(df):
        inter = set([i[0] for i in row[-1]])&pred
        inter = sorted([[i, round(user_comb_score[i].item(),4)] for i in inter ],reverse=True,key=lambda x:x[1])
        df[idx] = df[idx] + [inter, gt]
    assert len(set(users[:30])-chosen_uids)==0
    return users, df

In [91]:
rec = []
df=[]
K=[10,20,30]
for pid,gt in tqdm(test_data):
    if pid not in rand_te: continue
    if not post_data[pid]['tags']: continue
    if not set(post_data[pid]['tags'])-tags_to_ignore: continue
    text = post_data[pid]['title']+' '+post_data[pid]['text']
    pred,df_ = get_contrib_usr_ind_tt(text,K,[[i,1] for i in post_data[pid]['tags']], gt)
    df+=df_
    row=[]
    for k in K:
        row.append( len( gt&set(pred[:k]) )/ min(k,len(gt)) )
    rec.append(row)
#     print('-'*50)
rec = np.array(rec)
np.mean(rec[:,0]),np.mean(rec[:,1]),np.mean(rec[:,2])

100%|████████████████████████████████████████████████████████████████████████████████| 481/481 [00:05<00:00, 82.13it/s]


(0.01, 0.03459401709401709, 0.07823382173382173)

In [95]:
pd.DataFrame(df).to_csv('reddit.csv')

In [96]:
df[0]

['Minimizing civilian casualties during the siege of Mariupol? What are the best practices of minimizing civilian casualties in cities under siege, as applied to the siege of Mariupol?\nI am looking for best practices from the perspective of world civilians, including civilians inside and outside of Mariupol. So please, no answers advocating WWIII to help Mariupol civilians.\nRealistic scenarios only, please!\nSEE ALSO:\nCould humanitarian aid be provided in Mariupol through the air? (refers only to help through the air, which is a subset of the current question)\nWhat is the purpose of the siege of the Ukrainian city of Mariupol by the Russian invaders? (some general info useful to answer the current question)\nWhat is the rationale of Russian troops not allowing civilians to evacuate from the encircled cities? (still more info)\nSiege of Mariupol  (Wikipedia page on the subject, lots of info)\nNOTES:\nAnswers supported by references are preferred, historical references are highly app

In [98]:
import torch

In [94]:
all_uids, user_embs, all_pids, all_tags, _ = \
pickle.load(open('/data/Projects/recommend_users_to_posts/so_tr_data.pkl', 'rb'))

In [96]:
all_uid_id = {j:i for i,j in enumerate(all_uids)}
all_tag_id = {j:i for i,j in enumerate(all_tags)}

In [105]:
model_wgts = torch.load('/data/Projects/recommend_users_to_posts/Models/model.pt')
user_embs = model_wgts['node_embeddings.weight']
tag_embs = model_wgts['node_embeddings.weight']

In [109]:
user_tag = cos_sim(tag_embs, user_embs)

In [112]:
torch.min(user_tag)

tensor(0.8714, device='cuda:0')

In [None]:
topic_to_user

In [102]:
def get_contrib_tag_network(qtopics,K):
    user_comb_score = defaultdict(float)
    for topic in qtopics:
        cosines = cos_sim(tag_embs[all_tag_id[topic]] ,user_embs)
        ags = (-cosines).argsort()[0].tolist()[:30]
        for idx in ags:
            user_comb_score[all_uids[idx]]+=cosines[0][idx].item()
            
    users,_ = zip(*sorted(user_comb_score.items(), key=lambda x:x[1], reverse=True))
    users = [int(i[1:]) for i in users]
    return set(users[:K])

In [107]:
rec = []
K=10
for pid,gt in tqdm(test_data):
    
    if not post_data[pid]['tags']: continue

    pred = get_contrib_tag_network(post_data[pid]['tags'], K)
    rec.append( len( gt&pred )/ min(K,len(gt)) )
#     print('-'*50)
np.mean(rec)

100%|███████████████████████████████████████████████████████████████████████████████| 481/481 [00:01<00:00, 404.58it/s]


0.015525690525690528