From 56ac909102e4e05d8a1b5068a681d9b72817abfc Mon Sep 17 00:00:00 2001
From: Ramon Zhou
Date: Fri, 26 Apr 2024 08:47:18 +0000
Subject: [PATCH] init

---
 python/dgl/distributed/graph_services.py | 79 ++++++++++++++++++++----
 1 file changed, 67 insertions(+), 12 deletions(-)

diff --git a/python/dgl/distributed/graph_services.py b/python/dgl/distributed/graph_services.py
index 5590167d3099..2103733ec5ad 100644
--- a/python/dgl/distributed/graph_services.py
+++ b/python/dgl/distributed/graph_services.py
@@ -82,7 +82,7 @@ def __getstate__(self):
 
 
 def _sample_neighbors_graphbolt(
-    g, gpb, nodes, fanout, edge_dir="in", prob=None, replace=False
+    g, gpb, nodes, node_types, fanout, edge_dir="in", prob=None, replace=False
 ):
     """Sample from local partition via graphbolt.
 
@@ -99,6 +99,8 @@ def _sample_neighbors_graphbolt(
         The graph partition book.
     nodes : tensor
         The nodes to sample neighbors from.
+    node_types : tensor
+        The node type of each node. None if the graph is homogeneous.
     fanout : tensor or int
         The number of edges to be sampled for each node.
     edge_dir : str, optional
@@ -128,6 +130,14 @@ def _sample_neighbors_graphbolt(
     # Local partition may be saved in torch.int32 even though the global graph
     # is in torch.int64.
     nodes = nodes.to(dtype=g.indices.dtype)
+    if node_types is not None:
+        ntype_count = torch.bincount(
+            node_types, minlength=len(g.node_type_to_id)
+        )
+        seed_offsets = ntype_count.to(dtype=torch.int64).cumsum(0).tolist()
+        seed_offsets.insert(0, 0)
+    else:
+        seed_offsets = None
 
     # 2. Perform sampling.
     # [Rui][TODO] `prob` is not tested yet. Skip for now.
@@ -146,7 +156,7 @@ def _sample_neighbors_graphbolt(
 
     return_eids = g.edge_attributes is not None and EID in g.edge_attributes
     subgraph = g._sample_neighbors(
-        nodes, None, fanout, replace=replace, return_eids=return_eids
+        nodes, seed_offsets, fanout, replace=replace, return_eids=return_eids
     )
 
     # 3. Map local node IDs to global node IDs.
@@ -157,6 +167,14 @@ def _sample_neighbors_graphbolt(
         node_ids=subgraph.original_column_node_ids,
         output_size=local_src.shape[0],
     )
+    if subgraph.type_per_edge is None and g.type_per_edge is not None:
+        subgraph_type_per_edge = gb.expand_indptr(
+            subgraph.etype_offsets,
+            dtype=g.type_per_edge.dtype,
+            output_size=subgraph.indices.shape[0],
+        )
+    else:
+        subgraph_type_per_edge = subgraph.type_per_edge
     global_nid_mapping = g.node_attributes[NID]
     global_src = global_nid_mapping[local_src]
     global_dst = global_nid_mapping[local_dst]
@@ -165,7 +183,7 @@ def _sample_neighbors_graphbolt(
     if return_eids:
         global_eids = g.edge_attributes[EID][subgraph.original_edge_ids]
     return LocalSampledGraph(
-        global_src, global_dst, global_eids, subgraph.type_per_edge
+        global_src, global_dst, global_eids, subgraph_type_per_edge
     )
 
 
@@ -241,6 +259,7 @@ def _sample_etype_neighbors_dgl(
     local_g,
     partition_book,
     seed_nodes,
+    seed_ntypes,
     fan_out,
     edge_dir="in",
     prob=None,
@@ -256,6 +275,8 @@ def _sample_etype_neighbors_dgl(
     and edge IDs.
     """
     assert etype_offset is not None, "The etype offset is not provided."
+    # Node type information is only used by the graphbolt sampling path.
+    seed_ntypes = None
 
     local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
     local_ids = F.astype(local_ids, local_g.idtype)
@@ -680,7 +701,7 @@ def merge_graphs(res_list, num_nodes):
     )
 
 
-def _distributed_access(g, nodes, issue_remote_req, local_access):
+def _distributed_access(g, nodes, ntypes, issue_remote_req, local_access):
     """A routine that fetches local neighborhood of nodes from the distributed graph.
 
     The local neighborhood of some nodes are stored in the local machine and the other
@@ -695,6 +716,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
         The distributed graph
     nodes : tensor
         The nodes whose neighborhood are to be fetched.
+    ntypes : tensor
+        The node type of each node. It should be None if the graph is homogeneous.
     issue_remote_req : callable
         The function that issues requests to access remote data.
     local_access : callable
@@ -713,6 +736,10 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
     local_nids = None
     for pid in range(partition_book.num_partitions()):
         node_id = F.boolean_mask(nodes, partition_id == pid)
+        if ntypes is not None:
+            part_ntypes = F.boolean_mask(ntypes, partition_id == pid)
+        else:
+            part_ntypes = None
         # We optimize the sampling on a local partition if the server and the client
         # run on the same machine. With a good partitioning, most of the seed nodes
         # should reside in the local partition. If the server and the client
@@ -720,6 +747,7 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
         if pid == partition_book.partid and g.local_partition is not None:
             assert local_nids is None
             local_nids = node_id
+            local_ntypes = part_ntypes
         elif len(node_id) != 0:
             req = issue_remote_req(node_id)
             req_list.append((pid, req))
@@ -732,7 +760,9 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
     # sample neighbors for the nodes in the local partition.
     res_list = []
     if local_nids is not None:
-        res = local_access(g.local_partition, partition_book, local_nids)
+        res = local_access(
+            g.local_partition, partition_book, local_nids, local_ntypes
+        )
         res_list.append(res)
 
     # receive responses from remote machines.
@@ -896,6 +926,7 @@ def sample_etype_neighbors(
     gpb = g.get_partition_book()
     if isinstance(nodes, dict):
         homo_nids = []
+        ntype_list = []
         for ntype in nodes.keys():
             assert (
                 ntype in g.ntypes
@@ -906,8 +937,15 @@ def sample_etype_neighbors(
                 typed_nodes = nodes[ntype]
             else:
                 typed_nodes = toindex(nodes[ntype]).tousertensor()
-            homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
+            homo_nid = gpb.map_to_homo_nid(typed_nodes, ntype)
+            homo_nids.append(homo_nid)
+            ntype_list.append(
+                torch.full(
+                    (len(homo_nid),), g.get_ntype_id(ntype), dtype=F.int64
+                )
+            )
         nodes = F.cat(homo_nids, 0)
+        ntypes = F.cat(ntype_list, 0)
 
     def issue_remote_req(node_ids):
         if prob is not None:
@@ -932,7 +970,7 @@ def issue_remote_req(node_ids):
             use_graphbolt=use_graphbolt,
         )
 
-    def local_access(local_g, partition_book, local_nids):
+    def local_access(local_g, partition_book, local_nids, local_ntypes):
         etype_offset = gpb.local_etype_offset
         # See NOTE 1
         if prob is None:
@@ -949,6 +987,7 @@ def local_access(local_g, partition_book, local_nids):
             local_g,
             partition_book,
            local_nids,
+            local_ntypes,
            fanout,
            edge_dir=edge_dir,
            prob=_prob,
@@ -957,7 +996,9 @@ def local_access(local_g, partition_book, local_nids):
            etype_sorted=etype_sorted,
        )
 
-    frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
+    frontier = _distributed_access(
+        g, nodes, ntypes, issue_remote_req, local_access
+    )
     if not gpb.is_homogeneous:
         return _frontier_to_heterogeneous_graph(g, frontier, gpb)
     else:
@@ -1026,6 +1067,7 @@ def sample_neighbors(
     gpb = g.get_partition_book()
     if not gpb.is_homogeneous:
         assert isinstance(nodes, dict)
+        ntype_list = []
        homo_nids = []
         for ntype in nodes:
             assert (
@@ -1036,10 +1078,19 @@ def sample_neighbors(
                 ntype in g.ntypes
             ), "The sampled node type does not exist in the input graph"
             if F.is_tensor(nodes[ntype]):
                 typed_nodes = nodes[ntype]
             else:
                 typed_nodes = toindex(nodes[ntype]).tousertensor()
             homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
+            homo_nid = gpb.map_to_homo_nid(typed_nodes, ntype)
+            homo_nids.append(homo_nid)
+            ntype_list.append(
+                torch.full(
+                    (len(homo_nid),), g.get_ntype_id(ntype), dtype=F.int64
+                )
+            )
         nodes = F.cat(homo_nids, 0)
+        ntype_offsets = F.cat(ntype_list, 0)
     elif isinstance(nodes, dict):
         assert len(nodes) == 1
         nodes = list(nodes.values())[0]
+        ntype_offsets = None
 
     def issue_remote_req(node_ids):
         if prob is not None:
@@ -1056,7 +1107,7 @@ def issue_remote_req(node_ids):
             use_graphbolt=use_graphbolt,
         )
 
-    def local_access(local_g, partition_book, local_nids):
+    def local_access(local_g, partition_book, local_nids, local_ntype_offsets):
         # See NOTE 1
         _prob = [g.edata[prob].local_partition] if prob is not None else None
         return _sample_neighbors(
@@ -1064,13 +1115,16 @@ def local_access(local_g, partition_book, local_nids):
             local_g,
             partition_book,
             local_nids,
+            local_ntype_offsets,
             fanout,
             edge_dir=edge_dir,
             prob=_prob,
             replace=replace,
         )
 
-    frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
+    frontier = _distributed_access(
+        g, nodes, ntype_offsets, issue_remote_req, local_access
+    )
     if not gpb.is_homogeneous:
         return _frontier_to_heterogeneous_graph(g, frontier, gpb)
     else:
@@ -1214,10 +1268,11 @@ def in_subgraph(g, nodes):
     def issue_remote_req(node_ids):
         return InSubgraphRequest(node_ids)
 
-    def local_access(local_g, partition_book, local_nids):
+    def local_access(local_g, partition_book, local_nids, local_ntypes):
+        assert local_ntypes is None, "local_ntypes should be None."
         return _in_subgraph(local_g, partition_book, local_nids)
 
-    return _distributed_access(g, nodes, issue_remote_req, local_access)
+    return _distributed_access(g, nodes, None, issue_remote_req, local_access)
 
 
 def _distributed_get_node_property(g, n, issue_remote_req, local_access):
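
Usage sketch (illustrative only, not part of the diff): with this change, the distributed sampling APIs accept per-node-type seed dictionaries on graphbolt-backed heterogeneous partitions and forward each seed's node type down to the local sampler as seed offsets. The graph name, ip config file, node type names, and seed IDs below are hypothetical.

    import torch
    import dgl

    # Assumes a running DGL distributed setup whose partitions were created
    # with graphbolt, for a heterogeneous graph with node types "user"/"item".
    dgl.distributed.initialize("ip_config.txt")
    g = dgl.distributed.DistGraph("my_hetero_graph")

    # Seeds are given per node type; the patched code maps them to homogeneous
    # IDs, records each seed's node type, and derives per-type seed offsets
    # for the graphbolt local sampling call.
    seeds = {
        "user": torch.tensor([0, 1, 2]),
        "item": torch.tensor([3, 4]),
    }
    frontier = dgl.distributed.sample_neighbors(g, seeds, fanout=5)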