In [1]:
import math
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F 
import dgl

from ogb.lsc import DglPCQM4MDataset
from ogb.utils import smiles2graph

from timeit import default_timer
from tqdm import trange

from typing import Union

Using backend: pytorch


In [2]:
dataset = DglPCQM4MDataset(root='/home/ksadowski/datasets', smiles2graph=smiles2graph)

In [3]:
split_dict = dataset.get_idx_split()

train_idx = split_dict['train']
valid_idx = split_dict['valid']
test_idx = split_dict['test']

In [20]:
for i in trange(len(valid_idx)):
    if torch.isnan(dataset[valid_idx[i]][1]) == True:
        print(f'{i}: NaN')


100%|██████████| 380670/380670 [00:03<00:00, 99587.51it/s]


In [22]:
torch.save(train_idx, './data/train_idx.pt')
torch.save(valid_idx, './data/val_idx.pt')
torch.save(test_idx, './data/test_idx.pt')

In [32]:
max_node = 0
max_edge = 0

min_node = 0
min_edge = 0

for data in dataset:
    g = data[0]

    max_g_node = torch.max(g.ndata['feat'])
    min_g_node = torch.min(g.ndata['feat'])
    
    if max_g_node > max_node:
        max_node = max_g_node

    if min_g_node < min_node:
        min_node = min_g_node

    if len(g.edata['feat']):
        max_g_edge = torch.max(g.edata['feat'])
        min_g_edge = torch.min(g.edata['feat'])
        
        if max_g_edge > max_edge:
            max_edge = max_g_edge

        if min_g_edge < min_edge:
            min_edge = min_g_edge

print(max_node)
print(max_edge)


tensor(52)
tensor(3)


In [3]:
class ProcessedPCQM4M(dgl.data.DGLDataset):
    def __init__(self, ogb_dataset: dgl.data.DGLDataset, normalize: bool = False):
        self.ogb_dataset = ogb_dataset
        self.normalize = normalize
        self.graphs = []
        self.line_graphs = []
        self.labels = []
        super().__init__(name='processed_PCQM4M')

    def process(self):
        for i in trange(len(self.ogb_dataset)):
            g = self.ogb_dataset[i][0].add_self_loop()
            lg = dgl.line_graph(g, backtracking=False).add_self_loop()

            g.ndata['feat'] = g.ndata['feat'].float()
            g.edata['feat'] = g.edata['feat'].float()

            if self.normalize:
                g.ndata['feat'] /= 52
                g.edata['feat'] /= 3

            self.graphs.append(g)
            self.line_graphs.append(lg)
            self.labels.append(self.ogb_dataset[i][1])

    def __getitem__(self, idx: Union[int, torch.Tensor]):
        if isinstance(idx, int):
            return self.graphs[idx], self.line_graphs[idx], self.labels[idx]
        elif torch.is_tensor(idx) and idx.dtype == torch.long:
            if idx.dim() == 0:
                return self.graphs[idx], self.line_graphs[idx], self.labels[idx]
            elif idx.dim() == 1:
                return dgl.data.utils.Subset(self, idx.cpu())

    def __len__(self):
        return len(self.graphs)

In [4]:
processed_dataset_norm = ProcessedPCQM4M(dataset, normalize=True)

for i in trange(len(processed_dataset_norm)):
    if not torch.isnan(processed_dataset_norm[i][2]):
        assert processed_dataset_norm[i][2] == dataset[i][1]

print('Finished checking labels ordering.')

labels = {f'{i}': processed_dataset_norm[i][2] for i in range(len(processed_dataset_norm))}

# dgl.data.utils.save_graphs('./data/molecules_norm.bin', processed_dataset_norm.graphs, labels)
dgl.data.utils.save_graphs('./data/molecules_lg.bin', processed_dataset_norm.line_graphs, labels)

100%|██████████| 3803453/3803453 [46:08<00:00, 1373.65it/s]
100%|██████████| 3803453/3803453 [00:27<00:00, 138155.02it/s]
Finished checking labels ordering.


In [4]:
processed_dataset = ProcessedPCQM4M(dataset, normalize=False)

for i in trange(len(processed_dataset)):
    if not torch.isnan(processed_dataset[i][2]):
        assert processed_dataset[i][2] == dataset[i][1]

print('Finished checking labels ordering.')

labels = {f'{i}': processed_dataset[i][2] for i in range(len(processed_dataset))}

dgl.data.utils.save_graphs('./data/molecules.bin', processed_dataset.graphs, labels)

100%|██████████| 3803453/3803453 [42:52<00:00, 1478.42it/s]
100%|██████████| 3803453/3803453 [00:26<00:00, 146116.26it/s]
Finished checking labels ordering.


In [5]:
for i in trange(len(dataset)):
    g = dataset[i][0]
    lg = dgl.line_graph(g, backtracking=False).add_self_loop()

  1%|          | 29348/3803453 [00:08<18:03, 3482.74it/s]