In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl

from torch.utils.data import DataLoader

from ogb.graphproppred import DglGraphPropPredDataset, collate_dgl

from timeit import default_timer
from typing import Union
from tqdm import trange

In [2]:
# Load ogbg-molhiv as DGL graphs; the first run downloads and preprocesses it under `root`.
# NOTE(review): `root` is a hard-coded absolute user path — consider a configurable data directory.
dataset = DglGraphPropPredDataset(root='/home/ksadowski/datasets', name='ogbg-molhiv')

Downloading http://snap.stanford.edu/ogb/data/graphproppred/csv_mol_download/hiv.zip
Downloaded 0.00 GB: 100%|██████████| 3/3 [00:07<00:00,  2.61s/it]
Extracting /home/ksadowski/datasets/hiv.zip
Loading necessary files...
This might take a while.
 38%|███▊      | 15507/41127 [00:00<00:00, 155049.08it/s]Processing graphs...
100%|██████████| 41127/41127 [00:00<00:00, 146789.91it/s]
  4%|▍         | 1793/41127 [00:00<00:04, 8963.63it/s]Converting graphs into DGL objects...
100%|██████████| 41127/41127 [00:05<00:00, 7325.79it/s]
Saving...


In [12]:
# Number of samples (graph, label pairs) in the raw OGB dataset.
len(dataset)

41127

In [22]:
class ProcessedMolhiv(dgl.data.DGLDataset):
    """ogbg-molhiv samples augmented with self-loops and line graphs.

    For every ``(graph, label)`` sample of the wrapped OGB dataset this class
    stores:

    * the graph with self-loops added and its ``'feat'`` node/edge data cast
      to float,
    * the non-backtracking line graph of that graph, also with self-loops,
    * the original label, unchanged.

    Args:
        ogb_dataset: Source dataset; indexing it must yield a sequence whose
            first element is a DGL graph and whose second is its label.
    """

    def __init__(self, ogb_dataset: dgl.data.DGLDataset):
        # These attributes must exist before the base constructor runs:
        # constructing a DGLDataset triggers ``process()``, which fills them.
        self._ogb_dataset = ogb_dataset
        self.graphs = []
        self.line_graphs = []
        self.labels = []
        # Register under a name that matches the wrapped data (ogbg-molhiv);
        # the previous 'processed_PCQM4M' was a leftover from another dataset.
        super().__init__(name='processed_molhiv')

    def process(self):
        """Build the graph, line-graph and label lists from the source dataset."""
        # Start from empty containers so calling process() again rebuilds the
        # lists instead of appending duplicates.
        self.graphs = []
        self.line_graphs = []
        self.labels = []

        for i in trange(len(self._ogb_dataset)):
            # Fetch the sample once; the source dataset was previously indexed
            # twice per iteration (once for the graph, once for the label).
            sample = self._ogb_dataset[i]

            g = sample[0].add_self_loop()
            lg = dgl.line_graph(g, backtracking=False).add_self_loop()

            g.ndata['feat'] = g.ndata['feat'].float()
            g.edata['feat'] = g.edata['feat'].float()

            self.graphs.append(g)
            self.line_graphs.append(lg)
            self.labels.append(sample[1])

    def __getitem__(self, idx: Union[int, torch.Tensor]):
        """Return one sample or a subset of the dataset.

        Args:
            idx: A Python ``int`` or a 0-dim ``torch.long`` tensor selects a
                single sample; a 1-dim ``torch.long`` tensor selects a subset.

        Returns:
            ``(graph, line_graph, label)`` for a single index, or a
            ``dgl.data.utils.Subset`` over this dataset for a 1-dim tensor.

        Raises:
            TypeError: If ``idx`` is neither an ``int`` nor a ``torch.long``
                tensor.
            IndexError: If ``idx`` is a tensor with more than one dimension,
                or a single index is out of range.
        """
        if torch.is_tensor(idx):
            if idx.dtype != torch.long:
                raise TypeError(
                    f'tensor indices must have dtype torch.long, got {idx.dtype}')
            if idx.dim() == 1:
                return dgl.data.utils.Subset(self, idx.cpu())
            if idx.dim() != 0:
                raise IndexError(
                    f'tensor indices must be 0- or 1-dimensional, got {idx.dim()} dimensions')
            # Scalar tensor: convert to a plain int for list indexing.
            idx = int(idx)
        elif not isinstance(idx, int):
            # Fail loudly instead of silently returning None.
            raise TypeError(
                f'indices must be int or torch.long tensor, got {type(idx).__name__}')

        return self.graphs[idx], self.line_graphs[idx], self.labels[idx]

    def __len__(self):
        """Return the number of processed samples."""
        return len(self.graphs)

# Build the processed dataset; construction runs process() over every sample
# of `dataset`, which is what the progress bar below reports.
processed_dataset = ProcessedMolhiv(dataset)

100%|██████████| 41127/41127 [00:28<00:00, 1449.64it/s]


In [29]:
# Inspect the first processed graph: node/edge counts and the 'feat' schemes after the float cast.
processed_dataset.graphs[0]

Graph(num_nodes=19, num_edges=59,
      ndata_schemes={'feat': Scheme(shape=(9,), dtype=torch.float32)}
      edata_schemes={'feat': Scheme(shape=(3,), dtype=torch.float32)})

In [30]:
# Name of the evaluation metric the OGB dataset object reports for ogbg-molhiv.
dataset.eval_metric

'rocauc'