In [1]:
import pickle 
import os 
import torch 
from collections import Counter 
from tqdm import tqdm 

# Load the preprocessed OAG-CS graph dump from the user's home directory.
# NOTE: pickle.load can execute arbitrary code embedded in the file, so this
# must only be pointed at a dump that comes from a trusted source.
raw_graph_path = os.path.expanduser("~/dataset/OAG/raw/graph_CS.dict.pkl")
with open(raw_graph_path, mode="rb") as fp:
    graph_info = pickle.load(fp)

# Expose the top-level entries of the dump under short names.
edge_list, node_feature, node_forward, times = (
    graph_info[entry]
    for entry in ('edge_list', 'node_feature', 'node_forward', 'times')
)

In [2]:
# Number of paper nodes, taken as the length of the 'paper' entry of node_feature.
# NOTE(review): presumably that entry holds one row per paper, indexed by paper
# node id — confirm against the preprocessing that produced the dump.
num_paper_nodes = len(node_feature["paper"])

num_paper_nodes

544244

In [3]:
# Recover the publication year of every paper from the field->paper edges:
# graph_info['edge_list']['field']['paper'][etype][field_id][paper_id] holds the year.
#
# Years are collected in a plain Python list and converted to a tensor once at the
# end; reading/writing a tensor element per edge is orders of magnitude slower.
# 0 is the "not seen yet" marker, so a genuine year of 0 would be treated as unset.
field_paper_edges = graph_info['edge_list']['field']['paper']
paper_years = [0] * num_paper_nodes

for etype, papers_by_field in tqdm(field_paper_edges.items()):
    for year_by_paper in papers_by_field.values():
        for paper_id, paper_year in year_by_paper.items():
            known_year = paper_years[paper_id]

            if known_year == 0:
                # int() mirrors the conversion an int64 tensor assignment would perform.
                paper_years[paper_id] = int(paper_year)
            else:
                # Every edge touching the same paper must agree on its year.
                assert known_year == paper_year, \
                    f"paper {paper_id}: conflicting years {known_year} and {paper_year}"

paper_year_vec = torch.tensor(paper_years, dtype=torch.int64)

# Number and fraction of papers that received a year (expected: all of them).
(paper_year_vec > 0).sum(), (paper_year_vec > 0).float().mean()

100%|██████████| 6/6 [00:56<00:00,  9.45s/it]


(tensor(544244), tensor(1.))

In [4]:
# Chronological split of the papers by publication year:
#   pretrain: before 2014 | train: 2014-2016 | val: 2017 | test: 2018 and later.
# Each comparison yields a fresh boolean tensor with one entry per paper.
paper_pretrain_mask = paper_year_vec < 2014
paper_train_mask = (paper_year_vec >= 2014) & (paper_year_vec < 2017)
paper_val_mask = paper_year_vec == 2017
paper_test_mask = paper_year_vec >= 2018

# Fraction of all papers that falls into each split.
tuple(
    split_mask.float().mean()
    for split_mask in (
        paper_pretrain_mask,
        paper_train_mask,
        paper_val_mask,
        paper_test_mask,
    )
)

(tensor(0.6265), tensor(0.2087), tensor(0.0716), tensor(0.0933))

In [5]:
# Histogram of papers per publication year, as (year, count) pairs in ascending year order.
year_histogram = Counter(paper_year_vec.tolist())
counter_result = sorted(year_histogram.items())

# Show the ten earliest and the ten latest years.
counter_result[:10], counter_result[-10:]

([(1922, 1),
  (1930, 1),
  (1938, 1),
  (1946, 1),
  (1948, 1),
  (1950, 1),
  (1951, 2),
  (1952, 3),
  (1953, 1),
  (1954, 10)],
 [(2011, 30429),
  (2012, 33031),
  (2013, 35273),
  (2014, 36100),
  (2015, 39158),
  (2016, 38309),
  (2017, 38942),
  (2018, 34218),
  (2019, 16281),
  (2020, 254)])

In [6]:
# Flatten the nested adjacency dicts
#   graph_info['edge_list'][src_ntype][dest_ntype][etype][src_nid][dest_nid]
# into one int64 edge-index tensor of shape [2, num_edges] per relation, keyed by
# (src_ntype, etype, dest_ntype). Row 0 holds the src ids, row 1 the dest ids.
edge_index_dict = dict()

for src_ntype, edges_by_dest_ntype in tqdm(graph_info['edge_list'].items()):
    for dest_ntype, edges_by_etype in edges_by_dest_ntype.items():
        for etype, adjacency in edges_by_etype.items():
            # Use a dedicated local name so the notebook-level `edge_list`
            # (graph_info['edge_list']) is not overwritten by a list of tuples.
            edge_pairs = [
                (src_nid, dest_nid)
                for src_nid, dest_nids in adjacency.items()
                for dest_nid in dest_nids
            ]

            # The explicit reshape keeps the result [2, 0] for a relation without
            # edges; torch.tensor([]) alone would give shape [0], which later
            # breaks `tuple(edge_index)` and `torch.cat(..., dim=-1)`.
            edge_index = (
                torch.tensor(edge_pairs, dtype=torch.int64)
                .reshape(len(edge_pairs), 2)
                .T
            )

            edge_index_dict[(src_ntype, etype, dest_ntype)] = edge_index

edge_index_dict.keys()

100%|██████████| 5/5 [00:10<00:00,  2.17s/it]


dict_keys([('venue', 'PV_Conference', 'paper'), ('venue', 'PV_Journal', 'paper'), ('venue', 'PV_Repository', 'paper'), ('venue', 'PV_Patent', 'paper'), ('paper', 'rev_PV_Conference', 'venue'), ('paper', 'rev_PV_Journal', 'venue'), ('paper', 'rev_PV_Repository', 'venue'), ('paper', 'rev_PV_Patent', 'venue'), ('paper', 'PP_cite', 'paper'), ('paper', 'rev_PP_cite', 'paper'), ('paper', 'rev_PF_in_L0', 'field'), ('paper', 'rev_PF_in_L3', 'field'), ('paper', 'rev_PF_in_L1', 'field'), ('paper', 'rev_PF_in_L2', 'field'), ('paper', 'rev_PF_in_L5', 'field'), ('paper', 'rev_PF_in_L4', 'field'), ('paper', 'AP_write_last', 'author'), ('paper', 'AP_write_other', 'author'), ('paper', 'AP_write_first', 'author'), ('field', 'FF_in', 'field'), ('field', 'rev_FF_in', 'field'), ('field', 'PF_in_L0', 'paper'), ('field', 'PF_in_L3', 'paper'), ('field', 'PF_in_L1', 'paper'), ('field', 'PF_in_L2', 'paper'), ('field', 'PF_in_L5', 'paper'), ('field', 'PF_in_L4', 'paper'), ('affiliation', 'in', 'author'), ('auth

In [7]:
# Normalise node-type and edge-type names:
#   * node type 'affiliation' becomes 'institution';
#   * an edge type starting with two upper-case letters gets them swapped
#     (e.g. venue->paper 'PV_Conference' becomes 'VP_Conference');
#   * otherwise a leading 'rev_' is stripped
#     (e.g. paper->venue 'rev_PV_Conference' becomes 'PV_Conference');
#   * the infixes '_in_' and '_write_' are collapsed to '_'.
#
# For relations between nodes of the same type, the forward and the 'rev_' relation
# end up with the SAME key (e.g. 'FF_in' and 'rev_FF_in' both become
# ('field', 'FF_in', 'field')). Plain assignment would silently let the later one
# overwrite the earlier one and drop its edges, so colliding relations are merged
# (concatenated and de-duplicated) instead.
# NOTE: this cell rebinds edge_index_dict, so it must be run exactly once per kernel.
_edge_index_dict = dict()

for (src_ntype, etype, dest_ntype), edge_index in edge_index_dict.items():
    if src_ntype == 'affiliation':
        src_ntype = 'institution'
    if dest_ntype == 'affiliation':
        dest_ntype = 'institution'

    if etype[:2].isupper():
        etype = etype[1] + etype[0] + etype[2:]
    elif etype[:3] == 'rev':
        etype = etype[4:]

    etype = etype.replace('_in_', '_')
    etype = etype.replace('_write_', '_')

    renamed_key = (src_ntype, etype, dest_ntype)

    if renamed_key in _edge_index_dict:
        # Two original relations map to this key: keep the union of their edges.
        edge_index = torch.unique(
            torch.cat([_edge_index_dict[renamed_key], edge_index], dim=-1),
            dim=-1,
        )

    _edge_index_dict[renamed_key] = edge_index

edge_index_dict = _edge_index_dict

edge_index_dict.keys()

dict_keys([('venue', 'VP_Conference', 'paper'), ('venue', 'VP_Journal', 'paper'), ('venue', 'VP_Repository', 'paper'), ('venue', 'VP_Patent', 'paper'), ('paper', 'PV_Conference', 'venue'), ('paper', 'PV_Journal', 'venue'), ('paper', 'PV_Repository', 'venue'), ('paper', 'PV_Patent', 'venue'), ('paper', 'PP_cite', 'paper'), ('paper', 'PF_L0', 'field'), ('paper', 'PF_L3', 'field'), ('paper', 'PF_L1', 'field'), ('paper', 'PF_L2', 'field'), ('paper', 'PF_L5', 'field'), ('paper', 'PF_L4', 'field'), ('paper', 'PA_last', 'author'), ('paper', 'PA_other', 'author'), ('paper', 'PA_first', 'author'), ('field', 'FF_in', 'field'), ('field', 'FP_L0', 'paper'), ('field', 'FP_L3', 'paper'), ('field', 'FP_L1', 'paper'), ('field', 'FP_L2', 'paper'), ('field', 'FP_L5', 'paper'), ('field', 'FP_L4', 'paper'), ('institution', 'in', 'author'), ('author', 'in', 'institution'), ('author', 'AP_last', 'paper'), ('author', 'AP_other', 'paper'), ('author', 'AP_first', 'paper')])

In [8]:
# Give the two generic 'in' relations between institutions and authors explicit names.
edge_index_dict['institution', 'IA', 'author'] = edge_index_dict.pop(('institution', 'in', 'author'))
edge_index_dict['author', 'AI', 'institution'] = edge_index_dict.pop(('author', 'in', 'institution'))

# Turn the citation relation into a symmetric, duplicate-free paper-paper relation:
# append every edge with its endpoints swapped, then drop repeated columns.
PP_edge_index = edge_index_dict.pop(('paper', 'PP_cite', 'paper'))
PP_edge_index = torch.unique(
    torch.cat((PP_edge_index, PP_edge_index.flip(0)), dim=1),
    dim=1,
)
edge_index_dict['paper', 'PP', 'paper'] = PP_edge_index

# Shape of the edge index of every relation.
dict((relation, edge_index.shape) for relation, edge_index in edge_index_dict.items())

{('venue', 'VP_Conference', 'paper'): torch.Size([2, 296775]),
 ('venue', 'VP_Journal', 'paper'): torch.Size([2, 228062]),
 ('venue', 'VP_Repository', 'paper'): torch.Size([2, 19216]),
 ('venue', 'VP_Patent', 'paper'): torch.Size([2, 191]),
 ('paper', 'PV_Conference', 'venue'): torch.Size([2, 296775]),
 ('paper', 'PV_Journal', 'venue'): torch.Size([2, 228062]),
 ('paper', 'PV_Repository', 'venue'): torch.Size([2, 19216]),
 ('paper', 'PV_Patent', 'venue'): torch.Size([2, 191]),
 ('paper', 'PF_L0', 'field'): torch.Size([2, 544371]),
 ('paper', 'PF_L3', 'field'): torch.Size([2, 866423]),
 ('paper', 'PF_L1', 'field'): torch.Size([2, 1197205]),
 ('paper', 'PF_L2', 'field'): torch.Size([2, 2337525]),
 ('paper', 'PF_L5', 'field'): torch.Size([2, 202221]),
 ('paper', 'PF_L4', 'field'): torch.Size([2, 303541]),
 ('paper', 'PA_last', 'author'): torch.Size([2, 429392]),
 ('paper', 'PA_other', 'author'): torch.Size([2, 662167]),
 ('paper', 'PA_first', 'author'): torch.Size([2, 454913]),
 ('field',

In [9]:
import dgl 

# Build the DGL heterograph; every relation is passed as a (src_ids, dest_ids) pair
# taken from the two rows of its edge index. No node counts are passed in, so the
# counts read back below are whatever DGL derives from the edges.
hg = dgl.heterograph({
    relation: (edge_index[0], edge_index[1])
    for relation, edge_index in edge_index_dict.items()
})

# Number of nodes per node type as reported by the graph.
num_nodes_dict = dict((ntype, hg.num_nodes(ntype)) for ntype in hg.ntypes)

num_nodes_dict

{'author': 510189,
 'field': 45717,
 'institution': 9079,
 'paper': 544244,
 'venue': 6934}

In [10]:
# Stack the per-paper embedding vectors into one float32 matrix, one row per paper.
# NOTE(review): assumes the 'emb' entries are ordered by paper node id — confirm.
paper_feat_mat = torch.tensor(
    [paper_emb for paper_emb in node_feature['paper']['emb']],
    dtype=torch.float32,
)

paper_feat_mat.size()

torch.Size([544244, 768])

In [14]:
# Single-label venue target: only paper->venue edges of type 'PV_Journal' are used.
PV_edge_index = edge_index_dict[('paper', 'PV_Journal', 'venue')]

# -1 marks papers that have no 'PV_Journal' edge and therefore no label.
paper_venue_label_vec = torch.full([num_paper_nodes], -1, dtype=torch.int64)

# Venue node id -> contiguous class id, assigned in order of first appearance.
venue_remap: dict[int, int] = {}

for paper_nid, venue_nid in PV_edge_index.T.tolist():
    # setdefault hands out the next free class id the first time a venue is seen.
    label_id = venue_remap.setdefault(venue_nid, len(venue_remap))
    current_label = paper_venue_label_vec[paper_nid]

    if current_label == -1:
        paper_venue_label_vec[paper_nid] = label_id
    else:
        # A paper must not be linked to two different venues.
        assert current_label == label_id

# Fraction of labelled papers and number of venue classes.
(paper_venue_label_vec >= 0).float().mean(), int(paper_venue_label_vec.max()) + 1

(tensor(0.4190), 3505)

In [18]:
# Multi-label field target built from the paper->field edges of type 'PF_L1'.
PF_edge_index = edge_index_dict[('paper', 'PF_L1', 'field')]

# Field node id -> contiguous class id, assigned in order of first appearance.
# The comprehension also records the class id of every edge, in edge order.
field_remap: dict[int, int] = {}
edge_label_ids = [
    field_remap.setdefault(field_nid, len(field_remap))
    for field_nid in PF_edge_index[1].tolist()
]

# Boolean indicator matrix: entry [paper, class] is True iff the paper has that field.
paper_field_label_mat = torch.zeros([num_paper_nodes, len(field_remap)], dtype=torch.bool)
paper_field_label_mat[
    PF_edge_index[0],
    torch.tensor(edge_label_ids, dtype=torch.int64),
] = True

# Matrix shape and overall density of positive labels.
paper_field_label_mat.shape, paper_field_label_mat.float().mean()

(torch.Size([544244, 275]), tensor(0.0080))

In [None]:
# Persist the assembled heterogeneous graph together with paper features, years,
# split masks and labels as a single pickled dict.
#
# The location is resolved relative to the current user's home directory, in line
# with how the raw dump is read, instead of being tied to one specific account;
# for a user whose home is /home/gh the resulting path is unchanged.
output_path = os.path.expanduser('~/dataset/OAG/OAG-CS-Venue/hg_full.dict.pkl')

# Create the target directory on first use so open() does not fail.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

with open(output_path, 'wb') as fp:
    pickle.dump(
        dict(
            edge_index_dict = edge_index_dict,
            num_nodes_dict = num_nodes_dict,
            paper_feat = paper_feat_mat,
            paper_year = paper_year_vec,
            paper_pretrain_mask = paper_pretrain_mask,
            paper_train_mask = paper_train_mask,
            paper_val_mask = paper_val_mask,
            paper_test_mask = paper_test_mask,
            paper_venue_label = paper_venue_label_vec,
            paper_field_label = paper_field_label_mat,
        ),
        fp,
    )