diff --git a/python/dgl/subgraph.py b/python/dgl/subgraph.py
index eed2a3fe38a0..5752e5a36045 100644
--- a/python/dgl/subgraph.py
+++ b/python/dgl/subgraph.py
@@ -20,6 +20,7 @@
     "out_subgraph",
     "khop_in_subgraph",
     "khop_out_subgraph",
+    "khop_subgraph",
 ]
 
 
@@ -984,6 +985,159 @@ def khop_out_subgraph(
 DGLGraph.khop_out_subgraph = utils.alias_func(khop_out_subgraph)
 
 
+def khop_subgraph(
+    graph, nodes, k, *, fanout=None, relabel_nodes=True, store_ids=True, output_device=None
+):
+    """Return the subgraph induced by the k-hop neighborhood of the specified
+    node(s), treating directed edges as undirected while hopping.
+
+    We can expand a set of nodes by including their successors and
+    predecessors. From a specified node set, a k-hop subgraph is obtained by
+    first repeating the node set expansion k times and then creating a node
+    induced subgraph. In addition to extracting the subgraph, DGL also copies
+    the features of the extracted nodes and edges to the resulting graph. The
+    copy is *lazy* and incurs data movement only when needed. The number of
+    neighbors included at each hop can be controlled with :attr:`fanout`.
+
+    If the graph is heterogeneous, DGL extracts a subgraph per relation and
+    composes them as the resulting graph. Thus the resulting graph has the
+    same set of relations as the input one.
+
+    Parameters
+    ----------
+    graph : DGLGraph
+        The input graph.
+    nodes : nodes or dict[str, nodes]
+        The starting node(s) to expand, which cannot have any duplicate value.
+        The result will be undefined otherwise. The allowed formats are:
+
+        * Int: ID of a single node.
+        * Int Tensor: Each element is a node ID. The tensor must have the
+          same device type and ID data type as the graph's.
+        * iterable[int]: Each element is a node ID.
+
+        If the graph is homogeneous, one can directly pass the above formats.
+        Otherwise, the argument must be a dictionary with keys being node
+        types and values being the node IDs in the above formats.
+    k : int
+        The number of hops.
+    fanout : int, optional
+        The number of successors and predecessors to include for each node
+        when expanding. If None, include all of them.
+    relabel_nodes : bool, optional
+        If True, it will remove the isolated nodes and relabel the remaining
+        nodes in the extracted subgraph.
+    store_ids : bool, optional
+        If True, it will store the raw IDs of the extracted edges in the
+        ``edata`` of the resulting graph under name ``dgl.EID``; if
+        ``relabel_nodes`` is ``True``, it will also store the raw IDs of the
+        extracted nodes in the ``ndata`` of the resulting graph under name
+        ``dgl.NID``.
+    output_device : Framework-specific device context object, optional
+        The output device. Default is the same as the input graph.
+
+    Returns
+    -------
+    DGLGraph
+        The subgraph.
+    Tensor or dict[str, Tensor], optional
+        The new IDs of the input :attr:`nodes` after node relabeling. This is
+        returned only when :attr:`relabel_nodes` is True. It is in the same
+        form as :attr:`nodes`.
+    """
+    import numpy as np
+    import torch
+
+    if graph.is_block:
+        raise DGLError("Extracting subgraph of a block graph is not allowed.")
+
+    is_mapping = isinstance(nodes, Mapping)
+    if not is_mapping:
+        assert (
+            len(graph.ntypes) == 1
+        ), "need a dict of node type and IDs for graph with multiple node types"
+        nodes = {graph.ntypes[0]: nodes}
+
+    for nty, nty_nodes in nodes.items():
+        nodes[nty] = utils.prepare_tensor(
+            graph, nty_nodes, 'nodes["{}"]'.format(nty)
+        )
+
+    last_hop_nodes = nodes
+    k_hop_nodes_ = [last_hop_nodes]
+    device = context_of(nodes)
+    # Empty tensor to stand in for node types with no seed nodes.
+    place_holder = F.copy_to(F.tensor([], dtype=graph.idtype), device)
+    for _ in range(k):
+        current_hop_nodes = {nty: [] for nty in graph.ntypes}
+        # add outgoing nbrs
+        for cetype in graph.canonical_etypes:
+            srctype, _, dsttype = cetype
+            _, out_nbrs = graph.out_edges(
+                last_hop_nodes.get(srctype, place_holder), etype=cetype
+            )
+            if fanout is not None and fanout
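A minimal usage sketch of the new function (illustrative, not part of the patch; it assumes the PyTorch backend and that `khop_subgraph` is re-exported at the top level like the other subgraph routines):

import dgl
import torch

# Directed path graph 0 -> 1 -> 2 -> 3.
g = dgl.graph((torch.tensor([0, 1, 2]), torch.tensor([1, 2, 3])))

# 1-hop subgraph around node 1, hopping along edges in both directions:
# predecessor 0 and successor 2 join the seed, so the node set is {0, 1, 2}.
sg, inverse_indices = dgl.khop_subgraph(g, 1, k=1)

# With fanout set, at most `fanout` successors and predecessors are
# included per node at each hop instead of the full neighborhood.
sg_small, _ = dgl.khop_subgraph(g, 1, k=1, fanout=1)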