Skip to content

Commit

Permalink
Merge branch 'master' into spot_target
Browse files Browse the repository at this point in the history
  • Loading branch information
drivanov committed Dec 11, 2023
2 parents 6894c64 + d873acc commit 363fdf4
Show file tree
Hide file tree
Showing 12 changed files with 667 additions and 17 deletions.
1 change: 1 addition & 0 deletions python/dgl/graphbolt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .internal import (
compact_csc_format,
unique_and_compact,
unique_and_compact_csc_formats,
unique_and_compact_node_pairs,
)
from .utils import add_reverse_edges, exclude_seed_edges
Expand Down
1 change: 1 addition & 0 deletions python/dgl/graphbolt/impl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .fused_csc_sampling_graph import *
from .gpu_cached_feature import *
from .in_subgraph_sampler import *
from .legacy_dataset import *
from .neighbor_sampler import *
from .ondisk_dataset import *
from .ondisk_metadata import *
Expand Down
55 changes: 54 additions & 1 deletion python/dgl/graphbolt/impl/fused_csc_sampling_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,55 @@ def num_nodes(self) -> Union[int, Dict[str, int]]:

return num_nodes_per_type

@property
def num_edges(self) -> Union[int, Dict[str, int]]:
    """The number of edges in the graph.

    - If the graph is homogenous, returns an integer.
    - If the graph is heterogenous, returns a dictionary.

    The graph is treated as homogenous when either ``type_per_edge`` or
    ``edge_type_to_id`` is ``None``.

    Returns
    -------
    Union[int, Dict[str, int]]
        The number of edges. Integer indicates the total edges number of a
        homogenous graph; dict indicates edges number per edge types of a
        heterogenous graph. Edge types whose ID is larger than every ID
        present in ``type_per_edge`` are reported with a count of 0.

    Examples
    --------
    >>> import dgl.graphbolt as gb, torch
    >>> total_num_nodes = 5
    >>> total_num_edges = 12
    >>> ntypes = {"N0": 0, "N1": 1}
    >>> etypes = {"N0:R0:N0": 0, "N0:R1:N1": 1,
    ...     "N1:R2:N0": 2, "N1:R3:N1": 3}
    >>> indptr = torch.LongTensor([0, 3, 5, 7, 9, 12])
    >>> indices = torch.LongTensor([0, 1, 4, 2, 3, 0, 1, 1, 2, 0, 3, 4])
    >>> node_type_offset = torch.LongTensor([0, 2, 5])
    >>> type_per_edge = torch.LongTensor(
    ...     [0, 0, 2, 2, 2, 1, 1, 1, 3, 1, 3, 3])
    >>> metadata = gb.GraphMetadata(ntypes, etypes)
    >>> graph = gb.from_fused_csc(indptr, indices, node_type_offset,
    ...     type_per_edge, None, metadata)
    >>> print(graph.num_edges)
    {'N0:R0:N0': 2, 'N0:R1:N1': 4, 'N1:R2:N0': 3, 'N1:R3:N1': 3}
    """
    type_per_edge = self.type_per_edge
    edge_type_to_id = self.edge_type_to_id

    # Homogenous: no per-edge type information, report the total count.
    if type_per_edge is None or edge_type_to_id is None:
        return self._c_csc_graph.num_edges()

    # Heterogenous: convert the histogram to a Python list once instead of
    # calling ``.item()`` for every edge type.
    counts = torch.bincount(type_per_edge).tolist()
    num_counted = len(counts)
    # ``bincount`` only extends to the largest ID present, so edge types
    # with a larger ID have no slot and therefore no edges.
    return {
        etype: counts[etype_id] if etype_id < num_counted else 0
        for etype, etype_id in edge_type_to_id.items()
    }

@property
def csc_indptr(self) -> torch.tensor:
"""Returns the indices pointer in the CSC graph.
Expand Down Expand Up @@ -294,7 +343,7 @@ def edge_attributes(self) -> Optional[Dict[str, torch.Tensor]]:
Returns
-------
torch.Tensor or None
Dict[str, torch.Tensor] or None
If present, returns a dictionary of edge attributes. Each key
represents the attribute's name, while the corresponding value
holds the attribute's specific value. The length of each value
Expand Down Expand Up @@ -583,6 +632,7 @@ def sample_neighbors(
corresponding to each neighboring edge of a node. It must be a 1D
floating-point or boolean tensor, with the number of elements
equalling the total number of edges.
Returns
-------
FusedSampledSubgraphImpl
Expand Down Expand Up @@ -698,6 +748,7 @@ def _sample_neighbors(
corresponding to each neighboring edge of a node. It must be a 1D
floating-point or boolean tensor, with the number of elements
equalling the total number of edges.
Returns
-------
torch.classes.graphbolt.SampledSubgraph
Expand Down Expand Up @@ -767,6 +818,7 @@ def sample_layer_neighbors(
corresponding to each neighboring edge of a node. It must be a 1D
floating-point or boolean tensor, with the number of elements
equalling the total number of edges.
Returns
-------
FusedSampledSubgraphImpl
Expand Down Expand Up @@ -934,6 +986,7 @@ def from_fused_csc(
Edge attributes of the graph, by default None.
metadata: Optional[GraphMetadata], optional
Metadata of the graph, by default None.
Returns
-------
FusedCSCSamplingGraph
Expand Down
156 changes: 156 additions & 0 deletions python/dgl/graphbolt/impl/legacy_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
"""Graphbolt dataset for legacy DGLDataset."""
from typing import List, Union

from dgl.data import AsNodePredDataset, DGLDataset
from ..base import etype_tuple_to_str
from ..dataset import Dataset, Task
from ..itemset import ItemSet, ItemSetDict
from ..sampling_graph import SamplingGraph
from .basic_feature_store import BasicFeatureStore
from .fused_csc_sampling_graph import from_dglgraph
from .ondisk_dataset import OnDiskTask
from .torch_based_feature_store import TorchBasedFeature


class LegacyDataset(Dataset):
    """A Graphbolt dataset for legacy DGLDataset.

    Wraps a single-graph ``DGLDataset`` and exposes it through the Graphbolt
    ``Dataset`` interface: one node classification task (train / validation /
    test item sets), the sampling graph, a feature store holding every node
    and edge feature of the graph, the dataset name and an item set covering
    all nodes.

    Parameters
    ----------
    legacy : DGLDataset
        The legacy dataset to wrap. It must contain exactly one graph. Items
        may be either a graph or a ``(graph, labels)`` tuple (OGB style).
    """

    def __init__(self, legacy: DGLDataset):
        # Only supports single graph cases.
        assert (
            len(legacy) == 1
        ), "LegacyDataset only supports datasets with a single graph."
        graph = legacy[0]
        # Handle OGB Dataset.
        if isinstance(graph, tuple):
            graph, _ = graph
        if graph.is_homogeneous:
            self._init_as_homogeneous_node_pred(legacy)
        else:
            self._init_as_heterogeneous_node_pred(legacy)

    @staticmethod
    def _to_feature(tensor):
        """Wrap ``tensor`` as a feature, viewing 1-D data as one column."""
        if tensor.dim() == 1:
            tensor = tensor.view(-1, 1)
        return TorchBasedFeature(tensor)

    def _init_as_heterogeneous_node_pred(self, legacy: DGLDataset):
        """Populate the dataset from a heterogeneous node prediction dataset.

        Raises
        ------
        NotImplementedError
            If ``legacy`` does not provide ``get_idx_split``.
        """
        # OGB Dataset has the idx split.
        if not hasattr(legacy, "get_idx_split"):
            raise NotImplementedError(
                "Only support heterogeneous ogb node pred dataset"
            )

        def _init_item_set_dict(idx, labels):
            # One (seed_nodes, labels) item set per key of the split.
            return ItemSetDict(
                {
                    key: ItemSet(
                        (node_ids, labels[key][node_ids]),
                        names=("seed_nodes", "labels"),
                    )
                    for key, node_ids in idx.items()
                }
            )

        graph, labels = legacy[0]
        split_idx = legacy.get_idx_split()

        # Initialize tasks.
        metadata = {
            "num_classes": legacy.num_classes,
            "name": "node_classification",
        }
        train_set = _init_item_set_dict(split_idx["train"], labels)
        validation_set = _init_item_set_dict(split_idx["valid"], labels)
        test_set = _init_item_set_dict(split_idx["test"], labels)
        self._tasks = [
            OnDiskTask(metadata, train_set, validation_set, test_set)
        ]

        self._all_nodes_set = ItemSetDict(
            {
                ntype: ItemSet(graph.num_nodes(ntype), names="seed_nodes")
                for ntype in graph.ntypes
            }
        )

        features = {}
        for ntype in graph.ntypes:
            for name, tensor in graph.nodes[ntype].data.items():
                features[("node", ntype, name)] = self._to_feature(tensor)
        for etype in graph.canonical_etypes:
            # The Graphbolt key uses the string form of the canonical etype.
            gb_etype = etype_tuple_to_str(etype)
            for name, tensor in graph.edges[etype].data.items():
                features[("edge", gb_etype, name)] = self._to_feature(tensor)
        self._feature = BasicFeatureStore(features)
        self._graph = from_dglgraph(graph, is_homogeneous=False)
        self._dataset_name = legacy.name

    def _init_as_homogeneous_node_pred(self, legacy: DGLDataset):
        """Populate the dataset from a homogeneous node prediction dataset."""
        legacy = AsNodePredDataset(legacy)
        # Fetch the graph once instead of indexing the dataset repeatedly.
        graph = legacy[0]
        labels = graph.ndata["label"]

        def _init_item_set(idx):
            return ItemSet((idx, labels[idx]), names=("seed_nodes", "labels"))

        # Initialize tasks.
        metadata = {
            "num_classes": legacy.num_classes,
            "name": "node_classification",
        }
        train_set = _init_item_set(legacy.train_idx)
        validation_set = _init_item_set(legacy.val_idx)
        test_set = _init_item_set(legacy.test_idx)
        self._tasks = [
            OnDiskTask(metadata, train_set, validation_set, test_set)
        ]

        self._all_nodes_set = ItemSet(graph.num_nodes(), names="seed_nodes")

        # Homogeneous features carry no node/edge type, hence the None key.
        features = {}
        for name, tensor in graph.ndata.items():
            features[("node", None, name)] = self._to_feature(tensor)
        for name, tensor in graph.edata.items():
            features[("edge", None, name)] = self._to_feature(tensor)
        self._feature = BasicFeatureStore(features)
        self._graph = from_dglgraph(graph, is_homogeneous=True)
        self._dataset_name = legacy.name

    @property
    def tasks(self) -> List[Task]:
        """Return the tasks."""
        return self._tasks

    @property
    def graph(self) -> SamplingGraph:
        """Return the graph."""
        return self._graph

    @property
    def feature(self) -> BasicFeatureStore:
        """Return the feature."""
        return self._feature

    @property
    def dataset_name(self) -> str:
        """Return the dataset name."""
        return self._dataset_name

    @property
    def all_nodes_set(self) -> Union[ItemSet, ItemSetDict]:
        """Return the itemset containing all nodes."""
        return self._all_nodes_set
42 changes: 30 additions & 12 deletions python/dgl/graphbolt/impl/neighbor_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import torch
from torch.utils.data import functional_datapipe

from ..internal import compact_csc_format, unique_and_compact_node_pairs
from ..internal import (
compact_csc_format,
unique_and_compact_csc_formats,
unique_and_compact_node_pairs,
)

from ..subgraph_sampler import SubgraphSampler
from .sampled_subgraph_impl import FusedSampledSubgraphImpl, SampledSubgraphImpl
Expand Down Expand Up @@ -123,17 +127,31 @@ def _sample_subgraphs(self, seeds):
)
if self.deduplicate:
if self.output_cscformat:
raise RuntimeError("Not implemented yet.")
(
original_row_node_ids,
compacted_node_pairs,
) = unique_and_compact_node_pairs(subgraph.node_pairs, seeds)
subgraph = FusedSampledSubgraphImpl(
node_pairs=compacted_node_pairs,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
(
original_row_node_ids,
compacted_csc_format,
) = unique_and_compact_csc_formats(
subgraph.node_pairs, seeds
)
subgraph = SampledSubgraphImpl(
node_pairs=compacted_csc_format,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
else:
(
original_row_node_ids,
compacted_node_pairs,
) = unique_and_compact_node_pairs(
subgraph.node_pairs, seeds
)
subgraph = FusedSampledSubgraphImpl(
node_pairs=compacted_node_pairs,
original_column_node_ids=seeds,
original_row_node_ids=original_row_node_ids,
original_edge_ids=subgraph.original_edge_ids,
)
else:
(
original_row_node_ids,
Expand Down

0 comments on commit 363fdf4

Please sign in to comment.