-
Notifications
You must be signed in to change notification settings - Fork 334
/
ogb.py
46 lines (35 loc) · 1.09 KB
/
ogb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np
from spektral.data import Dataset, Graph
from spektral.utils import sparse
class OGB(Dataset):
    """
    Wrapper for datasets from the [Open Graph Benchmark (OGB)](https://ogb.stanford.edu/).

    **Arguments**

    - `dataset`: an OGB library-agnostic dataset.
    """

    def __init__(self, dataset, **kwargs):
        # Stored before calling the base constructor because `read()` needs it;
        # presumably `Dataset.__init__` invokes `read()` (confirm in Dataset).
        self.dataset = dataset
        super().__init__(**kwargs)

    def read(self):
        """
        Convert every element of the wrapped OGB dataset into a Spektral `Graph`.

        Elements are fetched by explicit integer index rather than by iterating
        over the dataset. OGB's single-graph datasets (e.g.
        `NodePropPredDataset`) crash when iterated over but do support
        `dataset[0]`. Indexing over `range(len(...))` therefore works the same
        way for single-graph and multi-graph datasets. For an empty dataset it
        returns an empty list instead of raising `IndexError`.

        **Return**

        - A list of `Graph` objects, one per element of the wrapped dataset.
        """
        return [
            Graph(*_elem_to_numpy(self.dataset[i]))
            for i in range(len(self.dataset))
        ]
def _elem_to_numpy(elem):
    """
    Convert one element of an OGB dataset into the arguments of a Spektral `Graph`.

    **Arguments**

    - `elem`: either a `(graph, label)` pair or a bare `graph` dict. OGB's
    link-property-prediction datasets return the graph dict alone, with no
    label. `graph` must provide the keys `"num_nodes"`, `"node_feat"`,
    `"edge_index"` and `"edge_feat"`; `"edge_feat"` may be `None`.
    `graph["edge_index"]` must unpack into exactly two rows (source indices,
    target indices).

    **Return**

    - A tuple `(x, a, e, y)`:
        - `x`: the node features, unchanged.
        - `a`: the adjacency matrix built by `sparse.edge_index_to_matrix`.
        Every edge gets weight 1, with the same dtype as the edge indices.
        - `e`: the edge features as returned by `edge_index_to_matrix`
        (presumably reordered to match `a`; confirm in `spektral.utils.sparse`),
        or `None` when the graph has no edge features.
        - `y`: the label, or `None` when `elem` is a bare graph dict.
    """
    if isinstance(elem, dict):
        # A bare graph (e.g. from a link-prediction dataset). Unpacking a dict
        # as a pair would iterate over its keys and fail, so no label is used.
        graph, label = elem, None
    else:
        graph, label = elem

    n = graph["num_nodes"]
    x = graph["node_feat"]
    row, col = graph["edge_index"]
    e = graph["edge_feat"]

    a_e = sparse.edge_index_to_matrix(
        edge_index=np.array((row, col)).T,  # (n_edges, 2) pairs of (src, dst)
        edge_weight=np.ones_like(row),
        edge_features=e,
        shape=(n, n),
    )
    # With edge features, the helper returns (adjacency, edge_features);
    # without them it returns only the adjacency matrix.
    if e is None:
        a = a_e
    else:
        a, e = a_e

    return x, a, e, label