
[Sampler] Metapath sampler for metapath2vec #861

Merged: 11 commits, Sep 26, 2019
34 changes: 21 additions & 13 deletions examples/pytorch/metapath2vec/download.py
@@ -1,38 +1,46 @@
 import os
-import torch as th
-import torch.nn as nn
 import tqdm
 
 
-class AminerDataset:
+class PBar(object):
+    def __enter__(self):
+        self.t = None
+        return self
+
+    def __call__(self, blockno, readsize, totalsize):
+        if self.t is None:
+            self.t = tqdm.tqdm(total=totalsize)
+        self.t.update(readsize)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.t.close()
+
+
+class AminerDataset(object):
     """
     Download Aminer Dataset from Amazon S3 bucket.
     """
     def __init__(self, path):
-
         self.url = 'https://s3.us-east-2.amazonaws.com/dgl.ai/dataset/aminer.zip'
-
-        if not os.path.exists(os.path.join(path, 'aminer')):
+        if not os.path.exists(os.path.join(path, 'aminer.txt')):
             print('File not found. Downloading from', self.url)
             self._download_and_extract(path, 'aminer.zip')
+        self.fn = os.path.join(path, 'aminer.txt')
 
     def _download_and_extract(self, path, filename):
         import shutil, zipfile, zlib
-        from tqdm import tqdm
-        import requests
+        import urllib.request
 
         fn = os.path.join(path, filename)
 
         if os.path.exists(path):
             shutil.rmtree(path, ignore_errors=True)
         os.makedirs(path)
-        f_remote = requests.get(self.url, stream=True)
-        assert f_remote.status_code == 200, 'fail to open {}'.format(self.url)
-        with open(fn, 'wb') as writer:
-            for chunk in tqdm(f_remote.iter_content(chunk_size=1024*1024*3)):
-                writer.write(chunk)
+        with PBar() as pb:
+            urllib.request.urlretrieve(self.url, fn, pb)
         print('Download finished. Unzipping the file...')
 
         with zipfile.ZipFile(fn) as zf:
             zf.extractall(path)
         print('Unzip finished.')
+        self.fn = fn
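
A note on the download rewrite: `urllib.request.urlretrieve` calls its third argument after every block as `reporthook(block_number, block_size, total_size)`, which is why `PBar` can create the `tqdm` bar lazily on the first call and advance it by `readsize`. A minimal standalone sketch of the same pattern (the function name here is illustrative, not part of this PR):

```python
# Sketch of urlretrieve's reporthook contract as used by PBar above.
# tqdm's update() takes an increment, so each block's size is enough.
import tqdm
import urllib.request

def download_with_progress(url, filename):
    with tqdm.tqdm(unit='B', unit_scale=True) as bar:
        def hook(block_number, block_size, total_size):
            if total_size > 0:
                bar.total = total_size  # size known once headers arrive
            bar.update(block_size)
        urllib.request.urlretrieve(url, filename, hook)
```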
184 changes: 60 additions & 124 deletions examples/pytorch/metapath2vec/sampler.py
@@ -1,167 +1,103 @@
 import numpy as np
 import torch
-import torchvision
-from torch.autograd import Variable
-import random
-import time
+import tqdm
+import dgl
+import sys
+import os
 
 Metapath = "Conference-Paper-Author-Paper-Conference"
 num_walks_per_node = 1000
 walk_length = 100
+path = sys.argv[1]
 
-#construct mapping from text, could be changed to DGL later
-def construct_id_dict():
-    id_to_paper = {}
-    id_to_author = {}
-    id_to_conf = {}
-    f_3 = open(".../id_author.txt", encoding="ISO-8859-1")
-    f_4 = open(".../id_conf.txt", encoding="ISO-8859-1")
-    f_5 = open(".../paper.txt", encoding="ISO-8859-1")
+def construct_graph():
+    paper_ids = []
+    paper_names = []
+    author_ids = []
+    author_names = []
+    conf_ids = []
+    conf_names = []
+    f_3 = open(os.path.join(path, "id_author.txt"), encoding="ISO-8859-1")
+    f_4 = open(os.path.join(path, "id_conf.txt"), encoding="ISO-8859-1")
+    f_5 = open(os.path.join(path, "paper.txt"), encoding="ISO-8859-1")
     while True:
         z = f_3.readline()
         if not z:
             break
-        z = z.split('\t')
+        z = z.strip().split()
         identity = int(z[0])
-        id_to_author[identity] = z[1].strip("\n")
+        author_ids.append(identity)
+        author_names.append(z[1])
     while True:
         w = f_4.readline()
         if not w:
             break;
-        w = w.split('\t')
+        w = w.strip().split()
         identity = int(w[0])
-        id_to_conf[identity] = w[1].strip("\n")
+        conf_ids.append(identity)
+        conf_names.append(w[1])
     while True:
         v = f_5.readline()
         if not v:
             break;
-        v = v.split(' ')
+        v = v.strip().split()
         identity = int(v[0])
-        paper_name = ""
-        for s in range(5, len(v)):
-            paper_name += v[s]
-        paper_name = 'p' + paper_name
-        id_to_paper[identity] = paper_name.strip('\n')
+        paper_name = 'p' + ''.join(v[1:])
+        paper_ids.append(identity)
+        paper_names.append(paper_name)
     f_3.close()
     f_4.close()
     f_5.close()
-    return id_to_paper, id_to_author, id_to_conf
 
-#construct mapping from text, could be changed to DGL later
-def construct_types_mappings():
-    paper_to_author = {}
-    author_to_paper = {}
-    paper_to_conf = {}
-    conf_to_paper = {}
-    f_1 = open(".../paper_author.txt", "r")
-    f_2 = open(".../paper_conf.txt", "r")
+    author_ids_invmap = {x: i for i, x in enumerate(author_ids)}
+    conf_ids_invmap = {x: i for i, x in enumerate(conf_ids)}
+    paper_ids_invmap = {x: i for i, x in enumerate(paper_ids)}
+
+    paper_author_src = []
+    paper_author_dst = []
+    paper_conf_src = []
+    paper_conf_dst = []
+    f_1 = open(os.path.join(path, "paper_author.txt"), "r")
+    f_2 = open(os.path.join(path, "paper_conf.txt"), "r")
     for x in f_1:
         x = x.split('\t')
         x[0] = int(x[0])
         x[1] = int(x[1].strip('\n'))
-        if x[0] in paper_to_author:
-            paper_to_author[x[0]].append(x[1])
-        else:
-            paper_to_author[x[0]] = []
-            paper_to_author[x[0]].append(x[1])
-        if x[1] in author_to_paper:
-            author_to_paper[x[1]].append(x[0])
-        else:
-            author_to_paper[x[1]] = []
-            author_to_paper[x[1]].append(x[0])
+        paper_author_src.append(paper_ids_invmap[x[0]])
+        paper_author_dst.append(author_ids_invmap[x[1]])
     for y in f_2:
         y = y.split('\t')
         y[0] = int(y[0])
         y[1] = int(y[1].strip('\n'))
-        if y[0] in paper_to_conf:
-            paper_to_conf[y[0]].append(y[1])
-        else:
-            paper_to_conf[y[0]] = []
-            paper_to_conf[y[0]].append(y[1])
-        if y[1] in conf_to_paper:
-            conf_to_paper[y[1]].append(y[0])
-        else:
-            conf_to_paper[y[1]] = []
-            conf_to_paper[y[1]].append(y[0])
+        paper_conf_src.append(paper_ids_invmap[y[0]])
+        paper_conf_dst.append(conf_ids_invmap[y[1]])
     f_1.close()
     f_2.close()
-    return paper_to_author, author_to_paper, paper_to_conf, conf_to_paper
 
+    pa = dgl.bipartite((paper_author_src, paper_author_dst), 'paper', 'pa', 'author')
+    ap = dgl.bipartite((paper_author_dst, paper_author_src), 'author', 'ap', 'paper')
+    pc = dgl.bipartite((paper_conf_src, paper_conf_dst), 'paper', 'pc', 'conf')
+    cp = dgl.bipartite((paper_conf_dst, paper_conf_src), 'conf', 'cp', 'paper')
+    hg = dgl.hetero_from_relations([pa, ap, pc, cp])
+    return hg, author_names, conf_names, paper_names
 
 #"conference - paper - Author - paper - conference" metapath sampling
 def generate_metapath():
-    output_path = open(".../output_path.txt", "w")
-    id_to_paper, id_to_author, id_to_conf = construct_id_dict()
-    paper_to_author, author_to_paper, paper_to_conf, conf_to_paper = construct_types_mappings()
+    output_path = open(os.path.join(path, "output_path.txt"), "w")
     count = 0
-    #loop all conferences
-    for conf_id in conf_to_paper.keys():
-        start_time = time.time()
-        print("sampling" + str(count))
-        conf = id_to_conf[conf_id]
-        conf0 = conf
-        #for each conference, simulate num_walks_per_node walks
-        for i in range(num_walks_per_node):
-            outline = conf0
-            # each walk with length walk_length
-            for j in range(walk_length):
-                # C - P
-                paper_list_1 = conf_to_paper[conf_id]
-                # check whether the paper nodes link to any author node
-                connections_1 = False
-                available_paper_1 = []
-                for k in range(len(paper_list_1)):
-                    if paper_list_1[k] in paper_to_author:
-                        available_paper_1.append(paper_list_1[k])
-                num_p_1 = len(available_paper_1)
-                if num_p_1 != 0:
-                    connections_1 = True
-                    paper_1_index = random.randrange(num_p_1)
-                    #paper_id_1 = paper_list_1[paper_1_index]
-                    paper_id_1 = available_paper_1[paper_1_index]
-                    paper_1 = id_to_paper[paper_id_1]
-                    outline += " " + paper_1
-                else:
-                    break
-                # C - P - A
-                author_list = paper_to_author[paper_id_1]
-                num_a = len(author_list)
-                # No need to check
-                author_index = random.randrange(num_a)
-                author_id = author_list[author_index]
-                author = id_to_author[author_id]
-                outline += " " + author
-                # C - P - A - P
-                paper_list_2 = author_to_paper[author_id]
-                #check whether paper node links to any conference node
-                connections_2 = False
-                available_paper_2 = []
-                for m in range(len(paper_list_2)):
-                    if paper_list_2[m] in paper_to_conf:
-                        available_paper_2.append(paper_list_2[m])
-                num_p_2 = len(available_paper_2)
-                if num_p_2 != 0:
-                    connections_2 = True
-                    paper_2_index = random.randrange(num_p_2)
-                    paper_id_2 = available_paper_2[paper_2_index]
-                    paper_2 = id_to_paper[paper_id_2]
-                    outline += " " + paper_2
-                else:
-                    break
-                # C - P - A - P - C
-                conf_list = paper_to_conf[paper_id_2]
-                num_c = len(conf_list)
-                conf_index = random.randrange(num_c)
-                conf_id = conf_list[conf_index]
-                conf = id_to_conf[conf_id]
-                outline += " " + conf
-            if connections_1 and connections_2:
-                output_path.write(outline + "\n")
-            else:
-                break
-        # Note that the original mapping text has type indicator in front of each node just like "cVLDB"
-        # So the sampling sequence looks like "cconference ppaper aauthor ppaper cconference"
-        count += 1
-        print("--- %s seconds ---" % (time.time() - start_time))
+
+    hg, author_names, conf_names, paper_names = construct_graph()
+
+    for conf_idx in tqdm.trange(hg.number_of_nodes('conf')):
+        traces = dgl.contrib.sampling.metapath_random_walk(
+            hg, ['cp', 'pa', 'ap', 'pc'] * walk_length, [conf_idx], num_walks_per_node)
+        traces = traces[0]
+        for trace in traces:
+            tr = np.insert(trace.numpy(), 0, conf_idx)
+            outline = ' '.join(
+                (conf_names if i % 4 == 0 else author_names)[tr[i]]
+                for i in range(0, len(tr), 2))  # skip paper
+            print(outline, file=output_path)
+    output_path.close()
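
Reviewer note on the new pipeline: each repetition of `['cp', 'pa', 'ap', 'pc']` walks one Conference-Paper-Author-Paper-Conference round, the returned traces do not include the seed node (hence the `np.insert(trace.numpy(), 0, conf_idx)` above), and since the prepended trace alternates conf/paper/author/paper, stepping by 2 and testing `i % 4` keeps only the conference and author names. A toy run of the same calls, assuming the 0.4-era heterograph API this PR targets; the graph below is made up for illustration:

```python
# Toy version of the new sampler's core pattern (illustrative graph,
# not the Aminer data).
import dgl

# Two papers, two authors, one conference, with each relation reversed.
pa = dgl.bipartite(([0, 0, 1], [0, 1, 1]), 'paper', 'pa', 'author')
ap = dgl.bipartite(([0, 1, 1], [0, 0, 1]), 'author', 'ap', 'paper')
pc = dgl.bipartite(([0, 1], [0, 0]), 'paper', 'pc', 'conf')
cp = dgl.bipartite(([0, 0], [0, 1]), 'conf', 'cp', 'paper')
hg = dgl.hetero_from_relations([pa, ap, pc, cp])

# 3 walks from conf node 0, two C-P-A-P-C rounds each.
traces = dgl.contrib.sampling.metapath_random_walk(
    hg, ['cp', 'pa', 'ap', 'pc'] * 2, [0], 3)
for trace in traces[0]:
    print(trace)  # 8 node IDs: paper, author, paper, conf, twice
```

Compared with the deleted dictionary-based loops, the walk itself now runs inside the DGL core, so the Python side only formats names.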


79 changes: 0 additions & 79 deletions include/dgl/sampler.h
@@ -17,15 +17,6 @@ namespace dgl {

 class ImmutableGraph;
 
-struct RandomWalkTraces {
-  /*! \brief number of traces generated for each seed */
-  IdArray trace_counts;
-  /*! \brief length of each trace, concatenated */
-  IdArray trace_lengths;
-  /*! \brief the vertices, concatenated */
-  IdArray vertices;
-};
-
 class SamplerOp {
  public:
   /*!
@@ -65,76 +56,6 @@ class SamplerOp {
                              IdArray layer_sizes);
 };
 
-/*!
- * \brief Batch-generate random walk traces
- * \param seeds The array of starting vertex IDs
- * \param num_traces The number of traces to generate for each seed
- * \param num_hops The number of hops for each trace
- * \return a flat ID array with shape (num_seeds, num_traces, num_hops + 1)
- */
-IdArray RandomWalk(const GraphInterface *gptr,
-                   IdArray seeds,
-                   int num_traces,
-                   int num_hops);
-
-/*!
- * \brief Batch-generate random walk traces with restart
- *
- * Stop generating traces if max_frequrent_visited_nodes nodes are visited more than
- * max_visit_counts times.
- *
- * \param seeds The array of starting vertex IDs
- * \param restart_prob The restart probability
- * \param visit_threshold_per_seed Stop generating more traces once the number of nodes
- *        visited for a seed exceeds this number. (Algorithm 1 in [1])
- * \param max_visit_counts Alternatively, stop generating traces for a seed if no less
- *        than \c max_frequent_visited_nodes are visited no less than \c max_visit_counts
- *        times. (Algorithm 2 in [1])
- * \param max_frequent_visited_nodes See \c max_visit_counts
- * \return A RandomWalkTraces instance.
- *
- * \sa [1] Eksombatchai et al., 2017 https://arxiv.org/abs/1711.07601
- */
-RandomWalkTraces RandomWalkWithRestart(
-    const GraphInterface *gptr,
-    IdArray seeds,
-    double restart_prob,
-    uint64_t visit_threshold_per_seed,
-    uint64_t max_visit_counts,
-    uint64_t max_frequent_visited_nodes);
-
-/*
- * \brief Batch-generate random walk traces with restart on a bipartite graph, walking two
- *        hops at a time.
- *
- * Since it is walking on a bipartite graph, the vertices of a trace will always stay on the
- * same side.
- *
- * Stop generating traces if max_frequrent_visited_nodes nodes are visited more than
- * max_visit_counts times.
- *
- * \param seeds The array of starting vertex IDs
- * \param restart_prob The restart probability
- * \param visit_threshold_per_seed Stop generating more traces once the number of nodes
- *        visited for a seed exceeds this number. (Algorithm 1 in [1])
- * \param max_visit_counts Alternatively, stop generating traces for a seed if no less
- *        than \c max_frequent_visited_nodes are visited no less than \c max_visit_counts
- *        times. (Algorithm 2 in [1])
- * \param max_frequent_visited_nodes See \c max_visit_counts
- * \return A RandomWalkTraces instance.
- *
- * \note Doesn't verify whether the graph is indeed a bipartite graph
- *
- * \sa [1] Eksombatchai et al., 2017 https://arxiv.org/abs/1711.07601
- */
-RandomWalkTraces BipartiteSingleSidedRandomWalkWithRestart(
-    const GraphInterface *gptr,
-    IdArray seeds,
-    double restart_prob,
-    uint64_t visit_threshold_per_seed,
-    uint64_t max_visit_counts,
-    uint64_t max_frequent_visited_nodes);
-
 }  // namespace dgl
 
 #endif  // DGL_SAMPLER_H_
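
For context on what this header loses: `RandomWalkTraces` packed all generated traces into a single flat `vertices` array, with `trace_counts` holding how many traces each seed produced and `trace_lengths` the length of each trace in order. A hedged sketch of decoding that layout (the arrays are invented for illustration; this is not a DGL API):

```python
# Decoding the flattened RandomWalkTraces layout removed by this PR
# (made-up data, shown only to document the representation).
import numpy as np

trace_counts = np.array([2, 1])      # seed 0 produced 2 traces, seed 1 one
trace_lengths = np.array([3, 2, 4])  # length of each trace, in order
vertices = np.array([5, 1, 9, 5, 2, 7, 4, 4, 0])  # all traces, concatenated

offsets = np.concatenate([[0], np.cumsum(trace_lengths)])
traces = [vertices[offsets[i]:offsets[i + 1]]
          for i in range(len(trace_lengths))]
per_seed = np.split(np.arange(len(traces)), np.cumsum(trace_counts)[:-1])
# per_seed[s] lists which entries of `traces` belong to seed s
```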