
Commit

Merge branch 'ondisk_dataset' of https://github.com/drivanov/dgl into ondisk_dataset
drivanov committed Feb 8, 2024
2 parents 9548212 + 0283dcc commit b8fc709
Showing 2 changed files with 77 additions and 45 deletions.
4 changes: 4 additions & 0 deletions docs/source/guide/minibatch-node.rst
@@ -44,6 +44,8 @@ putting the list of generated MFGs onto GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = gb.BuiltinDataset("ogbn-arxiv").load()
g = dataset.graph
feature = dataset.feature
train_set = dataset.tasks[0].train_set
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
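Note: the ``feature`` handle introduced by this hunk is what the rest of the guide's pipeline reads node features from. A minimal sketch of the assumed continuation, not part of this diff (the ``fetch_feature``, ``copy_to``, and ``gb.DataLoader`` steps and the "feat" key are assumptions chosen to match the surrounding guide):

# Assumed continuation (illustrative, not in this diff): fetch "feat" for the
# sampled nodes of each minibatch, move everything to the device, build the loader.
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)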
@@ -205,6 +207,8 @@ of node types to node IDs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = gb.BuiltinDataset("ogbn-mag").load()
g = dataset.graph
feature = dataset.feature
train_set = dataset.tasks[0].train_set
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
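For the heterogeneous ogbn-mag case the same ``feature`` store is read with per-node-type keys. A hedged sketch of the assumed continuation, not part of this diff (the {"paper": ["feat"]} key is an assumption matching the surrounding guide):

# Assumed continuation (illustrative, not in this diff): per-node-type feature
# keys for a heterogeneous graph, then device transfer and loader construction.
datapipe = datapipe.fetch_feature(feature, node_feature_keys={"paper": ["feat"]})
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe)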
118 changes: 73 additions & 45 deletions tests/distributed/test_mp_dataloader.py
@@ -22,10 +22,19 @@


class NeighborSampler(object):
def __init__(self, g, fanouts, sample_neighbors):
def __init__(
self,
g,
fanouts,
sample_neighbors,
use_graphbolt=False,
return_eids=False,
):
self.g = g
self.fanouts = fanouts
self.sample_neighbors = sample_neighbors
self.use_graphbolt = use_graphbolt
self.return_eids = return_eids

def sample_blocks(self, seeds):
import torch as th
@@ -35,13 +44,16 @@ def sample_blocks(self, seeds):
for fanout in self.fanouts:
# For each seed node, sample ``fanout`` neighbors.
frontier = self.sample_neighbors(
self.g, seeds, fanout, replace=True
self.g, seeds, fanout, use_graphbolt=self.use_graphbolt
)
# Then we compact the frontier into a bipartite graph for
# message passing.
block = dgl.to_block(frontier, seeds)
# Obtain the seed nodes for the next layer.
seeds = block.srcdata[dgl.NID]
if frontier.num_edges() > 0:
if not self.use_graphbolt or self.return_eids:
block.edata[dgl.EID] = frontier.edata[dgl.EID]

blocks.insert(0, block)
return blocks
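For context, a minimal sketch of how this sampler is consumed; it mirrors the DistDataLoader construction that is collapsed out of this hunk, with use_graphbolt=True and the batch size chosen for illustration:

# Illustrative wiring (mirrors the collapsed part of start_dist_dataloader):
# sample_blocks is used as the collate_fn, so each batch yields a list of blocks.
sampler = NeighborSampler(
    dist_graph, [5, 10], dgl.distributed.sample_neighbors, use_graphbolt=True
)
dataloader = dgl.distributed.DistDataLoader(
    dataset=train_nid.numpy(),
    batch_size=32,
    collate_fn=sampler.sample_blocks,
    shuffle=True,
    drop_last=False,
)
for blocks in dataloader:
    pass  # blocks[-1].dstdata[dgl.NID] holds the original IDs of the seed nodes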
@@ -53,6 +65,7 @@ def start_server(
part_config,
disable_shared_mem,
num_clients,
use_graphbolt=False,
):
print("server: #clients=" + str(num_clients))
g = DistGraphServer(
@@ -63,6 +76,7 @@
part_config,
disable_shared_mem=disable_shared_mem,
graph_format=["csc", "coo"],
use_graphbolt=use_graphbolt,
)
g.start()

@@ -75,30 +89,36 @@ def start_dist_dataloader(
drop_last,
orig_nid,
orig_eid,
group_id=0,
use_graphbolt=False,
return_eids=False,
):
import dgl
import torch as th

os.environ["DGL_GROUP_ID"] = str(group_id)
dgl.distributed.initialize(ip_config)
gpb = None
disable_shared_mem = num_server > 0
disable_shared_mem = num_server > 1
if disable_shared_mem:
_, _, _, gpb, _, _, _ = load_partition(part_config, rank)
num_nodes_to_sample = 202
batch_size = 32
train_nid = th.arange(num_nodes_to_sample)
dist_graph = DistGraph("test_mp", gpb=gpb, part_config=part_config)

for i in range(num_server):
part, _, _, _, _, _, _ = load_partition(part_config, i)
dist_graph = DistGraph(
"test_sampling",
gpb=gpb,
part_config=part_config,
use_graphbolt=use_graphbolt,
)

# Create sampler
sampler = NeighborSampler(
dist_graph, [5, 10], dgl.distributed.sample_neighbors
dist_graph,
[5, 10],
dgl.distributed.sample_neighbors,
use_graphbolt=use_graphbolt,
return_eids=return_eids,
)

# Enable sanity check in distributed sampling.
os.environ["DGL_DIST_DEBUG"] = "1"

# We need to test creating DistDataLoader multiple times.
for i in range(2):
# Create DataLoader for constructing blocks
@@ -113,7 +133,7 @@ def start_dist_dataloader(
groundtruth_g = CitationGraphDataset("cora")[0]
max_nid = []

for epoch in range(2):
for _ in range(2):
for idx, blocks in zip(
range(0, num_nodes_to_sample, batch_size), dataloader
):
@@ -129,6 +149,16 @@ def start_dist_dataloader(
src_nodes_id, dst_nodes_id
)
assert np.all(F.asnumpy(has_edges))

if use_graphbolt and not return_eids:
continue
eids = orig_eid[block.edata[dgl.EID]]
expected_eids = groundtruth_g.edge_ids(
src_nodes_id, dst_nodes_id
)
assert th.equal(
eids, expected_eids
), f"{eids} != {expected_eids}"
if drop_last:
assert (
np.max(max_nid)
@@ -311,23 +341,22 @@ def check_neg_dataloader(g, num_server, num_workers):
assert p.exitcode == 0


@unittest.skip(reason="Skip due to glitch in CI")
@pytest.mark.parametrize("num_server", [3])
@pytest.mark.parametrize("num_server", [1])
@pytest.mark.parametrize("num_workers", [0, 4])
@pytest.mark.parametrize("drop_last", [True, False])
@pytest.mark.parametrize("num_groups", [1])
def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
@pytest.mark.parametrize("drop_last", [False, True])
@pytest.mark.parametrize("use_graphbolt", [False, True])
@pytest.mark.parametrize("return_eids", [False, True])
def test_dist_dataloader(
num_server, num_workers, drop_last, use_graphbolt, return_eids
):
reset_envs()
# No multiple partitions on single machine for
# multiple client groups in case of race condition.
if num_groups > 1:
num_server = 1
os.environ["DGL_DIST_MODE"] = "distributed"
os.environ["DGL_NUM_SAMPLER"] = str(num_workers)
with tempfile.TemporaryDirectory() as test_dir:
ip_config = "ip_config.txt"
generate_ip_config(ip_config, num_server, num_server)

g = CitationGraphDataset("cora")[0]
print(g.idtype)
num_parts = num_server
num_hops = 1

@@ -339,6 +368,8 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
num_hops=num_hops,
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=return_eids,
)

part_config = os.path.join(test_dir, "test_sampling.json")
@@ -353,36 +384,33 @@ def test_dist_dataloader(num_server, num_workers, drop_last, num_groups):
part_config,
num_server > 1,
num_workers + 1,
use_graphbolt,
),
)
p.start()
time.sleep(1)
pserver_list.append(p)

os.environ["DGL_DIST_MODE"] = "distributed"
os.environ["DGL_NUM_SAMPLER"] = str(num_workers)
ptrainer_list = []
num_trainers = 1
for trainer_id in range(num_trainers):
for group_id in range(num_groups):
p = ctx.Process(
target=start_dist_dataloader,
args=(
trainer_id,
ip_config,
part_config,
num_server,
drop_last,
orig_nid,
orig_eid,
group_id,
),
)
p.start()
time.sleep(
1
) # avoid race condition when instantiating DistGraph
ptrainer_list.append(p)
p = ctx.Process(
target=start_dist_dataloader,
args=(
trainer_id,
ip_config,
part_config,
num_server,
drop_last,
orig_nid,
orig_eid,
use_graphbolt,
return_eids,
),
)
p.start()
time.sleep(1) # avoid race condition when instantiating DistGraph
ptrainer_list.append(p)

for p in ptrainer_list:
p.join()
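Usage note: the expanded parameter grid (use_graphbolt and return_eids on top of num_workers and drop_last) can be exercised selectively with, for example, pytest tests/distributed/test_mp_dataloader.py -k test_dist_dataloader.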

