In [6]:
import os 
import pickle 
import torch 

# Load the full OAG-Engin heterogeneous graph dump.
# NOTE(review): pickle.load can execute arbitrary code — only use on trusted files.
hg_full_path = os.path.expanduser('~/dataset/OAG/OAG-Engin/hg_full.dict.pkl')
with open(hg_full_path, 'rb') as fp:
    graph_info_dict = pickle.load(fp)

# Paper -> venue (journal) edges. Later cells unpack the two rows as
# (paper node id, venue node id) pairs.
edge_index_dict = graph_info_dict['edge_index_dict']
PV_edge_index = edge_index_dict[('paper', 'PV_Journal', 'venue')]

# Later cells index this vector by paper node id, so its length is used as
# the number of paper nodes.
paper_year_vec = graph_info_dict['paper_year']
num_paper_nodes = len(paper_year_vec)

PV_edge_index.shape, paper_year_vec.shape

(torch.Size([2, 294050]), torch.Size([370624]))

In [7]:
# Map each raw venue node id to a compact label 0..K-1, numbered in order of
# first appearance among the paper->venue edges (second row of PV_edge_index).
# dict.fromkeys de-duplicates while keeping first-seen order.
venue_id_map: dict[int, int] = {
    venue_nid: label
    for label, venue_nid in enumerate(dict.fromkeys(PV_edge_index[1].tolist()))
}

num_venue_nodes = len(venue_id_map)

num_venue_nodes

3957

In [8]:
# Label every paper with the compact id of its venue; -1 marks papers that
# have no PV_Journal edge.
# Labels are accumulated in a plain Python list and converted to a tensor
# once, instead of issuing one tensor __setitem__ per edge (slow for ~10^5 edges).
paper_labels: list[int] = [-1] * num_paper_nodes
for P_nid, V_nid in zip(*PV_edge_index.tolist()):
    # If a paper has several journal edges, the last one in edge order wins.
    # A sequential loop keeps that deterministic; a vectorized index_put_ with
    # duplicate indices would not guarantee it.
    paper_labels[P_nid] = venue_id_map[V_nid]
paper_label_vec = torch.tensor(paper_labels, dtype=torch.int64)

# Sanity check: the highest label actually assigned is the last venue id.
assert int(torch.max(paper_label_vec)) + 1 == num_venue_nodes

# Fraction of papers that received a venue label.
(paper_label_vec > -1).float().mean()

tensor(0.7934)

In [9]:
# Temporal split of *labeled* papers by publication year:
#   year <  TRAIN_FIRST_YEAR                      -> no split (pretrain pool)
#   TRAIN_FIRST_YEAR <= year <= TRAIN_LAST_YEAR   -> train
#   year == VAL_YEAR                              -> val
#   any other year                                -> test
# Unlabeled papers (label -1) never enter train/val/test.
TRAIN_FIRST_YEAR = 2014
TRAIN_LAST_YEAR = 2016
VAL_YEAR = 2017

paper_years = torch.as_tensor(paper_year_vec, device=paper_label_vec.device)
is_labeled = paper_label_vec > -1

# Vectorized equivalent of an if/elif/else cascade: each rule excludes the
# ones before it, so a paper lands in at most one split. Years matched by no
# earlier rule (including non-integer or NaN years) fall through to test.
before_train = paper_years < TRAIN_FIRST_YEAR
in_train_years = ~before_train & (paper_years <= TRAIN_LAST_YEAR)
in_val_year = ~before_train & ~in_train_years & (paper_years == VAL_YEAR)
in_test_years = ~before_train & ~in_train_years & ~in_val_year

paper_train_mask = is_labeled & in_train_years
paper_val_mask = is_labeled & in_val_year
paper_test_mask = is_labeled & in_test_years

# Everything outside train/val/test goes to the pretraining pool: older
# labeled papers and all unlabeled papers.
paper_pretrain_mask = ~(paper_train_mask | paper_val_mask | paper_test_mask)

(
    paper_pretrain_mask.sum(),
    paper_train_mask.sum(),
    paper_val_mask.sum(),
    paper_test_mask.sum(),
    paper_pretrain_mask.float().mean(),
    paper_train_mask.float().mean(),
    paper_val_mask.float().mean(),
    paper_test_mask.float().mean(),
)

(tensor(251915),
 tensor(84103),
 tensor(28524),
 tensor(6082),
 tensor(0.6797),
 tensor(0.2269),
 tensor(0.0770),
 tensor(0.0164))

In [10]:
# Attach the venue labels and split masks to the graph dict, then write it out
# as a new pickle next to the input dump (same directory as hg_full.dict.pkl).
graph_info_dict.update({
    'paper_label': paper_label_vec,
    'paper_pretrain_mask': paper_pretrain_mask,
    'paper_train_mask': paper_train_mask,
    'paper_val_mask': paper_val_mask,
    'paper_test_mask': paper_test_mask,
})

hg_venue_path = os.path.expanduser('~/dataset/OAG/OAG-Engin/hg_venue.dict.pkl')
with open(hg_venue_path, 'wb') as fp:
    pickle.dump(graph_info_dict, fp)