In [1]:
import pickle 
import os 
import torch 
from collections import Counter 
from tqdm import tqdm 

# Load the preprocessed OAG-CS graph dict produced by the PT-HGNN pipeline.
# NOTE: pickle.load can execute arbitrary code — only load files from a trusted source.
raw_graph_path = os.path.expanduser("~/dataset/OAG/raw/PT-HGNN/graph_CS_20190919.dict.pkl")
with open(raw_graph_path, "rb") as fp:
    graph_info = pickle.load(fp)

# Unpack the top-level sections of the graph dict (looked up in this order).
edge_list, node_feature, node_forward, times = (
    graph_info[section]
    for section in ('edge_list', 'node_feature', 'node_forward', 'times')
)

In [2]:
# The number of paper nodes equals the number of rows in the paper feature table;
# per-paper tensors further down are sized with this value.
paper_table = node_feature["paper"]
num_paper_nodes = len(paper_table)

num_paper_nodes

546704

In [3]:
# Each paper's publication year is stored as the value of every field -> paper entry in
#   graph_info['edge_list']['field']['paper'][etype][field_id][paper_id] = year
# Collect one year per paper, checking that all of a paper's entries agree.
# A plain dict is used instead of per-element tensor writes, which are very slow in
# a triple Python loop. It also avoids using 0 as an "unset" sentinel, which would
# silently skip the consistency check for a genuine year value of 0.
field_paper_edges = graph_info['edge_list']['field']['paper']
year_by_paper = {}

for field_to_papers in tqdm(field_paper_edges.values()):
    for paper_to_year in field_to_papers.values():
        for paper_id, paper_year in paper_to_year.items():
            known_year = year_by_paper.setdefault(paper_id, paper_year)
            assert known_year == paper_year, (
                f"inconsistent year for paper {paper_id}: {known_year} vs {paper_year}"
            )

# Scatter the collected years into a dense per-paper vector; papers never seen keep 0.
# An out-of-range paper id raises IndexError here, as the original per-element write did.
paper_year_vec = torch.zeros(num_paper_nodes, dtype=torch.int64)
if year_by_paper:
    paper_ids = torch.tensor(list(year_by_paper.keys()), dtype=torch.int64)
    paper_years = torch.tensor(list(year_by_paper.values()), dtype=torch.int64)
    paper_year_vec[paper_ids] = paper_years

# Coverage check: how many papers received a (non-zero) year, and what fraction.
(paper_year_vec > 0).sum(), (paper_year_vec > 0).float().mean()

100%|██████████| 6/6 [00:57<00:00,  9.57s/it]


(tensor(546704), tensor(1.))

In [4]:
# Histogram of publication years as (year, count) pairs in ascending year order;
# display the 10 earliest and the 10 latest years.
years, year_counts = torch.unique(paper_year_vec, sorted=True, return_counts=True)
counter_result = list(zip(years.tolist(), year_counts.tolist()))

counter_result[:10], counter_result[-10:]

([(1922, 1),
  (1930, 1),
  (1938, 1),
  (1946, 1),
  (1948, 1),
  (1950, 1),
  (1951, 2),
  (1952, 3),
  (1953, 1),
  (1954, 10)],
 [(2011, 30523),
  (2012, 33140),
  (2013, 35382),
  (2014, 36214),
  (2015, 39314),
  (2016, 38459),
  (2017, 39072),
  (2018, 34293),
  (2019, 16308),
  (2020, 255)])

In [5]:
# Flatten the nested adjacency dict
#   graph_info['edge_list'][outer_type][inner_type][etype][outer_id][inner_id]
# into one (2, E) int64 edge-index tensor per (outer_type, etype, inner_type) key.
# Row 0 holds outer-type ids and row 1 holds inner-type ids.
# NOTE(review): relation names look reversed relative to this direction (e.g. 'PV_*'
# lands under ('venue', ..., 'paper')), which the renaming step below compensates for.
# The list of pairs is named `pairs` so it no longer clobbers the global `edge_list`.
edge_index_dict = dict()

for src_ntype, by_dest_ntype in tqdm(graph_info['edge_list'].items()):
    for dest_ntype, by_etype in by_dest_ntype.items():
        for etype, adjacency in by_etype.items():
            pairs = [
                (src_nid, dest_nid)
                for src_nid, neighbours in adjacency.items()
                for dest_nid in neighbours
            ]
            # reshape(-1, 2) turns an empty relation into a (2, 0) tensor rather than a
            # 1-D (0,) tensor, so later torch.cat(dim=-1) / tuple(edge_index) still work.
            edge_index_dict[(src_ntype, etype, dest_ntype)] = (
                torch.tensor(pairs, dtype=torch.int64).reshape(-1, 2).T
            )

edge_index_dict.keys()

100%|██████████| 5/5 [00:10<00:00,  2.09s/it]


dict_keys([('venue', 'PV_Conference', 'paper'), ('venue', 'PV_Journal', 'paper'), ('venue', 'PV_Repository', 'paper'), ('venue', 'PV_Patent', 'paper'), ('paper', 'rev_PV_Conference', 'venue'), ('paper', 'rev_PV_Journal', 'venue'), ('paper', 'rev_PV_Repository', 'venue'), ('paper', 'rev_PV_Patent', 'venue'), ('paper', 'PP_cite', 'paper'), ('paper', 'rev_PP_cite', 'paper'), ('paper', 'rev_PF_in_L0', 'field'), ('paper', 'rev_PF_in_L3', 'field'), ('paper', 'rev_PF_in_L1', 'field'), ('paper', 'rev_PF_in_L2', 'field'), ('paper', 'rev_PF_in_L5', 'field'), ('paper', 'rev_PF_in_L4', 'field'), ('paper', 'AP_write_last', 'author'), ('paper', 'AP_write_other', 'author'), ('paper', 'AP_write_first', 'author'), ('field', 'FF_in', 'field'), ('field', 'rev_FF_in', 'field'), ('field', 'PF_in_L0', 'paper'), ('field', 'PF_in_L3', 'paper'), ('field', 'PF_in_L1', 'paper'), ('field', 'PF_in_L2', 'paper'), ('field', 'PF_in_L5', 'paper'), ('field', 'PF_in_L4', 'paper'), ('affiliation', 'in', 'author'), ('auth

In [6]:
# Canonicalise node-type and relation names:
#   * node type 'affiliation' -> 'institution'
#   * two-uppercase-letter relations have their letters swapped ('PV_x' -> 'VP_x'),
#     matching the (outer, inner) direction the edge tensors were built in
#   * 'rev_' relations drop the prefix ('rev_PV_x' -> 'PV_x')
#   * the '_in_' / '_write_' infixes are shortened to '_'
_RENAME_NTYPE = {'affiliation': 'institution'}


def _canonical_etype(etype):
    """Return the canonical name for a raw relation name (see rules above)."""
    if etype[:2].isupper():
        etype = etype[1] + etype[0] + etype[2:]
    elif etype.startswith('rev_'):
        etype = etype[len('rev_'):]
    return etype.replace('_in_', '_').replace('_write_', '_')


_edge_index_dict = dict()
_collapsed = []  # (raw_key, canonical_key) pairs whose canonical key was already taken

for (src_ntype, etype, dest_ntype), edge_index in edge_index_dict.items():
    new_key = (
        _RENAME_NTYPE.get(src_ntype, src_ntype),
        _canonical_etype(etype),
        _RENAME_NTYPE.get(dest_ntype, dest_ntype),
    )
    # Same-type relations such as PP_cite / rev_PP_cite and FF_in / rev_FF_in map to
    # the same canonical key. The later entry overwrites the earlier one (kept so the
    # saved output is unchanged), but record it instead of dropping it silently.
    if new_key in _edge_index_dict:
        _collapsed.append(((src_ntype, etype, dest_ntype), new_key))
    _edge_index_dict[new_key] = edge_index

if _collapsed:
    print(f"{len(_collapsed)} relation(s) collapsed onto an existing key (later one kept):")
    for raw_key, new_key in _collapsed:
        print(f"  {raw_key} -> {new_key}")

edge_index_dict = _edge_index_dict

edge_index_dict.keys()

dict_keys([('venue', 'VP_Conference', 'paper'), ('venue', 'VP_Journal', 'paper'), ('venue', 'VP_Repository', 'paper'), ('venue', 'VP_Patent', 'paper'), ('paper', 'PV_Conference', 'venue'), ('paper', 'PV_Journal', 'venue'), ('paper', 'PV_Repository', 'venue'), ('paper', 'PV_Patent', 'venue'), ('paper', 'PP_cite', 'paper'), ('paper', 'PF_L0', 'field'), ('paper', 'PF_L3', 'field'), ('paper', 'PF_L1', 'field'), ('paper', 'PF_L2', 'field'), ('paper', 'PF_L5', 'field'), ('paper', 'PF_L4', 'field'), ('paper', 'PA_last', 'author'), ('paper', 'PA_other', 'author'), ('paper', 'PA_first', 'author'), ('field', 'FF_in', 'field'), ('field', 'FP_L0', 'paper'), ('field', 'FP_L3', 'paper'), ('field', 'FP_L1', 'paper'), ('field', 'FP_L2', 'paper'), ('field', 'FP_L5', 'paper'), ('field', 'FP_L4', 'paper'), ('institution', 'in', 'author'), ('author', 'in', 'institution'), ('author', 'AP_last', 'paper'), ('author', 'AP_other', 'paper'), ('author', 'AP_first', 'paper')])

In [7]:
# Give the author <-> institution relations explicit two-letter names (IA then AI).
for old_key, new_key in [
    (('institution', 'in', 'author'), ('institution', 'IA', 'author')),
    (('author', 'in', 'institution'), ('author', 'AI', 'institution')),
]:
    edge_index_dict[new_key] = edge_index_dict.pop(old_key)

# Treat citations as undirected: append every reversed edge, then de-duplicate columns.
cite_edges = edge_index_dict.pop(('paper', 'PP_cite', 'paper'))
PP_edge_index = torch.unique(torch.cat([cite_edges, cite_edges.flip(0)], dim=-1), dim=-1)
edge_index_dict[('paper', 'PP', 'paper')] = PP_edge_index

{key: edge_index.shape for key, edge_index in edge_index_dict.items()}

{('venue', 'VP_Conference', 'paper'): torch.Size([2, 298075]),
 ('venue', 'VP_Journal', 'paper'): torch.Size([2, 229104]),
 ('venue', 'VP_Repository', 'paper'): torch.Size([2, 19318]),
 ('venue', 'VP_Patent', 'paper'): torch.Size([2, 207]),
 ('paper', 'PV_Conference', 'venue'): torch.Size([2, 298075]),
 ('paper', 'PV_Journal', 'venue'): torch.Size([2, 229104]),
 ('paper', 'PV_Repository', 'venue'): torch.Size([2, 19318]),
 ('paper', 'PV_Patent', 'venue'): torch.Size([2, 207]),
 ('paper', 'PF_L0', 'field'): torch.Size([2, 546831]),
 ('paper', 'PF_L3', 'field'): torch.Size([2, 869729]),
 ('paper', 'PF_L1', 'field'): torch.Size([2, 1202487]),
 ('paper', 'PF_L2', 'field'): torch.Size([2, 2347239]),
 ('paper', 'PF_L5', 'field'): torch.Size([2, 203338]),
 ('paper', 'PF_L4', 'field'): torch.Size([2, 304912]),
 ('paper', 'PA_last', 'author'): torch.Size([2, 431027]),
 ('paper', 'PA_other', 'author'): torch.Size([2, 664266]),
 ('paper', 'PA_first', 'author'): torch.Size([2, 456903]),
 ('field',

In [8]:
# Merge the per-venue-kind relations into one de-duplicated paper -> venue relation,
# and derive the venue -> paper relation by swapping the two rows.
pv_relations = ['PV_Conference', 'PV_Journal', 'PV_Repository', 'PV_Patent']
PV_edge_index = torch.unique(
    torch.cat([edge_index_dict[('paper', rel, 'venue')] for rel in pv_relations], dim=-1),
    dim=-1,
)
VP_edge_index = PV_edge_index.flip(0)

edge_index_dict[('paper', 'PV', 'venue')] = PV_edge_index
edge_index_dict[('venue', 'VP', 'paper')] = VP_edge_index

In [9]:
# Merge paper -> field relations of levels L5..L2 into one de-duplicated relation, and
# derive field -> paper by swapping rows.
# NOTE(review): levels L0 and L1 are not merged here (their PF_L0/PF_L1 keys remain in
# the dict). This looks intentional, e.g. to keep a target level out of the merged
# relation — confirm.
pf_relations = ['PF_L5', 'PF_L4', 'PF_L3', 'PF_L2']
PF_edge_index = torch.unique(
    torch.cat([edge_index_dict[('paper', rel, 'field')] for rel in pf_relations], dim=-1),
    dim=-1,
)
FP_edge_index = PF_edge_index.flip(0)

edge_index_dict[('paper', 'PF', 'field')] = PF_edge_index
edge_index_dict[('field', 'FP', 'paper')] = FP_edge_index

In [10]:
# Merge the author-position relations (last / other / first) into one de-duplicated
# paper -> author relation, and derive author -> paper by swapping rows.
pa_relations = ['PA_last', 'PA_other', 'PA_first']
PA_edge_index = torch.unique(
    torch.cat([edge_index_dict[('paper', rel, 'author')] for rel in pa_relations], dim=-1),
    dim=-1,
)
AP_edge_index = PA_edge_index.flip(0)

edge_index_dict[('paper', 'PA', 'author')] = PA_edge_index
edge_index_dict[('author', 'AP', 'paper')] = AP_edge_index

{key: edge_index.shape for key, edge_index in edge_index_dict.items()}

{('venue', 'VP_Conference', 'paper'): torch.Size([2, 298075]),
 ('venue', 'VP_Journal', 'paper'): torch.Size([2, 229104]),
 ('venue', 'VP_Repository', 'paper'): torch.Size([2, 19318]),
 ('venue', 'VP_Patent', 'paper'): torch.Size([2, 207]),
 ('paper', 'PV_Conference', 'venue'): torch.Size([2, 298075]),
 ('paper', 'PV_Journal', 'venue'): torch.Size([2, 229104]),
 ('paper', 'PV_Repository', 'venue'): torch.Size([2, 19318]),
 ('paper', 'PV_Patent', 'venue'): torch.Size([2, 207]),
 ('paper', 'PF_L0', 'field'): torch.Size([2, 546831]),
 ('paper', 'PF_L3', 'field'): torch.Size([2, 869729]),
 ('paper', 'PF_L1', 'field'): torch.Size([2, 1202487]),
 ('paper', 'PF_L2', 'field'): torch.Size([2, 2347239]),
 ('paper', 'PF_L5', 'field'): torch.Size([2, 203338]),
 ('paper', 'PF_L4', 'field'): torch.Size([2, 304912]),
 ('paper', 'PA_last', 'author'): torch.Size([2, 431027]),
 ('paper', 'PA_other', 'author'): torch.Size([2, 664266]),
 ('paper', 'PA_first', 'author'): torch.Size([2, 456903]),
 ('field',

In [11]:
import dgl

# Build a DGL heterograph from the (2, E) edge-index tensors and read back per-type
# node counts.
# DGL infers a node type's count as (largest id seen in its edges) + 1. That would
# silently drop trailing papers with no edges and leave the graph out of sync with
# the per-paper feature/year tensors, which are sized by num_paper_nodes. Pinning the
# paper count keeps them aligned; DGL raises if any edge references a paper id
# >= num_paper_nodes.
# NOTE(review): other node types are still inferred from edges — pin them too if an
# authoritative count is available.
hg = dgl.heterograph(
    {etype: tuple(edge_index) for etype, edge_index in edge_index_dict.items()},
    num_nodes_dict={'paper': num_paper_nodes},
)
num_nodes_dict = {ntype: hg.num_nodes(ntype) for ntype in hg.ntypes}

num_nodes_dict

{'author': 511122,
 'field': 45775,
 'institution': 9090,
 'paper': 546704,
 'venue': 6946}

In [12]:
import numpy as np

# Stack the per-paper embedding vectors into one (num_papers, emb_dim) float32 matrix.
# Stacking with numpy first avoids torch.tensor(list_of_ndarrays), which converts
# element by element (very slow for ~0.5M rows) and emits a UserWarning.
paper_feat_mat = torch.as_tensor(
    np.stack(list(node_feature['paper']['emb'])), dtype=torch.float32
)

paper_feat_mat.shape

torch.Size([546704, 768])

In [13]:
# Persist the processed graph (edge indices, node counts, paper features and years)
# as one pickle for downstream use.
output_path = os.path.expanduser('~/dataset/OAG/OAG-CS/hg_full.dict.pkl')
# open(..., 'wb') raises FileNotFoundError if the parent directory is missing.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

with open(output_path, 'wb') as fp:
    pickle.dump(
        dict(
            edge_index_dict=edge_index_dict,
            num_nodes_dict=num_nodes_dict,
            paper_feat=paper_feat_mat,
            paper_year=paper_year_vec,
        ),
        fp,
    )