In [11]:
import os

import jojo_graph 
from jojo_graph.imports import * 

dataset = jojo_graph.hetero_graph_dataset.OGB.OgbnMag()

# Pull the graph, paper features/labels and split masks out of the dataset object
# into notebook-level names used by the cells below.
graph, feat, label = dataset.graph, dataset.feat, dataset.label
train_mask, val_mask, test_mask = dataset.train_mask, dataset.val_mask, dataset.test_mask

graph

HeteroGraph:
{'num_nodes_dict': {'author': 1134649,
                    'institution': 8740,
                    'paper': 736389,
                    'field': 59965},
 'total_num_nodes': 1939743,
 'num_edges_dict': {('author', 'ai', 'institution'): 1043998,
                    ('author', 'ap', 'paper'): 7145660,
                    ('paper', 'pp_cites', 'paper'): 5416271,
                    ('paper', 'pf', 'field'): 7505078,
                    ('institution', 'ia', 'author'): 1043998,
                    ('paper', 'pa', 'author'): 7145660,
                    ('paper', 'pp_cited', 'paper'): 5416271,
                    ('field', 'fp', 'paper'): 7505078},
 'total_num_edges': 42222014,
 'device': 'cpu'}

In [12]:
# How many papers carry each label id.
counter = Counter(label.tolist())

# (label_id, num_papers) pairs for the most frequent labels, largest first.
num_top_labels = 5
label_subset = counter.most_common(num_top_labels)
label_subset

[(1, 30902), (134, 30671), (300, 27542), (258, 24804), (283, 24463)]

In [13]:
# Papers covered by the selected labels, absolute and as a fraction of all papers.
label_subset_cnt = sum(num_papers for _, num_papers in label_subset)
label_subset_cnt, label_subset_cnt / len(label)

(138382, 0.1879197000498378)

In [14]:
# Collapse the (label_id, count) pairs into a plain set of label ids.
# Rebuilt from `counter` rather than from `label_subset` itself: the old form
# `{x[0] for x in label_subset}` raised TypeError if this cell was run twice, because
# `label_subset` was already a set of ints. `len(label_subset)` is the same for the
# list of pairs and for the set, so the top-k size stays in sync with the cell that
# called `counter.most_common(...)`.
label_subset = {label_id for label_id, _ in counter.most_common(len(label_subset))}
label_subset

{1, 134, 258, 283, 300}

In [15]:
# Original ids of every paper whose label is one of the selected labels.
paper_id_subset = {
    paper_oid
    for paper_oid, paper_label in enumerate(label.tolist())
    if paper_label in label_subset
}
len(paper_id_subset)

138382

In [16]:
def _remap_bipartite_edges(edge_index: list[list[int]],
                           src_oid_to_nid: dict[int, int],
                           dst_oid_to_nid: dict[int, int]) -> tuple[list[int], list[int]]:
    """Filter and renumber one relation of the original graph.

    Keeps every edge whose source id is a key of `src_oid_to_nid`. Destination ids not
    yet in `dst_oid_to_nid` are added to it in place, numbered in order of first
    appearance while scanning `edge_index`. Duplicate (src, dst) pairs collapse to one.

    Args:
        edge_index: `[src_ids, dst_ids]` in original ids (a 2-row edge index after
            `.tolist()`).
        src_oid_to_nid: original -> new id map for the source node type (read only).
        dst_oid_to_nid: original -> new id map for the destination node type
            (extended in place).

    Returns:
        `(src_nids, dst_nids)` of equal length, grouped by source node in first-seen
        order.
    """
    src_nid_to_dst_nids: dict[int, set[int]] = defaultdict(set)
    for src_oid, dst_oid in zip(*edge_index):
        src_nid = src_oid_to_nid.get(src_oid)
        if src_nid is None:
            continue
        # len() is evaluated before insertion, so a new key gets the next free id.
        dst_nid = dst_oid_to_nid.setdefault(dst_oid, len(dst_oid_to_nid))
        src_nid_to_dst_nids[src_nid].add(dst_nid)

    new_src: list[int] = []
    new_dst: list[int] = []
    for src_nid, dst_nids in src_nid_to_dst_nids.items():
        new_src.extend([src_nid] * len(dst_nids))
        new_dst.extend(dst_nids)
    return new_src, new_dst


def _to_edge_tensor(src: list[int], dst: list[int]) -> torch.Tensor:
    """Stack source/destination id lists into a (2, E) int64 edge index.

    The explicit dtype matters when E == 0: `torch.tensor(([], []))` would otherwise
    infer float32.
    """
    return torch.tensor([src, dst], dtype=torch.int64)


def generate_ogbn_mag_subset(hg: jojo_graph.HeteroGraph,
                             paper_id_subset: set[int]) -> jojo_graph.HeteroGraph:
    """Build the ogbn-mag subgraph induced by a subset of paper nodes.

    Papers are renumbered 0..len(paper_id_subset)-1 in ascending order of original id,
    so row i of any per-paper tensor indexed with `sorted(paper_id_subset)` matches
    paper i here. Authors and fields are kept only if they touch a kept paper, and
    institutions only if they touch a kept author. Those types are numbered in order
    of first appearance in the original edge lists.

    The paper-author, paper-field and author-institution edges are de-duplicated.
    Citation edges are kept only when both endpoints are kept, and are not
    de-duplicated. Each reverse relation is the transpose of its forward relation.

    Args:
        hg: source graph; its `edge_index_dict` must contain
            ('paper', 'pa', 'author'), ('paper', 'pf', 'field'),
            ('paper', 'pp_cites', 'paper') and ('author', 'ai', 'institution').
        paper_id_subset: original ids of the papers to keep.

    Returns:
        A new HeteroGraph with the same eight edge types as the input ogbn-mag graph,
        constructed with `hg.device`.
    """
    edge_index_dict = hg.edge_index_dict
    PA_edge_index = edge_index_dict[('paper', 'pa', 'author')].tolist()
    PF_edge_index = edge_index_dict[('paper', 'pf', 'field')].tolist()
    PP_edge_index = edge_index_dict[('paper', 'pp_cites', 'paper')].tolist()
    AI_edge_index = edge_index_dict[('author', 'ai', 'institution')].tolist()

    paper_nid_to_oid_map: list[int] = sorted(paper_id_subset)
    paper_oid_to_nid_map: dict[int, int] = {
        oid: nid for nid, oid in enumerate(paper_nid_to_oid_map)
    }

    author_oid_to_nid_map: dict[int, int] = {}
    PA_src, PA_dst = _remap_bipartite_edges(
        PA_edge_index, paper_oid_to_nid_map, author_oid_to_nid_map)

    field_oid_to_nid_map: dict[int, int] = {}
    PF_src, PF_dst = _remap_bipartite_edges(
        PF_edge_index, paper_oid_to_nid_map, field_oid_to_nid_map)

    # Must run after the paper-author pass: only authors kept there survive.
    institution_oid_to_nid_map: dict[int, int] = {}
    AI_src, AI_dst = _remap_bipartite_edges(
        AI_edge_index, author_oid_to_nid_map, institution_oid_to_nid_map)

    PP_src: list[int] = []
    PP_dst: list[int] = []
    for paper_oid_1, paper_oid_2 in zip(*PP_edge_index):
        paper_nid_1 = paper_oid_to_nid_map.get(paper_oid_1)
        paper_nid_2 = paper_oid_to_nid_map.get(paper_oid_2)
        if paper_nid_1 is not None and paper_nid_2 is not None:
            PP_src.append(paper_nid_1)
            PP_dst.append(paper_nid_2)

    new_edge_index_dict: dict[EdgeType, IntTensor] = {
        ('author', 'ai', 'institution'): _to_edge_tensor(AI_src, AI_dst),
        ('author', 'ap', 'paper'): _to_edge_tensor(PA_dst, PA_src),
        ('paper', 'pp_cites', 'paper'): _to_edge_tensor(PP_src, PP_dst),
        ('paper', 'pf', 'field'): _to_edge_tensor(PF_src, PF_dst),
        ('institution', 'ia', 'author'): _to_edge_tensor(AI_dst, AI_src),
        ('paper', 'pa', 'author'): _to_edge_tensor(PA_src, PA_dst),
        ('paper', 'pp_cited', 'paper'): _to_edge_tensor(PP_dst, PP_src),
        ('field', 'fp', 'paper'): _to_edge_tensor(PF_dst, PF_src),
    }

    new_num_nodes_dict: dict[NodeType, int] = {
        'paper': len(paper_oid_to_nid_map),
        'author': len(author_oid_to_nid_map),
        'field': len(field_oid_to_nid_map),
        'institution': len(institution_oid_to_nid_map),
    }

    return jojo_graph.HeteroGraph(
        edge_index_dict = new_edge_index_dict,
        num_nodes_dict = new_num_nodes_dict,
        device = hg.device,
    )

In [17]:
# Induced subgraph over the selected papers (and the authors/fields/institutions they reach).
new_hg = generate_ogbn_mag_subset(hg=graph, paper_id_subset=paper_id_subset)
new_hg

HeteroGraph:
{'num_nodes_dict': {'paper': 138382,
                    'author': 260402,
                    'field': 21680,
                    'institution': 4936},
 'total_num_nodes': 425400,
 'num_edges_dict': {('author', 'ai', 'institution'): 279401,
                    ('author', 'ap', 'paper'): 928352,
                    ('paper', 'pp_cites', 'paper'): 976341,
                    ('paper', 'pf', 'field'): 1458089,
                    ('institution', 'ia', 'author'): 279401,
                    ('paper', 'pa', 'author'): 928352,
                    ('paper', 'pp_cited', 'paper'): 976341,
                    ('field', 'fp', 'paper'): 1458089},
 'total_num_edges': 7284366,
 'device': 'cpu'}

In [18]:
# Sorted original ids: row i below corresponds to paper nid i in `new_hg`, since
# generate_ogbn_mag_subset numbers papers in the same sorted order.
paper_oid_list = sorted(paper_id_subset)

new_paper_feat = feat[paper_oid_list]

# Compact the surviving label ids to 0..k-1, numbered by first appearance.
subset_label_oids = label[paper_oid_list].tolist()
label_oid_to_nid_map: dict[int, int] = {}
for label_oid in subset_label_oids:
    label_oid_to_nid_map.setdefault(label_oid, len(label_oid_to_nid_map))
new_label = torch.tensor(
    [label_oid_to_nid_map[label_oid] for label_oid in subset_label_oids],
    dtype=torch.int64,
)

new_train_mask, new_val_mask, new_test_mask = (
    split_mask[paper_oid_list] for split_mask in (train_mask, val_mask, test_mask)
)

# Size of each split, absolute and as a fraction of the kept papers.
print(len(new_label))
for split_mask in (new_train_mask, new_val_mask, new_test_mask):
    print(split_mask.sum(), split_mask.sum() / len(new_label))

138382
tensor(122828) tensor(0.8876)
tensor(8806) tensor(0.0636)
tensor(6748) tensor(0.0488)


In [19]:
# Class balance of the remapped labels within train / val / test, in that order.
for split_mask in (new_train_mask, new_val_mask, new_test_mask):
    print(Counter(new_label[split_mask].tolist()))

Counter({0: 28600, 3: 27449, 2: 22771, 4: 22689, 1: 21319})
Counter({1: 3447, 3: 2071, 0: 1480, 2: 1126, 4: 682})
Counter({1: 2776, 4: 1433, 3: 1382, 0: 591, 2: 566})


In [20]:
# Output location. The default is the original path; override it via the environment
# instead of editing a hard-coded absolute path.
PROCESSED_DIR = os.environ.get(
    'OGBN_MAG_PROCESSED_DIR',
    '/home/genghao/dataset/hetero_graph/OGB/ogbn-mag/processed',
)
OUTPUT_PATH = os.path.join(PROCESSED_DIR, 'ogbn-mag_field_top5.dict.pkl')

# Only paper features are available here (`feat` is indexed by paper id above), so the
# other node types get uniform-random placeholders of the same width. A dedicated,
# seeded generator makes the pickled placeholders reproducible across runs without
# touching torch's global RNG state.
PLACEHOLDER_FEAT_SEED = 0
feat_dim = new_paper_feat.shape[-1]
feat_rng = torch.Generator().manual_seed(PLACEHOLDER_FEAT_SEED)

graph_info = dict(
    edge_index_dict = new_hg.edge_index_dict,
    num_nodes_dict = new_hg.num_nodes_dict,
    feat_dict = dict(
        paper = new_paper_feat,
        author = torch.rand(new_hg.num_nodes_dict['author'], feat_dim, dtype=torch.float32, generator=feat_rng),
        field = torch.rand(new_hg.num_nodes_dict['field'], feat_dim, dtype=torch.float32, generator=feat_rng),
        institution = torch.rand(new_hg.num_nodes_dict['institution'], feat_dim, dtype=torch.float32, generator=feat_rng),
    ),
    label = new_label,
    train_mask = new_train_mask,
    val_mask = new_val_mask,
    test_mask = new_test_mask,
)

os.makedirs(PROCESSED_DIR, exist_ok=True)
with open(OUTPUT_PATH, 'wb') as fp:
    pickle.dump(graph_info, fp)