In [1]:
import networkx as nx

In [2]:
import pickle
from tqdm import tqdm
import math
import random
import numpy as np
from copy import deepcopy
from collections import defaultdict

In [3]:
trees, user_data, post_data, chosen_uids, chosen_trees = pickle.load(open('so_politics.pkl','rb'))

In [4]:
len(trees), len(chosen_trees), len(chosen_uids)

(15302, 1456, 552)

In [5]:
def checkAA(pid):
    if post_data[pid]['type'] in [1,3]: return False
    rid = post_data[pid]['rootid']
    if math.isnan(post_data[rid]['accepted_answer_id']): return False
    return post_data[rid]['accepted_answer_id']==pid

In [6]:
def remove_null(d1):
    d = deepcopy(d1)
    for k,v in list(d.items()):
        if not v:
            d[k]="N/A"
        elif k=='tags' and type(v)==list:
            d[k] = ' - '.join(v)
        elif type(v)==list:
#             print(k)
            del d[k]
        else:
            d[k]=str(v)
        
    return d

In [7]:
edges = []
max_len = 0
max_rid = -1
node_attr = defaultdict(dict)
user_tags = defaultdict(set)
user_posts = defaultdict(list)
all_tags=set()
for rid in tqdm(chosen_trees):
    
    cur_tree = trees[rid]
    tags = post_data[rid]['tags']
    all_tags.update(tags)
    tmp_len=0
    for pid in cur_tree:
        if post_data[pid]['user_id'] in chosen_uids:
            uid = 'u'+str(int(post_data[pid]['user_id']))
            user_posts[uid].append(post_data[pid]['text'])
    for edge in cur_tree.edges():
        p1 = post_data[edge[0]]
        p2 = post_data[edge[1]]
        if post_data[edge[0]]['user_id'] not in chosen_uids or \
            post_data[edge[1]]['user_id'] not in chosen_uids: continue
        t1 = post_data[edge[0]]['type']
        t2 = post_data[edge[1]]['type']
        u1 = 'u'+str(int(post_data[edge[0]]['user_id']))
        u2 = 'u'+str(int(post_data[edge[1]]['user_id']))
        edge = (str(edge[0]), str(edge[1]))
        
        
        user_tags[u1].update(tags)
        user_tags[u2].update(tags)
        
        if t1==1 and t2==2:
            node_attr[edge[1]]['AA'] = checkAA(int(edge[1]))
            edges.append(edge)#+({'type':'A'},))
            edges.append((u1,edge[0]))#,{'type':'OP'}))
            edges.append((u2,edge[1]))#,{'type':'Contributor'}))
        elif t1==1 and t2==3:
            edges.append(edge)#+({'type':'C'},))
            edges.append((u1,edge[0]))#,{'type':'OP'}))
            edges.append((u2,edge[1]))#,{'type':'Contributor'}))
        elif t1==2 and t2==3:
            node_attr[edge[0]]['AA'] = checkAA(int(edge[0]))
            edges.append(edge)#+({'type':'C'},))
            edges.append((u1,edge[0]))#,{'type':'Contributor'}))
            edges.append((u2,edge[1]))#,{'type':'Contributor'}))
        else:
            print(t1,t2)
        tmp_len+=1
        node_attr[u1] = user_data[p1['user_id']]
        node_attr[u2] = user_data[p2['user_id']]
        node_attr[edge[0]].update(p1)
        node_attr[edge[1]].update(p2)
    if tmp_len>max_len:
        max_len = tmp_len
        max_rid=rid

100%|████████████████████████████████████████████████████████████████████████████| 1456/1456 [00:00<00:00, 2239.25it/s]


In [8]:
G = nx.DiGraph(edges)

In [9]:
nx.set_node_attributes(G, node_attr)

In [10]:
G_und = G.to_undirected()

In [11]:
def get_nei_nodes(nid, dep=3):
    queue = []
    visited = set()
    queue.append([nid,0])
    visited.add(nid)
    
    nids = []
    while queue:
        src = queue.pop(0)
        nids.append(src[0])
        if src[1]<dep:
            for sibid in nx.all_neighbors(G_und, src[0]):
                if sibid not in visited:
                    queue.append([sibid, src[1]+1])
                    visited.add(sibid)
    return nids

In [12]:
for i in random.sample(G.nodes(),5):
    sub_nodes=get_nei_nodes(i,3)
    sub_G = G.subgraph(sub_nodes)
    print(len(sub_G), len(sub_G.edges()))

844 1026
2050 2688
152 184
354 468
1135 1386


In [13]:

nx.write_gexf(G.subgraph(sub_nodes), "so_politics.gexf")

In [15]:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2', {'cache_dir':'/data/huggingface_cache'}).to('cuda')

2023-05-06 03:36:21.516442: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/opt/amazon/efa/lib:/opt/amazon/openmpi/lib:/usr/local/lib:/usr/lib:/usr/local/cuda/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/opt/amazon/efa/lib:/opt/amazon/openmpi/lib:/usr/local/lib:/usr/lib:
2023-05-06 03:36:21.516615: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/cuda/lib:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/opt/amazon/efa/lib:/opt/amazon/openmpi/lib:/usr/local/lib:/usr/lib:/usr/local/cuda/lib:/usr/local/cuda/

Using Maximum Sequence Length:  384


In [16]:
all_uids = [i for i in node_attr.keys() if i[0]=='u']
all_uid_id = {j:i for i,j in enumerate(all_uids)}
all_pids = [i for i in node_attr.keys() if i[0]!='u']
all_pid_id = {j:i for i,j in enumerate(all_pids)}
all_tags = list(all_tags)
all_tag_id = {j:i for i,j in enumerate(all_tags)}

In [29]:
all_posts = [post_data[int(i) if i[0]!='c' else i]['text'] for i in all_pids]

In [29]:
user_embs = np.array([model.encode(user_posts[i]).mean(0) for i in tqdm(all_uids)])

100%|████████████████████████████████████████████████████████████████████████████████| 552/552 [09:08<00:00,  1.01it/s]


In [33]:
post_embs = model.encode(all_posts, show_progress_bar=True, batch_size=128)

Batches: 100%|█████████████████████████| 339/339 [05:37<00:00,  1.00it/s]


In [31]:
training_samples = []
for uid,u_tags in user_tags.items():    
    for u_tag in u_tags:
        row = []
        row.append([ all_uid_id[uid], all_tag_id[u_tag]])
        negs = random.sample(set(all_tags)-u_tags,10)
        for n in negs:
            row.append([all_uid_id[uid], all_tag_id[n]])
        training_samples.append(row)
    

In [39]:
len(training_samples)

29866

In [34]:
pickle.dump([all_uids, user_emb, all_pids, post_embs], open('./so_net_data.pkl','wb'))

In [118]:
all_user_post_edges = set()
all_user_user_edges = set()
user_to_posts = defaultdict(set)
user_to_users = defaultdict(set)
for u,v,d in G.edges(data=True):
    if u[0]=='u':
        all_user_post_edges.add(tuple(sorted((u,v))))
        user_to_posts[u].add(v)
        
        for nid in nx.all_neighbors(G, v):
            if nid[0]!='u':
                for uid in nx.all_neighbors(G, nid):
                    if uid[0]=='u' and uid!=u:
                        all_user_user_edges.add(tuple(sorted((u,uid))))
                        user_to_users[u].add(uid)

In [123]:
for a,b in all_user_post_edges:
    assert nx.shortest_path_length(G_und,a,b)==1
for a,b in all_user_user_edges:
    assert nx.shortest_path_length(G_und,a,b)==3

In [112]:
nx.shortest_path_length?

In [124]:
len(all_user_post_edges), len(all_user_user_edges)

(43284, 12232)

In [125]:
user_user_samples = []
for uid, p_uids in user_to_users.items():
    for p_uid in p_uids:
        row=[]
        row.append([all_uid_id[uid], all_uid_id[p_uid]])
        n_uids = random.sample(set(all_uids)-p_uids,10)
        row += [[all_uid_id[uid], all_uid_id[nuid]] for nuid in n_uids] 
        user_user_samples.append(row)

In [126]:
user_post_samples = []
for uid, p_uids in tqdm(user_to_posts.items()):
#     print(len(p_uids))
    for p_uid in p_uids:
        row=[]
        row.append([all_uid_id[uid], all_pid_id[p_uid]])
        n_uids = random.sample(set(random.sample(all_pids,len(p_uids)+10))-p_uids,5)
        row += [[all_uid_id[uid], all_pid_id[nuid]] for nuid in n_uids]
        user_post_samples.append(row)

100%|████████████████████████████████████████████████████████████████████████████████| 552/552 [00:09<00:00, 59.23it/s]


In [127]:
len(user_post_samples), len(user_user_samples)

(43284, 24464)

In [138]:
67.07*0.22*0.16

2.360864

In [128]:
pickle.dump([all_uids, user_emb, all_pids, post_embs, user_user_samples, user_post_samples],\
            open('./so_tr_data.pkl','wb'))

In [83]:
import torch
from sentence_transformers.util import cos_sim

In [129]:
model_wgts = torch.load('./Models/model.pt')

In [131]:
train_user_embs = model_wgts['node_embeddings.weight']
train_post_embs = model_wgts['word_embeddings.weight']

In [133]:
train_user_embs.shape, user_emb.shape, post_embs.shape, train_post_embs.shape

(torch.Size([552, 768]), (552, 768), (43284, 768), torch.Size([43284, 768]))

In [134]:
user_post_cos = cos_sim(user_emb, post_embs)
tr_user_post_cos = cos_sim(train_user_embs, train_post_embs)

In [135]:
jac1=[]
jac2=[]
for uid, p_pids in tqdm(user_to_posts.items()):
    uid = all_uid_id[uid]
    p_pids = set([all_pid_id[i] for i in p_pids])
    pred1 = set((-user_post_cos[uid]).argsort()[:10].tolist())
    pred2 = set((-tr_user_post_cos[uid]).argsort()[:10].tolist())
    
    jac1.append(len(p_pids&pred1)/len(p_pids|pred1))
    jac2.append(len(p_pids&pred2)/len(p_pids|pred2))
    
np.mean(jac1), np.mean(jac2)

100%|███████████████████████████████████████████████████████████████████████████████| 552/552 [00:03<00:00, 165.61it/s]


(0.023813022739192925, 0.020648614608151174)

In [136]:
user_user_cos = cos_sim(user_emb, user_emb)
tr_user_user_cos = cos_sim(train_user_embs, train_user_embs)

In [137]:
jac1=[]
jac2=[]
for uid, p_uids in tqdm(user_to_users.items()):
    uid = all_uid_id[uid]
    p_uids = set([all_uid_id[i] for i in p_uids])
    pred1 = set((-user_user_cos[uid]).argsort()[:10].tolist())
    pred2 = set((-tr_user_user_cos[uid]).argsort()[:10].tolist())
    
    jac1.append(len(p_uids&pred1)/len(p_uids|pred1))
    jac2.append(len(p_uids&pred2)/len(p_uids|pred2))
    
np.mean(jac1), np.mean(jac2)

100%|███████████████████████████████████████████████████████████████████████████████| 552/552 [00:00<00:00, 561.29it/s]


(0.06063541316893815, 0.03193194002743345)