Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
RamonZhou committed Apr 26, 2024
1 parent 3afa105 commit 56ac909
Showing 1 changed file with 67 additions and 12 deletions.
79 changes: 67 additions & 12 deletions python/dgl/distributed/graph_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __getstate__(self):


def _sample_neighbors_graphbolt(
g, gpb, nodes, fanout, edge_dir="in", prob=None, replace=False
g, gpb, nodes, node_types, fanout, edge_dir="in", prob=None, replace=False
):
"""Sample from local partition via graphbolt.
Expand All @@ -99,6 +99,8 @@ def _sample_neighbors_graphbolt(
The graph partition book.
nodes : tensor
The nodes to sample neighbors from.
node_types : tensor
The node type ID of each node in ``nodes``, or None to sample without per-type seed offsets.
fanout : tensor or int
The number of edges to be sampled for each node.
edge_dir : str, optional
Expand Down Expand Up @@ -128,6 +130,14 @@ def _sample_neighbors_graphbolt(
# Local partition may be saved in torch.int32 even though the global graph
# is in torch.int64.
nodes = nodes.to(dtype=g.indices.dtype)
if node_types is not None:
ntype_count = torch.bincount(
node_types, minlength=len(g.node_type_to_id)
)
seed_offsets = ntype_count.to(dtype=torch.int64).cumsum(0).tolist()
seed_offsets.insert(0, 0)
else:
seed_offsets = None

# 2. Perform sampling.
# [Rui][TODO] `prob` is not tested yet. Skip for now.
Expand All @@ -146,7 +156,7 @@ def _sample_neighbors_graphbolt(

return_eids = g.edge_attributes is not None and EID in g.edge_attributes
subgraph = g._sample_neighbors(
nodes, None, fanout, replace=replace, return_eids=return_eids
nodes, seed_offsets, fanout, replace=replace, return_eids=return_eids
)

# 3. Map local node IDs to global node IDs.
Expand All @@ -157,6 +167,14 @@ def _sample_neighbors_graphbolt(
node_ids=subgraph.original_column_node_ids,
output_size=local_src.shape[0],
)
if subgraph.type_per_edge is None and g.type_per_edge is not None:
subgraph_type_per_edge = gb.expand_indptr(
subgraph.etype_offsets,
dtype=g.type_per_edge.dtype,
output_size=subgraph.indices.shape[0],
)
else:
subgraph_type_per_edge = subgraph.type_per_edge
global_nid_mapping = g.node_attributes[NID]
global_src = global_nid_mapping[local_src]
global_dst = global_nid_mapping[local_dst]
Expand All @@ -165,7 +183,7 @@ def _sample_neighbors_graphbolt(
if return_eids:
global_eids = g.edge_attributes[EID][subgraph.original_edge_ids]
return LocalSampledGraph(
global_src, global_dst, global_eids, subgraph.type_per_edge
global_src, global_dst, global_eids, subgraph_type_per_edge
)


Expand Down Expand Up @@ -241,6 +259,7 @@ def _sample_etype_neighbors_dgl(
local_g,
partition_book,
seed_nodes,
seed_ntypes,
fan_out,
edge_dir="in",
prob=None,
Expand All @@ -256,6 +275,8 @@ def _sample_etype_neighbors_dgl(
and edge IDs.
"""
assert etype_offset is not None, "The etype offset is not provided."
if seed_ntypes is not None:
seed_ntypes = None

local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
local_ids = F.astype(local_ids, local_g.idtype)
Expand Down Expand Up @@ -680,7 +701,7 @@ def merge_graphs(res_list, num_nodes):
)


def _distributed_access(g, nodes, issue_remote_req, local_access):
def _distributed_access(g, nodes, ntypes, issue_remote_req, local_access):
"""A routine that fetches local neighborhood of nodes from the distributed graph.
The local neighborhoods of some nodes are stored on the local machine and the other
Expand All @@ -695,6 +716,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
The distributed graph
nodes : tensor
The nodes whose neighborhoods are to be fetched.
ntypes : tensor
The node type of each node in ``nodes``. It must be None if the graph is homogeneous.
issue_remote_req : callable
The function that issues requests to access remote data.
local_access : callable
Expand All @@ -713,13 +736,18 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
local_nids = None
for pid in range(partition_book.num_partitions()):
node_id = F.boolean_mask(nodes, partition_id == pid)
if ntypes is not None:
part_ntypes = F.boolean_mask(ntypes, partition_id == pid)
else:
part_ntypes = None
# We optimize the sampling on a local partition if the server and the client
# run on the same machine. With a good partitioning, most of the seed nodes
# should reside in the local partition. If the server and the client
# are not co-located, the client doesn't have a local partition.
if pid == partition_book.partid and g.local_partition is not None:
assert local_nids is None
local_nids = node_id
local_ntypes = part_ntypes
elif len(node_id) != 0:
req = issue_remote_req(node_id)
req_list.append((pid, req))
Expand All @@ -732,7 +760,9 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
# sample neighbors for the nodes in the local partition.
res_list = []
if local_nids is not None:
res = local_access(g.local_partition, partition_book, local_nids)
res = local_access(
g.local_partition, partition_book, local_nids, local_ntypes
)
res_list.append(res)

# receive responses from remote machines.
Expand Down Expand Up @@ -896,6 +926,7 @@ def sample_etype_neighbors(
gpb = g.get_partition_book()
if isinstance(nodes, dict):
homo_nids = []
ntype_list = []
for ntype in nodes.keys():
assert (
ntype in g.ntypes
Expand All @@ -906,8 +937,15 @@ def sample_etype_neighbors(
typed_nodes = nodes[ntype]
else:
typed_nodes = toindex(nodes[ntype]).tousertensor()
homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
homo_nid = gpb.map_to_homo_nid(typed_nodes, ntype)
homo_nids.append(homo_nid)
ntype_list.append(
torch.full(
(len(homo_nid),), g.get_ntype_id(ntype), dtype=F.int64
)
)
nodes = F.cat(homo_nids, 0)
ntypes = F.cat(ntype_list, 0)

def issue_remote_req(node_ids):
if prob is not None:
Expand All @@ -932,7 +970,7 @@ def issue_remote_req(node_ids):
use_graphbolt=use_graphbolt,
)

def local_access(local_g, partition_book, local_nids):
def local_access(local_g, partition_book, local_nids, local_ntypes):
etype_offset = gpb.local_etype_offset
# See NOTE 1
if prob is None:
Expand All @@ -949,6 +987,7 @@ def local_access(local_g, partition_book, local_nids):
local_g,
partition_book,
local_nids,
local_ntypes,
fanout,
edge_dir=edge_dir,
prob=_prob,
Expand All @@ -957,7 +996,9 @@ def local_access(local_g, partition_book, local_nids):
etype_sorted=etype_sorted,
)

frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
frontier = _distributed_access(
g, nodes, ntypes, issue_remote_req, local_access
)
if not gpb.is_homogeneous:
return _frontier_to_heterogeneous_graph(g, frontier, gpb)
else:
Expand Down Expand Up @@ -1026,6 +1067,7 @@ def sample_neighbors(
gpb = g.get_partition_book()
if not gpb.is_homogeneous:
assert isinstance(nodes, dict)
ntype_list = []
homo_nids = []
for ntype in nodes:
assert (
Expand All @@ -1036,10 +1078,19 @@ def sample_neighbors(
else:
typed_nodes = toindex(nodes[ntype]).tousertensor()
homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
homo_nid = gpb.map_to_homo_nid(typed_nodes, ntype)
homo_nids.append(homo_nid)
ntype_list.append(
torch.full(
(len(homo_nid),), g.get_ntype_id(ntype), dtype=F.int64
)
)
nodes = F.cat(homo_nids, 0)
ntype_offsets = F.cat(ntype_list, 0)
elif isinstance(nodes, dict):
assert len(nodes) == 1
nodes = list(nodes.values())[0]
ntype_offsets = None

def issue_remote_req(node_ids):
if prob is not None:
Expand All @@ -1056,21 +1107,24 @@ def issue_remote_req(node_ids):
use_graphbolt=use_graphbolt,
)

def local_access(local_g, partition_book, local_nids):
def local_access(local_g, partition_book, local_nids, local_ntype_offsets):
# See NOTE 1
_prob = [g.edata[prob].local_partition] if prob is not None else None
return _sample_neighbors(
use_graphbolt,
local_g,
partition_book,
local_nids,
local_ntype_offsets,
fanout,
edge_dir=edge_dir,
prob=_prob,
replace=replace,
)

frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
frontier = _distributed_access(
g, nodes, ntype_offsets, issue_remote_req, local_access
)
if not gpb.is_homogeneous:
return _frontier_to_heterogeneous_graph(g, frontier, gpb)
else:
Expand Down Expand Up @@ -1214,10 +1268,11 @@ def in_subgraph(g, nodes):
def issue_remote_req(node_ids):
return InSubgraphRequest(node_ids)

def local_access(local_g, partition_book, local_nids):
def local_access(local_g, partition_book, local_nids, local_ntypes):
assert local_ntypes is None, "local_ntypes should be None."
return _in_subgraph(local_g, partition_book, local_nids)

return _distributed_access(g, nodes, issue_remote_req, local_access)
return _distributed_access(g, nodes, None, issue_remote_req, local_access)


def _distributed_get_node_property(g, n, issue_remote_req, local_access):
Expand Down

0 comments on commit 56ac909

Please sign in to comment.