In [1]:
import networkx as nx
import numpy as np
import pandas as pd
import random

In [2]:
def load_embeddings(filename):
    """Load a pickled Python object (here: a dict of node embeddings) from ``.npy``.

    The file presumably holds a dict written with ``np.save``, which wraps it
    in a 0-d object array; ``.item()`` unwraps that array back into the
    stored Python object.

    NOTE(review): ``allow_pickle=True`` can execute arbitrary code embedded in
    a crafted file -- only load files produced by a trusted pipeline.
    """
    wrapped = np.load(filename, allow_pickle=True)
    return wrapped.item()

In [3]:
# The .npy file must hold a pickled dict mapping node id -> embedding vector;
# fail fast here rather than with a confusing error in a later cell.
embeddings = load_embeddings('data/rolx_embeddings.npy')
assert isinstance(embeddings, dict), 'expected a dict of node embeddings'

In [4]:
# Infer the per-node feature length from any embedding instead of assuming a
# node with id 0 exists, and check every node has the same length (the
# dataframe columns built later assume a fixed width).
assert embeddings, 'no embeddings loaded'
embedding_dim = len(next(iter(embeddings.values())))
assert all(len(vec) == embedding_dim for vec in embeddings.values()), 'embeddings have inconsistent lengths'

In [39]:
def get_weights_dict(filename):
    """Build a symmetric ``(node, node) -> weight`` lookup from an edge-list CSV.

    The CSV has no header and three columns: source node, destination node,
    edge weight. Each edge is stored under both orientations, so a lookup
    succeeds regardless of the order in which an undirected edge is visited.
    If the same node pair appears on several rows, the last row wins.

    Args:
        filename: path to the headerless ``src,dst,weight`` CSV.

    Returns:
        dict mapping both ``(src, dst)`` and ``(dst, src)`` to the edge weight.
    """
    weights = pd.read_csv(filename, header = None)
    weights.columns = ['src', 'dst', 'weight']

    weights_dict = {}
    # Iterate whole columns together instead of doing three scalar .iloc
    # lookups per row, which is far slower on hundreds of thousands of edges.
    for src, dst, weight in zip(weights['src'], weights['dst'], weights['weight']):
        weights_dict[(src, dst)] = weight
        weights_dict[(dst, src)] = weight
    return weights_dict

In [40]:
# Symmetric (src, dst) -> weight lookup built from the weighted edge list.
weights_dict = get_weights_dict(filename='data/reddit_nodes_weighted_full.csv')

In [41]:
def load_graph(filename):
    """Read a headerless ``source,target,weight`` CSV into an undirected graph.

    Every row becomes an edge of a simple ``nx.Graph``, with the third column
    stored as the ``'weight'`` edge attribute. Because the graph is simple
    and undirected, repeated node pairs end up as a single edge.
    """
    edge_frame = pd.read_csv(filename, header=None, names=['source', 'target', 'weight'])
    return nx.from_pandas_edgelist(
        edge_frame,
        source='source',
        target='target',
        edge_attr='weight',
        create_using=nx.Graph(),
    )

In [42]:
# Undirected, weighted graph over the same edge list used for weights_dict.
G = load_graph(filename='data/reddit_nodes_weighted_full.csv')

In [43]:
def get_positive_examples(G, embeddings, weights_dict):
    """Build one labelled feature row for every existing edge in ``G``.

    Each row is the source node's features, then the destination node's
    features, then the edge weight ``weights_dict[(u, v)]`` as the label.

    Args:
        G: graph whose ``edges()`` are the positive examples.
        embeddings: mapping node -> feature sequence (list or 1-D array).
        weights_dict: mapping ``(u, v)`` node pair -> edge weight.

    Returns:
        list of rows (Python lists), one per edge, in ``G.edges()`` order.

    Raises:
        KeyError: if a node has no embedding or an edge has no weight entry.
    """
    pos_examples = []
    for src, dst in G.edges():
        # list() makes this a concatenation even for ndarray embeddings, where
        # "+" would otherwise add the vectors element-wise and broadcast the label.
        edge_vector = list(embeddings[src]) + list(embeddings[dst]) + [weights_dict[(src, dst)]]  # label = edge weight
        pos_examples.append(edge_vector)
    return pos_examples

In [44]:
def get_negative_examples(G, embeddings, num_examples, attempts = 3000000, len_threshold = 5):
    """Sample far-apart, unconnected node pairs as label-0 training rows.

    Node pairs are drawn uniformly at random (with replacement) from ``G``.
    A pair is kept only if:

    - its two nodes are distinct;
    - it is not an edge;
    - it has not already been kept in either orientation;
    - a path exists between its nodes, and the unweighted shortest-path
      length is at least ``len_threshold``.

    Sampling stops after ``num_examples`` rows or ``attempts`` draws,
    whichever comes first, so fewer than ``num_examples`` rows may be returned.

    Args:
        G: networkx graph to sample from.
        embeddings: mapping node -> feature sequence (list or 1-D array).
        num_examples: target number of negative rows.
        attempts: maximum number of random pair draws.
        len_threshold: minimum hop distance for a pair to count as negative.

    Returns:
        (neg_examples, edges_used): the rows, each being src features +
        dst features + ``[0]``, and the set of ``(src, dst)`` pairs they were
        built from (one orientation per pair).
    """
    node_list = list(G.nodes())
    neg_examples = []
    edges_used = set()
    if len(node_list) < 2:
        return neg_examples, edges_used
    for _ in range(attempts):
        if len(neg_examples) >= num_examples:
            break
        src, dst = random.choices(node_list, k = 2)
        if src == dst or G.has_edge(src, dst):
            continue
        # Undirected graph: (dst, src) is the same pair. Skip repeats before
        # paying for a BFS, and so no duplicate rows enter the training set.
        if (src, dst) in edges_used or (dst, src) in edges_used:
            continue
        try:
            # weight=None -> hop count, ignoring edge weights.
            path_length = nx.shortest_path_length(G, source=src, target=dst, weight = None)
        except nx.NetworkXNoPath:
            continue
        if path_length >= len_threshold:
            edge_vector = list(embeddings[src]) + list(embeddings[dst]) + [0]  # label = 0
            neg_examples.append(edge_vector)
            edges_used.add((src, dst))
    return neg_examples, edges_used

In [46]:
# One labelled row per existing edge (label = edge weight).
pos_examples = get_positive_examples(G, embeddings=embeddings, weights_dict=weights_dict)
num_pos_examples = len(pos_examples)
print(num_pos_examples)

309667


In [47]:
# Request as many negatives as there are positives so the classes are balanced.
neg_examples, edges_used = get_negative_examples(G, embeddings, num_examples=num_pos_examples)
num_neg_examples = len(neg_examples)
print(num_neg_examples)

309667


In [48]:
# Positives first, then negatives, in a new list.
all_examples = [*pos_examples, *neg_examples]

In [49]:
# Train/test frame. Column layout: src0..src{d-1}, dst0..dst{d-1}, label.
cols = [f'{side}{i}' for side in ('src', 'dst') for i in range(embedding_dim)] + ['label']
df = pd.DataFrame(data=all_examples, columns=cols)

In [50]:
# Preview 10 random rows. A fixed random_state keeps the preview identical
# across Restart & Run All.
df.sample(10, random_state=0)

Unnamed: 0,src0,src1,src2,src3,src4,src5,src6,src7,src8,src9,...,dst87,dst88,dst89,dst90,dst91,dst92,dst93,dst94,dst95,label
251177,169.0,422.0,2265.0,14677.0,47982.0,148158.0,309.721893,1047.869822,7730.236686,34057.16568,...,41103539.0,124978835.0,355628556.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0,1
93657,568.0,2476.0,21527.0,115882.0,123287.0,265353.0,291.853873,870.846831,7480.308099,40529.573944,...,35257864.0,72151194.0,202259803.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0,1
140180,303.0,862.0,8142.0,56938.0,92805.0,236708.0,359.029703,1154.19802,9633.636964,46793.09571,...,44294540.0,145149413.0,447730576.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0,5
600900,1.0,1.0,1.0,1.0,14.0,23.0,15.0,24.0,55.0,386.0,...,49867.0,262186.0,823556.0,385.0,1139.0,2513.0,21137.0,59030.0,179815.0,0
523772,1.0,1.0,1.0,1.0,11.0,17.0,12.0,18.0,21.0,48.0,...,903589.0,736977.0,1635373.0,4601.0,19896.0,110785.0,333575.0,145591.0,276937.0,0
537636,1.0,1.0,1.0,1.0,2.0,2.0,3.0,3.0,4.0,4.0,...,13689995.0,21902733.0,55504297.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0,0
349449,1.0,2.0,1.0,2.0,9.0,12.0,10.0,14.0,35.0,96.0,...,712281.0,978443.0,2477203.0,3530.0,19896.0,110785.0,333575.0,152995.0,325429.0,0
134344,336.0,705.0,6872.0,53983.0,89733.0,229609.0,306.967262,1002.589286,8093.583333,39118.345238,...,13143764.0,18768708.0,46986463.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0,1
460366,1.0,1.0,1.0,1.0,1290.0,5504.0,1291.0,5505.0,36959.0,163781.0,...,29484.0,135527.0,404155.0,346.0,392.0,2911.0,18897.0,71260.0,191596.0,0
509123,1.0,1.0,1.0,1.0,2.0,2.0,3.0,3.0,3.0,3.0,...,13822499.0,18679247.0,45958336.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0,0


In [51]:
# Expected: (positives + negatives, 2 * embedding_dim feature columns + 1 label column).
df.shape

(619334, 193)

In [52]:
# index=True is pandas' default, spelled out: the row index is written as the
# first column, so readers of this file should pass index_col=0.
df.to_csv('data/rolx_dataset_weighted.csv', index=True)

In [59]:
def get_inference_examples(G, embeddings, edges_used, num_examples = 500000, attempts = 1000000):
    """Sample unlabelled, non-edge node pairs for link-prediction inference.

    Node pairs are drawn uniformly at random (with replacement) from ``G``.
    A pair is kept only if:

    - its two nodes are distinct;
    - it is not already an edge;
    - it does not appear in ``edges_used`` in either orientation, so training
      negatives cannot leak into the inference set;
    - it has not already been kept in either orientation.

    No path-length filter is applied. Sampling stops after ``num_examples``
    rows or ``attempts`` draws, whichever comes first.

    Args:
        G: graph to sample from (pairs are treated as unordered).
        embeddings: mapping node -> feature sequence (list or 1-D array).
        edges_used: collection of ``(src, dst)`` pairs to exclude.
        num_examples: target number of rows.
        attempts: maximum number of random pair draws.

    Returns:
        list of rows: ``[src, dst]`` followed by src features and dst features.
    """
    node_list = list(G.nodes())
    inference_examples = []
    if len(node_list) < 2:
        return inference_examples
    seen_pairs = set()
    for _ in range(attempts):
        if len(inference_examples) >= num_examples:
            break
        src, dst = random.choices(node_list, k = 2)
        if src == dst or G.has_edge(src, dst):
            continue
        pair, reverse_pair = (src, dst), (dst, src)
        if pair in edges_used or reverse_pair in edges_used:
            continue
        if pair in seen_pairs or reverse_pair in seen_pairs:
            continue
        seen_pairs.add(pair)
        # list() keeps this a concatenation even for ndarray embeddings.
        edge_vector = [src, dst] + list(embeddings[src]) + list(embeddings[dst])
        inference_examples.append(edge_vector)
    return inference_examples

In [60]:
# Exclude the pairs already used as training negatives.
inference_examples = get_inference_examples(G, embeddings, edges_used=edges_used)
print(len(inference_examples))

500000


In [61]:
# Inference frame. Column layout: src_id, dst_id, src0..src{d-1}, dst0..dst{d-1} (no label).
cols = ['src_id', 'dst_id'] + [f'{side}{i}' for side in ('src', 'dst') for i in range(embedding_dim)]
inference_df = pd.DataFrame(data=inference_examples, columns=cols)

In [62]:
# Preview 10 random rows. A fixed random_state keeps the preview identical
# across Restart & Run All.
inference_df.sample(10, random_state=0)

Unnamed: 0,src_id,dst_id,src0,src1,src2,src3,src4,src5,src6,src7,...,dst86,dst87,dst88,dst89,dst90,dst91,dst92,dst93,dst94,dst95
323975,83297,41861,3.0,3.0,3.0,3.0,293.0,657.0,98.666667,220.0,...,4788753.0,26752923.0,53116392.0,144675320.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0
453416,72121,73628,9.0,10.0,9.0,10.0,245.0,403.0,28.222222,45.888889,...,191809.0,831058.0,1327229.0,3494486.0,4601.0,11182.0,102665.0,310542.0,145456.0,280330.0
148504,75993,65471,3.0,3.0,4.0,4.0,452.0,1033.0,152.333333,346.0,...,1024626.0,3749750.0,3665822.0,8444421.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0
411709,80680,64296,24.0,50.0,47.0,92.0,1009.0,2551.0,44.958333,111.875,...,462460.0,1726392.0,1503194.0,3419057.0,3401.0,15890.0,108629.0,322218.0,148508.0,276937.0
308771,59252,87907,2.0,2.0,3.0,4.0,292.0,658.0,148.0,332.0,...,4317736.0,22500045.0,41117101.0,107877370.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0
85190,64214,43009,5.0,8.0,9.0,13.0,239.0,399.0,50.4,83.4,...,296034.0,1034772.0,692518.0,1351169.0,5547.0,23350.0,120147.0,345094.0,146038.0,313436.0
393738,52398,38986,1.0,1.0,1.0,1.0,1625.0,5381.0,1626.0,5382.0,...,770771.0,2922791.0,2801801.0,6380880.0,3401.0,15890.0,108629.0,322218.0,162133.0,318744.0
331663,40528,83334,46.0,50.0,90.0,318.0,7341.0,26596.0,162.5,590.913043,...,3394036.0,17605236.0,26585671.0,68093473.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0
475650,81625,56100,2.0,2.0,2.0,2.0,39.0,47.0,20.5,24.5,...,315662.0,1122534.0,1167524.0,2921497.0,5547.0,23350.0,120147.0,345094.0,151131.0,297063.0
323707,9564,69237,1.0,1.0,1.0,1.0,1434.0,4267.0,1435.0,4268.0,...,1433048.0,5539074.0,6601731.0,16987740.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0


In [63]:
# Expected: (sampled pairs, 2 id columns + 2 * embedding_dim feature columns); there is no label column.
inference_df.shape

(500000, 194)

In [64]:
# index=True is pandas' default, spelled out: the row index is written as the
# first column, ahead of src_id/dst_id.
inference_df.to_csv('data/rolx_inference_weighted.csv', index=True)