Skip to content

Commit

Permalink
Merge branch 'master' into untyped_storage
Browse files Browse the repository at this point in the history
  • Loading branch information
drivanov committed Feb 21, 2024
2 parents bb9d2da + 364cb71 commit 77473a7
Show file tree
Hide file tree
Showing 14 changed files with 842 additions and 646 deletions.
2 changes: 1 addition & 1 deletion examples/pytorch/hgp_sl/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
"""
import dgl
import torch
from dgl._sparse_ops import _gsddmm, _gspmm
from dgl.backend import astype
from dgl.base import ALL, is_all
from dgl.heterograph_index import HeteroGraphIndex
from dgl.sparse import _gsddmm, _gspmm
from torch import Tensor
from torch.autograd import Function

Expand Down
4 changes: 2 additions & 2 deletions examples/pytorch/mvgrl/graph/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def collate(samples):

print("Epoch {}, Loss {:.4f}".format(epoch, loss_all))

if loss < best:
best = loss
if loss_all < best:
best = loss_all
best_t = epoch
cnt_wait = 0
th.save(model.state_dict(), f"{args.dataname}.pkl")
Expand Down
29 changes: 27 additions & 2 deletions examples/sampling/graphbolt/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def create_dataloader(
# [Role]:
# Initialize a neighbor sampler for sampling the neighborhoods of nodes.
############################################################################
datapipe = datapipe.sample_neighbor(
datapipe = getattr(datapipe, args.sample_mode)(
graph, fanout if job != "infer" else [-1]
)

Expand Down Expand Up @@ -157,7 +157,11 @@ def create_dataloader(
# [Role]:
# Initialize a multi-process dataloader to load the data in parallel.
############################################################################
dataloader = gb.DataLoader(datapipe, num_workers=num_workers)
dataloader = gb.DataLoader(
datapipe,
num_workers=num_workers,
overlap_graph_fetch=args.overlap_graph_fetch,
)

# Return the fully-initialized DataLoader object.
return dataloader
Expand Down Expand Up @@ -357,13 +361,34 @@ def parse_args():
help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
" identical with the number of layers in your model. Default: 10,10,10",
)
parser.add_argument(
"--dataset",
type=str,
default="ogbn-products",
help="The dataset we can use for node classification example. Currently"
"dataset ogbn-products, ogbn-arxiv, ogbn-papers100M is supported.",
)
parser.add_argument(
"--mode",
default="pinned-cuda",
choices=["cpu-cpu", "cpu-cuda", "pinned-cuda", "cuda-cuda"],
help="Dataset storage placement and Train device: 'cpu' for CPU and RAM,"
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
)
parser.add_argument(
"--sample-mode",
default="sample_neighbor",
choices=["sample_neighbor", "sample_layer_neighbor"],
help="The sampling function when doing layerwise sampling.",
)
parser.add_argument(
"--overlap-graph-fetch",
action="store_true",
help="An option for enabling overlap_graph_fetch in graphbolt dataloader."
"If True, the data loader will overlap the UVA graph fetching operations"
"with the rest of operations by using an alternative CUDA stream. Disabled"
"by default.",
)
return parser.parse_args()


Expand Down

0 comments on commit 77473a7

Please sign in to comment.