Skip to content

Commit

Permalink
[GraphBolt][CUDA] Pipelined sampling accuracy fix (#7088)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfbalin authored Feb 5, 2024
1 parent 4ee0a8b commit a2e1c79
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 10 deletions.
27 changes: 19 additions & 8 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,33 @@ def _fetch_per_layer_impl(self, minibatch, stream):
with torch.cuda.stream(self.stream):
index = minibatch._seed_nodes
if isinstance(index, dict):
for idx in index.values():
idx.record_stream(torch.cuda.current_stream())
index = self.graph._convert_to_homogeneous_nodes(index)
else:
index.record_stream(torch.cuda.current_stream())

def record_stream(tensor):
if stream is not None and tensor.is_cuda:
tensor.record_stream(stream)
return tensor

index, original_positions = index.sort()
if (original_positions.diff() == 1).all().item(): # is_sorted
if self.graph.node_type_offset is None:
# sorting not needed.
minibatch._subgraph_seed_nodes = None
else:
minibatch._subgraph_seed_nodes = original_positions
index.record_stream(torch.cuda.current_stream())
index, original_positions = index.sort()
if (original_positions.diff() == 1).all().item():
# already sorted.
minibatch._subgraph_seed_nodes = None
else:
minibatch._subgraph_seed_nodes = record_stream(
original_positions.sort()[1]
)
index_select_csc_with_indptr = partial(
torch.ops.graphbolt.index_select_csc, self.graph.csc_indptr
)

def record_stream(tensor):
if stream is not None and tensor.is_cuda:
tensor.record_stream(stream)

indptr, indices = index_select_csc_with_indptr(
self.graph.indices, index, None
)
Expand Down
8 changes: 6 additions & 2 deletions tests/python/pytorch/graphbolt/impl/test_neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@ def get_hetero_graph():
@unittest.skipIf(F._default_context_str != "gpu", reason="Enabled only on GPU.")
@pytest.mark.parametrize("hetero", [False, True])
@pytest.mark.parametrize("prob_name", [None, "weight", "mask"])
def test_NeighborSampler_GraphFetch(hetero, prob_name):
items = torch.arange(3)
@pytest.mark.parametrize("sorted", [False, True])
def test_NeighborSampler_GraphFetch(hetero, prob_name, sorted):
if sorted:
items = torch.arange(3)
else:
items = torch.tensor([2, 0, 1])
names = "seed_nodes"
itemset = gb.ItemSet(items, names=names)
graph = get_hetero_graph().to(F.ctx())
Expand Down

0 comments on commit a2e1c79

Please sign in to comment.