# Stochastic Training of GNN for Node Classification on Heterogeneous Large Graphs

This tutorial shows how to train a multi-layer R-GCN for node classification on the `ogbn-mag` dataset provided by OGB.

The ogbn-mag dataset is a heterogeneous network composed of a subset of the Microsoft Academic Graph (MAG) [1]. It contains four types of entities—papers (736,389 nodes), authors (1,134,649 nodes), institutions (8,740 nodes), and fields of study (59,965 nodes)—as well as four types of directed relations connecting two types of entities—an author is “affiliated with” an institution, an author “writes” a paper, a paper “cites” a paper, and a paper “has a topic of” a field of study.

This tutorial covers:

* Creating a DGL graph with the OGB data loader for DGL.
* Training a GNN model on a single machine with a single GPU, on a graph of any size.

## Load Dataset

We load the dataset with the Python package provided by OGB, which downloads `ogbn-mag` and returns it as a DGL graph together with the node labels.

In [1]:
!pip install ogb



In [2]:
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset(name='ogbn-mag')

Using backend: pytorch


The dataset contains the following:

* DGL graph object (source-destination pairs)
* The node label tensor

We can also use the utility function in the dataset to get the train, validation, and test splits.

In [3]:
graph, label = dataset[0] # graph: DGL heterogeneous graph, label: dict mapping node type to a label tensor (only 'paper' is used below)

split_idx = dataset.get_idx_split()
train_nids, valid_nids, test_nids = split_idx["train"], split_idx["valid"], split_idx["test"]

We can see the size of the graph, features, and labels as follows.

In [4]:
print(graph)

print('Node labels')
node_labels = label['paper'].flatten()
print('Shape of target node labels:', node_labels.shape)
num_classes = (node_labels.max() + 1).item()
print('Number of classes:', num_classes)

print('Node features')
node_features = graph.nodes['paper'].data['feat']
num_features = node_features.shape[1]
print('Shape of features of paper node type: {}'.format(num_features))

Graph(num_nodes={'author': 1134649, 'field_of_study': 59965, 'institution': 8740, 'paper': 736389},
      num_edges={('author', 'affiliated_with', 'institution'): 1043998, ('author', 'writes', 'paper'): 7145660, ('paper', 'cites', 'paper'): 5416271, ('paper', 'has_topic', 'field_of_study'): 7505078},
      metagraph=[('author', 'institution', 'affiliated_with'), ('author', 'paper', 'writes'), ('paper', 'paper', 'cites'), ('paper', 'field_of_study', 'has_topic')])
Node labels
Shape of target node labels: torch.Size([736389])
Number of classes: 349
Node features
Shape of features of paper node type: 128


## Define a Data Loader with Neighbor Sampling

### Neighbor sampling overview

The formulation of message passing usually has the following form:

$$
\begin{gathered}
  \boldsymbol{a}_v^{(l)} = \rho^{(l)} \left(
    \left\lbrace
      \boldsymbol{h}_u^{(l-1)} : u \in \mathcal{N} \left( v \right)
    \right\rbrace
  \right)
\\
  \boldsymbol{h}_v^{(l)} = \phi^{(l)} \left(
    \boldsymbol{h}_v^{(l-1)}, \boldsymbol{a}_v^{(l)}
  \right)
\end{gathered}
$$

where $\rho^{(l)}$ and $\phi^{(l)}$ are parameterized functions, and $\mathcal{N}(v)$ represents the set of predecessors (or equivalently *neighbors*) of $v$ on graph $\mathcal{G}$:
$$
\mathcal{N} \left( v \right) = \left\lbrace
  s \left( e \right) : e \in \mathbb{E}, t \left( e \right) = v
\right\rbrace
$$
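
As a minimal sketch of the two equations above, the following plain-Python snippet runs one round of message passing on a toy directed graph. The mean used for $\rho$, the addition used for $\phi$, the edge list, and the helper name `message_passing_layer` are all illustrative assumptions, not what DGL uses internally.

```python
from collections import defaultdict

def message_passing_layer(edges, h):
    """One round of message passing with scalar node features.

    edges: list of (source, target) pairs; h: dict node -> feature.
    rho is a mean over predecessors and phi adds the result to the
    node's own feature (both chosen only for illustration).
    """
    # N(v): the sources of all edges that point to v
    predecessors = defaultdict(list)
    for src, dst in edges:
        predecessors[dst].append(src)

    new_h = {}
    for v in h:
        neigh = predecessors[v]
        # a_v = rho({h_u : u in N(v)}); nodes without predecessors get 0
        a_v = sum(h[u] for u in neigh) / len(neigh) if neigh else 0.0
        # h_v' = phi(h_v, a_v)
        new_h[v] = h[v] + a_v
    return new_h

edges = [(0, 2), (1, 2), (2, 3)]
h0 = {0: 1.0, 1: 3.0, 2: 10.0, 3: 0.0}
h1 = message_passing_layer(edges, h0)
print(h1)
```

Node 2 has predecessors 0 and 1, so its new value is its own feature plus the mean of theirs; nodes 0 and 1 have no predecessors and keep their values.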

For instance, to update the red node in the following graph by message passing:

![Imgur](assets/1.png)

One needs to aggregate the node features of its neighbors, shown as green nodes:

![Imgur](assets/2.png)

Let's consider how multi-layer message passing works for computing the output of a single node.  In the following text, we refer to the nodes whose outputs the GNN computes as seed nodes.

Consider computing with a 2-layer GNN the output of the seed node 8, colored red, in the following graph:

![Imgur](assets/seed.png)

By the formulation:

$$
\begin{gathered}
  \boldsymbol{a}_8^{(2)} = \rho^{(2)} \left(
    \left\lbrace
      \boldsymbol{h}_u^{(1)} : u \in \mathcal{N} \left( 8 \right)
    \right\rbrace
  \right) = \rho^{(2)} \left(
    \left\lbrace
      \boldsymbol{h}_4^{(1)}, \boldsymbol{h}_5^{(1)},
      \boldsymbol{h}_7^{(1)}, \boldsymbol{h}_{11}^{(1)}
    \right\rbrace
  \right)
\\
  \boldsymbol{h}_8^{(2)} = \phi^{(2)} \left(
    \boldsymbol{h}_8^{(1)}, \boldsymbol{a}_8^{(2)}
  \right)
\end{gathered}
$$

We can tell from the formulation that, to compute $\boldsymbol{h}_8^{(2)}$, we need messages from nodes 4, 5, 7, and 11 (colored green) along the edges visualized below.

![Imgur](assets/3.png)

The values of $\boldsymbol{h}_\cdot^{(1)}$ are the outputs from the first GNN layer.  To compute those values for the red and green nodes, we further need to perform message passing on the edges visualized below.

![Imgur](assets/4.png)

Therefore, to compute the 2-layer GNN representation of the red node, we need the input features of the red node as well as those of the green and yellow nodes.  Note that we should take the red node's neighbors again for this layer.

You may notice that the procedure which determines computation dependency is in the reverse direction of message aggregation: you start from the layer closest to the output and work backward to the input.
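
The backward search for computation dependencies can be sketched in a few lines of plain Python. The toy edge list and the helper name `required_nodes` are assumptions made for illustration; the graph is not the one shown in the figures.

```python
from collections import defaultdict

def required_nodes(edges, seeds, num_layers):
    """Return the node sets needed at each level, from output to input.

    result[0] holds the seed nodes; result[k] holds the nodes whose
    representations are needed k layers below the output.
    """
    predecessors = defaultdict(set)
    for src, dst in edges:
        predecessors[dst].add(src)

    frontiers = [set(seeds)]
    for _ in range(num_layers):
        current = frontiers[-1]
        # Each node needs its own previous-layer state plus its predecessors'.
        needed = set(current)
        for v in current:
            needed |= predecessors[v]
        frontiers.append(needed)
    return frontiers

edges = [(1, 0), (2, 0), (3, 1), (4, 1), (5, 2), (6, 5)]
levels = required_nodes(edges, seeds=[0], num_layers=2)
for depth, nodes in enumerate(levels):
    print(depth, sorted(nodes))
```

Starting from the single seed node 0, one step back requires nodes 0 to 2 and a second step back requires nodes 0 to 5, while node 6 is three hops away and is never touched by a 2-layer model.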

You can also see that computing representation for a small number of nodes often requires input features of a significantly larger number of nodes.  Taking all neighbors for message aggregation is often too costly since the nodes needed would easily cover a large portion of the graph.

Neighbor sampling addresses this issue by selecting a random subset of the neighbors to perform aggregation.  For instance, to compute $\boldsymbol{h}_8^{(2)}$, we can choose to sample 2 neighbors and aggregate.

![Imgur](assets/5.png)

Similarly, to compute the first-layer representations of the red and green nodes, we can also do neighbor sampling that takes 2 neighbors for each node.  Note that we should take the red node's neighbors again for this layer.

![Imgur](assets/6.png)

You can see that this method requires input features from fewer nodes.
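
The effect of sampling on the number of required input nodes can be illustrated with a small standalone sketch. The helper `sample_dependencies`, the toy tree-shaped graph, and the convention that the last entry of `fanouts` belongs to the layer closest to the output are assumptions of this sketch; it only mimics the idea and is not DGL's sampler.

```python
import random
from collections import defaultdict

def sample_dependencies(edges, seeds, fanouts, seed=0):
    """Collect the nodes needed per level, keeping at most `fanout`
    randomly chosen predecessors per node at each layer."""
    rng = random.Random(seed)
    predecessors = defaultdict(list)
    for src, dst in edges:
        predecessors[dst].append(src)

    frontiers = [set(seeds)]
    # Walk from the output layer back to the input layer.
    for fanout in reversed(fanouts):
        needed = set(frontiers[-1])
        for v in sorted(frontiers[-1]):
            neigh = predecessors[v]
            k = min(fanout, len(neigh))
            needed.update(rng.sample(neigh, k))
        frontiers.append(needed)
    return frontiers

# Toy graph: node 0 has 10 predecessors, each of which has 10 of its own.
edges = [(i, 0) for i in range(1, 11)]
edges += [(10 * i + j, i) for i in range(1, 11) for j in range(1, 11)]

full = sample_dependencies(edges, [0], [100, 100])   # fanout larger than any degree
sampled = sample_dependencies(edges, [0], [2, 2])    # 2 neighbors per node per layer
print('all neighbors:', [len(s) for s in full])
print('2 per node:   ', [len(s) for s in sampled])
```

With a fanout larger than every in-degree the sketch needs all 111 nodes as input for a single seed, whereas with a fanout of 2 the input set stays below ten nodes.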

### Defining neighbor sampler and data loader in DGL

DGL provides useful tools to generate such computation dependencies while iterating over the dataset in minibatches.  For node classification, you can use `dgl.dataloading.NodeDataLoader` for iterating over the dataset, and `dgl.dataloading.MultiLayerNeighborSampler` to generate computation dependencies of the nodes from a multi-layer GNN with neighbor sampling.

The syntax of `dgl.dataloading.NodeDataLoader` is mostly similar to a PyTorch `DataLoader`, with the addition that it needs a graph to generate computation dependency from, a set of node IDs to iterate on, and the neighbor sampler you defined.

Let's consider training a 3-layer R-GCN with neighbor sampling, where each node gathers messages from up to 4 sampled neighbors per edge type on each layer.  The code defining the data loader and neighbor sampler looks like the following.

In [5]:
import dgl

sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4, 4])
train_dataloader = dgl.dataloading.NodeDataLoader(
    graph, train_nids, sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0
)

We can iterate over the data loader we created and see what it gives us.

In [6]:
example_minibatch = next(iter(train_dataloader))
print(example_minibatch)

({'author': tensor([365776,  60441, 198717,  ...,  41465,  65043, 649864]), 'field_of_study': tensor([], dtype=torch.int64), 'institution': tensor([], dtype=torch.int64), 'paper': tensor([224131,   2074,  13576,  ..., 716197, 698763, 661550])}, {'author': tensor([], dtype=torch.int64), 'field_of_study': tensor([], dtype=torch.int64), 'institution': tensor([], dtype=torch.int64), 'paper': tensor([224131,   2074,  13576,  ..., 190503,  58541, 207510])}, [Block(num_src_nodes={'author': 27329, 'field_of_study': 0, 'institution': 0, 'paper': 22345},
      num_dst_nodes={'author': 10952, 'field_of_study': 0, 'institution': 0, 'paper': 10004},
      num_edges={('author', 'affiliated_with', 'institution'): 0, ('author', 'writes', 'paper'): 32706, ('paper', 'cites', 'paper'): 22723, ('paper', 'has_topic', 'field_of_study'): 0},
      metagraph=[('author', 'institution', 'affiliated_with'), ('author', 'paper', 'writes'), ('paper', 'paper', 'cites'), ('paper', 'field_of_study', 'has_topic')]), Bl

`NodeDataLoader` gives us three items per iteration.

* The input node list: the nodes whose input features are needed to compute the outputs.
* The output node list: the nodes whose GNN representations are to be computed.
* The list of computation dependencies, one entry per layer.

In [7]:
input_nodes, output_nodes, bipartites = example_minibatch
print("To compute {} target nodes' output we need {} nodes' input features".format(len(output_nodes['paper']), len(input_nodes['paper'])))

print("")
print("Output nodes")
print(output_nodes)

print("")
print("Input nodes")
print(input_nodes)

To compute 1024 target nodes' output we need 22345 nodes' input features

Output nodes
{'author': tensor([], dtype=torch.int64), 'field_of_study': tensor([], dtype=torch.int64), 'institution': tensor([], dtype=torch.int64), 'paper': tensor([224131,   2074,  13576,  ..., 190503,  58541, 207510])}

Input nodes
{'author': tensor([365776,  60441, 198717,  ...,  41465,  65043, 649864]), 'field_of_study': tensor([], dtype=torch.int64), 'institution': tensor([], dtype=torch.int64), 'paper': tensor([224131,   2074,  13576,  ..., 716197, 698763, 661550])}


The variable `bipartites` shows how messages aggregate on each layer.  As its name suggests, it is a **list** of bipartite graphs.

So why does DGL return a list of *bipartite* graphs rather than subgraphs of the original graph?  The reason is that the number of input nodes and the number of output nodes of a given GNN layer differ.  Take the example above:

![Imgur](assets/6.png)

That GNN layer will output the representation of three nodes (two green nodes and one red node), but it will require input from 7 nodes (the green nodes and red node, plus 4 yellow nodes).  Only a bipartite graph can describe such computation:

![](assets/bipartite.png)

Minibatch training of GNNs usually involves message passing on such bipartite graphs.
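
The structure of such a bipartite graph can be sketched without DGL. The helper `make_block` below is hypothetical: it relabels the sampled edges with compact local IDs and, as a convention of this sketch, lists the output nodes first among the input nodes so that a layer can also read each output node's own input feature. The node IDs are toy values, not those of the figure.

```python
def make_block(sampled_edges, dst_nodes):
    """Build a compact bipartite structure from sampled (src, dst) edges.

    Returns (src_nodes, dst_nodes, local_edges), where local_edges holds
    (index into src_nodes, index into dst_nodes) pairs.
    """
    dst_nodes = list(dst_nodes)
    # Output nodes come first on the input side.
    src_nodes = list(dst_nodes)
    src_index = {v: i for i, v in enumerate(src_nodes)}
    for src, _ in sampled_edges:
        if src not in src_index:
            src_index[src] = len(src_nodes)
            src_nodes.append(src)
    dst_index = {v: i for i, v in enumerate(dst_nodes)}
    local_edges = [(src_index[s], dst_index[d]) for s, d in sampled_edges]
    return src_nodes, dst_nodes, local_edges

# Three output nodes, each with two sampled predecessors.
sampled_edges = [(4, 8), (5, 8), (1, 4), (8, 4), (2, 5), (3, 5)]
src_nodes, dst_nodes, local_edges = make_block(sampled_edges, dst_nodes=[8, 4, 5])
print('input side: ', src_nodes)
print('output side:', dst_nodes)
print('edges (local IDs):', local_edges)
```

The input side has six nodes and the output side three, which is why a single ordinary graph over one node set would not describe the computation as directly.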

In [8]:
for block in bipartites:
    print(block)
    print()

Block(num_src_nodes={'author': 27329, 'field_of_study': 0, 'institution': 0, 'paper': 22345},
      num_dst_nodes={'author': 10952, 'field_of_study': 0, 'institution': 0, 'paper': 10004},
      num_edges={('author', 'affiliated_with', 'institution'): 0, ('author', 'writes', 'paper'): 32706, ('paper', 'cites', 'paper'): 22723, ('paper', 'has_topic', 'field_of_study'): 0},
      metagraph=[('author', 'institution', 'affiliated_with'), ('author', 'paper', 'writes'), ('paper', 'paper', 'cites'), ('paper', 'field_of_study', 'has_topic')])

Block(num_src_nodes={'author': 10952, 'field_of_study': 0, 'institution': 0, 'paper': 10004},
      num_dst_nodes={'author': 3263, 'field_of_study': 0, 'institution': 0, 'paper': 3748},
      num_edges={('author', 'affiliated_with', 'institution'): 0, ('author', 'writes', 'paper'): 12154, ('paper', 'cites', 'paper'): 8975, ('paper', 'has_topic', 'field_of_study'): 0},
      metagraph=[('author', 'institution', 'affiliated_with'), ('author', 'paper', 'writ

## Defining Model

The model can be written as follows:

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn

class RGCN(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers, rel_names):
        super().__init__()
        
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        
        self.layers.append(dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, n_hidden)
            for rel in rel_names}, aggregate='sum'))
        
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.HeteroGraphConv({
                rel: dglnn.GraphConv(n_hidden, n_hidden)
                for rel in rel_names}, aggregate='sum'))
            
        self.layers.append(dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(n_hidden, n_classes)
            for rel in rel_names}, aggregate='sum'))

    def forward(self, bipartites, x):
        # x: dict mapping each node type to the input features of its nodes
        for l, (layer, bipartite) in enumerate(zip(self.layers, bipartites)):
            x = layer(bipartite, x)
            if l != self.n_layers - 1:
                x = {k: F.relu(v) for k, v in x.items()}
        return x
    

class NodeEmbed(nn.Module):
    def __init__(self, num_nodes, embed_size):
        super().__init__()
        self.embed_size = embed_size
        self.node_embeds = nn.ModuleDict()
        for ntype in num_nodes:
            node_embed = torch.nn.Embedding(num_nodes[ntype], self.embed_size)
            nn.init.uniform_(node_embed.weight, -1.0, 1.0)
            self.node_embeds[str(ntype)] = node_embed
    
    def forward(self, node_ids):
        embeds = {}
        for ntype in node_ids:
            embeds[ntype] = self.node_embeds[ntype](node_ids[ntype])
        return embeds

You can see that `RGCN.forward` iterates over pairs of a layer and the corresponding bipartite graph generated by the data loader.
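
To make the role of `HeteroGraphConv` with `aggregate='sum'` more concrete, the following plain-Python sketch computes one message per edge with a relation-specific weight and sums the results per destination node across relations. The scalar features, the per-relation weights, and the helper `hetero_layer` are illustrative assumptions; the actual module runs one `GraphConv` per relation.

```python
from collections import defaultdict

def hetero_layer(relations, feats, weights):
    """relations: {(src_type, rel, dst_type): [(src_id, dst_id), ...]}
    feats: {node_type: {node_id: scalar feature}}
    weights: {rel: scalar weight}, a stand-in for one module per relation.
    """
    out = defaultdict(lambda: defaultdict(float))
    for (stype, rel, dtype), edges in relations.items():
        if stype not in feats:
            continue  # this sketch skips relations whose source type has no features
        for src, dst in edges:
            # Per-relation message, summed over edges and over relations.
            out[dtype][dst] += weights[rel] * feats[stype][src]
    return {ntype: dict(vals) for ntype, vals in out.items()}

relations = {
    ('author', 'writes', 'paper'): [(0, 0), (1, 0)],
    ('paper', 'cites', 'paper'): [(1, 0)],
}
feats = {'author': {0: 1.0, 1: 2.0}, 'paper': {0: 5.0, 1: 4.0}}
weights = {'writes': 0.5, 'cites': 2.0}
out = hetero_layer(relations, feats, weights)
print(out)
```

Only the `paper` type appears in the result, because no relation in this toy example points to an author. Paper 0 receives two `writes` messages and one `cites` message, which are added together.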

## Defining Training Loop

The following initializes the model and defines the optimizer.

In [10]:
num_nodes = {ntype: graph.number_of_nodes(ntype) for ntype in graph.ntypes if ntype != 'paper'}
embed = NodeEmbed(num_nodes, 128)
model = RGCN(num_features, 128, num_classes, 3, graph.etypes).cuda()
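# Only the R-GCN weights are passed to the optimizer; the parameters of `embed`
# are not, so the node embeddings keep their random initial values.
opt = torch.optim.Adam(model.parameters())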

In [11]:
embed

NodeEmbed(
  (node_embeds): ModuleDict(
    (author): Embedding(1134649, 128)
    (field_of_study): Embedding(59965, 128)
    (institution): Embedding(8740, 128)
  )
)

In [12]:
model

RGCN(
  (layers): ModuleList(
    (0): HeteroGraphConv(
      (mods): ModuleDict(
        (affiliated_with): GraphConv(in=128, out=128, normalization=both, activation=None)
        (cites): GraphConv(in=128, out=128, normalization=both, activation=None)
        (has_topic): GraphConv(in=128, out=128, normalization=both, activation=None)
        (writes): GraphConv(in=128, out=128, normalization=both, activation=None)
      )
    )
    (1): HeteroGraphConv(
      (mods): ModuleDict(
        (affiliated_with): GraphConv(in=128, out=128, normalization=both, activation=None)
        (cites): GraphConv(in=128, out=128, normalization=both, activation=None)
        (has_topic): GraphConv(in=128, out=128, normalization=both, activation=None)
        (writes): GraphConv(in=128, out=128, normalization=both, activation=None)
      )
    )
    (2): HeteroGraphConv(
      (mods): ModuleDict(
        (affiliated_with): GraphConv(in=128, out=349, normalization=both, activation=None)
        (cites): 

When computing the validation score for model selection, usually you can also do neighbor sampling.  To do that, you need to define another data loader.

In [13]:
valid_dataloader = dgl.dataloading.NodeDataLoader(
    graph, valid_nids, sampler,
    batch_size=1024,
    shuffle=False,
    drop_last=False,
    num_workers=0
)

The following is a training loop that performs validation every epoch.  It also saves the model with the best validation accuracy into a file.

In [14]:
import tqdm
import numpy as np
import sklearn.metrics

best_accuracy = 0
best_model_path = 'model.pt'
for epoch in range(1):
    model.train()
    
    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, output_nodes, bipartites) in enumerate(tq):
            bipartites = [b.to(torch.device('cuda')) for b in bipartites]
            
            # Select the input node IDs of the node types that have no input features
            nodes_to_embed = {ntype: node_ids for ntype, node_ids in input_nodes.items() if ntype != "paper"}
            
            # Look up the embeddings of those nodes on CPU and copy them to the GPU
            embeddings = {ntype: node_embedding.cuda() for ntype, node_embedding in embed(nodes_to_embed).items()}
            
            # Gather the input features of the 'paper' nodes, the only type that has features
            inputs = {'paper': node_features[input_nodes['paper']].cuda()}
            
            # Merge the paper features and the embeddings into a single input dict
            inputs.update(embeddings)
            
            labels = node_labels[output_nodes['paper']].cuda()
            predictions = model(bipartites, inputs)['paper']

            loss = F.cross_entropy(predictions, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

            accuracy = sklearn.metrics.accuracy_score(labels.cpu().numpy(), predictions.argmax(1).detach().cpu().numpy())
            
            tq.set_postfix({'loss': '%.03f' % loss.item(), 'acc': '%.03f' % accuracy}, refresh=False)
        
    model.eval()
    
    predictions = []
    labels = []
    with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
        for input_nodes, output_nodes, bipartites in tq:
            bipartites = [b.to(torch.device('cuda')) for b in bipartites]
            
            nodes_to_embed = {ntype: node_ids for ntype, node_ids in input_nodes.items() if ntype != "paper"}
            embeddings = {ntype: node_embedding.cuda() for ntype, node_embedding in embed(nodes_to_embed).items()}
            inputs = {'paper': node_features[input_nodes['paper']].cuda()}
            inputs.update(embeddings)
            
            labels.append(node_labels[output_nodes['paper']].numpy())
            predictions.append(model(bipartites, inputs)['paper'].argmax(1).cpu().numpy())
        predictions = np.concatenate(predictions)
        labels = np.concatenate(labels)
        accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        print('Epoch {} Validation Accuracy {}'.format(epoch, accuracy))
        if best_accuracy < accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), best_model_path)

100%|██████████| 615/615 [01:42<00:00,  6.01it/s, loss=3.945, acc=0.157]
100%|██████████| 64/64 [00:02<00:00, 27.76it/s]

Epoch 0 Validation Accuracy 0.08822577413338677





## Offline Inference without Neighbor Sampling

For offline inference it is usually desirable to aggregate over the entire neighborhood, which eliminates the randomness introduced by neighbor sampling.  However, reusing the training methodology is inefficient because it leads to a lot of redundant computation.  Moreover, simply running neighbor sampling with all neighbors will often exhaust GPU memory, because the input features of all required nodes may not fit.

Instead, you need to compute the representations layer by layer: you first compute the output of the first GNN layer for all nodes, then you compute the output of the second GNN layer for all nodes using the first layer's output as input, and so on.  This gives us a different algorithm from the one used in training.  During training we have an outer loop that iterates over the nodes and an inner loop that iterates over the layers.  In contrast, during inference we have an outer loop that iterates over the layers and an inner loop that iterates over the nodes.
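
The two loop orders can be compared on a toy example. In the sketch below, `layer` is a made-up stand-in for a GNN layer (own state plus the sum over predecessors), and `node_wise` and `layer_wise` are hypothetical helpers; with full neighborhoods both orders give the same result, but the layer-wise version computes each node's representation only once per layer.

```python
from collections import defaultdict

def layer(h_self, h_neigh):
    # Stand-in for one GNN layer: own state plus the sum over predecessors.
    return h_self + sum(h_neigh)

def node_wise(preds, x, v, num_layers):
    """Training-style order: for one node, recurse through the layers."""
    if num_layers == 0:
        return x[v]
    h_self = node_wise(preds, x, v, num_layers - 1)
    h_neigh = [node_wise(preds, x, u, num_layers - 1) for u in preds[v]]
    return layer(h_self, h_neigh)

def layer_wise(preds, x, num_layers):
    """Inference-style order: for each layer, update every node once."""
    h = dict(x)
    for _ in range(num_layers):
        h = {v: layer(h[v], [h[u] for u in preds[v]]) for v in h}
    return h

edges = [(1, 0), (2, 0), (3, 1), (3, 2), (0, 3)]
preds = defaultdict(list)
for src, dst in edges:
    preds[dst].append(src)
x = {0: 1.0, 1: 2.0, 2: 3.0, 3: 4.0}

full = layer_wise(preds, x, num_layers=2)
print(full)
print(node_wise(preds, x, 0, 2))
```

In the node-wise version the representation of node 3 is recomputed for every path that reaches it, which is the redundancy that the layer-wise order avoids.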

If you do not care about randomness too much (e.g., during model selection in validation), you can still use the `dgl.dataloading.MultiLayerNeighborSampler` and `dgl.dataloading.NodeDataLoader` to do offline inference, since it is usually faster for evaluating a small number of nodes.

![Imgur](assets/anim.gif)

In [15]:
def inference(model, graph, input_features, batch_size):
    nodes = {'paper': torch.arange(graph.number_of_nodes('paper'))}
    
    sampler = dgl.dataloading.MultiLayerNeighborSampler([None])  # one layer at a time, taking all neighbors
    dataloader = dgl.dataloading.NodeDataLoader(
        graph, nodes, sampler,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0)
    
    with torch.no_grad():
        for l, layer in enumerate(model.layers):
            # Allocate a buffer of output representations for every node
            # Note that the buffer is on CPU memory.
            output_features = {ntype: torch.zeros(
                graph.number_of_nodes(ntype), model.n_hidden if l != model.n_layers - 1 else model.n_classes)
                for ntype in graph.ntypes}

            for input_nodes, output_nodes, bipartites in tqdm.tqdm(dataloader):
                bipartite = bipartites[0].to(torch.device('cuda'))

                # Gather this layer's input features for the input nodes and copy them to the GPU
                x = {ntype: input_features[ntype][input_nodes[ntype]].cuda() for ntype in input_nodes}

                # the following code is identical to the loop body in model.forward()
                x = layer(bipartite, x)
                if l != model.n_layers - 1:
                    x = {k: F.relu(v) for k, v in x.items()}
                
                for ntype in x:
                    output_features[ntype][output_nodes[ntype]] = x[ntype].cpu()
            input_features = output_features
    return output_features

The following code loads the best model from the file saved previously and performs offline inference.  It computes the accuracy on the test set afterwards.

In [16]:
model.load_state_dict(torch.load(best_model_path))

nodes_to_embed = {ntype: torch.arange(num_nodes_ntype) for ntype, num_nodes_ntype in num_nodes.items()}
embeddings = {ntype: node_embedding for ntype, node_embedding in embed(nodes_to_embed).items()}
inputs = {'paper': node_features}
inputs.update(embeddings)

all_predictions = inference(model, graph, inputs, 8192)

100%|██████████| 90/90 [00:07<00:00, 12.12it/s]
100%|██████████| 90/90 [00:07<00:00, 12.06it/s]
100%|██████████| 90/90 [00:07<00:00, 11.90it/s]
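
The three progress bars correspond to the three layers of the model, and the 90 iterations per bar follow from the number of paper nodes and the batch size. A quick check, using only numbers that appear earlier in this tutorial:

```python
import math

num_papers = 736389   # number of 'paper' nodes reported by print(graph)
batch_size = 8192     # batch size passed to inference()
num_batches = math.ceil(num_papers / batch_size)
print(num_batches)    # 90
```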


In [17]:
test_predictions = all_predictions['paper'][test_nids['paper']].argmax(1)
test_labels = node_labels[test_nids['paper']]
test_accuracy = sklearn.metrics.accuracy_score(test_labels.numpy(), test_predictions.numpy())
print('Test accuracy:', test_accuracy)

Test accuracy: 0.03962898495433844


## Conclusion

In this tutorial, you have learned how to train a multi-layer R-GCN with neighbor sampling on a large heterogeneous graph that cannot fit into a single GPU.  The method you have learned can scale to a graph of any size, and works on a single machine with a single GPU.

## What's next?

The next tutorial will be about training a GraphSAGE model in an unsupervised manner with link prediction, i.e. predicting whether an edge exists between two nodes.