From 870d8d025bae7742aaab94eba150ba0d7d444800 Mon Sep 17 00:00:00 2001
From: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
Date: Thu, 8 Feb 2024 09:56:31 +0800
Subject: [PATCH 1/3] [doc] fix undefined variable in code snippet (#7107)

---
 docs/source/guide/minibatch-node.rst | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/source/guide/minibatch-node.rst b/docs/source/guide/minibatch-node.rst
index af83463248d4..6e81895b8026 100644
--- a/docs/source/guide/minibatch-node.rst
+++ b/docs/source/guide/minibatch-node.rst
@@ -44,6 +44,7 @@ putting the list of generated MFGs onto GPU.
 
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     dataset = gb.BuiltinDataset("ogbn-arxiv").load()
+    g = dataset.graph
     train_set = dataset.tasks[0].train_set
     datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
     datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
@@ -205,6 +206,7 @@ of node types to node IDs.
 
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = gb.BuiltinDataset("ogbn-mag").load()
+    g = dataset.graph
     train_set = dataset.tasks[0].train_set
     datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
     datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.

From 763bd39ff56a6e21473b077ae096d6acb79b4cba Mon Sep 17 00:00:00 2001
From: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
Date: Thu, 8 Feb 2024 10:37:01 +0800
Subject: [PATCH 2/3] [DistGB] sample with graphbolt on homograph via DistDataLoader (#7098)

---
 tests/distributed/test_mp_dataloader.py | 118 +++++++++++++++---------
 1 file changed, 73 insertions(+), 45 deletions(-)

diff --git a/tests/distributed/test_mp_dataloader.py b/tests/distributed/test_mp_dataloader.py
index 4f8d80c9ddbf..cdefb28b93ad 100644
--- a/tests/distributed/test_mp_dataloader.py
+++ b/tests/distributed/test_mp_dataloader.py
@@ -22,10 +22,19 @@
 
 
 class NeighborSampler(object):
-    def __init__(self, g, fanouts, sample_neighbors):
+    def __init__(
+        self,
+        g,
+        fanouts,
+        sample_neighbors,
+        use_graphbolt=False,
+        return_eids=False,
+    ):
         self.g = g
         self.fanouts = fanouts
         self.sample_neighbors = sample_neighbors
+        self.use_graphbolt = use_graphbolt
+        self.return_eids = return_eids
 
     def sample_blocks(self, seeds):
         import torch as th
@@ -35,13 +44,16 @@ def sample_blocks(self, seeds):
         for fanout in self.fanouts:
             # For each seed node, sample ``fanout`` neighbors.
             frontier = self.sample_neighbors(
-                self.g, seeds, fanout, replace=True
+                self.g, seeds, fanout, use_graphbolt=self.use_graphbolt
             )
             # Then we compact the frontier into a bipartite graph for
             # message passing.
             block = dgl.to_block(frontier, seeds)
             # Obtain the seed nodes for next layer.
             seeds = block.srcdata[dgl.NID]
+            if frontier.num_edges() > 0:
+                if not self.use_graphbolt or self.return_eids:
+                    block.edata[dgl.EID] = frontier.edata[dgl.EID]
             blocks.insert(0, block)
         return blocks
 
@@ -53,6 +65,7 @@ def start_server(
     part_config,
     disable_shared_mem,
     num_clients,
+    use_graphbolt=False,
 ):
     print("server: #clients=" + str(num_clients))
     g = DistGraphServer(
@@ -63,6 +76,7 @@ def start_server(
         part_config,
         disable_shared_mem=disable_shared_mem,
         graph_format=["csc", "coo"],
+        use_graphbolt=use_graphbolt,
     )
     g.start()
 
@@ -75,30 +89,36 @@ def start_dist_dataloader(
     drop_last,
     orig_nid,
     orig_eid,
-    group_id=0,
+    use_graphbolt=False,
+    return_eids=False,
 ):
-    import dgl
-    import torch as th
-
-    os.environ["DGL_GROUP_ID"] = str(group_id)
     dgl.distributed.initialize(ip_config)
     gpb = None
-    disable_shared_mem = num_server > 0
+    disable_shared_mem = num_server > 1
     if disable_shared_mem:
         _, _, _, gpb, _, _, _ = load_partition(part_config, rank)
     num_nodes_to_sample = 202
     batch_size = 32
     train_nid = th.arange(num_nodes_to_sample)
-    dist_graph = DistGraph("test_mp", gpb=gpb, part_config=part_config)
-
-    for i in range(num_server):
-        part, _, _, _, _, _, _ = load_partition(part_config, i)
+    dist_graph = DistGraph(
+        "test_sampling",
+        gpb=gpb,
+        part_config=part_config,
+        use_graphbolt=use_graphbolt,
+    )
 
     # Create sampler
     sampler = NeighborSampler(
-        dist_graph, [5, 10], dgl.distributed.sample_neighbors
+        dist_graph,
+        [5, 10],
+        dgl.distributed.sample_neighbors,
+        use_graphbolt=use_graphbolt,
+        return_eids=return_eids,
     )
 
+    # Enable sanity check in distributed sampling.
+    os.environ["DGL_DIST_DEBUG"] = "1"
+
     # We need to test creating DistDataLoader multiple times.
     for i in range(2):
         # Create DataLoader for constructing blocks
@@ -113,7 +133,7 @@ def start_dist_dataloader(
 
         groundtruth_g = CitationGraphDataset("cora")[0]
         max_nid = []
-        for epoch in range(2):
+        for _ in range(2):
             for idx, blocks in zip(
                 range(0, num_nodes_to_sample, batch_size), dataloader
             ):
@@ -129,6 +149,16 @@ def start_dist_dataloader(
                     src_nodes_id, dst_nodes_id
                 )
                 assert np.all(F.asnumpy(has_edges))
+
+                if use_graphbolt and not return_eids:
+                    continue
+                eids = orig_eid[block.edata[dgl.EID]]
+                expected_eids = groundtruth_g.edge_ids(
+                    src_nodes_id, dst_nodes_id
+                )
+                assert th.equal(
+                    eids, expected_eids
+                ), f"{eids} != {expected_eids}"
         if drop_last:
             assert (
                 np.max(max_nid)
@@ -311,23 +341,22 @@ def check_neg_dataloader(g, num_server, num_workers):
     assert p.exitcode == 0
 
 
-@unittest.skip(reason="Skip due to glitch in CI")
-@pytest.mark.parametrize("num_server", [3])
+@pytest.mark.parametrize("num_server", [1])
 @pytest.mark.parametrize("num_workers", [0, 4])
-@pytest.mark.parametrize("drop_last", [True, False])
-@pytest.mark.parametrize("num_groups", [1])
-def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
+@pytest.mark.parametrize("drop_last", [False, True])
+@pytest.mark.parametrize("use_graphbolt", [False, True])
+@pytest.mark.parametrize("return_eids", [False, True])
+def test_dist_dataloader(
+    num_server, num_workers, drop_last, use_graphbolt, return_eids
+):
     reset_envs()
-    # No multiple partitions on single machine for
-    # multiple client groups in case of race condition.
-    if num_groups > 1:
-        num_server = 1
+    os.environ["DGL_DIST_MODE"] = "distributed"
+    os.environ["DGL_NUM_SAMPLER"] = str(num_workers)
 
     with tempfile.TemporaryDirectory() as test_dir:
         ip_config = "ip_config.txt"
         generate_ip_config(ip_config, num_server, num_server)
         g = CitationGraphDataset("cora")[0]
-        print(g.idtype)
 
         num_parts = num_server
         num_hops = 1
@@ -339,6 +368,8 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
             num_hops=num_hops,
             part_method="metis",
             return_mapping=True,
+            use_graphbolt=use_graphbolt,
+            store_eids=return_eids,
         )
 
         part_config = os.path.join(test_dir, "test_sampling.json")
@@ -353,36 +384,33 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
                     part_config,
                     num_server > 1,
                     num_workers + 1,
+                    use_graphbolt,
                 ),
             )
             p.start()
             time.sleep(1)
             pserver_list.append(p)
 
-        os.environ["DGL_DIST_MODE"] = "distributed"
-        os.environ["DGL_NUM_SAMPLER"] = str(num_workers)
         ptrainer_list = []
         num_trainers = 1
         for trainer_id in range(num_trainers):
-            for group_id in range(num_groups):
-                p = ctx.Process(
-                    target=start_dist_dataloader,
-                    args=(
-                        trainer_id,
-                        ip_config,
-                        part_config,
-                        num_server,
-                        drop_last,
-                        orig_nid,
-                        orig_eid,
-                        group_id,
-                    ),
-                )
-                p.start()
-                time.sleep(
-                    1
-                )  # avoid race condition when instantiating DistGraph
-                ptrainer_list.append(p)
+            p = ctx.Process(
+                target=start_dist_dataloader,
+                args=(
+                    trainer_id,
+                    ip_config,
+                    part_config,
+                    num_server,
+                    drop_last,
+                    orig_nid,
+                    orig_eid,
+                    use_graphbolt,
+                    return_eids,
+                ),
+            )
+            p.start()
+            time.sleep(1)  # avoid race condition when instantiating DistGraph
+            ptrainer_list.append(p)
 
         for p in ptrainer_list:
            p.join()

From 7f7967b384578999f2e2e40d146b516c90649505 Mon Sep 17 00:00:00 2001
From: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com>
Date: Thu, 8 Feb 2024 12:01:03 +0800
Subject: [PATCH 3/3] [doc] fix undefined variable in example

---
 docs/source/guide/minibatch-node.rst | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/docs/source/guide/minibatch-node.rst b/docs/source/guide/minibatch-node.rst
index 6e81895b8026..4fa695694a6d 100644
--- a/docs/source/guide/minibatch-node.rst
+++ b/docs/source/guide/minibatch-node.rst
@@ -45,6 +45,7 @@ putting the list of generated MFGs onto GPU.
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     dataset = gb.BuiltinDataset("ogbn-arxiv").load()
     g = dataset.graph
+    feature = dataset.feature
     train_set = dataset.tasks[0].train_set
     datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
     datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
@@ -207,6 +208,7 @@ of node types to node IDs.
     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
     dataset = gb.BuiltinDataset("ogbn-mag").load()
     g = dataset.graph
+    feature = dataset.feature
     train_set = dataset.tasks[0].train_set
     datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
     datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
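
Note on patches 1/3 and 3/3: they only bind names (`g`, `feature`) that the
rest of the guide snippet already uses. For context, the sketch below shows
how the fixed ogbn-arxiv pipeline reads end to end once both names are
defined. The stages after sample_neighbor (feature fetching, device copy,
DataLoader construction, and the loop body) are assumptions taken from the
surrounding GraphBolt node-classification guide, not part of these patches.

    import torch
    import dgl.graphbolt as gb

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = gb.BuiltinDataset("ogbn-arxiv").load()
    g = dataset.graph          # bound by PATCH 1/3
    feature = dataset.feature  # bound by PATCH 3/3
    train_set = dataset.tasks[0].train_set

    datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
    datapipe = datapipe.sample_neighbor(g, [10, 10])  # 2 layers.
    # `feature` becomes necessary at this stage; the "feat" key is an
    # assumption based on the builtin ogbn-arxiv dataset.
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe)

    for minibatch in dataloader:
        blocks = minibatch.blocks            # sampled MFGs, one per layer
        x = minibatch.node_features["feat"]  # fetched input features
        y = minibatch.labels                 # seed-node labels
        break  # sketch only: inspect a single batch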