In [206]:
import pandas as pd
import networkx as nx
import json
from networkx.readwrite import json_graph
import numpy as np
import random
import sys
import os
from numba import jitclass, int32, float32, uint64
import scipy
from scipy.sparse import csr_matrix, load_npz, save_npz, spdiags

In [207]:
# Alias-method (Walker/Vose) sampler over a fixed discrete distribution,
# compiled as a numba jitclass so each draw costs O(1).
#   values : uint64 outcomes returned by the draw methods
#   q      : per-bucket acceptance threshold (probability scaled by K)
#   J      : per-bucket alias, used when the threshold test rejects the bucket
#   K      : number of outcomes
# NOTE(review): `from numba import jitclass` only works on older numba;
# newer releases expose it as `numba.experimental.jitclass`.
@jitclass([
    ('K', uint64),
    ('values', uint64[:]),
    ('q', float32[:]),
    ('J', uint64[:]),
])
class FastRandomChoiceCached(object):
    def __init__(self, values, probs):
        """Build the alias tables; ``probs[k]`` is the probability of ``values[k]``."""
        self.values = values
        self.K = np.uint64(probs.size)
        self.q = np.zeros(self.K, dtype=np.float32)
        self.J = np.zeros(self.K, dtype=np.uint64)
        self.prep_var_sample(probs)

    def prep_var_sample(self, probs):
        """Fill ``q`` and ``J`` from ``probs`` (assumed non-negative, summing to 1)."""
        smaller, larger = [], []
        for kk, prob in enumerate(probs):
            self.q[kk] = self.K * prob
            if self.q[kk] < 1.0:
                smaller.append(kk)
            else:
                larger.append(kk)
        # Pair each under-full bucket with an over-full one; the over-full
        # bucket donates exactly the mass the under-full one is missing.
        while len(smaller) > 0 and len(larger) > 0:
            small, large = smaller.pop(), larger.pop()
            self.J[small] = large
            self.q[large] = self.q[large] - (1.0 - self.q[small])
            if self.q[large] < 1.0:
                smaller.append(large)
            else:
                larger.append(large)
        # In exact arithmetic every bucket left over here has q == 1; any
        # deviation is floating-point round-off (q is float32). Without this,
        # a leftover bucket with q just below 1 would occasionally fall through
        # to its default alias J == 0 and return values[0] spuriously.
        while len(larger) > 0:
            self.q[larger.pop()] = 1.0
        while len(smaller) > 0:
            self.q[smaller.pop()] = 1.0

    def draw_one(self):
        """Return a single outcome from ``values``."""
        # Pick a bucket uniformly, then keep it or take its alias.
        kk = int(np.floor(np.random.rand() * len(self.J)))
        if np.random.rand() < self.q[kk]:
            return self.values[kk]
        else:
            return self.values[self.J[kk]]

    def sample(self, n, r1, r2):
        """Map ``n`` pairs of uniforms ``(r1[i], r2[i])`` to positions in ``values``.

        ``r1`` selects the bucket, ``r2`` is compared with its threshold.
        Positions are stored as int32, so K must stay below 2**31.
        """
        res = np.zeros(n, dtype=np.int32)
        lj = len(self.J)
        for i in range(n):
            kk = int(np.floor(r1[i] * lj))
            if r2[i] < self.q[kk]:
                res[i] = kk
            else:
                res[i] = self.J[kk]
        return res

    def draw_n(self, n):
        """Return ``n`` outcomes from ``values`` (drawn with replacement)."""
        r1, r2 = np.random.rand(n), np.random.rand(n)
        return self.values[self.sample(n, r1, r2)]

In [208]:
def normalize(graph):
    """Row-normalise a sparse matrix so every non-empty row sums to 1.

    Rows whose entries sum to zero are left as all zeros instead of
    producing inf values.

    Parameters
    ----------
    graph : scipy.sparse matrix or array of shape (n, m)

    Returns
    -------
    The sparse product ``diag(1 / row_sums) @ graph``.
    """
    # ravel (not squeeze) keeps a 1-D vector even for a single-row input,
    # so len() below always works.
    row_sums = np.asarray(graph.sum(axis=1)).ravel()
    # Empty rows divide by zero; silence the warning and zero the +/-inf
    # scale factors so those rows stay empty.
    with np.errstate(divide="ignore"):
        scale = 1 / row_sums
    scale[np.isinf(scale)] = 0
    n_rows = len(row_sums)
    # `@` is matrix multiplication for both sparse matrices and sparse
    # arrays; `*` would be element-wise for the latter.
    return spdiags(scale, 0, n_rows, n_rows) @ graph

In [209]:
# Load the PPI graph (node-link JSON) and build its row-normalised adjacency matrix.
prefix = './example_data/ppi'
with open(prefix + "-G.json") as graph_file:
    G_data = json.load(graph_file)
G = json_graph.node_link_graph(G_data)
# networkx < 3.0 provides to_scipy_sparse_matrix; 3.0 removed it. The newer
# sparse-array result is wrapped in csr_matrix so that row indexing G_ours[i]
# keeps 2-D matrix semantics (rows exposing .indices / .data).
if hasattr(nx, "to_scipy_sparse_matrix"):
    G_ours = nx.to_scipy_sparse_matrix(G)
else:
    G_ours = csr_matrix(nx.to_scipy_sparse_array(G))
G_ours = normalize(G_ours)

In [210]:
import tqdm
from scipy.sparse import csr_matrix, coo_matrix, lil_matrix

In [211]:
def run_random_walks(G, num_walks=1000, walk_len=2):
    """Count (start node, visited node) co-occurrences from random walks.

    From every node with at least one neighbour, ``num_walks`` walks of
    ``walk_len`` steps are run. Transition probabilities are the stored
    values of the current node's row of ``G`` (e.g. a row-normalised
    adjacency matrix). At each step the node occupied *before* moving is
    recorded, and the start node itself is skipped. The node reached by the
    final step is therefore never counted. A walk stops early if it reaches
    a node with no out-edges.

    Parameters
    ----------
    G : square scipy.sparse matrix whose rows ``G[i]`` expose ``.indices``
        and ``.data`` (CSR rows).
    num_walks : int
        Walks started from each node.
    walk_len : int
        Steps per walk.

    Returns
    -------
    scipy.sparse.coo_matrix of shape ``G.shape`` holding float64 counts,
    duplicate-free and sorted row-major.
    """
    from collections import Counter

    n_nodes = G.shape[0]
    random_cached_G = []
    has_neighbors = np.zeros(n_nodes, dtype=bool)
    for i in tqdm.tqdm(range(n_nodes)):
        row = G[i]
        has_neighbors[i] = len(row.indices) > 0
        random_cached_G.append(FastRandomChoiceCached(
            row.indices.astype(np.uint64),
            row.data))

    # Accumulate counts in plain Python containers; element-wise `+=` on a
    # lil_matrix is orders of magnitude slower.
    rows, cols, counts = [], [], []
    for node in tqdm.tqdm(range(n_nodes)):
        if not has_neighbors[node]:
            continue
        visits = Counter()
        for _ in range(num_walks):
            curr_node = node
            for _ in range(walk_len):
                # self co-occurrences are useless
                if curr_node != node:
                    visits[curr_node] += 1
                # Sampling from a node without out-edges would index an
                # empty alias table, so end the walk there.
                if not has_neighbors[curr_node]:
                    break
                curr_node = random_cached_G[curr_node].draw_one()
        for other, count in visits.items():
            rows.append(node)
            cols.append(other)
            counts.append(count)

    pairs = coo_matrix(
        (np.asarray(counts, dtype=np.float64),
         (np.asarray(rows, dtype=np.int64), np.asarray(cols, dtype=np.int64))),
        shape=G.shape)
    # Canonical (sorted, duplicate-free) order, matching lil_matrix.tocoo().
    pairs.sum_duplicates()
    return pairs

In [212]:
# Co-occurrence counts from 1000 walks of 2 steps started at every node.
pairs = run_random_walks(G_ours, num_walks=1000, walk_len=2)

100%|██████████| 14755/14755 [00:02<00:00, 5389.41it/s]
100%|██████████| 14755/14755 [02:25<00:00, 101.18it/s]


In [214]:
# Persist the co-occurrence counts next to the input graph. The dataset
# `prefix` is reused so both paths stay in sync.
scipy.sparse.save_npz(prefix + '-walks.npz', pairs)

In [215]:
# Show the current working directory; the relative './example_data' paths above resolve against it.
!pwd

/Users/kolya/work/GraphSAGE-master 2


In [197]:
class batch_sampler:
    """Sample (row, col) index pairs with probability proportional to their weight.

    The non-zero entries of a sparse weight matrix are flattened to linear
    indices and drawn with ``FastRandomChoiceCached``.
    """

    def __init__(self, pairs):
        """
        Parameters
        ----------
        pairs : scipy.sparse matrix
            Non-negative pair weights. Converted to COO; this is a no-op for
            COO input.

        Raises
        ------
        ValueError
            If the total weight is not positive, so there is nothing to sample.
        """
        pairs = pairs.tocoo()
        total = pairs.data.sum()
        if not total > 0:
            raise ValueError("batch_sampler needs at least one pair with positive weight")
        self.shape = pairs.shape
        flat_index = np.ravel_multi_index((pairs.row, pairs.col), dims=pairs.shape)
        self.node_sampler = FastRandomChoiceCached(
            flat_index.astype(np.uint64),
            pairs.data / total,
        )

    def sample_batch(self, batch_size=10):
        """Draw ``batch_size`` pairs with replacement.

        Returns
        -------
        tuple ``(rows, cols)`` of index arrays, each of length ``batch_size``.
        """
        batch = self.node_sampler.draw_n(batch_size)
        return np.unravel_index(batch, self.shape)

In [201]:
# Weighted sampler over the non-zero (row, col) entries of the co-occurrence counts.
bs = batch_sampler(pairs=pairs)

In [203]:
%%timeit
# Benchmark: time to draw one batch of 10,000 (row, col) pairs from `bs`.
bs.sample_batch(10000)

727 µs ± 1.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
