[Refactor] Sampler code refactor (#454)
* refactored sampler code

* docstring

* fix tutorial
BarclayII authored and jermainewang committed Mar 18, 2019
1 parent eb1acec commit 681e521
Showing 3 changed files with 180 additions and 130 deletions.
298 changes: 170 additions & 128 deletions python/dgl/contrib/sampling/sampler.py
@@ -3,7 +3,7 @@
import sys
import numpy as np
import threading
import random
from numbers import Integral
import traceback

from ..._ffi.function import _init_api
@@ -19,84 +19,26 @@

__all__ = ['NeighborSampler', 'LayerSampler']

class SampledSubgraphLoader(object):
    def __init__(self, g, batch_size, sampler,
                 expand_factor=None, num_hops=1, layer_sizes=None,
                 neighbor_type='in', node_prob=None, seed_nodes=None,
                 shuffle=False, num_workers=1, add_self_loop=False):
        self._g = g
        if not g._graph.is_readonly():
            raise NotImplementedError("NodeFlow loader only supports read-only graphs.")
        self._batch_size = batch_size
class NodeFlowSamplerIter(object):
    def __init__(self, sampler):
        super(NodeFlowSamplerIter, self).__init__()
        self._sampler = sampler
        if sampler == 'neighbor':
            self._expand_factor = expand_factor
            self._num_hops = num_hops
        elif sampler == 'layer':
            self._layer_sizes = utils.toindex(layer_sizes)
        else:
            raise NotImplementedError('Invalid sampler option: "%s"' % sampler)
        self._node_prob = node_prob
        if node_prob is not None:
            raise NotImplementedError('Non-uniform sampling is currently not supported.')
        self._add_self_loop = add_self_loop
        if self._node_prob is not None:
            assert self._node_prob.shape[0] == g.number_of_nodes(), \
                "We need to know the sampling probability of every node"
        if seed_nodes is None:
            self._seed_nodes = F.arange(0, g.number_of_nodes())
        else:
            self._seed_nodes = seed_nodes
        if shuffle:
            self._seed_nodes = F.rand_shuffle(self._seed_nodes)
        self._seed_nodes = utils.toindex(self._seed_nodes)
        self._num_workers = num_workers
        self._neighbor_type = neighbor_type
        self._nflows = []
        self._seed_ids = []
        self._nflow_idx = 0

    def _prefetch(self):
        if self._sampler == 'neighbor':
            handles = unwrap_to_ptr_list(_CAPI_UniformSampling(
                self._g._graph._handle,
                self._seed_nodes.todgltensor(),
                int(self._nflow_idx),    # start batch id
                int(self._batch_size),   # batch size
                int(self._num_workers),  # num batches
                int(self._expand_factor),
                int(self._num_hops),
                self._neighbor_type,
                self._add_self_loop))
        elif self._sampler == 'layer':
            handles = unwrap_to_ptr_list(_CAPI_LayerSampling(
                self._g._graph._handle,
                self._seed_nodes.todgltensor(),
                int(self._nflow_idx),    # start batch id
                int(self._batch_size),   # batch size
                int(self._num_workers),  # num batches
                self._layer_sizes.todgltensor(),
                self._neighbor_type))
        else:
            raise NotImplementedError('Invalid sampler option: "%s"' % self._sampler)
        nflows = [NodeFlow(self._g, hdl) for hdl in handles]
    def prefetch(self):
        nflows = self._sampler.fetch(self._nflow_idx)
        self._nflows.extend(nflows)
        self._nflow_idx += len(nflows)

    def __iter__(self):
        return self

    def __next__(self):
        # If we don't have prefetched NodeFlows, let's prefetch them.
        if len(self._nflows) == 0:
            self._prefetch()
        # At this point, if we still don't have NodeFlows, we must have
        # iterated over all NodeFlows and we should stop the iterator now.
            self.prefetch()
        if len(self._nflows) == 0:
            raise StopIteration
        return self._nflows.pop(0)

class _Prefetcher(object):
class PrefetchingWrapper(object):
    """Internal shared prefetcher logic. It can be subclassed by a thread-based
    or a process-based implementation."""
    _dataq = None  # Data queue transmits prefetched elements
@@ -105,17 +47,17 @@ class _Prefetcher(object):

    _checked_start = False  # True once startup has been checked by _check_start

    def __init__(self, loader, num_prefetch):
        super(_Prefetcher, self).__init__()
        self.loader = loader
    def __init__(self, sampler_iter, num_prefetch):
        super(PrefetchingWrapper, self).__init__()
        self.sampler_iter = sampler_iter
        assert num_prefetch > 0, 'Unbounded Prefetcher is unsupported.'
        self.num_prefetch = num_prefetch

    def run(self):
        """Method representing the process’s activity."""
        """Method representing the process activity."""
        # Startup - Master waits for this
        try:
            loader_iter = iter(self.loader)
            loader_iter = self.sampler_iter
            self._errorq.put(None)
        except Exception as e:  # pylint: disable=broad-except
            tb = traceback.format_exc()
@@ -174,45 +116,97 @@ def _check_start(self):
    def next(self):
        return self.__next__()


class _ThreadPrefetcher(_Prefetcher, threading.Thread):
class ThreadPrefetchingWrapper(PrefetchingWrapper, threading.Thread):
    """Internal threaded prefetcher."""

    def __init__(self, *args, **kwargs):
        super(_ThreadPrefetcher, self).__init__(*args, **kwargs)
        super(ThreadPrefetchingWrapper, self).__init__(*args, **kwargs)
        self._dataq = queue.Queue(self.num_prefetch)
        self._controlq = queue.Queue()
        self._errorq = queue.Queue(self.num_prefetch)
        self.daemon = True
        self.start()
        self._check_start()

class _PrefetchingLoader(object):
    """Prefetcher for a Loader in a separate Thread or Process.

    This iterator will create another thread or process to perform
    ``iter_next`` and then store the data in memory. It potentially accelerates
    the data read, at the cost of more memory usage.

    Parameters
    ----------
    loader : an iterator
        Source loader.
    num_prefetch : int, default 1
        Number of elements to prefetch from the loader. Must be greater than 0.
    """
class NodeFlowSampler(object):
    '''
    Base class that generates NodeFlows from a graph.

    Class properties
    ----------------
    immutable_only : bool
        Whether the sampler only works on immutable graphs.
        Subclasses can override this property.
    '''
    immutable_only = False

    def __init__(
            self,
            g,
            batch_size,
            seed_nodes,
            shuffle,
            num_prefetch,
            prefetching_wrapper_class):
        self._g = g
        if self.immutable_only and not g._graph.is_readonly():
            raise NotImplementedError("This loader only supports read-only graphs.")

    def __init__(self, loader, num_prefetch=1):
        self._loader = loader
        self._num_prefetch = num_prefetch
        if num_prefetch < 1:
            raise ValueError('num_prefetch must be greater than 0.')
        self._batch_size = batch_size

        if seed_nodes is None:
            self._seed_nodes = F.arange(0, g.number_of_nodes())
        else:
            self._seed_nodes = seed_nodes
        if shuffle:
            self._seed_nodes = F.rand_shuffle(self._seed_nodes)
        self._seed_nodes = utils.toindex(self._seed_nodes)

        if num_prefetch:
            self._prefetching_wrapper_class = prefetching_wrapper_class
        self._num_prefetch = num_prefetch

    def fetch(self, current_nodeflow_index):
        '''
        Method that returns the next "bunch" of NodeFlows.

        Each worker will return a single NodeFlow constructed from a single
        batch.

        Subclasses of NodeFlowSampler should override this method.

        Parameters
        ----------
        current_nodeflow_index : int
            How many NodeFlows the sampler has generated so far.

        Returns
        -------
        list[NodeFlow]
            Next "bunch" of NodeFlows to be processed.
        '''
        raise NotImplementedError

    def __iter__(self):
        return _ThreadPrefetcher(self._loader, self._num_prefetch)
        it = NodeFlowSamplerIter(self)
        if self._num_prefetch:
            return self._prefetching_wrapper_class(it, self._num_prefetch)
        else:
            return it

    @property
    def g(self):
        return self._g

def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
                    neighbor_type='in', node_prob=None, seed_nodes=None,
                    shuffle=False, num_workers=1, prefetch=False, add_self_loop=False):
    @property
    def seed_nodes(self):
        return self._seed_nodes

    @property
    def batch_size(self):
        return self._batch_size

class NeighborSampler(NodeFlowSampler):
    '''Create a sampler that samples neighborhoods.

    It returns a generator of :class:`~dgl.NodeFlow`. This can be viewed as
@@ -283,26 +277,52 @@ def NeighborSampler(g, batch_size, expand_factor, num_hops=1,
    add_self_loop : bool, optional
        If true, add self loop to the sampled NodeFlow.
        The edge IDs of the self loop edges are -1. Default: False

    Returns
    -------
    generator
        The generator of NodeFlows.
    '''
    loader = SampledSubgraphLoader(g, batch_size, 'neighbor',
                                   expand_factor=expand_factor, num_hops=num_hops,
                                   neighbor_type=neighbor_type, node_prob=node_prob,
                                   seed_nodes=seed_nodes, shuffle=shuffle,
                                   num_workers=num_workers,
                                   add_self_loop=add_self_loop)
    if not prefetch:
        return loader
    else:
        return _PrefetchingLoader(loader, num_prefetch=num_workers*2)

def LayerSampler(g, batch_size, layer_sizes,
                 neighbor_type='in', node_prob=None, seed_nodes=None,
                 shuffle=False, num_workers=1, prefetch=False):

    immutable_only = True

    def __init__(
            self,
            g,
            batch_size,
            expand_factor=None,
            num_hops=1,
            neighbor_type='in',
            node_prob=None,
            seed_nodes=None,
            shuffle=False,
            num_workers=1,
            prefetch=False,
            add_self_loop=False):
        super(NeighborSampler, self).__init__(
            g, batch_size, seed_nodes, shuffle, num_workers * 2 if prefetch else 0,
            ThreadPrefetchingWrapper)

        assert node_prob is None, 'non-uniform node probability not supported'
        assert isinstance(expand_factor, Integral), 'non-int expand_factor not supported'

        self._expand_factor = expand_factor
        self._num_hops = num_hops
        self._add_self_loop = add_self_loop
        self._num_workers = num_workers
        self._neighbor_type = neighbor_type

    def fetch(self, current_nodeflow_index):
        handles = unwrap_to_ptr_list(_CAPI_UniformSampling(
            self.g.c_handle,
            self.seed_nodes.todgltensor(),
            current_nodeflow_index,  # start batch id
            self.batch_size,         # batch size
            self._num_workers,       # num batches
            self._expand_factor,
            self._num_hops,
            self._neighbor_type,
            self._add_self_loop))
        nflows = [NodeFlow(self.g, hdl) for hdl in handles]
        return nflows


class LayerSampler(NodeFlowSampler):
    '''Create a sampler that samples neighborhoods layer-wise.

    This creates a NodeFlow loader that samples subgraphs from the input graph
@@ -325,20 +345,42 @@ def LayerSampler(g, batch_size, layer_sizes,
    num_workers : int, default 1
        The number of worker threads that sample NodeFlows in parallel.
    prefetch : bool, default False
        Whether to prefetch the samples in the next batch.
    Returns
    -------
    A NodeFlow iterator
        The iterator returns a list of batched NodeFlows.
    '''
    loader = SampledSubgraphLoader(g, batch_size, 'layer', layer_sizes=layer_sizes,
                                   neighbor_type=neighbor_type, node_prob=node_prob,
                                   seed_nodes=seed_nodes, shuffle=shuffle,
                                   num_workers=num_workers)
    if not prefetch:
        return loader
    else:
        return _PrefetchingLoader(loader, num_prefetch=num_workers*2)

    immutable_only = True

    def __init__(
            self,
            g,
            batch_size,
            layer_sizes,
            neighbor_type='in',
            node_prob=None,
            seed_nodes=None,
            shuffle=False,
            num_workers=1,
            prefetch=False):
        super(LayerSampler, self).__init__(
            g, batch_size, seed_nodes, shuffle, num_workers * 2 if prefetch else 0,
            ThreadPrefetchingWrapper)

        assert node_prob is None, 'non-uniform node probability not supported'

        self._num_workers = num_workers
        self._neighbor_type = neighbor_type
        self._layer_sizes = utils.toindex(layer_sizes)

    def fetch(self, current_nodeflow_index):
        handles = unwrap_to_ptr_list(_CAPI_LayerSampling(
            self.g.c_handle,
            self.seed_nodes.todgltensor(),
            current_nodeflow_index,  # start batch id
            self.batch_size,         # batch size
            self._num_workers,       # num batches
            self._layer_sizes.todgltensor(),
            self._neighbor_type))
        nflows = [NodeFlow(self.g, hdl) for hdl in handles]
        return nflows

def create_full_nodeflow(g, num_layers, add_self_loop=False):
"""Convert a full graph to NodeFlow to run a L-layer GNN model.
@@ -362,6 +404,6 @@ def create_full_nodeflow(g, num_layers, add_self_loop=False):
    expand_factor = g.number_of_nodes()
    sampler = NeighborSampler(g, batch_size, expand_factor,
                              num_layers, add_self_loop=add_self_loop)
    return next(sampler)
    return next(iter(sampler))

_init_api('dgl.sampling', __name__)
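
For context, a minimal usage sketch of the refactored sampler API follows. It is illustrative, not part of the commit: the graph `g` and the training step `train_on` are placeholders, and the parameter values are arbitrary.

# Usage sketch (assumes a read-only DGLGraph `g`; `train_on` is a placeholder).
from dgl.contrib.sampling import NeighborSampler

# NeighborSampler is now a class rather than a factory function returning a
# loader; iterating over it yields NodeFlows, optionally through the
# background prefetching thread added in this commit.
sampler = NeighborSampler(g, batch_size=32, expand_factor=10, num_hops=2,
                          neighbor_type='in', shuffle=True,
                          num_workers=2, prefetch=True)
for nflow in sampler:
    train_on(nflow)  # each NodeFlow is built from one batch of seed nodes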
5 changes: 5 additions & 0 deletions python/dgl/graph.py
@@ -38,6 +38,11 @@ class DGLBaseGraph(object):
    def __init__(self, graph):
        self._graph = graph

    @property
    def c_handle(self):
        """The C handle for the graph."""
        return self._graph._handle

    def number_of_nodes(self):
        """Return the number of nodes in the graph.
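
To illustrate the extension point this refactor introduces, a hypothetical subclass sketch follows; `ToySampler` and its batching logic are not part of the commit. Subclasses only override fetch(); seed-node handling, iteration, and prefetching come from the NodeFlowSampler base class.

# Hypothetical sketch (not in this commit): a custom sampler only needs fetch().
from dgl.contrib.sampling.sampler import (
    NodeFlowSampler, ThreadPrefetchingWrapper)

class ToySampler(NodeFlowSampler):
    def __init__(self, g, batch_size):
        # num_prefetch=2 enables the threaded prefetcher built in this commit;
        # passing 0 would make __iter__ return a plain NodeFlowSamplerIter.
        super(ToySampler, self).__init__(
            g, batch_size, seed_nodes=None, shuffle=False,
            num_prefetch=2, prefetching_wrapper_class=ThreadPrefetchingWrapper)

    def fetch(self, current_nodeflow_index):
        # Return the next "bunch" of items, or an empty list to stop iteration.
        # A real sampler would return NodeFlows; plain seed-node batches are
        # enough to exercise the iterator and prefetching machinery.
        start = current_nodeflow_index * self.batch_size
        batch = self.seed_nodes.tousertensor()[start:start + self.batch_size]
        return [batch] if len(batch) > 0 else []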
