-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into spot_target
- Loading branch information
Showing
12 changed files
with
667 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
"""Graphbolt dataset for legacy DGLDataset.""" | ||
from typing import List, Union | ||
|
||
from dgl.data import AsNodePredDataset, DGLDataset | ||
from ..base import etype_tuple_to_str | ||
from ..dataset import Dataset, Task | ||
from ..itemset import ItemSet, ItemSetDict | ||
from ..sampling_graph import SamplingGraph | ||
from .basic_feature_store import BasicFeatureStore | ||
from .fused_csc_sampling_graph import from_dglgraph | ||
from .ondisk_dataset import OnDiskTask | ||
from .torch_based_feature_store import TorchBasedFeature | ||
|
||
|
||
class LegacyDataset(Dataset): | ||
"""A Graphbolt dataset for legacy DGLDataset.""" | ||
|
||
def __init__(self, legacy: DGLDataset): | ||
# Only supports single graph cases. | ||
assert len(legacy) == 1 | ||
graph = legacy[0] | ||
# Handle OGB Dataset. | ||
if isinstance(graph, tuple): | ||
graph, _ = graph | ||
if graph.is_homogeneous: | ||
self._init_as_homogeneous_node_pred(legacy) | ||
else: | ||
self._init_as_heterogeneous_node_pred(legacy) | ||
|
||
def _init_as_heterogeneous_node_pred(self, legacy: DGLDataset): | ||
def _init_item_set_dict(idx, labels): | ||
item_set_dict = {} | ||
for key in idx.keys(): | ||
item_set = ItemSet( | ||
(idx[key], labels[key][idx[key]]), | ||
names=("seed_nodes", "labels"), | ||
) | ||
item_set_dict[key] = item_set | ||
return ItemSetDict(item_set_dict) | ||
|
||
# OGB Dataset has the idx split. | ||
if hasattr(legacy, "get_idx_split"): | ||
graph, labels = legacy[0] | ||
split_idx = legacy.get_idx_split() | ||
|
||
# Initialize tasks. | ||
tasks = [] | ||
metadata = { | ||
"num_classes": legacy.num_classes, | ||
"name": "node_classification", | ||
} | ||
train_set = _init_item_set_dict(split_idx["train"], labels) | ||
validation_set = _init_item_set_dict(split_idx["valid"], labels) | ||
test_set = _init_item_set_dict(split_idx["test"], labels) | ||
task = OnDiskTask(metadata, train_set, validation_set, test_set) | ||
tasks.append(task) | ||
self._tasks = tasks | ||
|
||
item_set_dict = {} | ||
for ntype in graph.ntypes: | ||
item_set = ItemSet(graph.num_nodes(ntype), names="seed_nodes") | ||
item_set_dict[ntype] = item_set | ||
self._all_nodes_set = ItemSetDict(item_set_dict) | ||
|
||
features = {} | ||
for ntype in graph.ntypes: | ||
for name in graph.nodes[ntype].data.keys(): | ||
tensor = graph.nodes[ntype].data[name] | ||
if tensor.dim() == 1: | ||
tensor = tensor.view(-1, 1) | ||
features[("node", ntype, name)] = TorchBasedFeature(tensor) | ||
for etype in graph.canonical_etypes: | ||
for name in graph.edges[etype].data.keys(): | ||
tensor = graph.edges[etype].data[name] | ||
if tensor.dim() == 1: | ||
tensor = tensor.view(-1, 1) | ||
gb_etype = etype_tuple_to_str(etype) | ||
features[("edge", gb_etype, name)] = TorchBasedFeature( | ||
tensor | ||
) | ||
self._feature = BasicFeatureStore(features) | ||
self._graph = from_dglgraph(graph, is_homogeneous=False) | ||
self._dataset_name = legacy.name | ||
else: | ||
raise NotImplementedError( | ||
"Only support heterogeneous ogn node pred dataset" | ||
) | ||
|
||
def _init_as_homogeneous_node_pred(self, legacy: DGLDataset): | ||
legacy = AsNodePredDataset(legacy) | ||
|
||
# Initialize tasks. | ||
tasks = [] | ||
metadata = { | ||
"num_classes": legacy.num_classes, | ||
"name": "node_classification", | ||
} | ||
train_labels = legacy[0].ndata["label"][legacy.train_idx] | ||
validation_labels = legacy[0].ndata["label"][legacy.val_idx] | ||
test_labels = legacy[0].ndata["label"][legacy.test_idx] | ||
train_set = ItemSet( | ||
(legacy.train_idx, train_labels), | ||
names=("seed_nodes", "labels"), | ||
) | ||
validation_set = ItemSet( | ||
(legacy.val_idx, validation_labels), | ||
names=("seed_nodes", "labels"), | ||
) | ||
test_set = ItemSet( | ||
(legacy.test_idx, test_labels), names=("seed_nodes", "labels") | ||
) | ||
task = OnDiskTask(metadata, train_set, validation_set, test_set) | ||
tasks.append(task) | ||
self._tasks = tasks | ||
|
||
num_nodes = legacy[0].num_nodes() | ||
self._all_nodes_set = ItemSet(num_nodes, names="seed_nodes") | ||
features = {} | ||
for name in legacy[0].ndata.keys(): | ||
tensor = legacy[0].ndata[name] | ||
if tensor.dim() == 1: | ||
tensor = tensor.view(-1, 1) | ||
features[("node", None, name)] = TorchBasedFeature(tensor) | ||
for name in legacy[0].edata.keys(): | ||
tensor = legacy[0].edata[name] | ||
if tensor.dim() == 1: | ||
tensor = tensor.view(-1, 1) | ||
features[("edge", None, name)] = TorchBasedFeature(tensor) | ||
self._feature = BasicFeatureStore(features) | ||
self._graph = from_dglgraph(legacy[0], is_homogeneous=True) | ||
self._dataset_name = legacy.name | ||
|
||
@property | ||
def tasks(self) -> List[Task]: | ||
"""Return the tasks.""" | ||
return self._tasks | ||
|
||
@property | ||
def graph(self) -> SamplingGraph: | ||
"""Return the graph.""" | ||
return self._graph | ||
|
||
@property | ||
def feature(self) -> BasicFeatureStore: | ||
"""Return the feature.""" | ||
return self._feature | ||
|
||
@property | ||
def dataset_name(self) -> str: | ||
"""Return the dataset name.""" | ||
return self._dataset_name | ||
|
||
@property | ||
def all_nodes_set(self) -> Union[ItemSet, ItemSetDict]: | ||
"""Return the itemset containing all nodes.""" | ||
return self._all_nodes_set |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.