In [1]:
import pickle 

with open("/home/gh/dataset/OAG/raw/graph_CS.dict.pkl", "rb") as fp:
    graph_info = pickle.load(fp) 
    
edge_list = graph_info['edge_list'] 
node_feature = graph_info['node_feature']
node_forward = graph_info['node_forward'] 
times = graph_info['times']

In [2]:
num_paper_nodes = len(node_feature["paper"])

num_paper_nodes

544244

In [3]:
import torch 
from collections import Counter 
from tqdm import tqdm 

paper_label_list = list(edge_list["venue"]["paper"]["PV_Journal"].keys())
paper_label_vec = torch.full(fill_value=-1, size=[num_paper_nodes], dtype=torch.int64) 

paper_pretrain_mask = torch.zeros(num_paper_nodes, dtype=torch.bool)
paper_train_mask = torch.zeros(num_paper_nodes, dtype=torch.bool)
paper_val_mask = torch.zeros(num_paper_nodes, dtype=torch.bool)
paper_test_mask = torch.zeros(num_paper_nodes, dtype=torch.bool)

for paper_id in tqdm(edge_list["paper"]["venue"]["rev_PV_Journal"]):
    for venue_id in edge_list["paper"]["venue"]["rev_PV_Journal"][paper_id]:
        year = int(edge_list["paper"]["venue"]["rev_PV_Journal"][paper_id][venue_id]) 

        if year < 2014: 
            paper_pretrain_mask[paper_id] = True 
        elif 2014 <= year <= 2016: 
            paper_train_mask[paper_id] = True 
        elif year == 2017: 
            paper_val_mask[paper_id] = True 
        elif year >= 2018: 
            paper_test_mask[paper_id] = True 
        else:
            raise AssertionError 
        
        paper_label_vec[paper_id] = paper_label_list.index(venue_id)
        
paper_pretrain_mask.float().mean(), \
    paper_train_mask.float().mean(), \
    paper_val_mask.float().mean(), \
    paper_test_mask.float().mean()

100%|██████████| 228062/228062 [00:05<00:00, 43498.37it/s]


(tensor(132552), tensor(48799), tensor(19494), tensor(27217))

In [5]:
int(torch.max(paper_label_vec)) + 1, \
    torch.sum(paper_label_vec >= 0)

(3505, tensor(228062))

In [6]:
edge_index_dict: dict[tuple[str, str, str], torch.Tensor] = dict() 

for dtype in tqdm(edge_list):
    for src_ntype in edge_list[dtype]:
        for rel in edge_list[dtype][src_ntype]:
            if rel != "PV_Journal" and not rel.startswith("rev_"):
                src_list = []
                dst_list = []
                for dst in edge_list[dtype][src_ntype][rel]:
                    for src in edge_list[dtype][src_ntype][rel][dst]:
                        src_list.append(src)
                        dst_list.append(dst)

                edge_index = torch.tensor([src_list, dst_list], dtype=torch.int64) 
                edge_index_dict[(src_ntype, rel, dtype)] = edge_index 
                
edge_index_dict[('author', 'AI', 'institution')] = edge_index_dict.pop(('author', 'in', 'affiliation'))

{ k: v.shape for k, v in edge_index_dict.items() }

100%|██████████| 5/5 [00:04<00:00,  1.21it/s]


{('paper', 'PV_Conference', 'venue'): torch.Size([2, 296775]),
 ('paper', 'PV_Repository', 'venue'): torch.Size([2, 19216]),
 ('paper', 'PV_Patent', 'venue'): torch.Size([2, 191]),
 ('paper', 'PP_cite', 'paper'): torch.Size([2, 5796354]),
 ('author', 'AP_write_last', 'paper'): torch.Size([2, 429392]),
 ('author', 'AP_write_other', 'paper'): torch.Size([2, 662167]),
 ('author', 'AP_write_first', 'paper'): torch.Size([2, 454913]),
 ('field', 'FF_in', 'field'): torch.Size([2, 262526]),
 ('paper', 'PF_in_L0', 'field'): torch.Size([2, 544371]),
 ('paper', 'PF_in_L3', 'field'): torch.Size([2, 866423]),
 ('paper', 'PF_in_L1', 'field'): torch.Size([2, 1197205]),
 ('paper', 'PF_in_L2', 'field'): torch.Size([2, 2337525]),
 ('paper', 'PF_in_L5', 'field'): torch.Size([2, 202221]),
 ('paper', 'PF_in_L4', 'field'): torch.Size([2, 303541]),
 ('author', 'AI', 'institution'): torch.Size([2, 612872])}

In [9]:
import dgl 

hg = dgl.heterograph({k: tuple(v) for k, v in edge_index_dict.items()}) 
num_nodes_dict = { ntype: hg.num_nodes(ntype) for ntype in hg.ntypes } 

num_nodes_dict 

{'author': 510189,
 'field': 45717,
 'institution': 9079,
 'paper': 544244,
 'venue': 6933}

In [10]:
paper_feat = torch.tensor(list(node_feature['paper']['emb']), dtype=torch.float32) 

paper_feat.shape 

torch.Size([544244, 768])

In [None]:
with open('/home/gh/dataset/OAG/OAG-CS-Venue/hg.dict.pkl', 'wb') as fp: 
    pickle.dump(
        dict(
            edge_index_dict = edge_index_dict, 
            num_nodes_dict = num_nodes_dict, 
            paper_feat = paper_feat, 
            paper_label = paper_label_vec,  
            paper_pretrain_mask = paper_pretrain_mask,
            paper_train_mask = paper_train_mask,
            paper_val_mask = paper_val_mask,
            paper_test_mask = paper_test_mask,
        ), 
        fp, 
    )