In [20]:
import dgl
import dgl.function as fn
from dgl.nn import GATConv
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data.utils import load_graphs


In [21]:
from ogb.nodeproppred import DglNodePropPredDataset

# Load the OGB 'ogbn-mag' node-property-prediction dataset in its DGL form.
dataset = DglNodePropPredDataset(name='ogbn-mag')

# The dataset ships its own split; pull out each partition by name.
split_idx = dataset.get_idx_split()
train_idx = split_idx["train"]
valid_idx = split_idx["valid"]
test_idx = split_idx["test"]

# The first dataset entry unpacks into the graph `g` and a second item `l`
# (presumably the node labels -- confirm against the OGB docs).
g, l = dataset[0]

In [22]:
# Inspect the data stored on 'paper' nodes; the output below shows
# a 'feat' tensor and a 'year' tensor.
g.nodes['paper'].data

{'feat': tensor([[-0.0954,  0.0408, -0.2109,  ...,  0.0616, -0.0277, -0.1338],
        [-0.1510, -0.1073, -0.2220,  ...,  0.3458, -0.0277, -0.2185],
        [-0.1148, -0.1760, -0.2606,  ...,  0.1731, -0.1564, -0.2780],
        ...,
        [ 0.0228, -0.0865,  0.0981,  ..., -0.0547, -0.2077, -0.2305],
        [-0.2891, -0.2029, -0.1525,  ...,  0.1042,  0.2041, -0.3528],
        [-0.0890, -0.0348, -0.2642,  ...,  0.2601, -0.0875, -0.5171]]), 'year': tensor([[2015],
        [2012],
        [2012],
        ...,
        [2016],
        [2017],
        [2014]])}

In [36]:
# Take the first 10,000 node IDs of every node type in one pass. A type with
# fewer than 10,000 nodes simply contributes all of its IDs (plain slicing).
selected_paper, selected_author, selected_field, institution = (
    g.nodes(ntype)[:10000]
    for ntype in ('paper', 'author', 'field_of_study', 'institution')
)

In [37]:
# Build the subgraph induced by the selected node IDs, keyed by node type.
sub_g = dgl.node_subgraph(
    g,
    {
        'paper': selected_paper,
        'author': selected_author,
        'field_of_study': selected_field,
        'institution': institution,
    },
)
# Display the subgraph summary as the cell output.
sub_g

Graph(num_nodes={'author': 10000, 'field_of_study': 10000, 'institution': 8740, 'paper': 10000},
      num_edges={('author', 'affiliated_with', 'institution'): 10503, ('author', 'writes', 'paper'): 622, ('paper', 'cites', 'paper'): 865, ('paper', 'has_topic', 'field_of_study'): 34953},
      metagraph=[('author', 'institution', 'affiliated_with'), ('author', 'paper', 'writes'), ('paper', 'paper', 'cites'), ('paper', 'field_of_study', 'has_topic')])

In [39]:
def node_level_subsampling(g, list_of_nodes, node_numbers):
    """Return the subgraph induced by the first ``node_numbers`` nodes of each type.

    Parameters
    ----------
    g
        Graph to subsample; it is queried with ``g.nodes(node_type)`` and
        passed to ``dgl.node_subgraph``.
    list_of_nodes
        Non-empty sequence of node-type names to keep nodes from.
    node_numbers
        Upper bound on the number of node IDs kept per type. The IDs are
        taken with a plain ``[:node_numbers]`` slice, so a type with fewer
        nodes contributes all of them.

    Returns
    -------
    The result of ``dgl.node_subgraph`` for the selected node IDs.

    Raises
    ------
    ValueError
        If ``list_of_nodes`` is empty.
    """
    # Fail fast with a real exception type: the previous `Error` name was
    # undefined, so this guard raised NameError instead of the intended error.
    if len(list_of_nodes) == 0:
        raise ValueError('list of nodes are empty')

    # Map each requested node type to the leading slice of its node IDs.
    subsample_data = {
        node_type: g.nodes(node_type)[:node_numbers]
        for node_type in list_of_nodes
    }

    return dgl.node_subgraph(g, subsample_data)
    

In [41]:
# Same selection as the manual cells above, now through the helper:
# keep up to 10,000 nodes of each of the four node types.
node_level_subsampling(
    g,
    list_of_nodes=['paper', 'author', 'field_of_study', 'institution'],
    node_numbers=10000,
)

Graph(num_nodes={'author': 10000, 'field_of_study': 10000, 'institution': 8740, 'paper': 10000},
      num_edges={('author', 'affiliated_with', 'institution'): 10503, ('author', 'writes', 'paper'): 622, ('paper', 'cites', 'paper'): 865, ('paper', 'has_topic', 'field_of_study'): 34953},
      metagraph=[('author', 'institution', 'affiliated_with'), ('author', 'paper', 'writes'), ('paper', 'paper', 'cites'), ('paper', 'field_of_study', 'has_topic')])