In [34]:
import networkx as nx
import numpy as np
import pandas as pd
import random

In [21]:
def load_embeddings(filename):
    """Load a ``.npy`` file that stores a single pickled Python object.

    The file is read with ``allow_pickle=True`` and the resulting 0-d object
    array is unwrapped with ``.item()``, so the stored object itself (the
    callers in this notebook expect a dict) is returned.

    Note: ``allow_pickle=True`` unpickles arbitrary objects, so this should
    only be pointed at trusted files.
    """
    return np.load(filename, allow_pickle=True).item()

In [22]:
# The .npy file is expected to hold a single pickled dict: load_embeddings()
# unwraps the loaded array with .item(), and the rest of the notebook indexes
# the result by node id (embeddings[node]).
embeddings = load_embeddings('data/rolx_embeddings.npy')

In [23]:
# Width of a single node embedding. Read it from an arbitrary entry instead of
# key 0, so this does not depend on a node with id 0 being present.
embedding_dim = len(next(iter(embeddings.values())))

In [24]:
# Load the edge list into networkx as a weighted, undirected graph
def load_graph(filename):
    """Read a headerless edge-list CSV and return it as an undirected ``nx.Graph``.

    The three CSV columns are interpreted, in order, as ``source``, ``target``
    and ``weight``; the weight is stored on each edge under the ``'weight'``
    attribute.
    """
    edge_table = pd.read_csv(
        filename,
        header=None,
        names=['source', 'target', 'weight'],
    )
    return nx.from_pandas_edgelist(
        edge_table,
        source='source',
        target='target',
        edge_attr='weight',
        create_using=nx.Graph(),
    )

In [25]:
# Build the undirected, weighted graph from the (source, target, weight) edge-list CSV.
G = load_graph('data/reddit_nodes_weighted_full.csv')

In [32]:
# generate positive examples of edges
def get_positive_examples(G, embeddings):
    """Build one labelled feature row for every edge of ``G``.

    Args:
        G: Graph-like object whose ``edges()`` yields ``(src, dst)`` pairs.
        embeddings: Mapping from node id to that node's embedding. Any
            sequence works (list, tuple or 1-D numpy array).

    Returns:
        A list with one row per edge, in ``G.edges()`` order. Each row is a
        plain list: the source embedding, then the destination embedding,
        then the label ``1``.

    Raises:
        KeyError: If an edge endpoint has no entry in ``embeddings``.
    """
    pos_examples = []
    for src, dst in G.edges():
        # list(...) forces concatenation; a bare `+` between numpy arrays
        # would add element-wise and silently produce a wrong-length row.
        edge_vector = list(embeddings[src]) + list(embeddings[dst]) + [1]  # label = 1
        pos_examples.append(edge_vector)
    return pos_examples

In [47]:
# generate negative examples
def get_negative_examples(G, embeddings, num_examples, attempts = 3000000, len_threshold = 5, seed = None):
    """Sample labelled feature rows for node pairs that are far apart in ``G``.

    A random pair of distinct nodes is accepted as a negative example when the
    two nodes are not adjacent, are connected by some path, and their
    unweighted shortest-path length is at least ``len_threshold`` hops.

    Args:
        G: networkx graph to sample node pairs from.
        embeddings: Mapping from node id to that node's embedding (list,
            tuple or 1-D numpy array).
        num_examples: Number of negative rows wanted.
        attempts: Maximum number of random pairs to draw before giving up.
        len_threshold: Minimum hop distance for a pair to be accepted.
        seed: Optional seed for a private ``random.Random`` so the sample is
            reproducible. ``None`` (default) uses the module-level generator.

    Returns:
        A list of plain-list rows: source embedding, then destination
        embedding, then the label ``0``. It may hold fewer than
        ``num_examples`` rows if ``attempts`` runs out first, and is empty
        when the graph has fewer than two nodes.

    Raises:
        KeyError: If a sampled node has no entry in ``embeddings``.
    """
    node_list = list(G.nodes())
    neg_examples = []
    # A pair needs two distinct nodes; smaller graphs cannot yield any.
    if len(node_list) < 2:
        return neg_examples
    rng = random if seed is None else random.Random(seed)
    for _ in range(attempts):
        # `>=` also terminates immediately for zero or negative requests.
        if len(neg_examples) >= num_examples:
            break
        # Sample without replacement so an attempt is never spent on (v, v).
        src, dst = rng.sample(node_list, 2)
        if G.has_edge(src, dst):
            continue
        try:
            # weight=None: distance is counted in hops, ignoring edge weights.
            path_length = nx.shortest_path_length(G, source=src, target=dst, weight=None)
        except nx.NetworkXNoPath:
            # Pairs in different connected components are skipped, not used.
            continue
        if path_length >= len_threshold:
            # list(...) forces concatenation even for numpy-array embeddings.
            edge_vector = list(embeddings[src]) + list(embeddings[dst]) + [0]  # label = 0
            neg_examples.append(edge_vector)
    return neg_examples

In [48]:
# One positive row per edge of G; the count is reused below as the number of
# negative rows to request.
pos_examples = get_positive_examples(G, embeddings)
num_pos_examples = len(pos_examples)

In [49]:
# Ask for as many negatives as positives. The sampler has a bounded number of
# attempts, so it may return fewer; num_neg_examples records the actual count.
neg_examples = get_negative_examples(G, embeddings, num_pos_examples)
num_neg_examples = len(neg_examples)

In [51]:
# Positives first, then negatives; the rows are not shuffled here.
all_examples = pos_examples + neg_examples

In [52]:
# Assemble the example rows into a DataFrame.
# Column layout: src0..src{d-1}, dst0..dst{d-1}, label (d = embedding_dim).
cols = [f'{side}{i}' for side in ('src', 'dst') for i in range(embedding_dim)] + ['label']
df = pd.DataFrame(all_examples, columns=cols)

In [56]:
# Peek at a random handful of rows. A bare `df.reset_index()` call used to sit
# here; it was dropped because it returns a new frame, and with the result
# discarded it left `df` unchanged while still copying the whole table.
df.sample(10)

Unnamed: 0,src0,src1,src2,src3,src4,src5,src6,src7,src8,src9,...,dst87,dst88,dst89,dst90,dst91,dst92,dst93,dst94,dst95,label
112757,907.0,2183.0,18790.0,90321.0,113966.0,267898.0,166.084895,492.124587,3876.971334,19737.265711,...,17901700.0,33491632.0,88906120.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0,1
516091,1.0,1.0,1.0,1.0,26.0,116.0,27.0,117.0,134.0,871.0,...,2869932.0,2950432.0,6959095.0,5547.0,23350.0,120147.0,345094.0,152922.0,289292.0,0
271631,13.0,16.0,38.0,158.0,6455.0,23842.0,501.384615,1857.076923,14369.307692,65516.692308,...,43584744.0,119098422.0,355610912.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0,1
193160,139.0,166.0,2145.0,18473.0,60204.0,177162.0,462.985612,1539.151079,12078.503597,51898.352518,...,44151214.0,145149413.0,447730576.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0,1
424778,1.0,1.0,1.0,1.0,3.0,3.0,4.0,4.0,4.0,4.0,...,2044886.0,2308379.0,5537900.0,3530.0,19896.0,110785.0,333575.0,152922.0,288501.0,0
244995,108.0,171.0,1572.0,14705.0,43666.0,136244.0,432.425926,1532.25,11707.222222,53162.518519,...,44294540.0,120648599.0,367510458.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0,1
135787,108.0,349.0,1153.0,11489.0,51190.0,153979.0,494.333333,1635.259259,13129.759259,56678.222222,...,6121248.0,6438625.0,15240038.0,5547.0,23350.0,120147.0,345094.0,151131.0,298604.0,1
218331,116.0,239.0,928.0,4630.0,29114.0,85927.0,265.982759,818.517241,6388.75,27244.637931,...,18152087.0,33480033.0,89005952.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0,1
2728,747.0,2960.0,14294.0,88205.0,110341.0,254933.0,184.982597,573.471218,4676.651941,23257.096386,...,44294540.0,145149413.0,447730576.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0,1
137511,409.0,1186.0,8467.0,50492.0,104159.0,262668.0,295.070905,886.224939,7406.017115,35823.836186,...,14651949.0,20079808.0,49732667.0,5547.0,23350.0,120147.0,345094.0,162133.0,351372.0,1


In [54]:
# Sanity check: rows = positives + negatives, columns = 2 * embedding_dim + 1 (label).
df.shape

(619334, 193)

In [58]:
# Persist the labelled dataset.
# NOTE(review): to_csv writes the row index as an extra leading column by
# default. Rows are ordered positives-first, so that index tracks the label;
# confirm downstream readers drop it (or pass index=False here) before training.
df.to_csv('data/rolx_dataset.csv')