Skip to content

Commit

Permalink
Merge branch 'master' into untyped_storage
Browse files Browse the repository at this point in the history
  • Loading branch information
drivanov committed Jan 22, 2024
2 parents 6a74b84 + d67dae1 commit 7cedd70
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 58 deletions.
17 changes: 10 additions & 7 deletions examples/multigpu/graphbolt/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,11 @@ def create_dataloader(
# [Output]:
# A CopyTo object copying data in the datapipe to a specified device.
############################################################################
if not args.cpu_sampling:
if args.storage_device != "cpu":
datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"])
datapipe = datapipe.sample_neighbor(graph, args.fanout)
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
if args.cpu_sampling:
if args.storage_device == "cpu":
datapipe = datapipe.copy_to(device)

dataloader = gb.DataLoader(datapipe, args.num_workers)
Expand Down Expand Up @@ -276,7 +276,7 @@ def run(rank, world_size, args, devices, dataset):
)

# Pin the graph and features to enable GPU access.
if not args.cpu_sampling:
if args.storage_device == "pinned":
dataset.graph.pin_memory_()
dataset.feature.pin_memory_()

Expand Down Expand Up @@ -388,15 +388,17 @@ def parse_args():
type=str,
default="10,10,10",
help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
" identical with the number of layers in your model. Default: 15,10,5",
" identical with the number of layers in your model. Default: 10,10,10",
)
parser.add_argument(
"--num-workers", type=int, default=0, help="The number of processes."
)
parser.add_argument(
"--cpu-sampling",
action="store_true",
help="Disables GPU sampling and utilizes the CPU for dataloading.",
"--mode",
default="pinned-cuda",
choices=["cpu-cuda", "pinned-cuda"],
help="Dataset storage placement and Train device: 'cpu' for CPU and RAM,"
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
)
return parser.parse_args()

Expand All @@ -406,6 +408,7 @@ def parse_args():
if not torch.cuda.is_available():
print(f"Multi-gpu training needs to be in gpu mode.")
exit(0)
args.storage_device, _ = args.mode.split("-")

devices = list(map(int, args.gpu.split(",")))
world_size = len(devices)
Expand Down
15 changes: 9 additions & 6 deletions python/dgl/graphbolt/impl/uniform_negative_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,12 @@ def __init__(
super().__init__(datapipe, negative_ratio)
self.graph = graph

def _sample_with_etype(self, node_pairs, etype=None):
return self.graph.sample_negative_edges_uniform(
etype,
node_pairs,
self.negative_ratio,
)
def _sample_with_etype(self, node_pairs, etype=None, use_seeds=False):
if not use_seeds:
return self.graph.sample_negative_edges_uniform(
etype,
node_pairs,
self.negative_ratio,
)
else:
raise NotImplementedError("Not implemented yet.")
9 changes: 8 additions & 1 deletion python/dgl/graphbolt/item_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ class ItemShufflerAndBatcher:
rank : int
The rank of the current replica. Applies only when `distributed` is
True.
rng : np.random.Generator
The random number generator to use for shuffling.
"""

def __init__(
Expand All @@ -128,6 +130,7 @@ def __init__(
drop_uneven_inputs: Optional[bool] = False,
world_size: Optional[int] = 1,
rank: Optional[int] = 0,
rng: Optional[np.random.Generator] = None,
):
self._item_set = item_set
self._shuffle = shuffle
Expand All @@ -142,6 +145,7 @@ def __init__(
self._drop_uneven_inputs = drop_uneven_inputs
self._num_replicas = world_size
self._rank = rank
self._rng = rng

def _collate_batch(self, buffer, indices, offsets=None):
"""Collate a batch from the buffer. For internal use only."""
Expand Down Expand Up @@ -216,7 +220,7 @@ def __iter__(self):
buffer = self._item_set[start_offset + start : start_offset + end]
indices = torch.arange(end - start)
if self._shuffle:
np.random.shuffle(indices.numpy())
self._rng.shuffle(indices.numpy())
offsets = self._calculate_offsets(buffer)
for i in range(0, len(indices), self._batch_size):
if output_count <= 0:
Expand Down Expand Up @@ -494,6 +498,7 @@ def __init__(
self._drop_uneven_inputs = False
self._world_size = None
self._rank = None
self._rng = np.random.default_rng()

def _organize_items(self, data_pipe) -> None:
# Shuffle before batch.
Expand Down Expand Up @@ -529,6 +534,7 @@ def _collate(batch):

def __iter__(self) -> Iterator:
if self._use_indexing:
seed = self._rng.integers(0, np.iinfo(np.int32).max)
data_pipe = IterableWrapper(
ItemShufflerAndBatcher(
self._item_set,
Expand All @@ -540,6 +546,7 @@ def __iter__(self) -> Iterator:
drop_uneven_inputs=self._drop_uneven_inputs,
world_size=self._world_size,
rank=self._rank,
rng=np.random.default_rng(seed),
)
)
else:
Expand Down
48 changes: 28 additions & 20 deletions python/dgl/graphbolt/negative_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,24 @@ def _sample(self, minibatch):
An instance of 'MiniBatch' encompasses both positive and negative
samples.
"""
node_pairs = minibatch.node_pairs
assert node_pairs is not None
if isinstance(node_pairs, Mapping):
minibatch.negative_srcs, minibatch.negative_dsts = {}, {}
for etype, pos_pairs in node_pairs.items():
self._collate(
minibatch, self._sample_with_etype(pos_pairs, etype), etype
)
if minibatch.seeds is None:
node_pairs = minibatch.node_pairs
assert node_pairs is not None
if isinstance(node_pairs, Mapping):
minibatch.negative_srcs, minibatch.negative_dsts = {}, {}
for etype, pos_pairs in node_pairs.items():
self._collate(
minibatch,
self._sample_with_etype(pos_pairs, etype),
etype,
)
else:
self._collate(minibatch, self._sample_with_etype(node_pairs))
else:
self._collate(minibatch, self._sample_with_etype(node_pairs))
raise NotImplementedError("Not implemented yet.")
return minibatch

def _sample_with_etype(self, node_pairs, etype=None):
def _sample_with_etype(self, node_pairs, etype=None, use_seeds=False):
Generate negative pairs for a given etype from positive pairs
for a given etype.
Expand Down Expand Up @@ -102,14 +107,17 @@ def _collate(self, minibatch, neg_pairs, etype=None):
etype : str
Canonical edge type.
"""
neg_src, neg_dst = neg_pairs
if neg_src is not None:
neg_src = neg_src.view(-1, self.negative_ratio)
if neg_dst is not None:
neg_dst = neg_dst.view(-1, self.negative_ratio)
if etype is not None:
minibatch.negative_srcs[etype] = neg_src
minibatch.negative_dsts[etype] = neg_dst
if minibatch.seeds is None:
neg_src, neg_dst = neg_pairs
if neg_src is not None:
neg_src = neg_src.view(-1, self.negative_ratio)
if neg_dst is not None:
neg_dst = neg_dst.view(-1, self.negative_ratio)
if etype is not None:
minibatch.negative_srcs[etype] = neg_src
minibatch.negative_dsts[etype] = neg_dst
else:
minibatch.negative_srcs = neg_src
minibatch.negative_dsts = neg_dst
else:
minibatch.negative_srcs = neg_src
minibatch.negative_dsts = neg_dst
raise NotImplementedError("Not implemented yet.")
71 changes: 47 additions & 24 deletions python/dgl/graphbolt/sampled_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,34 @@ def sampled_csc(
) -> Union[CSCFormatBase, Dict[str, CSCFormatBase],]:
"""Returns the node pairs representing edges in csc format.
- If `sampled_csc` is a CSCFormatBase: It should be in the csc format.
`indptr` stores the index in the data array where each column
starts. `indices` stores the row indices of the non-zero elements.
`indptr` stores the index in the data array where each column
starts. `indices` stores the row indices of the non-zero elements.
- If `sampled_csc` is a dictionary: The keys should be edge type and
the values should be corresponding node pairs. The ids inside
is heterogeneous ids."""
the values should be corresponding node pairs. The ids inside are
heterogeneous ids.
Examples
--------
1. Homogeneous graph.
>>> import dgl.graphbolt as gb
>>> import torch
>>> sampled_csc = gb.CSCFormatBase(
... indptr=torch.tensor([0, 1, 2, 3]),
... indices=torch.tensor([0, 1, 2]))
>>> print(sampled_csc)
CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([0, 1, 2]),
)
2. Heterogeneous graph.
>>> sampled_csc = {"A:relation:B": gb.CSCFormatBase(
... indptr=torch.tensor([0, 1, 2, 3]),
... indices=torch.tensor([0, 1, 2]))}
>>> print(sampled_csc)
{'A:relation:B': CSCFormatBase(indptr=tensor([0, 1, 2, 3]),
indices=tensor([0, 1, 2]),
)}
"""
raise NotImplementedError

@property
Expand All @@ -46,11 +69,11 @@ def original_column_node_ids(
Column's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is
the mapped ids of the column.
- If `original_column_node_ids` is a tensor: It represents the
original node ids.
- If `original_column_node_ids` is a tensor: It represents the original
node ids.
- If `original_column_node_ids` is a dictionary: The keys should be
node type and the values should be corresponding original
heterogeneous node ids.
node type and the values should be corresponding original
heterogeneous node ids.
If present, it means column IDs are compacted, and `sampled_csc`
column IDs match these compacted ones.
"""
Expand All @@ -64,11 +87,11 @@ def original_row_node_ids(
Row's reverse node ids in the original graph. A graph structure
can be treated as a coordinated row and column pair, and this is
the mapped ids of the row.
- If `original_row_node_ids` is a tensor: It represents the
original node ids.
- If `original_row_node_ids` is a dictionary: The keys should be
node type and the values should be corresponding original
heterogeneous node ids.
- If `original_row_node_ids` is a tensor: It represents the original
node ids.
- If `original_row_node_ids` is a dictionary: The keys should be node
type and the values should be corresponding original heterogeneous
node ids.
If present, it means row IDs are compacted, and `sampled_csc`
row IDs match these compacted ones."""
return None
Expand All @@ -77,12 +100,12 @@ def original_row_node_ids(
def original_edge_ids(self) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
"""Returns corresponding reverse edge ids in the original graph.
Reverse edge ids in the original graph. This is useful when edge
features are needed.
- If `original_edge_ids` is a tensor: It represents the
original edge ids.
- If `original_edge_ids` is a dictionary: The keys should be
edge type and the values should be corresponding original
heterogeneous edge ids.
features are needed.
- If `original_edge_ids` is a tensor: It represents the original edge
ids.
- If `original_edge_ids` is a dictionary: The keys should be edge type
and the values should be corresponding original heterogeneous edge
ids.
"""
return None

Expand All @@ -105,12 +128,12 @@ def exclude_edges(
----------
self : SampledSubgraph
The sampled subgraph.
edges : Union[Dict[str, Tuple[torch.Tensor, torch.Tensor]],
Tuple[torch.Tensor, torch.Tensor]]
edges : Union[Tuple[torch.Tensor, torch.Tensor],
Dict[str, Tuple[torch.Tensor, torch.Tensor]]]
Edges to exclude. If sampled subgraph is homogeneous, then `edges`
should be a pair of tensors representing the edges to exclude. If
sampled subgraph is heterogeneous, then `edges` should be a dictionary
of edge types and the corresponding edges to exclude.
sampled subgraph is heterogeneous, then `edges` should be a
dictionary of edge types and the corresponding edges to exclude.
assume_num_node_within_int32: bool
If True, assumes the value of node IDs in the provided `edges` fall
within the int32 range, which can significantly enhance computation
Expand All @@ -119,7 +142,7 @@ def exclude_edges(
Returns
-------
SampledSubgraph
An instance of a class that inherits from `SampledSubgraph`.
An instance of a class that inherits from `SampledSubgraph`.
Examples
--------
Expand Down
27 changes: 27 additions & 0 deletions tests/python/pytorch/graphbolt/impl/test_negative_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,33 @@ def test_NegativeSampler_invoke():
next(iter(negative_sampler))


def test_UniformNegativeSampler_seeds_invoke():
    """Check that seeds-based uniform negative sampling is rejected.

    An item set named "seeds" is fed through an ``ItemSampler`` into a
    uniform negative sampler, built first via the class constructor and
    then via the functional form. Pulling the first minibatch from either
    one must raise ``NotImplementedError``.
    """
    num_seeds = 30
    batch_size = 10
    negative_ratio = 2

    # Instantiate graph and required datapipes.
    graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
    seed_pairs = torch.arange(0, 2 * num_seeds).reshape(-1, 2)
    item_set = gb.ItemSet(seed_pairs, names="seeds")
    item_sampler = gb.ItemSampler(item_set, batch_size=batch_size)

    # Each factory is called lazily inside the loop so that construction
    # and iteration happen in the same order as two sequential checks.
    sampler_factories = (
        # Invoke UniformNegativeSampler via class constructor.
        lambda: gb.UniformNegativeSampler(item_sampler, graph, negative_ratio),
        # Invoke UniformNegativeSampler via functional form.
        lambda: item_sampler.sample_uniform_negative(graph, negative_ratio),
    )
    for make_sampler in sampler_factories:
        negative_sampler = make_sampler()
        with pytest.raises(NotImplementedError):
            next(iter(negative_sampler))


def test_UniformNegativeSampler_invoke():
# Instantiate graph and required datapipes.
graph = gb_test_utils.rand_csc_graph(100, 0.05, bidirection_edge=True)
Expand Down

0 comments on commit 7cedd70

Please sign in to comment.