Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO NOT MERGE][Dist] Apply optimizations to DistDGL #7359

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 67 additions & 12 deletions python/dgl/distributed/graph_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __getstate__(self):


def _sample_neighbors_graphbolt(
g, gpb, nodes, fanout, edge_dir="in", prob=None, replace=False
g, gpb, nodes, node_types, fanout, edge_dir="in", prob=None, replace=False
):
"""Sample from local partition via graphbolt.

Expand All @@ -99,6 +99,8 @@ def _sample_neighbors_graphbolt(
The graph partition book.
nodes : tensor
The nodes to sample neighbors from.
node_types : tensor
The node type of each node.
fanout : tensor or int
The number of edges to be sampled for each node.
edge_dir : str, optional
Expand Down Expand Up @@ -128,6 +130,14 @@ def _sample_neighbors_graphbolt(
# Local partition may be saved in torch.int32 even though the global graph
# is in torch.int64.
nodes = nodes.to(dtype=g.indices.dtype)
if node_types is not None:
ntype_count = torch.bincount(
node_types, minlength=len(g.node_type_to_id)
)
seed_offsets = ntype_count.to(dtype=torch.int64).cumsum(0).tolist()
seed_offsets.insert(0, 0)
else:
seed_offsets = None

# 2. Perform sampling.
# [Rui][TODO] `prob` is not tested yet. Skip for now.
Expand All @@ -146,7 +156,7 @@ def _sample_neighbors_graphbolt(

return_eids = g.edge_attributes is not None and EID in g.edge_attributes
subgraph = g._sample_neighbors(
nodes, None, fanout, replace=replace, return_eids=return_eids
nodes, seed_offsets, fanout, replace=replace, return_eids=return_eids
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are nodes required to be sorted by node type regardless of whether seed_offsets is None or not?

)

# 3. Map local node IDs to global node IDs.
Expand All @@ -157,6 +167,14 @@ def _sample_neighbors_graphbolt(
node_ids=subgraph.original_column_node_ids,
output_size=local_src.shape[0],
)
if subgraph.type_per_edge is None and g.type_per_edge is not None:
subgraph_type_per_edge = gb.expand_indptr(
subgraph.etype_offsets,
dtype=g.type_per_edge.dtype,
output_size=subgraph.indices.shape[0],
)
else:
subgraph_type_per_edge = subgraph.type_per_edge
global_nid_mapping = g.node_attributes[NID]
global_src = global_nid_mapping[local_src]
global_dst = global_nid_mapping[local_dst]
Expand All @@ -165,7 +183,7 @@ def _sample_neighbors_graphbolt(
if return_eids:
global_eids = g.edge_attributes[EID][subgraph.original_edge_ids]
return LocalSampledGraph(
global_src, global_dst, global_eids, subgraph.type_per_edge
global_src, global_dst, global_eids, subgraph_type_per_edge
)


Expand Down Expand Up @@ -241,6 +259,7 @@ def _sample_etype_neighbors_dgl(
local_g,
partition_book,
seed_nodes,
seed_ntypes,
fan_out,
edge_dir="in",
prob=None,
Expand All @@ -256,6 +275,8 @@ def _sample_etype_neighbors_dgl(
and edge IDs.
"""
assert etype_offset is not None, "The etype offset is not provided."
if seed_ntypes is not None:
seed_ntypes = None

local_ids = partition_book.nid2localnid(seed_nodes, partition_book.partid)
local_ids = F.astype(local_ids, local_g.idtype)
Expand Down Expand Up @@ -680,7 +701,7 @@ def merge_graphs(res_list, num_nodes):
)


def _distributed_access(g, nodes, issue_remote_req, local_access):
def _distributed_access(g, nodes, ntypes, issue_remote_req, local_access):
"""A routine that fetches local neighborhood of nodes from the distributed graph.

The local neighborhood of some nodes are stored in the local machine and the other
Expand All @@ -695,6 +716,8 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
The distributed graph
nodes : tensor
The nodes whose neighborhood are to be fetched.
ntypes : tensor
The node types of each node. It should be None if the graph is homogeneous.
issue_remote_req : callable
The function that issues requests to access remote data.
local_access : callable
Expand All @@ -713,13 +736,18 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
local_nids = None
for pid in range(partition_book.num_partitions()):
node_id = F.boolean_mask(nodes, partition_id == pid)
if ntypes is not None:
part_ntypes = F.boolean_mask(ntypes, partition_id == pid)
else:
part_ntypes = None
# We optimize the sampling on a local partition if the server and the client
# run on the same machine. With a good partitioning, most of the seed nodes
# should reside in the local partition. If the server and the client
# are not co-located, the client doesn't have a local partition.
if pid == partition_book.partid and g.local_partition is not None:
assert local_nids is None
local_nids = node_id
local_ntypes = part_ntypes
elif len(node_id) != 0:
req = issue_remote_req(node_id)
req_list.append((pid, req))
Expand All @@ -732,7 +760,9 @@ def _distributed_access(g, nodes, issue_remote_req, local_access):
# sample neighbors for the nodes in the local partition.
res_list = []
if local_nids is not None:
res = local_access(g.local_partition, partition_book, local_nids)
res = local_access(
g.local_partition, partition_book, local_nids, local_ntypes
)
res_list.append(res)

# receive responses from remote machines.
Expand Down Expand Up @@ -896,6 +926,7 @@ def sample_etype_neighbors(
gpb = g.get_partition_book()
if isinstance(nodes, dict):
homo_nids = []
ntype_list = []
for ntype in nodes.keys():
assert (
ntype in g.ntypes
Expand All @@ -906,8 +937,15 @@ def sample_etype_neighbors(
typed_nodes = nodes[ntype]
else:
typed_nodes = toindex(nodes[ntype]).tousertensor()
homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
homo_nid = gpb.map_to_homo_nid(typed_nodes, ntype)
homo_nids.append(homo_nid)
ntype_list.append(
torch.full(
(len(homo_nid),), g.get_ntype_id(ntype), dtype=F.int64
)
)
nodes = F.cat(homo_nids, 0)
ntypes = F.cat(ntype_list, 0)

def issue_remote_req(node_ids):
if prob is not None:
Expand All @@ -932,7 +970,7 @@ def issue_remote_req(node_ids):
use_graphbolt=use_graphbolt,
)

def local_access(local_g, partition_book, local_nids):
def local_access(local_g, partition_book, local_nids, local_ntypes):
etype_offset = gpb.local_etype_offset
# See NOTE 1
if prob is None:
Expand All @@ -949,6 +987,7 @@ def local_access(local_g, partition_book, local_nids):
local_g,
partition_book,
local_nids,
local_ntypes,
fanout,
edge_dir=edge_dir,
prob=_prob,
Expand All @@ -957,7 +996,9 @@ def local_access(local_g, partition_book, local_nids):
etype_sorted=etype_sorted,
)

frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
frontier = _distributed_access(
g, nodes, ntypes, issue_remote_req, local_access
)
if not gpb.is_homogeneous:
return _frontier_to_heterogeneous_graph(g, frontier, gpb)
else:
Expand Down Expand Up @@ -1026,6 +1067,7 @@ def sample_neighbors(
gpb = g.get_partition_book()
if not gpb.is_homogeneous:
assert isinstance(nodes, dict)
ntype_list = []
homo_nids = []
for ntype in nodes:
assert (
Expand All @@ -1036,10 +1078,19 @@ def sample_neighbors(
else:
typed_nodes = toindex(nodes[ntype]).tousertensor()
homo_nids.append(gpb.map_to_homo_nid(typed_nodes, ntype))
homo_nid = gpb.map_to_homo_nid(typed_nodes, ntype)
homo_nids.append(homo_nid)
ntype_list.append(
torch.full(
(len(homo_nid),), g.get_ntype_id(ntype), dtype=F.int64
)
)
nodes = F.cat(homo_nids, 0)
ntype_offsets = F.cat(ntype_list, 0)
elif isinstance(nodes, dict):
assert len(nodes) == 1
nodes = list(nodes.values())[0]
ntype_offsets = None

def issue_remote_req(node_ids):
if prob is not None:
Expand All @@ -1056,21 +1107,24 @@ def issue_remote_req(node_ids):
use_graphbolt=use_graphbolt,
)

def local_access(local_g, partition_book, local_nids):
def local_access(local_g, partition_book, local_nids, local_ntype_offsets):
# See NOTE 1
_prob = [g.edata[prob].local_partition] if prob is not None else None
return _sample_neighbors(
use_graphbolt,
local_g,
partition_book,
local_nids,
local_ntype_offsets,
fanout,
edge_dir=edge_dir,
prob=_prob,
replace=replace,
)

frontier = _distributed_access(g, nodes, issue_remote_req, local_access)
frontier = _distributed_access(
g, nodes, ntype_offsets, issue_remote_req, local_access
)
if not gpb.is_homogeneous:
return _frontier_to_heterogeneous_graph(g, frontier, gpb)
else:
Expand Down Expand Up @@ -1214,10 +1268,11 @@ def in_subgraph(g, nodes):
def issue_remote_req(node_ids):
return InSubgraphRequest(node_ids)

def local_access(local_g, partition_book, local_nids):
def local_access(local_g, partition_book, local_nids, local_ntypes):
assert local_ntypes is None, "local_ntypes should be None."
return _in_subgraph(local_g, partition_book, local_nids)

return _distributed_access(g, nodes, issue_remote_req, local_access)
return _distributed_access(g, nodes, None, issue_remote_req, local_access)


def _distributed_get_node_property(g, n, issue_remote_req, local_access):
Expand Down
Loading