# Training GNNs on Large Graphs

We have seen how to train GNNs on an entire graph. In practice, however, graphs are often very large, with millions or billions of nodes and edges, and storing node and edge features multiplies the memory required. If we want to use GPUs for faster computation, full-graph training is often impossible because the graph and its features do not fit into the memory of a single GPU. On top of that, the node representations of every intermediate layer must also be stored for backpropagation.
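A back-of-the-envelope estimate makes the scale concrete. The graph size (100 million nodes), input width (128), hidden width (256) and layer count (2) below are hypothetical numbers chosen for illustration. The sketch assumes 32-bit floats and counts one stored hidden vector per node per layer, ignoring edges entirely.

```python
def dense_memory_gb(num_nodes, width, bytes_per_value=4):
    """Memory in GB (10**9 bytes) of a dense num_nodes x width float32 matrix."""
    return num_nodes * width * bytes_per_value / 1e9

# Hypothetical graph: 100 million nodes with 128-dim float32 input features.
input_gb = dense_memory_gb(100_000_000, 128)

# Intermediate representations kept for backpropagation:
# assume 2 layers, each storing a 256-dim hidden vector per node.
hidden_gb = 2 * dense_memory_gb(100_000_000, 256)

print(f"input features: {input_gb:.1f} GB")   # 51.2 GB
print(f"hidden states:  {hidden_gb:.1f} GB")  # 204.8 GB
```

Together these come to roughly 256 GB before a single edge is stored, which is more than a single GPU typically holds.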

To overcome this limitation, we combine two techniques:

1. Stochastic training on graphs.
2. Neighbor sampling on graphs.

### GraphSAGE Recap

In the previous session, we discussed the GraphSAGE model. The output representation $h_v^k$ of node $v$ from the $k$-th layer is computed as:

$$h_{N(v)}^k \leftarrow \mathrm{AGGREGATE}_k\left(\{h_u^{k-1}, \forall u \in N(v)\}\right)$$

$$h_v^k \leftarrow \sigma\left(W^k \cdot \mathrm{CONCAT}\left(h_v^{k-1}, h_{N(v)}^k\right)\right)$$

Note: the input of a GraphSAGE layer includes the neighbors' representations from the previous layer as well as the destination nodes' own representations from the previous layer.
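The two update equations can be sketched in plain Python for a single node. This minimal sketch assumes the mean aggregator for $\mathrm{AGGREGATE}_k$ and ReLU for $\sigma$, and represents vectors as lists; the feature values and the weight matrix are made up for illustration.

```python
def mean_aggregate(vectors):
    """AGGREGATE_k: element-wise mean of the neighbors' previous-layer vectors."""
    return [sum(column) / len(vectors) for column in zip(*vectors)]

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def relu(x):
    return [max(0.0, xi) for xi in x]

def sage_update(h_prev, neighbors, v, W):
    """Compute h_v^k from the layer-(k-1) representations h_prev."""
    h_neigh = mean_aggregate([h_prev[u] for u in neighbors[v]])  # h_{N(v)}^k
    concat = h_prev[v] + h_neigh          # CONCAT(h_v^{k-1}, h_{N(v)}^k), list concatenation
    return relu(matvec(W, concat))        # sigma(W^k . concat)

# Toy input: node 0 has neighbors 1 and 2; each representation is 2-dimensional.
h_prev = {0: [1.0, 0.0], 1: [0.0, 2.0], 2: [2.0, 2.0]}
neighbors = {0: [1, 2]}
W = [[1.0, 0.0, 1.0, 0.0],    # W^k maps the 4-dim concatenation to 2 dims
     [0.0, 1.0, 0.0, -1.0]]
print(sage_update(h_prev, neighbors, 0, W))  # [2.0, 0.0]
```

The node's own previous representation `h_prev[v]` enters the update separately from the aggregated neighbors, which is why a GraphSAGE layer needs the destination nodes' inputs as well as their neighbors'.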

In [None]:
import dgl
import dgl.function as fn
from dgl.nn.pytorch import conv as dgl_conv

import torch
from torch import nn
import torch.nn.functional as F

## Mini-batch Construction from a Graph

For stochastic training, we split the training data into small mini-batches and move only the information needed for each training step to the GPU. For node classification, this means splitting the labeled nodes into mini-batches. Let's take a close look at what information a mini-batch of nodes needs.
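Splitting the labeled nodes works like splitting any training set: shuffle the node IDs once per epoch and cut them into fixed-size chunks. A minimal sketch with plain Python lists; the function name `split_into_minibatches` and the example IDs are hypothetical.

```python
import random

def split_into_minibatches(labeled_nodes, batch_size, seed=None):
    """Shuffle the labeled node IDs and cut them into chunks of batch_size.

    The last chunk is smaller when len(labeled_nodes) is not a multiple
    of batch_size.
    """
    rng = random.Random(seed)
    ids = list(labeled_nodes)
    rng.shuffle(ids)
    return [ids[i:i + batch_size] for i in range(0, len(ids), batch_size)]

batches = split_into_minibatches([4, 6, 8, 9, 11], batch_size=2, seed=0)
print(batches)  # three batches, of sizes 2, 2 and 1
```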

Let's use NetworkX to construct the toy graph. *Note*: please don't use NetworkX when you want to scale to large graphs.

In [None]:
# A small graph

import networkx as nx

example_graph_nx = nx.Graph(
    [(0, 2), (0, 4), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10),
     (1, 2), (1, 3), (1, 5), (2, 3), (2, 4), (2, 6), (3, 5),
     (3, 8), (4, 7), (8, 9), (8, 11), (9, 10), (9, 11)])

# dgl.from_networkx turns each undirected NetworkX edge into two directed edges.
example_graph = dgl.from_networkx(example_graph_nx)

![](assets/graph.png)

### Single layer

If we wish to compute the output representations of nodes 4 and 6 with a GraphSAGE layer, we need the input features of nodes 4 and 6 themselves as well as those of their neighbors (nodes 7, 0, and 2).
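This dependency can be read directly off the edge list. The sketch below repeats the toy graph's edges in plain Python, treating them as undirected as `nx.Graph` does, and collects the nodes whose input features a one-layer GraphSAGE needs for nodes 4 and 6. The helper name `one_layer_inputs` is hypothetical.

```python
from collections import defaultdict

edges = [(0, 2), (0, 4), (0, 6), (0, 7), (0, 8), (0, 9), (0, 10),
         (1, 2), (1, 3), (1, 5), (2, 3), (2, 4), (2, 6), (3, 5),
         (3, 8), (4, 7), (8, 9), (8, 11), (9, 10), (9, 11)]

# Undirected graph: record each edge in both directions.
neighbors = defaultdict(set)
for u, v in edges:
    neighbors[u].add(v)
    neighbors[v].add(u)

def one_layer_inputs(seeds):
    """Nodes whose input features are needed to compute one layer for `seeds`."""
    needed = set(seeds)              # the seeds' own features (h_v^{k-1})
    for v in seeds:
        needed |= neighbors[v]       # their neighbors' features (for h_{N(v)}^k)
    return needed

print(sorted(one_layer_inputs([4, 6])))  # [0, 2, 4, 6, 7]
```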

To construct a mini-batch with neighbor sampling, we can use the DGL API `dgl.sampling.sample_neighbors`. Given a set of seed nodes and a fanout, it samples at most that many incoming edges for each seed node and returns a graph containing only the sampled edges (the returned graph keeps all nodes of the original graph). Such a graph describes exactly the computation dependency above.
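Conceptually, neighbor sampling with a fanout picks up to `fanout` incoming edges per seed node uniformly at random without replacement, keeping every edge when a node has fewer neighbors than that. The plain-Python sketch below mimics this idea on a small adjacency dict; the function name is hypothetical and this is not DGL's implementation, which also offers options such as sampling with replacement (`replace`) and per-edge probabilities (`prob`).

```python
import random

def sample_in_edges(in_neighbors, seeds, fanout, rng):
    """Return (src, dst) edges: up to `fanout` sampled in-neighbors per seed."""
    sampled = []
    for dst in seeds:
        candidates = in_neighbors[dst]
        if len(candidates) <= fanout:
            chosen = list(candidates)                  # few neighbors: keep all edges
        else:
            chosen = rng.sample(candidates, fanout)    # uniform, without replacement
        sampled.extend((src, dst) for src in chosen)
    return sampled

# In-neighbors of nodes 4 and 6 in the toy graph.
in_neighbors = {4: [0, 2, 7], 6: [0, 2]}
edges = sample_in_edges(in_neighbors, [4, 6], fanout=2, rng=random.Random(0))
print(edges)  # 2 edges into node 4 and both edges into node 6
```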

In [None]:
sampled_node_batch = torch.LongTensor([4, 6])   # These are the nodes whose outputs are to be computed
sampled_graph = dgl.sampling.sample_neighbors(example_graph, sampled_node_batch, 2)
print('|V|={}, |E|={}'.format(sampled_graph.number_of_nodes(), sampled_graph.number_of_edges()))
src, dst = sampled_graph.all_edges()
for s, d in zip(src, dst):
    print(s.numpy(), d.numpy())

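DGL further provides a bipartite structure, the *block*, that represents this dependency more directly: its source nodes are the nodes whose inputs are read, and its destination nodes are the nodes whose outputs are computed. A sampled subgraph can be converted to a block with `dgl.to_block`.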
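A block relabels nodes into two local ID spaces, one for destination nodes (the seeds) and one for source nodes (everything the seeds read from). DGL places the destination nodes first among the source nodes, which `SAGENet.propagate` relies on later. The plain-Python sketch below reproduces that relabeling for a list of sampled edges; the function name and the ordering of the remaining source nodes (first appearance) are assumptions of this sketch, not DGL internals.

```python
def to_block_sketch(edges, dst_nodes):
    """Relabel (src, dst) edges given in original IDs into block-local IDs.

    Returns (src_nodes, dst_nodes, local_edges), where src_nodes[i] is the
    original ID of local source node i. Destination nodes come first among
    the source nodes; other sources follow in order of first appearance.
    """
    dst_nodes = list(dst_nodes)
    src_nodes = list(dst_nodes)
    src_index = {n: i for i, n in enumerate(src_nodes)}
    dst_index = {n: i for i, n in enumerate(dst_nodes)}
    for s, _ in edges:
        if s not in src_index:
            src_index[s] = len(src_nodes)
            src_nodes.append(s)
    local_edges = [(src_index[s], dst_index[d]) for s, d in edges]
    return src_nodes, dst_nodes, local_edges

# Sampled edges into seeds 4 and 6, in original node IDs.
src_nodes, dst_nodes, local_edges = to_block_sketch(
    [(0, 4), (7, 4), (0, 6), (2, 6)], dst_nodes=[4, 6])
print(src_nodes)    # [4, 6, 0, 7, 2]
print(local_edges)  # [(2, 0), (3, 0), (2, 1), (4, 1)]
```

Mapping back to original IDs is a lookup, `src_nodes[i]`, which is the same mapping `print_block_info` performs with `srcdata[dgl.NID]`.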

In [None]:
sampled_block = dgl.to_block(sampled_graph, sampled_node_batch)

def print_block_info(sampled_block):
    print('#source:', sampled_block.number_of_src_nodes())
    sampled_input_nodes = sampled_block.srcdata[dgl.NID]
    print('Node ID of source nodes in original graph:', sampled_input_nodes)

    sampled_output_nodes = sampled_block.dstdata[dgl.NID]
    print('#destination:', sampled_block.number_of_dst_nodes())
    print('Node ID of destination nodes in original graph:', sampled_output_nodes)

    sampled_block_edges_src, sampled_block_edges_dst = sampled_block.all_edges()
    print('edges in local node Ids')
    for s, d in zip(sampled_block_edges_src, sampled_block_edges_dst):
        print(s.numpy(), d.numpy())
    # We need to map the src and dst node IDs in the blocks to the node IDs in the original graph.
    sampled_block_edges_src_mapped = sampled_input_nodes[sampled_block_edges_src]
    sampled_block_edges_dst_mapped = sampled_output_nodes[sampled_block_edges_dst]
    print('edges in the original node Ids')
    for s, d in zip(sampled_block_edges_src_mapped, sampled_block_edges_dst_mapped):
        print(s.numpy(), d.numpy())
    
print_block_info(sampled_block)

### Multiple Layers

Now we wish to compute the output of nodes 4 and 6 from a 2-layer GraphSAGE. This requires the input features of not only the nodes themselves and their neighbors, but also the neighbors of these neighbors.

To compute the 2-layer output of nodes 4 and 6, we first need the 1-layer output of nodes 4 and 6 as well as of their neighbors. To obtain the 1-layer output of all these nodes, we again need the input features of these nodes as well as of *their* neighbors.

Generating the computation dependency for a multi-layer GNN therefore works backwards: we start from the output layer and grow the node set toward the input layer.
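The backward growth can be sketched without DGL. Starting from the seeds, each step finds the nodes the current set depends on, and each layer's (inputs, outputs) pair is prepended so that the list ends up ordered from the input layer to the output layer. This sketch uses full neighborhoods (no sampling) on a small hypothetical path graph; the function name is illustrative.

```python
def layer_node_sets(neighbors, seeds, num_layers):
    """Return [(inputs, outputs)] per layer, ordered from input layer to output layer."""
    layers = []
    outputs = set(seeds)
    for _ in range(num_layers):
        inputs = set(outputs)                     # a node's own features are needed too
        for v in outputs:
            inputs |= set(neighbors.get(v, ()))
        layers.insert(0, (inputs, outputs))       # prepend: we are walking backwards
        outputs = inputs                          # the previous layer must produce these
    return layers

# Hypothetical path graph 0 - 1 - 2 - 3 - 4 (undirected).
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
for inputs, outputs in layer_node_sets(neighbors, [0], num_layers=2):
    print(sorted(inputs), "->", sorted(outputs))
# [0, 1, 2] -> [0, 1]
# [0, 1] -> [0]
```

Each layer's output set equals the input set of the layer after it, which is the condition that `seeds = sampled_block.srcdata[dgl.NID]` enforces in the DGL sampler.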

The following sampler generates this computation dependency for a multi-layer GNN and returns it as a list of blocks, one per layer, ordered from the input layer to the output layer.

In [None]:
class NeighborSampler(object):
    def __init__(self, g, num_fanouts):
        """
        num_fanouts : list of fanouts on each layer.
        """
        self.g = g
        self.num_fanouts = num_fanouts
        
    def sample(self, seeds):
        seeds = torch.LongTensor(seeds)
        blocks = []
        for fanout in reversed(self.num_fanouts):
            # Sample up to `fanout` in-edges per seed; take all in-edges (in_subgraph) when fanout >= |V|.
            if fanout >= self.g.number_of_nodes():
                sampled_graph = dgl.in_subgraph(self.g, seeds)
            else:
                sampled_graph = dgl.sampling.sample_neighbors(self.g, seeds, fanout)
            
            sampled_block = dgl.to_block(sampled_graph, seeds)
            seeds = sampled_block.srcdata[dgl.NID]
            # Because the dependency is generated from the output layer backwards, we prepend the new
            # block instead of appending it, so that blocks[0] belongs to the first layer.
            blocks.insert(0, sampled_block)
            
        return blocks

In [None]:
block_sampler = NeighborSampler(example_graph, [2, 2])
sampled_blocks = block_sampler.sample(sampled_node_batch)

print('#blocks:', len(sampled_blocks))
print('Block for first layer')
print('---------------------')
print_block_info(sampled_blocks[0])
print()
print('Block for second layer')
print('----------------------')
print_block_info(sampled_blocks[1])

## Mini-batch Training for a 2-layer GraphSAGE

### GraphSAGE on blocks

A sampled block is essentially a bipartite graph. We saw in the previous example that DGL's built-in `SAGEConv` module works on the whole graph. Does it also work on a block? The answer is yes: DGL's built-in graph convolution modules, including `SAGEConv`, accept both homogeneous graphs and bipartite graphs such as blocks.
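What lets a layer run on a block is that it reads source-side and destination-side features separately. The plain-Python sketch below shows a mean-aggregation step over a block's local edges, taking a `(src_feats, dst_feats)` pair like the one `SAGENet.propagate` passes to `SAGEConv`. It omits the weights and nonlinearity, uses scalar "features" for readability, and is not DGL's implementation; the function name is hypothetical.

```python
def block_mean_step(local_edges, src_feats, dst_feats):
    """Pair each destination node's own feature with the mean of its
    in-neighbors' features. Node IDs in local_edges are block-local."""
    num_dst = len(dst_feats)
    sums = [0.0] * num_dst
    counts = [0] * num_dst
    for s, d in local_edges:
        sums[d] += src_feats[s]
        counts[d] += 1
    # Assumption of this sketch: a destination with no in-edges gets a mean of 0.0.
    neigh_means = [sums[d] / counts[d] if counts[d] else 0.0 for d in range(num_dst)]
    return list(zip(dst_feats, neigh_means))

# A block with source nodes [4, 6, 0, 7, 2] and destination nodes [4, 6] (original IDs).
src_feats = [4.0, 6.0, 0.0, 7.0, 2.0]     # scalar "feature" = original node ID
dst_feats = src_feats[:2]                  # destination nodes come first among sources
local_edges = [(2, 0), (3, 0), (2, 1), (4, 1)]
print(block_mean_step(local_edges, src_feats, dst_feats))  # [(4.0, 3.5), (6.0, 1.0)]
```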

In [None]:
import dgl.nn as dglnn

class SAGENet(nn.Module):
    def __init__(self, n_layers, in_feats, out_feats, hidden_feats=None):
        super().__init__()
        self.convs = nn.ModuleList()
        
        if hidden_feats is None:
            hidden_feats = out_feats
        
        if n_layers == 1:
            self.convs.append(dglnn.SAGEConv(in_feats, out_feats, 'mean'))
        else:
            self.convs.append(dglnn.SAGEConv(in_feats, hidden_feats, 'mean', activation=F.relu))
            for i in range(n_layers - 2):
                self.convs.append(dglnn.SAGEConv(hidden_feats, hidden_feats, 'mean', activation=F.relu))
            self.convs.append(dglnn.SAGEConv(hidden_feats, out_feats, 'mean'))
        
    def forward(self, blocks, input_features):
        """
        blocks : List of blocks generated by block sampler.
        input_features : Input features of the first block.
        """
        h = input_features
        for layer, block in zip(self.convs, blocks):
            h = self.propagate(block, h, layer)
        return h
    
    def propagate(self, block, src_feats, layer):
        # Because GraphSAGE requires not only the features of the neighbors, but also the features
        # of the output nodes themselves on the current layer, we need to copy the output node features
        # from the input side to the output side ourselves to make GraphSAGE work correctly.
        # The output nodes of a block are guaranteed to appear first among its input nodes, so we can
        # conveniently write like this:
        dst_feats = src_feats[:block.number_of_dst_nodes()]
        return layer(block, (src_feats, dst_feats))

### Inference with mini-batch

Inference can also be computed in mini-batches. At inference time we use all neighbors instead of sampling, and computing each node's full multi-layer neighborhood separately would repeat much of the work, so we proceed layer by layer. For a multi-layer GraphSAGE model, we first compute the representations of all nodes on the 1st GraphSAGE layer, batch by batch, taking all neighbors into account. Once all representations from the 1st layer are computed, we use them to compute the representations of all nodes on the 2nd layer. We repeat the process until we reach the last layer.
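The layer-by-layer scheme can be sketched in plain Python: each layer is computed for every node, one batch at a time, from the previous layer's complete output, so each node is computed exactly once per layer. The "layer" here is a hypothetical stand-in (the average of a node's own value and its neighbors' mean) rather than a trained GraphSAGE layer.

```python
def toy_layer(h, neighbors, v):
    """Hypothetical stand-in for one GNN layer: average of self and neighbor mean."""
    neigh = neighbors[v]
    neigh_mean = sum(h[u] for u in neigh) / len(neigh) if neigh else 0.0
    return (h[v] + neigh_mean) / 2

def layerwise_inference(h, neighbors, num_layers, batch_size):
    nodes = sorted(neighbors)
    for _ in range(num_layers):                           # outer loop: layers
        new_h = {}
        for start in range(0, len(nodes), batch_size):    # inner loop: node batches
            for v in nodes[start:start + batch_size]:
                new_h[v] = toy_layer(h, neighbors, v)     # reads only the previous layer
        h = new_h                                         # layer finished; move on
    return h

neighbors = {0: [1], 1: [0, 2], 2: [1]}
h0 = {0: 0.0, 1: 4.0, 2: 8.0}
print(layerwise_inference(h0, neighbors, num_layers=2, batch_size=2))
# {0: 3.0, 1: 4.0, 2: 5.0}
```

Because every batch reads only the previous layer's finished representations, the batch size changes memory use but not the result.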

In [None]:
def inference_with_sagenet(sagenet, graph, input_features, batch_size):
    block_sampler = NeighborSampler(graph, [graph.number_of_nodes()])
    h = input_features
    
    with torch.no_grad():
        # We are computing all representations of one layer at a time.
        # The outer loop iterates over GNN layers.
        for conv in sagenet.convs:
            new_h_list = []
            node_ids = torch.arange(graph.number_of_nodes())
            # The inner loop iterates over batch of nodes.
            for batch_start in range(0, graph.number_of_nodes(), batch_size):
                # Sample a block with full neighbors of the current node batch
                block = block_sampler.sample(node_ids[batch_start:batch_start+batch_size])[0]
                # Get the necessary input node IDs for this node batch on this layer
                input_node_ids = block.srcdata[dgl.NID]
                # Get the input features
                h_input = h[input_node_ids]
                # Compute the output of this node batch on this layer
                new_h = sagenet.propagate(block, h_input, conv)
                new_h_list.append(new_h)
            # We finished computing all representations on this layer.  We need to compute the
            # representations of next layer.
            h = torch.cat(new_h_list)
        
    return h

### Load Dataset

Load a built-in dataset from DGL: the PubMed citation network, where nodes are papers and edges are citations.

DGL provides many built-in datasets [here](https://doc.dgl.ai/api/python/data.html).

In [None]:
import dgl.data

dataset = dgl.data.PubmedGraphDataset()
graph = dataset[0]

# Set features and labels for each node
graph.ndata['features'] = graph.ndata['feat']
graph.ndata['labels'] = graph.ndata['label']
in_feats = graph.ndata['features'].shape[1]
num_labels = dataset.num_classes

# Find the node IDs in the training, validation, and test set.
train_nid = torch.nonzero(graph.ndata['train_mask'], as_tuple=True)[0]
val_nid = torch.nonzero(graph.ndata['val_mask'], as_tuple=True)[0]
test_nid = torch.nonzero(graph.ndata['test_mask'], as_tuple=True)[0]

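Alternatively, load the much larger [Amazon product co-purchasing network](https://ogb.stanford.edu/docs/nodeprop/#ogbn-products) (`ogbn-products`) from the [Open Graph Benchmark](https://ogb.stanford.edu/). Running this cell replaces the PubMed graph loaded above; the rest of the notebook works with either dataset.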

In [None]:
from ogb.nodeproppred import DglNodePropPredDataset

data = DglNodePropPredDataset(name='ogbn-products')
splitted_idx = data.get_idx_split()
graph, labels = data[0]
labels = labels[:, 0]   # labels come as an (N, 1) tensor; flatten to shape (N,)

graph.ndata['features'] = graph.ndata['feat']
graph.ndata['labels'] = labels
in_feats = graph.ndata['features'].shape[1]
num_labels = len(torch.unique(labels))

# Find the node IDs in the training, validation, and test set.
train_nid, val_nid, test_nid = splitted_idx['train'], splitted_idx['valid'], splitted_idx['test']
print('|V|={}, |E|={}'.format(graph.number_of_nodes(), graph.number_of_edges()))
print('train: {}, valid: {}, test: {}'.format(len(train_nid), len(val_nid), len(test_nid)))

### Define Neighbor Sampler

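We can reuse our neighbor sampler code above. The fanouts are listed from the first layer to the last: `[10, 25]` samples 10 neighbors per node for the first GraphSAGE layer and 25 for the second.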
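Sampling bounds the size of each mini-batch's computation graph. A block's source nodes are its destination nodes plus at most `fanout` sampled neighbors per destination node, so an upper bound on the number of nodes each block reads from is quick to compute. This sketch assumes 1,000 seeds (the `BATCH_SIZE` used in the DataLoader cell) and ignores overlap between neighborhoods, so real counts are usually smaller; the function name is hypothetical.

```python
def max_input_nodes(num_seeds, fanouts):
    """Upper bound on source-node counts per block, from input layer to output layer.

    fanouts is ordered from the first (input) layer to the last (output) layer,
    as in NeighborSampler; sampling proceeds from the last fanout backwards.
    """
    counts = []
    num_dst = num_seeds
    for fanout in reversed(fanouts):
        num_src = num_dst * (1 + fanout)   # the dst nodes themselves + sampled neighbors
        counts.insert(0, num_src)
        num_dst = num_src                  # these become the previous layer's outputs
    return counts

print(max_input_nodes(1000, [10, 25]))  # [286000, 26000]
```

The bound depends only on the batch size and the fanouts, not on the size of the graph.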

In [None]:
neighbor_sampler = NeighborSampler(graph, [10, 25])

### Define DataLoader

PyTorch generates minibatches with a `DataLoader` object, and we can use it here as well.

Note that to compute the output of a minibatch of nodes, we need a list of blocks as described above. Therefore, we pass the sampler as the `collate_fn` argument, which defines how a list of individual examples (here, node IDs) is composed into a minibatch.

The benefit of using PyTorch's `DataLoader` is that we can take advantage of its multiprocessing support (the `num_workers` argument) to generate mini-batches in parallel.

In [None]:
import torch.utils.data

BATCH_SIZE = 1000

train_dataloader = torch.utils.data.DataLoader(
    train_nid, batch_size=BATCH_SIZE, collate_fn=neighbor_sampler.sample, shuffle=True)

### Define Model and Optimizer

In [None]:
HIDDEN_FEATURES = 50
model = SAGENet(2, in_feats, num_labels, HIDDEN_FEATURES)

opt = torch.optim.Adam(model.parameters(), lr=1e-3)

### Evaluation

In [None]:
def compute_accuracy(pred, labels):
    return (pred.argmax(1) == labels).float().mean().item()

### Training Loop

In [None]:
NUM_EPOCHS = 200
EVAL_BATCH_SIZE = 10000
for epoch in range(NUM_EPOCHS):
    model.train()
    for blocks in train_dataloader:
        input_nodes = blocks[0].srcdata[dgl.NID]
        output_nodes = blocks[-1].dstdata[dgl.NID]
        
        input_features = graph.ndata['features'][input_nodes]
        output_labels = graph.ndata['labels'][output_nodes]
        
        output_predictions = model(blocks, input_features)
        loss = F.cross_entropy(output_predictions, output_labels)
        opt.zero_grad()
        loss.backward()
        opt.step()

    if (epoch + 1) % 5 == 0:
        model.eval()
        all_predictions = inference_with_sagenet(model, graph, graph.ndata['features'], EVAL_BATCH_SIZE)

        val_predictions = all_predictions[val_nid]
        val_labels = graph.ndata['labels'][val_nid]
        test_predictions = all_predictions[test_nid]
        test_labels = graph.ndata['labels'][test_nid]

        print('Validation acc:', compute_accuracy(val_predictions, val_labels),
              'Test acc:', compute_accuracy(test_predictions, test_labels))