In [1]:
import os
import pickle

import numpy as np
import ogb
import torch
import torch_geometric
from ogb.nodeproppred import NodePropPredDataset

# Load ogbn-mag through OGB's NodePropPredDataset.
# NOTE(review): OGB may nest its own sub-folders under `root` — confirm the on-disk layout.
ogbn_mag_root = os.path.expanduser('~/dataset/OGB/ogbn-mag/raw')
dataset = NodePropPredDataset(name='ogbn-mag', root=ogbn_mag_root)

# Official split; each entry is indexed by node type further down (e.g. train_idx['paper']).
split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = (split_idx[part] for part in ("train", "valid", "test"))

# dataset[0] unpacks into the graph dict and the label dict.
graph, label = dataset[0]

In [2]:
# Node counts per node type, with 'field_of_study' shortened to 'field'.
# Built as a fresh dict instead of popping from graph['num_nodes_dict']: the pop mutated
# `graph` in place, so re-running this cell raised KeyError('field_of_study').
# Untouched keys keep their original order and 'field' is appended last.
raw_num_nodes = graph['num_nodes_dict']
num_nodes_dict = {k: v for k, v in raw_num_nodes.items() if k != 'field_of_study'}
num_nodes_dict['field'] = raw_num_nodes['field_of_study']

num_nodes_dict

{'author': 1134649, 'institution': 8740, 'paper': 736389, 'field': 59965}

In [3]:
# Shallow copy: renaming relation keys on this dict must not mutate graph['edge_index_dict'],
# otherwise `graph` no longer matches what dataset[0] returned and re-runs break.
edge_index_dict = dict(graph['edge_index_dict'])

edge_index_dict.keys()

dict_keys([('author', 'affiliated_with', 'institution'), ('author', 'writes', 'paper'), ('paper', 'cites', 'paper'), ('paper', 'has_topic', 'field_of_study')])

In [4]:
# Raw OGB edge type (src, relation, dst) -> abbreviated edge type used downstream.
EDGE_TYPE_RENAME = {
    ('author', 'affiliated_with', 'institution'): ('author', 'AI', 'institution'),
    ('author', 'writes', 'paper'): ('author', 'AP', 'paper'),
    ('paper', 'cites', 'paper'): ('paper', 'PP', 'paper'),
    ('paper', 'has_topic', 'field_of_study'): ('paper', 'PF', 'field'),
}

# Read from the source graph without popping, so this cell is idempotent and leaves
# `graph` untouched. Edge types not listed above are kept as-is (original order); the
# renamed types follow, converted to int64 tensors, in the order listed in the table.
raw_edge_index_dict = graph['edge_index_dict']
edge_index_dict = {k: v for k, v in raw_edge_index_dict.items() if k not in EDGE_TYPE_RENAME}
edge_index_dict.update(
    (new_key, torch.tensor(raw_edge_index_dict[old_key], dtype=torch.int64))
    for old_key, new_key in EDGE_TYPE_RENAME.items()
)

edge_index_dict

{('author',
  'AI',
  'institution'): tensor([[      0,       1,       2,  ..., 1134645, 1134647, 1134648],
         [    845,     996,    3197,  ...,    5189,    4668,    4668]]),
 ('author',
  'AP',
  'paper'): tensor([[      0,       0,       0,  ..., 1134647, 1134648, 1134648],
         [  19703,  289285,  311768,  ...,  657395,  671118,  719594]]),
 ('paper',
  'PP',
  'paper'): tensor([[     0,      0,      0,  ..., 736388, 736388, 736388],
         [    88,  27449, 121051,  ..., 421711, 427339, 439864]]),
 ('paper',
  'PF',
  'field'): tensor([[     0,      0,      0,  ..., 736388, 736388, 736388],
         [   145,   2215,   3205,  ...,  21458,  22283,  31934]])}

In [5]:
# Paper node features, converted to a float32 tensor.
raw_paper_feat = graph['node_feat_dict']['paper']
paper_feat = torch.tensor(raw_paper_feat, dtype=torch.float32)

paper_feat.shape

torch.Size([736389, 128])

In [6]:
# Paper labels as an int64 tensor, flattened to 1-D (one entry per paper).
paper_label = torch.tensor(label['paper'], dtype=torch.int64).reshape(-1)

paper_label.shape

torch.Size([736389])

In [7]:
def index_to_mask(index, size):
    """Return a bool tensor of length `size` that is True exactly at positions `index`."""
    mask = torch.zeros(size, dtype=torch.bool)
    mask[index] = True
    return mask

# One mask slot per paper; the paper-typed split indices come from the official split.
num_papers = len(paper_label)
train_mask = index_to_mask(train_idx['paper'], num_papers)
val_mask = index_to_mask(valid_idx['paper'], num_papers)
test_mask = index_to_mask(test_idx['paper'], num_papers)

# Guard against split leakage: no paper may be in more than one split.
assert not (train_mask & val_mask).any(), 'train/val splits overlap'
assert not (train_mask & test_mask).any(), 'train/test splits overlap'
assert not (val_mask & test_mask).any(), 'val/test splits overlap'

train_mask.sum(), val_mask.sum(), test_mask.sum()

(tensor(629571), tensor(64879), tensor(41939))

In [8]:
# Per-paper year values from graph['node_year'], as a flattened 1-D int64 tensor.
paper_year = torch.tensor(graph['node_year']['paper'], dtype=torch.int64).reshape(-1)

paper_year.shape

torch.Size([736389])

In [9]:
# Bundle the renamed heterogeneous graph together with the paper features, years,
# labels and split masks into a single pickle for downstream use.
# NOTE: loading it back uses pickle — only load files you produced yourself.
processed_path = os.path.expanduser('~/dataset/OGB/ogbn-mag/processed/hg.dict.pkl')

# open(..., 'wb') does not create parent directories; create 'processed/' up front so
# a fresh machine / Restart-&-Run-All does not fail with FileNotFoundError.
os.makedirs(os.path.dirname(processed_path), exist_ok=True)

with open(processed_path, 'wb') as fp:
    pickle.dump(
        dict(
            num_nodes_dict=num_nodes_dict,
            edge_index_dict=edge_index_dict,
            paper_feat=paper_feat,
            paper_year=paper_year,
            paper_label=paper_label,
            paper_train_mask=train_mask,
            paper_val_mask=val_mask,
            paper_test_mask=test_mask,
        ),
        fp,
    )