In [1]:
import csv 
import os 
import json 
import numpy as np 
import traceback 
from tqdm import tqdm 

# Load node.dat: give every global node id (nid) a per-type local index and
# collect each type's feature vectors in file order.
with open(os.path.expanduser('~/dataset/HGB/raw/DBLP/node.dat'), 'r', encoding='utf-8') as fp:
    # Tab-separated columns: nid, name, node_type, comma-separated feature values.
    reader = csv.DictReader(fp, fieldnames=['nid', 'name', 'node_type', 'feat'], delimiter='\t')

    author_feat_list = []
    paper_feat_list = []
    term_feat_list = []
    venue_feat_list = []

    # node_type id -> (type name, list collecting that type's features in file order)
    ntype_spec = {
        0: ('author', author_feat_list),
        1: ('paper', paper_feat_list),
        2: ('term', term_feat_list),
        3: ('venue', venue_feat_list),
    }

    # global nid -> (type name, index local to that type)
    nid_map: dict[int, tuple[str, int]] = dict()

    for row in tqdm(reader):
        nid = int(row['nid'])
        ntype = int(row['node_type'])

        # A row without a feature column gets feat=None from DictReader (restval).
        # Treat that, or an empty field, as "no features". Malformed feature
        # strings raise here instead of being silently turned into None.
        feat_str = (row['feat'] or '').strip()
        feat = np.array(json.loads(f"[{feat_str}]"), dtype=np.float32) if feat_str else None

        if ntype not in ntype_spec:
            raise AssertionError(f"unknown node_type {ntype} for nid {nid}")
        assert nid not in nid_map, f"duplicate nid {nid} in node.dat"

        type_name, feat_list = ntype_spec[ntype]
        nid_map[nid] = (type_name, len(feat_list))
        feat_list.append(feat)

# Fail with a clear message (rather than an opaque np.stack error) if a typed
# node that is expected to carry features has none.
for type_name, feat_list in [('author', author_feat_list), ('paper', paper_feat_list), ('term', term_feat_list)]:
    n_missing = sum(f is None for f in feat_list)
    assert n_missing == 0, f"{n_missing} {type_name} node(s) have no features"

author_feat = np.stack(author_feat_list)
paper_feat = np.stack(paper_feat_list)
term_feat = np.stack(term_feat_list)
# Venue nodes have no features; only their count is used downstream.
assert all(x is None for x in venue_feat_list)

author_feat.shape, \
paper_feat.shape, \
term_feat.shape, \
len(venue_feat_list)

26128it [00:07, 3382.60it/s]


((4057, 334), (14328, 4231), (7723, 50), 20)

In [2]:
# Load link.dat and bucket each edge by its relation type, translating global
# nids into per-type local indices via nid_map (built from node.dat).
with open(os.path.expanduser('~/dataset/HGB/raw/DBLP/link.dat'), 'r', encoding='utf-8') as fp:
    # Tab-separated columns: source nid, destination nid, edge type id, score.
    reader = csv.DictReader(fp, fieldnames=['src_nid', 'dest_nid', 'etype', 'score'], delimiter='\t')

    author_paper_edge_list = []
    paper_term_edge_list = []
    paper_venue_edge_list = []
    paper_author_edge_list = []
    term_paper_edge_list = []
    venue_paper_edge_list = []

    # etype id -> (expected source type, expected destination type, target list)
    etype_spec = {
        0: ('author', 'paper', author_paper_edge_list),
        1: ('paper', 'term', paper_term_edge_list),
        2: ('paper', 'venue', paper_venue_edge_list),
        3: ('paper', 'author', paper_author_edge_list),
        4: ('term', 'paper', term_paper_edge_list),
        5: ('venue', 'paper', venue_paper_edge_list),
    }

    for row in tqdm(reader):
        src_nid = int(row['src_nid'])
        dest_nid = int(row['dest_nid'])
        etype = int(row['etype'])
        score = float(row['score'])
        # Graph is unweighted: every link is expected to carry score 1.
        assert score == 1., f"unexpected score {score} on edge {src_nid}->{dest_nid}"

        if etype not in etype_spec:
            raise AssertionError(f"unknown etype {etype} on edge {src_nid}->{dest_nid}")
        src_type, dest_type, edge_list = etype_spec[etype]

        src_kind, src_idx = nid_map[src_nid]
        dest_kind, dest_idx = nid_map[dest_nid]
        assert src_kind == src_type and dest_kind == dest_type, \
            f"etype {etype} expects {src_type}->{dest_type}, got {src_kind}->{dest_kind}"
        edge_list.append((src_idx, dest_idx))


def _to_edge_index(edge_list: list[tuple[int, int]]) -> np.ndarray:
    """Convert (src, dst) pairs into a (2, E) int64 edge index.

    The reshape keeps the result 2-D when the list is empty (shape (2, 0));
    a plain ``np.array([]).T`` would have shape (0,).
    """
    return np.array(edge_list, dtype=np.int64).reshape(-1, 2).T


author_paper_edge_index = _to_edge_index(author_paper_edge_list)
paper_term_edge_index = _to_edge_index(paper_term_edge_list)
paper_venue_edge_index = _to_edge_index(paper_venue_edge_list)
paper_author_edge_index = _to_edge_index(paper_author_edge_list)
term_paper_edge_index = _to_edge_index(term_paper_edge_list)
venue_paper_edge_index = _to_edge_index(venue_paper_edge_list)

paper_author_edge_index.shape, \
paper_venue_edge_index.shape, \
paper_term_edge_index.shape

239566it [00:00, 302345.96it/s]


((2, 19645), (2, 14328), (2, 85810))

In [3]:
# Training labels: record each labelled author's class and mark it as a train node.
with open(os.path.expanduser('~/dataset/HGB/raw/DBLP/label.dat'), 'r', encoding='utf-8') as fp:
    # Tab-separated columns: author nid, author name, node type (always 0 = author), class label.
    reader = csv.DictReader(fp, fieldnames=['author_nid', 'author_name', 'node_type', 'author_label'], delimiter='\t')

    num_authors = len(author_feat)
    # -1 means "no label assigned yet"; each author may be labelled at most once.
    author_label_arr = np.full([num_authors], fill_value=-1, dtype=np.int64)
    author_train_mask = np.zeros(num_authors, dtype=bool)

    for record in tqdm(reader):
        global_nid = int(record['author_nid'])
        _author_name = record['author_name'].strip()
        record_ntype = int(record['node_type'])
        assert record_ntype == 0
        label = int(record['author_label'])
        assert label in range(4)

        kind, local_idx = nid_map[global_nid]
        assert kind == 'author'
        assert author_label_arr[local_idx] == -1
        author_label_arr[local_idx] = label
        author_train_mask[local_idx] = True

(author_label_arr.shape,
 np.sum(author_label_arr > -1),
 np.sum(author_train_mask))

1217it [00:00, 63388.28it/s]


((4057,), 1217, 1217)

In [4]:
# Test labels: fill in the classes of the remaining authors. The -1 check
# guarantees that no author appears in both the train and test label files.
with open(os.path.expanduser('~/dataset/HGB/raw/DBLP/label.dat.test'), 'r', encoding='utf-8') as fp:
    reader = csv.DictReader(fp, fieldnames=['author_nid', 'author_name', 'node_type', 'author_label'], delimiter='\t')

    for record in tqdm(reader):
        global_nid = int(record['author_nid'])
        _author_name = record['author_name'].strip()
        record_ntype = int(record['node_type'])
        assert record_ntype == 0
        label = int(record['author_label'])
        assert label in range(4)

        kind, local_idx = nid_map[global_nid]
        assert kind == 'author'
        assert author_label_arr[local_idx] == -1
        author_label_arr[local_idx] = label

(author_label_arr.shape,
 np.sum(author_label_arr > -1))

2840it [00:00, 75693.58it/s]


((4057,), 4057)

In [5]:
import pickle
import torch

# Serialize the heterogeneous graph (features, node counts, typed edge
# indices, author labels and train/test masks) as a plain dict of tensors.
out_path = os.path.expanduser('~/dataset/HGB/processed/DBLP_hg.dict.pkl')
# The processed/ directory may not exist yet; open(..., 'wb') would fail without it.
os.makedirs(os.path.dirname(out_path), exist_ok=True)

# Test split = labelled authors outside the train split. Requiring a label
# (>= 0) keeps any unlabelled author (-1) out of evaluation; on the current
# data every author is labelled, so this equals ~author_train_mask.
author_test_mask = (author_label_arr >= 0) & ~author_train_mask

with open(out_path, 'wb') as fp:
    pickle.dump(
        dict(
            node_feat_dict = dict(
                author = torch.tensor(author_feat, dtype=torch.float32),
                paper = torch.tensor(paper_feat, dtype=torch.float32),
                term = torch.tensor(term_feat, dtype=torch.float32),
            ),
            # Venue nodes carry no features, so only their count is stored.
            num_nodes_dict = dict(
                author = len(author_feat),
                paper = len(paper_feat),
                term = len(term_feat),
                venue = len(venue_feat_list),
            ),
            # Each edge index is (2, E): row 0 = source local index, row 1 = destination local index.
            edge_index_dict = {
                ('author', 'AP', 'paper'): torch.tensor(author_paper_edge_index, dtype=torch.int64),
                ('paper', 'PT', 'term'): torch.tensor(paper_term_edge_index, dtype=torch.int64),
                ('paper', 'PV', 'venue'): torch.tensor(paper_venue_edge_index, dtype=torch.int64),
                ('paper', 'PA', 'author'): torch.tensor(paper_author_edge_index, dtype=torch.int64),
                ('term', 'TP', 'paper'): torch.tensor(term_paper_edge_index, dtype=torch.int64),
                ('venue', 'VP', 'paper'): torch.tensor(venue_paper_edge_index, dtype=torch.int64),
            },
            author_label = torch.tensor(author_label_arr, dtype=torch.int64),
            author_train_mask = torch.tensor(author_train_mask, dtype=torch.bool),
            author_test_mask = torch.tensor(author_test_mask, dtype=torch.bool),
        ),
        fp,
    )