Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DoNotMerge] Integrate GraphBolt with DistDGL #7006

Open
wants to merge 30 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
d94e52f
[gb_distdgl] add demo py
Rhett-Ying Oct 23, 2023
9de41c6
[gb_distdgl] enable to store node_attributes into CSCSamplingGraph
Rhett-Ying Oct 23, 2023
ddce1d4
[gb_distdgl] refine convert_dgl_partition_to_csc_sampling_graph to co…
Rhett-Ying Oct 23, 2023
59825dc
[gb_distdgl] TODO: control dtype more rigidly when construct CSCSampl…
Rhett-Ying Oct 23, 2023
174dc37
[gb_distdgl] add graph file size of ogbn-mag for comparision btw DGL …
Rhett-Ying Oct 23, 2023
4b767b3
[gb_distdgl] Add use_graphbolt to control save CSCSamplingGraph when …
Rhett-Ying Oct 24, 2023
d76004f
[gb_distdgl] enable copy to sharedMem and back
Rhett-Ying Oct 24, 2023
222dd2b
[gb_distdgl] successfully load graph in server and client from shared…
Rhett-Ying Oct 24, 2023
1cabe94
[gb_distdgl] update todo list
Rhett-Ying Oct 24, 2023
70c3a9d
[gb_distdgl] dataloader is created
Rhett-Ying Oct 25, 2023
9362afe
[gb_distdgl] create ItemSet with dict input
Rhett-Ying Oct 25, 2023
f440df1
[gb_distdgl] _distributed_access works with graph.sample hacked
Rhett-Ying Oct 25, 2023
93feb0b
[gb_distdgl] Not worked as crashed in csc_sampling_graph::_check_samp…
Rhett-Ying Oct 25, 2023
8da2a4a
[gb_distdgl] only replace sample_neigbors with GB and it does not crash
Rhett-Ying Oct 26, 2023
1b06f14
[gb_distdgl] train is ready though acc and time drops
Rhett-Ying Oct 26, 2023
b3eecae
multiply fanout with num_etypes
Rhett-Ying Oct 26, 2023
896ad9b
[WAHAHA] sample_etype_neighbors is applied truely except metadata shm…
Rhett-Ying Oct 26, 2023
5c8c612
[WAHAHA] gb_metadata is generated on-fly for obtain CSCSamplingGraph …
Rhett-Ying Oct 26, 2023
438d8ca
------ LAUNCH DistDGL with GraphBolt on heterograph -------
Rhett-Ying Oct 26, 2023
6d5d240
standalone with use_graphbolt is not supported
Rhett-Ying Oct 27, 2023
a99b192
format dtype when converting to CSCSamplingGraph
Rhett-Ying Oct 30, 2023
c0099a1
update README about partition size
Rhett-Ying Oct 30, 2023
ebb6cea
clean up unnecessary return edge types
Rhett-Ying Oct 30, 2023
be3f7f2
update graph partition size with eids
Rhett-Ying Oct 30, 2023
591f60e
[WAHAHA] node classfication on homogeneous graph with GraphBolt is ready
Rhett-Ying Oct 31, 2023
50c7be2
Homo/Hetero + NC is verified
Rhett-Ying Oct 31, 2023
aac0d6f
add comments for EID in DGL block creation
Rhett-Ying Nov 1, 2023
f806ece
add assertion for num_samplers > 0 as not supported yet
Rhett-Ying Nov 1, 2023
b15c931
add crash log for num_samplers>0
Rhett-Ying Nov 1, 2023
19f12bb
add script to check mem footprint
Rhett-Ying Nov 6, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,6 @@ pipeline {
steps {
unit_distributed_linux('pytorch', 'cpu')
}
when { expression { false } }
}
}
post {
Expand Down
52 changes: 52 additions & 0 deletions check_mem_footprint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import dgl
from dgl.distributed import load_partition
import psutil
import os
import argparse
import gc

# Measure how much resident memory (RSS) this process gains by loading a
# single graph partition, either as a DGL graph or via the GraphBolt path
# (selected with --graphbolt).
parser = argparse.ArgumentParser(description="check memory footprint")
parser.add_argument(
    "--part_config",
    type=str,
    required=True,  # load_partition cannot do anything useful with None
    help="partition config file",
)
parser.add_argument(
    "--graphbolt",
    action="store_true",
    help="use graphbolt",
)
parser.add_argument(
    "--part_id",
    type=int,
    required=True,  # there is no sensible default partition to measure
    help="partition id",
)

args = parser.parse_args()

use_graphbolt = args.graphbolt
part_id = args.part_id

# Reuse one process handle; memory_info() is re-read on every call.
process = psutil.Process(os.getpid())
prev_rss = process.memory_info().rss
(
    client_g,
    _,
    _,
    gpb,
    graph_name,
    ntypes,
    etypes,
) = load_partition(
    args.part_config,
    part_id,
    load_feats=False,  # features are excluded so only the graph is measured
    use_graphbolt=use_graphbolt,
)
if not use_graphbolt:
    # Restrict the DGL graph to the CSC format and build it before measuring.
    # NOTE(review): presumably done so the comparison against GraphBolt's
    # CSC-based graph is like-for-like -- confirm with the author.
    client_g = client_g.formats("csc")
    client_g.create_formats_()
new_rss = process.memory_info().rss
rss_delta_mb = (new_rss - prev_rss) / 1024 / 1024
print(
    f"[PartID_{part_id}] Loaded {graph_name} with "
    f"use_graphbolt[{use_graphbolt}] in size[{rss_delta_mb: .0f} MB]"
)
# Drop the only reference to the graph and force a collection pass.
client_g = None
gc.collect()
101 changes: 101 additions & 0 deletions dt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import argparse

import dgl
import dgl.graphbolt as gb
import numpy as np
import torch as th

# [TODO][P0] Set up distributed environment.

"""
num_trainers = 8
num_servers = 4
num_samplers = 0
part_config = ./ogbn-products.json
ip_config = ./ip_config.txt
"""

# Command-line arguments. The script reads these as attributes
# (args.ip_config, args.backend, ...), so a plain dict cannot be used here.
# `part_config` and `ip_config` default to the values listed in the note
# above; the remaining defaults are placeholders.
# NOTE(review): confirm graph_name / backend / batch_size defaults.
parser = argparse.ArgumentParser(description="DistDGL with GraphBolt demo")
parser.add_argument(
    "--graph_name", type=str, default="ogbn-products", help="graph name"
)
parser.add_argument(
    "--part_config",
    type=str,
    default="./ogbn-products.json",
    help="partition config file",
)
parser.add_argument(
    "--ip_config",
    type=str,
    default="./ip_config.txt",
    help="the file for IP configuration",
)
parser.add_argument(
    "--backend", type=str, default="gloo", help="PyTorch distributed backend"
)
parser.add_argument(
    "--batch_size", type=int, default=1000, help="mini-batch size"
)
args = parser.parse_args()

# Initialize distributed environment
dgl.distributed.initialize(args.ip_config)
th.distributed.init_process_group(backend=args.backend)
# [TODO][P0] Convert dgl partitioned graphs to graphbolt.CSCSamplingGraph.
# done@2023-10-23 16:49:00
# see details in: https://github.com/Rhett-Ying/dgl/commits/gb_distdgl
# ddce1d42de016be040cd0f8a5e71f2a10148de82
'''
In [1]: part_config='/home/ubuntu/workspace/dgl_2/data/ogbn-mag.json'
In [3]: dgl.distributed.convert_dgl_partition_to_csc_sampling_graph(part_config, store_orig_nids=True)
In [7]: !ls data/part0 -lh
total 1.1G
-rw-rw-r-- 1 ubuntu ubuntu 207M Oct 23 08:44 csc_sampling_graph.tar
-rw-rw-r-- 1 ubuntu ubuntu 694M Oct 23 02:47 graph.dgl

In [8]: !ls data/part1 -lh
total 1.1G
-rw-rw-r-- 1 ubuntu ubuntu 202M Oct 23 08:44 csc_sampling_graph.tar
-rw-rw-r-- 1 ubuntu ubuntu 678M Oct 23 02:47 graph.dgl
'''

# [TODO][P0] Load `CSCSamplingGraph` into `DistGraph`.
# done@2023-10-24 15:10:00
# see details in: https://github.com/Rhett-Ying/dgl/commits/gb_distdgl
# 222dd2bd51084cc4f242148b0a7e6e5d91e0ae80
## NID/EIDs are required.
g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)

# Generate train/val/test splits
##############
# train/val/test splits could be generated offline, then `train/val/test_masks`
# could be offloaded.
# No change is required as `node_split` requires graph partition book and
# masks only.
# This should be part of `OnDiskDataset::TVT`.
# [TODO][P1]: Add a standalone API to generate train/val/test splits.
##############
gpb = g.get_partition_book()
train_nids = dgl.distributed.node_split(g.ndata["train_masks"], gpb)
val_nids = dgl.distributed.node_split(g.ndata["val_masks"], gpb)
test_nids = dgl.distributed.node_split(g.ndata["test_masks"], gpb)
all_nids = dgl.distributed.node_split(th.arange(g.num_nodes()), gpb)

# [TODO][P2] How to handle feature data such as 'feat', 'mask'?
# Just use `g.ndata['feat']` for now. As no more memory could be offloaded.
# GB: feat_data = gb.OnDiskDataset().feature
# DistDGL: feat_data = g.ndata['feat'] # DistTensor


# Train.
##############
# GraphBolt version
# [TODO][P0] Add `gb.distributed_sample_neighbor` API.
# [TODO][P0] `remote_sample_neighbor()` returns original global node pairs + eids.
# [TODO][P0] Update `dgl.distributed.merge_graphs` API.
# https://github.com/dmlc/dgl/blob/7439b7e73bdb85b4285ab01f704ac5a4f77c927e/python/dgl/distributed/graph_services.py#L440.
##############
"""
datapipe = gb.ItemSampler(item_set, batch_size=batch_size, shuffle=shuffle)
datapipe = datapipe.sample_neighbor(g._graph, fanouts=fanouts)
datapipe = datapipe.to_dgl()
device = th.device("cpu")
datapipe = datapipe.copy_to(device)
data_loader = gb.MultiProcessDataLoader(datapipe, num_workers=num_workers)
"""
sampler = dgl.dataloading.NeighborSampler([25, 10])
train_dataloader = dgl.distributed.DistDataLoader(
    g, train_nids, sampler=sampler, batch_size=args.batch_size, shuffle=True
)
# NOTE: `model` is still a placeholder in this sketch. The training loop and
# `model.eval()` below will fail until a real model is assigned here.
model = None
for mini_batch in train_dataloader:
    in_feats = g.ndata["feat"][mini_batch.input_nodes]
    labels = g.ndata["label"][mini_batch.output_nodes]
    _ = model(mini_batch, in_feats)

# Evaluate.
model.eval()
# A fanout of -1 is used for evaluation.
# NOTE(review): presumably this means "sample all neighbors" -- confirm.
sampler = dgl.dataloading.NeighborSampler([-1])
val_dataloader = dgl.distributed.DistDataLoader(
    g, val_nids, sampler=sampler, batch_size=args.batch_size, shuffle=False
)
test_dataloader = dgl.distributed.DistDataLoader(
    g, test_nids, sampler=sampler, batch_size=args.batch_size, shuffle=False
)
80 changes: 80 additions & 0 deletions examples/distributed/graphsage/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,83 @@
## DistDGL with GraphBolt (Homogeneous Graph + Node Classification)

### How to partition graph

#### Partition from original dataset with `dgl.distributed.partition_graph()`

```
DGL_HOME=/home/ubuntu/workspace/dgl_2 DGL_LIBRARY_PATH=$DGL_HOME/build PYTHONPATH=tests:$DGL_HOME/python:tests/python/pytorch/graphbolt:$PYTHONPATH python3 examples/distributed/graphsage/partition_graph.py --dataset ogbn-products --num_parts 2 --balance_train --balance_edges --graphbolt
```

#### Convert existing partitions into GraphBolt formats

```
DGL_LIBRARY_PATH=$DGL_HOME/build PYTHONPATH=tests:$DGL_HOME/python:tests/python/pytorch/graphbolt:$PYTHONPATH python3 -c "from dgl.distributed import convert_dgl_partition_to_csc_sampling_graph as f;f('data/ogbn-products.json')"
```

#### Partition sizes compared between GraphBolt and DistDGL

`csc_sampling_graph.tar` holds the GraphBolt partition.
`graph.dgl` holds the original DistDGL partition, namely a `DGLGraph`.

###### ogbn-products
Homogeneous graph, ~2.4M nodes, ~123.7M edges (with reverse edges added), 2 partitions.

| DGL(GB) | GraphBolt w/o EIDs(MB) | GraphBolt w/ EIDs(MB) |
| --- | ------------------ | ----------------- |
| 1.6/1.7 | 258/272 | 502/530 |

```
-rw-rw-r-- 1 ubuntu ubuntu 258M Oct 31 01:56 homo_data/part0/csc_sampling_graph.tar
-rw-rw-r-- 1 ubuntu ubuntu 502M Oct 31 04:45 homo_data/part0/csc_sampling_graph_eids.tar
-rw-rw-r-- 1 ubuntu ubuntu 24 Oct 31 00:51 homo_data/part0/edge_feat.dgl
-rw-rw-r-- 1 ubuntu ubuntu 1.6G Oct 31 00:51 homo_data/part0/graph.dgl
-rw-rw-r-- 1 ubuntu ubuntu 501M Oct 31 00:51 homo_data/part0/node_feat.dgl
-rw-rw-r-- 1 ubuntu ubuntu 272M Oct 31 01:56 homo_data/part1/csc_sampling_graph.tar
-rw-rw-r-- 1 ubuntu ubuntu 530M Oct 31 04:45 homo_data/part1/csc_sampling_graph_eids.tar
-rw-rw-r-- 1 ubuntu ubuntu 24 Oct 31 00:51 homo_data/part1/edge_feat.dgl
-rw-rw-r-- 1 ubuntu ubuntu 1.7G Oct 31 00:51 homo_data/part1/graph.dgl
-rw-rw-r-- 1 ubuntu ubuntu 460M Oct 31 00:51 homo_data/part1/node_feat.dgl
```

### Train with GraphBolt partitions
Just append `--graphbolt` to the training command, as shown below.

```
python3 /home/ubuntu/workspace/dgl_2/tools/launch.py \
--workspace /home/ubuntu/workspace/dgl_2/examples/distributed/graphsage/ \
--num_trainers 4 \
--num_servers 2 \
--num_samplers 0 \
--part_config /home/ubuntu/workspace/dgl_2/homo_data/ogbn-products.json \
--ip_config /home/ubuntu/workspace/ip_config.txt \
"DGL_LIBRARY_PATH=/home/ubuntu/workspace/dgl_2/build PYTHONPATH=tests:/home/ubuntu/workspace/dgl_2/python:tests/python/pytorch/graphbolt:$PYTHONPATH python3 node_classification.py --graph_name ogbn-products --ip_config /home/ubuntu/workspace/ip_config.txt --num_epochs 3 --eval_every 2 --graphbolt"
```

#### Results
`g4dn.metal` x 2, `ogbn-products`.

DistDGL with GraphBolt takes less time for sampling (from **1.8283s** to **1.4470s**) and for the whole epoch (from **4.9259s** to **4.4898s**) while keeping comparable accuracy on the validation and test sets.

##### DistDGL

```
Part 0, Epoch Time(s): 4.9648, sample+data_copy: 1.8283, forward: 0.2912, backward: 1.1307, update: 0.0232, #seeds: 24577, #inputs: 4136843

Summary of node classification(GraphSAGE): GraphName ogbn-products | TrainEpochTime(mean) 4.9259 | TestAccuracy 0.6213
```

##### DistDGL with GraphBolt

```
Part 0, Epoch Time(s): 4.4826, sample+data_copy: 1.4470, forward: 0.2517, backward: 0.9081, update: 0.0175, #seeds: 24577, #inputs: 4136980

Summary of node classification(GraphSAGE): GraphName ogbn-products | TrainEpochTime(mean) 4.4898 | TestAccuracy 0.6174
```

---------------------------------------


## Distributed training

This is an example of training GraphSage in a distributed fashion. Before training, please install some python libs by pip:
Expand Down
10 changes: 10 additions & 0 deletions examples/distributed/graphsage/dgl_cmd.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/bash
#
# Launch distributed GraphSAGE node classification on ogbn-products through
# DGL's tools/launch.py, without the --graphbolt flag.
#
# The paths default to the original hard-coded locations and can be
# overridden through the environment:
#   DGL_HOME     root of the DGL checkout
#   PART_CONFIG  partition config JSON passed to launch.py
#   IP_CONFIG    IP configuration file for the cluster

DGL_HOME="${DGL_HOME:-/home/ubuntu/workspace/dgl_2}"
PART_CONFIG="${PART_CONFIG:-${DGL_HOME}/homo_data/ogbn-products.json}"
IP_CONFIG="${IP_CONFIG:-/home/ubuntu/workspace/ip_config.txt}"

# $PYTHONPATH inside the quoted command is expanded here, by the launching
# shell, before the command string is handed to launch.py.
python3 "${DGL_HOME}/tools/launch.py" \
    --workspace "${DGL_HOME}/examples/distributed/graphsage/" \
    --num_trainers 4 \
    --num_servers 2 \
    --num_samplers 0 \
    --part_config "${PART_CONFIG}" \
    --ip_config "${IP_CONFIG}" \
    "DGL_LIBRARY_PATH=${DGL_HOME}/build PYTHONPATH=tests:${DGL_HOME}/python:tests/python/pytorch/graphbolt:$PYTHONPATH python3 node_classification.py --graph_name ogbn-products --ip_config ${IP_CONFIG} --num_epochs 3 --eval_every 2"
10 changes: 10 additions & 0 deletions examples/distributed/graphsage/gb_cmd.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/bash
#
# Launch distributed GraphSAGE node classification on ogbn-products through
# DGL's tools/launch.py, with --graphbolt passed to the training script.
#
# The paths default to the original hard-coded locations and can be
# overridden through the environment:
#   DGL_HOME     root of the DGL checkout
#   PART_CONFIG  partition config JSON passed to launch.py
#   IP_CONFIG    IP configuration file for the cluster

DGL_HOME="${DGL_HOME:-/home/ubuntu/workspace/dgl_2}"
PART_CONFIG="${PART_CONFIG:-${DGL_HOME}/homo_data/ogbn-products.json}"
IP_CONFIG="${IP_CONFIG:-/home/ubuntu/workspace/ip_config.txt}"

# $PYTHONPATH inside the quoted command is expanded here, by the launching
# shell, before the command string is handed to launch.py.
python3 "${DGL_HOME}/tools/launch.py" \
    --workspace "${DGL_HOME}/examples/distributed/graphsage/" \
    --num_trainers 4 \
    --num_servers 2 \
    --num_samplers 0 \
    --part_config "${PART_CONFIG}" \
    --ip_config "${IP_CONFIG}" \
    "DGL_LIBRARY_PATH=${DGL_HOME}/build PYTHONPATH=tests:${DGL_HOME}/python:tests/python/pytorch/graphbolt:$PYTHONPATH python3 node_classification.py --graph_name ogbn-products --ip_config ${IP_CONFIG} --num_epochs 3 --eval_every 2 --graphbolt"
22 changes: 17 additions & 5 deletions examples/distributed/graphsage/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def forward(self, blocks, x):
h = self.dropout(h)
return h

def inference(self, g, x, batch_size, device):
def inference(self, g, x, batch_size, device, use_graphbolt):
"""
Distributed layer-wise inference with the GraphSAGE model on full
neighbors.
Expand Down Expand Up @@ -116,6 +116,7 @@ def inference(self, g, x, batch_size, device):
batch_size=batch_size,
shuffle=False,
drop_last=False,
use_graphbolt=use_graphbolt,
)

for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
Expand Down Expand Up @@ -155,7 +156,7 @@ def compute_acc(pred, labels):
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)


def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device, use_graphbolt):
"""
Evaluate the model on the validation and test set.

Expand Down Expand Up @@ -187,7 +188,7 @@ def evaluate(model, g, inputs, labels, val_nid, test_nid, batch_size, device):
"""
model.eval()
with th.no_grad():
pred = model.inference(g, inputs, batch_size, device)
pred = model.inference(g, inputs, batch_size, device, use_graphbolt)
model.train()
return compute_acc(pred[val_nid], labels[val_nid]), compute_acc(
pred[test_nid], labels[test_nid]
Expand Down Expand Up @@ -219,6 +220,7 @@ def run(args, device, data):
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
use_graphbolt=args.graphbolt,
)
model = DistSAGE(
in_feats,
Expand Down Expand Up @@ -325,6 +327,7 @@ def run(args, device, data):
test_nid,
args.batch_size_eval,
device,
args.graphbolt,
)
print(
f"Part {g.rank()}, Val Acc {val_acc:.4f}, "
Expand All @@ -338,13 +341,16 @@ def main(args):
"""
Main function.
"""
if args.graphbolt:
print("DistDGL with GraphBolt...")
host_name = socket.gethostname()
print(f"{host_name}: Initializing DistDGL.")
dgl.distributed.initialize(args.ip_config)
dgl.distributed.initialize(args.ip_config, use_graphbolt=args.graphbolt)
print(f"{host_name}: Initializing PyTorch process group.")
th.distributed.init_process_group(backend=args.backend)
print(f"{host_name}: Initializing DistGraph.")
g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config)
g = dgl.distributed.DistGraph(args.graph_name, part_config=args.part_config,
use_graphbolt=args.graphbolt)
print(f"Rank of {host_name}: {g.rank()}")

# Split train/val/test IDs for each trainer.
Expand Down Expand Up @@ -415,6 +421,12 @@ def main(args):

if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Distributed GraphSAGE.")
parser.add_argument(
"--graphbolt",
default=False,
action="store_true",
help="train with GraphBolt",
)
parser.add_argument("--graph_name", type=str, help="graph name")
parser.add_argument(
"--ip_config", type=str, help="The file for IP configuration"
Expand Down
6 changes: 6 additions & 0 deletions examples/distributed/graphsage/partition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ def load_ogb(name, root="dataset"):
argparser.add_argument(
"--part_method", type=str, default="metis", help="the partition method"
)
argparser.add_argument(
"--graphbolt",
action="store_true",
help="convert DGL to GraphBolt partitions.",
)
argparser.add_argument(
"--balance_train",
action="store_true",
Expand Down Expand Up @@ -127,4 +132,5 @@ def load_ogb(name, root="dataset"):
balance_ntypes=balance_ntypes,
balance_edges=args.balance_edges,
num_trainers_per_machine=args.num_trainers_per_machine,
use_graphbolt=args.graphbolt,
)
Loading