In [None]:
import os
import os.path as osp

import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid

from ogb.nodeproppred import PygNodePropPredDataset


def prep_data(dataset_name:str, K:int, root=None):
    """Download a dataset, attach SIGN features and standardize the data object.

    Args:
        dataset_name: One of 'cora', 'pubmed', 'products' or 'arxiv'
            (case-insensitive).
        K: Number of hops handed to ``T.SIGN``.
        root: Directory the dataset is downloaded to / read from. When omitted,
            the module-level ``path`` variable set by the calling cell is used,
            which is how this function located its data before ``root`` existed.

    Returns:
        The first graph of the dataset with these attributes added or replaced:
        ``dataset_name`` (lower-cased name), ``num_classes``, ``n_id``
        (``torch.arange(num_nodes)``) and ``train_mask`` / ``val_mask`` /
        ``test_mask`` holding node indices rather than boolean masks.

    Raises:
        AssertionError: If ``dataset_name`` is not one of the supported names.
    """
    possible_datasets = ['cora', 'pubmed', 'products', 'arxiv']
    dataset_name = dataset_name.lower()
    assert dataset_name in possible_datasets, f'Dataset {dataset_name} not available'

    if root is None:
        # legacy call style: fall back to the global set by the notebook cell
        root = path

    # Build the transform pipeline. arxiv additionally gets symmetrized edges
    # and self loops before the SIGN features are computed. (The original
    # check compared against the misspelled 'arixv' and therefore never fired.)
    steps = [T.NormalizeFeatures()]
    if dataset_name == 'arxiv':
        steps += [T.ToUndirected(), T.AddSelfLoops()]
    steps.append(T.SIGN(K))
    transform = T.Compose(steps)

    # download data
    if dataset_name in ['arxiv','products']:
        dataset = PygNodePropPredDataset(
            f'ogbn-{dataset_name}',
            root=root,
            transform=transform
            )
    else:
        dataset = Planetoid(
            root=root,
            name=dataset_name.title(),
            transform=transform,
            split='full'
            )

    # extract relevant information
    data = dataset[0]
    data.dataset_name = dataset_name
    data.num_classes = dataset.num_classes
    data.n_id = torch.arange(data.num_nodes)  # global node id

    # standardize mask -- node idx, not bool mask
    if hasattr(dataset, 'get_idx_split'):
        # datasets exposing get_idx_split() already provide the three splits
        masks = dataset.get_idx_split()
        data.train_mask = masks['train']
        data.val_mask = masks['valid']
        data.test_mask = masks['test']

        # labels are flattened to 1-D in this branch only
        data.y = data.y.flatten()
    else:
        # convert boolean masks into the indices of their True entries
        data.train_mask = torch.where(data.train_mask)[0]
        data.val_mask = torch.where(data.val_mask)[0]
        data.test_mask = torch.where(data.test_mask)[0]

    return data


# Top-level folder (under the current working directory) that holds every dataset.
folder_path = osp.join(os.getcwd(), 'data')

# Create the folder, and announce it, only when it is not there yet.
_data_dir_exists = osp.exists(folder_path)
if not _data_dir_exists:
    os.makedirs(folder_path)
    print("Directory '% s' created" % folder_path)


In [None]:
DATASET = 'cora'

# `path` must stay a module-level name: prep_data reads it as the dataset root.
path = osp.join(folder_path, DATASET)
os.makedirs(path, exist_ok=True)

K = 5
data = prep_data(DATASET, K)

# Write one snapshot per hop count, from K down to 0. After each save the
# highest remaining hop feature is dropped, so `<name>_sign_k{i}.pth` holds
# `x` plus `x1..x{i}`.
for i in range(K, -1, -1):
    filename = osp.join(path, f'{DATASET}_sign_k{i}.pth')
    torch.save(data, filename)
    if i > 0:
        # SIGN only creates x1..xK (raw features stay in `x`), so there is
        # no `x0` entry to delete on the last iteration.
        del data[f'x{i}']
del data


In [None]:
DATASET = 'pubmed'

# `path` must stay a module-level name: prep_data reads it as the dataset root.
path = osp.join(folder_path, DATASET)
os.makedirs(path, exist_ok=True)

K = 5
data = prep_data(DATASET, K)

# Write one snapshot per hop count, from K down to 0. After each save the
# highest remaining hop feature is dropped, so `<name>_sign_k{i}.pth` holds
# `x` plus `x1..x{i}`.
for i in range(K, -1, -1):
    filename = osp.join(path, f'{DATASET}_sign_k{i}.pth')
    torch.save(data, filename)
    if i > 0:
        # SIGN only creates x1..xK (raw features stay in `x`), so there is
        # no `x0` entry to delete on the last iteration.
        del data[f'x{i}']
del data


In [None]:
DATASET = 'arxiv'

# `path` must stay a module-level name: prep_data reads it as the dataset root.
path = osp.join(folder_path, DATASET)
os.makedirs(path, exist_ok=True)

K = 5
data = prep_data(DATASET, K)

# Write one snapshot per hop count, from K down to 0. After each save the
# highest remaining hop feature is dropped, so `<name>_sign_k{i}.pth` holds
# `x` plus `x1..x{i}`.
for i in range(K, -1, -1):
    filename = osp.join(path, f'{DATASET}_sign_k{i}.pth')
    torch.save(data, filename)
    if i > 0:
        # SIGN only creates x1..xK (raw features stay in `x`), so there is
        # no `x0` entry to delete on the last iteration.
        del data[f'x{i}']
del data


In [None]:
DATASET = 'products'

# `path` must stay a module-level name: prep_data reads it as the dataset root.
path = osp.join(folder_path, DATASET)
os.makedirs(path, exist_ok=True)

K = 5
data = prep_data(DATASET, K)

# Write one snapshot per hop count, from K down to 0. After each save the
# highest remaining hop feature is dropped, so `<name>_sign_k{i}.pth` holds
# `x` plus `x1..x{i}`.
for i in range(K, -1, -1):
    filename = osp.join(path, f'{DATASET}_sign_k{i}.pth')
    torch.save(data, filename)
    if i > 0:
        # SIGN only creates x1..xK (raw features stay in `x`), so there is
        # no `x0` entry to delete on the last iteration.
        del data[f'x{i}']
del data


In [None]:
# Reload the products snapshot that still carries every hop feature.
# The save loop writes k0..k5 (K = 5), so k5 is the highest existing suffix;
# the previously referenced k6 file is never produced by this notebook.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which may reject pickled graph objects -- confirm against the installed version.
data = torch.load(osp.join(folder_path, 'products', 'products_sign_k5.pth'))

In [None]:
# Report basic size statistics of the reloaded graph, one value per line.
for stat_name in ('num_nodes', 'num_edges', 'num_node_features'):
    print(getattr(data, stat_name))