# Stochastic Training of GNN for Node Classification on Large Graphs

*Note: this tutorial requires a GPU-enabled machine.*


This tutorial shows how to train a GraphSAGE model for node classification on the Amazon co-purchase network (`ogbn-products`) from the Open Graph Benchmark (OGB).

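The dataset contains about 2.4 million nodes and 61.9 million edges. At this scale, full-graph GNN training, which keeps every node's intermediate representations for every layer in memory, is typically impractical on a single GPU.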

By the end of this tutorial, you will know how to:

* Create a DGL graph from your own data in other formats such as CSV.
* Train a GNN model with minibatches and neighbor sampling on a single machine with a single GPU, so that only the graph, not the GPU, has to hold all the data.

## Load Dataset

Although you can use the Python package provided by OGB directly, for demonstration we will instead download the dataset manually, peek into its contents, and process it with only `pandas` and `numpy`.

In [1]:
!wget https://snap.stanford.edu/ogb/data/nodeproppred/products.zip

--2021-01-14 01:09:50--  https://snap.stanford.edu/ogb/data/nodeproppred/products.zip
Resolving snap.stanford.edu (snap.stanford.edu)... 171.64.75.80
Connecting to snap.stanford.edu (snap.stanford.edu)|171.64.75.80|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1480993786 (1.4G) [application/zip]
Saving to: ‘products.zip.1’


2021-01-14 01:12:28 (8.98 MB/s) - ‘products.zip.1’ saved [1480993786/1480993786]



In [2]:
!unzip -o products.zip

Archive:  products.zip
  inflating: products/split/sales_ranking/test.csv.gz  
  inflating: products/split/sales_ranking/train.csv.gz  
  inflating: products/split/sales_ranking/valid.csv.gz  
  inflating: products/raw/node-label.csv.gz  
 extracting: products/raw/num-node-list.csv.gz  
 extracting: products/raw/num-edge-list.csv.gz  
  inflating: products/raw/node-feat.csv.gz  
  inflating: products/raw/edge.csv.gz  
  inflating: products/mapping/README.md  
 extracting: products/mapping/labelidx2productcategory.csv.gz  
  inflating: products/mapping/nodeidx2asin.csv.gz  
  inflating: products/RELEASE_v1.txt  


The `products/raw` directory contains these files:

* `products/raw/edge.csv.gz` (source-destination pairs)
* `products/raw/node-feat.csv.gz` (node features)
* `products/raw/node-label.csv.gz` (node labels)
* `products/raw/num-edge-list.csv.gz` (number of edges)
* `products/raw/num-node-list.csv.gz` (number of nodes)

We will only use the first three CSV files.

In addition, the directory `products/split/sales_ranking` contains the files that define the training/validation/test split. Each of `train.csv.gz`, `valid.csv.gz` and `test.csv.gz` is a gzipped text file that lists the node IDs of the corresponding set, one ID per line.

In [3]:
import pandas as pd
edges = pd.read_csv('products/raw/edge.csv.gz', header=None).values
node_features = pd.read_csv('products/raw/node-feat.csv.gz', header=None).values
node_labels = pd.read_csv('products/raw/node-label.csv.gz', header=None).values[:, 0]

# pd.read_csv yields a DataFrame with one column, so we make them one-dimensional arrays.
train_nids = pd.read_csv('products/split/sales_ranking/train.csv.gz', header=None).values[:, 0]
valid_nids = pd.read_csv('products/split/sales_ranking/valid.csv.gz', header=None).values[:, 0]
test_nids = pd.read_csv('products/split/sales_ranking/test.csv.gz', header=None).values[:, 0]

### Loading Node IDs into DGL

<div class="alert alert-info">
    <b>Note:</b> The node IDs should be consecutive integers from 0 to the number of nodes minus 1.  If your node IDs are not consecutive or do not start from 0 (e.g., they start from 100000), you need to relabel them yourself.  In pandas, a convenient way to do this is to convert the ID column to the <code>"category"</code> dtype with <code>astype</code> and then take the integer codes from <code>.cat.codes</code>.
</div>
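Here is a minimal sketch of that relabeling. It assumes a hypothetical edge list `raw_edges` whose IDs start at 100000. Both columns share one categorical type, so a raw ID maps to the same new ID wherever it appears:

```python
import numpy as np
import pandas as pd

# Hypothetical edge list whose node IDs neither start from 0 nor are consecutive.
raw_edges = pd.DataFrame({'src': [100000, 100005, 100020],
                          'dst': [100005, 100020, 100000]})

# One categorical type over every ID that appears in either column, so that the
# same raw ID receives the same new ID in both columns.
all_ids = np.unique(np.concatenate([raw_edges['src'].values, raw_edges['dst'].values]))
id_type = pd.CategoricalDtype(categories=all_ids)

# .cat.codes is each value's position in the category list: 0, 1, ..., N - 1.
src = raw_edges['src'].astype(id_type).cat.codes.values
dst = raw_edges['dst'].astype(id_type).cat.codes.values
print(src, dst)  # [0 1 2] [1 2 0]
```

`id_type.categories[new_id]` maps a new ID back to the original one. pandas picks the smallest integer type that fits the codes (`int8` here), so you may want to cast them with `.astype('int64')` before building the graph.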

## Construct DGL Graph
We construct the graph as follows:

In [4]:
import dgl
import torch

graph = dgl.graph((edges[:, 0], edges[:, 1]))
node_features = torch.FloatTensor(node_features)
node_labels = torch.LongTensor(node_labels)

# Save the graph, features, labels and training/validation/test split for use in future tutorials.
import pickle
with open('data.pkl', 'wb') as f:
    pickle.dump((graph, node_features, node_labels, train_nids, valid_nids, test_nids), f)

Using backend: pytorch


In [5]:
# Load the graph back from the file we saved
import dgl
import torch
import numpy as np
import pickle
with open('data.pkl', 'rb') as f:
    graph, node_features, node_labels, train_nids, valid_nids, test_nids = pickle.load(f)

We can see the size of the graph, features, and labels as follows.

In [6]:
print('Graph')
print(graph)
print('Shape of node features:', node_features.shape)
print('Shape of node labels:', node_labels.shape)

num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
print('Number of classes:', num_classes)

Graph
Graph(num_nodes=2449029, num_edges=61859140,
      ndata_schemes={}
      edata_schemes={})
Shape of node features: torch.Size([2449029, 100])
Shape of node labels: torch.Size([2449029])
Number of classes: 47


## Define a Data Loader with Neighbor Sampling

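Before defining the data loader, let's review how message passing works and why it calls for neighbor sampling.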

### Message passing overview

Message passing is usually formulated as follows:

$$
\begin{gathered}
  \boldsymbol{a}_v^{(l)} = \rho^{(l)} \left(
    \left\lbrace
      \boldsymbol{h}_u^{(l-1)} : u \in \mathcal{N} \left( v \right)
    \right\rbrace
  \right)
\\
  \boldsymbol{h}_v^{(l)} = \phi^{(l)} \left(
    \boldsymbol{h}_v^{(l-1)}, \boldsymbol{a}_v^{(l)}
  \right)
\end{gathered}
$$

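where $\rho^{(l)}$ and $\phi^{(l)}$ are parameterized functions, and $\mathcal{N}(v)$ is the set of predecessors (or *neighbors*) of $v$ on graph $\mathcal{G}$. In other words, it is the set of source nodes $s(e)$ of all edges $e$ in the edge set $\mathbb{E}$ whose destination node $t(e)$ is $v$: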
$$
\mathcal{N} \left( v \right) = \left\lbrace
  s \left( e \right) : e \in \mathbb{E}, t \left( e \right) = v
\right\rbrace
$$
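The following minimal NumPy sketch makes the two functions concrete for a single layer. It assumes $\rho$ is the mean over predecessors, and that $\phi$ concatenates $\boldsymbol{h}_v^{(l-1)}$ with $\boldsymbol{a}_v^{(l)}$ and then applies a linear map followed by ReLU. The function name and the toy graph are hypothetical, and this is not DGL's implementation:

```python
import numpy as np

def message_passing_layer(src, dst, h, weight):
    """One layer of the formulation above on an edge list (src[i] -> dst[i]).

    rho: mean of the predecessors' representations h_u^{(l-1)}, u in N(v).
    phi: ReLU([h_v^{(l-1)} ; a_v^{(l)}] @ weight)."""
    num_nodes, dim = h.shape
    # rho, step 1: add each source's representation h[s(e)] into its destination t(e).
    agg = np.zeros((num_nodes, dim))
    np.add.at(agg, dst, h[src])
    # rho, step 2: divide by |N(v)|; nodes without predecessors keep a_v = 0.
    in_degree = np.bincount(dst, minlength=num_nodes)
    agg = agg / np.maximum(in_degree, 1)[:, None]
    # phi: combine each node's previous representation with its aggregated message.
    return np.maximum(np.concatenate([h, agg], axis=1) @ weight, 0.0)

# Toy graph with edges 0->2, 1->2 and 2->0; node 1 has no predecessors.
src = np.array([0, 1, 2])
dst = np.array([2, 2, 0])
h = np.array([[1.0], [3.0], [5.0]])
weight = np.array([[1.0], [1.0]])  # phi reduces to h_v + a_v, easy to check by hand
h_new = message_passing_layer(src, dst, h, weight)
print(h_new.ravel())  # [6. 3. 7.]
```

Node 2, for example, averages its predecessors' values 1 and 3 into $a_2 = 2$ and adds its own value 5, which gives 7.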


For instance, to update the red node in the following graph with one round of message passing:

![Imgur](assets/1.png)

You need to aggregate the node features of its neighbors, shown as green nodes:

![Imgur](assets/2.png)

Let's consider how multi-layer message passing works when computing the output of a single node.  In the following text, we refer to the nodes whose outputs the GNN computes as *seed nodes*.


### Multi-layer message passing 

Consider using a 2-layer GNN to compute the output of seed node 8, colored red, in the following graph:

![Imgur](assets/seed.png)

Applying the formulation above to node 8 gives:

$$
\begin{gathered}
  \boldsymbol{a}_8^{(2)} = \rho^{(2)} \left(
    \left\lbrace
      \boldsymbol{h}_u^{(1)} : u \in \mathcal{N} \left( 8 \right)
    \right\rbrace
  \right) = \rho^{(2)} \left(
    \left\lbrace
      \boldsymbol{h}_4^{(1)}, \boldsymbol{h}_5^{(1)},
      \boldsymbol{h}_7^{(1)}, \boldsymbol{h}_{11}^{(1)}
    \right\rbrace
  \right)
\\
  \boldsymbol{h}_8^{(2)} = \phi^{(2)} \left(
    \boldsymbol{h}_8^{(1)}, \boldsymbol{a}_8^{(2)}
  \right)
\end{gathered}
$$

To compute $\boldsymbol{h}_8^{(2)}$, we therefore need node 8's own $\boldsymbol{h}_8^{(1)}$, plus messages from nodes 4, 5, 7, and 11 (colored green) along the edges visualized below.

![Imgur](assets/3.png)

The values of $\boldsymbol{h}_\cdot^{(1)}$ are the outputs from the first GNN layer.

To compute those values for the red and green nodes, we further need to perform message passing on the edges visualized below.

![Imgur](assets/4.png)

Therefore, to compute the 2-layer GNN representation of the red node, we need the input features of the red node as well as those of the green and yellow nodes.  Note that the red node's neighbors are needed again in this layer, because $\boldsymbol{h}_8^{(1)}$ is itself computed by aggregating over them.

You may notice that the procedure that determines the computation dependency runs in the reverse direction of message aggregation: you start from the layer closest to the output and work backward to the input.
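This backward procedure can be sketched in plain Python. The adjacency dict below is hypothetical: only node 8's predecessors (4, 5, 7, 11) are taken from the figure. Each destination node is kept in the next set because $\phi$ also needs its own previous representation:

```python
def computation_dependency(in_neighbors, seeds, num_layers):
    """Node sets needed per layer, walking from the output layer back to the input.

    in_neighbors maps each node v to the list N(v) of its predecessors.
    Returns [seeds, ..., nodes whose input features are needed]."""
    node_sets = [set(seeds)]
    for _ in range(num_layers):
        frontier = node_sets[-1]
        needed = set(frontier)  # phi needs h_v^{(l-1)} of every destination node v
        for v in frontier:
            needed.update(in_neighbors.get(v, []))  # rho needs h_u^{(l-1)} for u in N(v)
        node_sets.append(needed)
    return node_sets

# Hypothetical graph: only node 8's predecessors match the figure.
in_neighbors = {8: [4, 5, 7, 11], 4: [1, 8], 5: [8, 9], 7: [3, 8], 11: [8, 10]}
deps = computation_dependency(in_neighbors, seeds=[8], num_layers=2)
print([sorted(s) for s in deps])
# [[8], [4, 5, 7, 8, 11], [1, 3, 4, 5, 7, 8, 9, 10, 11]]
```

Even in this tiny graph, one seed node needs the input features of nine nodes after only two layers.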

In summary:

* Computing the representations of a small number of nodes often still requires the input features of a much larger number of nodes.

* Taking all neighbors for message aggregation is often too costly, since the nodes needed would easily cover a large portion of the graph.

### Neighbor Sampling Overview

Neighbor sampling addresses these issues by aggregating over a random subset of each node's neighbors instead of all of them.

For example, to compute $\boldsymbol{h}_8^{(2)}$, we can sample 2 of node 8's neighbors and aggregate only over them.

![Imgur](assets/5.png)

Similarly, to compute the first-layer representations of the red and green nodes, we can again sample 2 neighbors for each node.  Note that the red node's neighbors are sampled again for this layer.

![Imgur](assets/6.png)

This way, fewer nodes' input features are needed.
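The following plain-Python sketch shows neighbor sampling during the backward walk: each destination node draws at most `fanout` predecessors without replacement. The `fanouts` list is an assumption of the sketch: it is given first layer first and consumed in reverse, because the walk starts at the output layer. The graph and function name are hypothetical:

```python
import random

def sample_computation_dependency(in_neighbors, seeds, fanouts, rng):
    """Node sets needed per layer when each node keeps at most `fanout` predecessors.

    fanouts[i] is the fanout of GNN layer i (first layer first), so it is
    consumed in reverse while walking from the output layer to the input."""
    node_sets = [set(seeds)]
    for fanout in reversed(fanouts):
        frontier = node_sets[-1]
        needed = set(frontier)  # destination nodes still need their own representation
        for v in sorted(frontier):  # sorted for a reproducible draw order
            nbrs = in_neighbors.get(v, [])
            # Sample without replacement; keep all neighbors if there are fewer than fanout.
            needed.update(rng.sample(nbrs, min(fanout, len(nbrs))))
        node_sets.append(needed)
    return node_sets

# Hypothetical graph in which node 8 has the same predecessors as in the figure.
in_neighbors = {8: [4, 5, 7, 11], 4: [1, 2, 8], 5: [3, 8, 9],
                7: [3, 6, 8], 11: [8, 10, 12]}
deps = sample_computation_dependency(in_neighbors, seeds=[8], fanouts=[2, 2],
                                     rng=random.Random(0))
print([sorted(s) for s in deps])
```

With a fanout of 2, the second-to-last set always has exactly 3 nodes (node 8 plus two sampled neighbors), and the input set has at most 9. Taking every neighbor in the same graph would require 12 nodes' input features.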

### Graph sampling strategies

Neighbor sampling is one of several families of sampling strategies for training GNNs on large graphs:
* Neighborhood sampling (GraphSAGE)
* Control-variate-based sampling (VRGCN)
* Layer-wise sampling (FastGCN, LADIES)
* Random-walk-based sampling (PinSage)
* Subgraph sampling (ClusterGCN, GraphSAINT)


### Defining neighbor sampler and node data loader in DGL



DGL provides useful tools to generate such computation dependencies while iterating over the dataset in minibatches and performing neighbor sampling.

For node classification, you can use
* `dgl.dataloading.NodeDataLoader` for iterating over the dataset, and
* `dgl.dataloading.MultiLayerNeighborSampler` to generate computation dependencies of the nodes with neighbor sampling.

The syntax of `dgl.dataloading.NodeDataLoader` is mostly similar to that of a PyTorch `DataLoader`. In addition, it needs a graph to generate computation dependencies from, a set of node IDs to iterate over, and the neighbor sampler you defined.

Let's consider training a 3-layer GraphSAGE with neighbor sampling, where each node gathers messages from 4 sampled neighbors on each layer.  The code defining the neighbor sampler and data loader looks like the following.

In [7]:
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4, 4])
train_dataloader = dgl.dataloading.NodeDataLoader(
    graph, train_nids, sampler,
    batch_size=1024,
    shuffle=True,
    drop_last=False,
    num_workers=0
)

We can peek at the first item in the data loader we created and see what it gives us.

In [8]:
example_minibatch = next(iter(train_dataloader))
print(example_minibatch)

(tensor([185967,  29396,  35032,  ...,  10059,  15183,    676]), tensor([185967,  29396,  35032,  ..., 134351,  60992,  40989]), [Block(num_src_nodes=35310, num_dst_nodes=16046, num_edges=52314), Block(num_src_nodes=16046, num_dst_nodes=4632, num_edges=16299), Block(num_src_nodes=4632, num_dst_nodes=1024, num_edges=3759)])


`NodeDataLoader` gives us three items per iteration:

* The input node list: the nodes whose input features are needed to compute the outputs.
* The output node list: the nodes whose GNN representations are to be computed.
* A list of computation dependencies, one for each layer.

In [9]:
input_nodes, output_nodes, bipartites = example_minibatch
print("To compute {} nodes' output we need {} nodes' input features".format(len(output_nodes), len(input_nodes)))

To compute 1024 nodes' output we need 35310 nodes' input features
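A quick bound puts the 35310 figure in perspective. Here $V_{\text{src}}^{(l)}$ and $V_{\text{dst}}^{(l)}$ denote the source and destination nodes of layer $l$'s computation dependency in the printed minibatch, and the source nodes of one layer are the destination nodes of the layer below (16046 and 4632 each appear twice in the printout). With a fanout of 4, each destination node contributes itself plus at most 4 sampled neighbors, so each step back toward the input can multiply the node count by at most 5:

$$
\begin{aligned}
|V_{\text{dst}}^{(3)}| &= 1024 \\
|V_{\text{src}}^{(3)}| &\le (1 + 4) \times 1024 = 5120 && (\text{observed: } 4632) \\
|V_{\text{src}}^{(2)}| &\le (1 + 4) \times 5120 = 25600 && (\text{observed: } 16046) \\
|V_{\text{src}}^{(1)}| &\le (1 + 4) \times 25600 = 128000 && (\text{observed: } 35310)
\end{aligned}
$$

The observed counts stay well below the bound for two reasons. First, the sampled neighborhoods of different seed nodes overlap. Second, some nodes have fewer than 4 neighbors to sample from: the last layer has 3759 edges rather than $4 \times 1024 = 4096$.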


The variable `bipartites` has the message passing computation dependency for each layer.

It is named suggestively, because it can be thought of as a **list** of bipartite graphs.

So why does DGL return a list of *bipartite* graphs for training on a *homogeneous* graph?

Because at each layer it needs to distinguish the source nodes that send messages from the destination nodes that are updated.

Recall the sampled subgraph from the example above:

![Imgur](assets/6.png)


The first GNN layer outputs the representations of three nodes (two green nodes and one red node), but requires input from 7 nodes (the green and red nodes, plus 4 yellow nodes).

A bipartite graph naturally captures this computation dependency:

![](assets/bipartite.png)
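To see what such a bipartite graph stores, you can build one by hand. The NumPy sketch below relabels one layer's sampled edges from global node IDs into local IDs; the function name and node IDs are hypothetical. Destination nodes are listed first among the source nodes. As far as we know, the blocks produced by DGL's samplers use the same ordering, which lets a layer read the destination nodes' own representations from the first `num_dst_nodes` rows of its input:

```python
import numpy as np

def to_bipartite(dst_nodes, edge_src, edge_dst):
    """Relabel sampled edges (global node IDs) into a bipartite structure with local IDs.

    Source nodes start with the destination nodes, followed by the remaining
    message senders in order of first appearance."""
    src_nodes = list(dst_nodes)
    src_local = {v: i for i, v in enumerate(src_nodes)}
    for u in edge_src:
        if u not in src_local:
            src_local[u] = len(src_nodes)
            src_nodes.append(u)
    dst_local = {v: i for i, v in enumerate(dst_nodes)}
    local_src = np.array([src_local[u] for u in edge_src])
    local_dst = np.array([dst_local[v] for v in edge_dst])
    return np.array(src_nodes), local_src, local_dst

# Hypothetical first-layer sample: 3 destination nodes with 2 sampled neighbors each.
dst_nodes = [8, 4, 5]
edge_src = [7, 11, 1, 8, 3, 4]
edge_dst = [8, 8, 4, 4, 5, 5]
src_nodes, local_src, local_dst = to_bipartite(dst_nodes, edge_src, edge_dst)
print(src_nodes.tolist())  # [8, 4, 5, 7, 11, 1, 3]
print(local_src.tolist())  # [3, 4, 5, 0, 6, 1]
print(local_dst.tolist())  # [0, 0, 1, 1, 2, 2]
```

As in the figure, 3 destination nodes depend on 7 source nodes, and `src_nodes[:len(dst_nodes)]` equals `dst_nodes`.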

Let's look at each *bipartite* graph in `bipartites`. DGL represents them as `Block` objects:

In [10]:
for block in bipartites:
    print(block)

Block(num_src_nodes=35310, num_dst_nodes=16046, num_edges=52314)
Block(num_src_nodes=16046, num_dst_nodes=4632, num_edges=16299)
Block(num_src_nodes=4632, num_dst_nodes=1024, num_edges=3759)


Minibatch training of GNNs usually involves message passing on such bipartite graphs.

## Defining Model

We train the GraphSAGE model mentioned earlier, with each layer implemented as a `dglnn.SAGEConv` using mean aggregation.

The model can be written as follows:

In [11]:
import torch.nn as nn
import torch.nn.functional as F
import dgl.nn as dglnn

class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        
    def forward(self, bipartites, x):
        for l, (layer, bipartite) in enumerate(zip(self.layers, bipartites)):
            x = layer(bipartite, x)
            if l != self.n_layers - 1:
                x = F.relu(x)
        return x

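The `forward` method iterates over pairs made of a layer and the bipartite graph that the data loader generated for it. Each layer reads the representations of its bipartite graph's source nodes and outputs representations of its destination nodes. As a result, the number of rows in `x` shrinks from layer to layer: 35310, 16046, 4632 and finally 1024 in the example minibatch above.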

## Defining the Training Loop


The following initializes the model and defines the optimizer.

In [12]:
model = SAGE(num_features, 128, num_classes, 3).cuda()
opt = torch.optim.Adam(model.parameters())

### Dataloader for validation

When computing the validation score for model selection, you can usually use neighbor sampling as well.  To do that, you need to define another data loader.

In [13]:
valid_dataloader = dgl.dataloading.NodeDataLoader(
    graph, valid_nids, sampler,
    batch_size=1024,
    shuffle=False,
    drop_last=False,
    num_workers=0
)

The following is a training loop that performs validation every epoch.  It also saves the model with the best validation accuracy into a file.

In [14]:
import tqdm
import sklearn.metrics

best_accuracy = 0
best_model_path = 'model.pt'
for epoch in range(10):
    model.train()
    
    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, output_nodes, bipartites) in enumerate(tq):
            bipartites = [b.to(torch.device('cuda')) for b in bipartites]
            inputs = node_features[input_nodes].cuda()
            labels = node_labels[output_nodes].cuda()
            predictions = model(bipartites, inputs)

            loss = F.cross_entropy(predictions, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

            accuracy = sklearn.metrics.accuracy_score(labels.cpu().numpy(), predictions.argmax(1).detach().cpu().numpy())
            
            tq.set_postfix({'loss': '%.03f' % loss.item(), 'acc': '%.03f' % accuracy}, refresh=False)
        
    model.eval()
    
    predictions = []
    labels = []
    with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
        for input_nodes, output_nodes, bipartites in tq:
            bipartites = [b.to(torch.device('cuda')) for b in bipartites]
            inputs = node_features[input_nodes].cuda()
            labels.append(node_labels[output_nodes].numpy())
            predictions.append(model(bipartites, inputs).argmax(1).cpu().numpy())
        predictions = np.concatenate(predictions)
        labels = np.concatenate(labels)
        accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        print('Epoch {} Validation Accuracy {}'.format(epoch, accuracy))
        if best_accuracy < accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), best_model_path)

100%|██████████| 193/193 [00:07<00:00, 24.35it/s, loss=0.845, acc=0.571]
100%|██████████| 39/39 [00:01<00:00, 24.39it/s]
  2%|▏         | 3/193 [00:00<00:08, 22.91it/s, loss=0.780, acc=0.812]

Epoch 0 Validation Accuracy 0.8259288457137044


100%|██████████| 193/193 [00:07<00:00, 24.97it/s, loss=0.693, acc=0.714]
100%|██████████| 39/39 [00:01<00:00, 24.48it/s]
  2%|▏         | 3/193 [00:00<00:08, 22.54it/s, loss=0.664, acc=0.827]

Epoch 1 Validation Accuracy 0.8476464155837551


100%|██████████| 193/193 [00:07<00:00, 25.05it/s, loss=1.057, acc=0.714]
100%|██████████| 39/39 [00:01<00:00, 24.30it/s]
  2%|▏         | 3/193 [00:00<00:08, 22.87it/s, loss=0.674, acc=0.823]

Epoch 2 Validation Accuracy 0.8549195127533504


100%|██████████| 193/193 [00:07<00:00, 25.06it/s, loss=0.176, acc=0.857]
100%|██████████| 39/39 [00:01<00:00, 24.41it/s]
  2%|▏         | 3/193 [00:00<00:08, 22.20it/s, loss=0.515, acc=0.862]

Epoch 3 Validation Accuracy 0.8643796251557613


100%|██████████| 193/193 [00:07<00:00, 24.96it/s, loss=0.414, acc=0.857]
100%|██████████| 39/39 [00:01<00:00, 24.47it/s]
  2%|▏         | 3/193 [00:00<00:08, 22.53it/s, loss=0.546, acc=0.848]

Epoch 4 Validation Accuracy 0.8671515398113063


100%|██████████| 193/193 [00:07<00:00, 24.76it/s, loss=0.533, acc=0.857]
100%|██████████| 39/39 [00:01<00:00, 24.04it/s]
  2%|▏         | 3/193 [00:00<00:08, 22.53it/s, loss=0.473, acc=0.886]

Epoch 5 Validation Accuracy 0.871296696589782


100%|██████████| 193/193 [00:07<00:00, 24.53it/s, loss=1.526, acc=0.571]
100%|██████████| 39/39 [00:01<00:00, 24.46it/s]
  2%|▏         | 3/193 [00:00<00:08, 22.32it/s, loss=0.579, acc=0.830]

Epoch 6 Validation Accuracy 0.8660326017852148


100%|██████████| 193/193 [00:07<00:00, 24.74it/s, loss=0.020, acc=1.000]
100%|██████████| 39/39 [00:01<00:00, 24.35it/s]
  2%|▏         | 3/193 [00:00<00:08, 22.36it/s, loss=0.460, acc=0.869]

Epoch 7 Validation Accuracy 0.8780866159753834


100%|██████████| 193/193 [00:07<00:00, 24.52it/s, loss=0.259, acc=1.000]
100%|██████████| 39/39 [00:01<00:00, 24.26it/s]
  2%|▏         | 3/193 [00:00<00:08, 22.70it/s, loss=0.456, acc=0.875]

Epoch 8 Validation Accuracy 0.8790021107240038


100%|██████████| 193/193 [00:07<00:00, 24.43it/s, loss=0.047, acc=1.000]
100%|██████████| 39/39 [00:01<00:00, 24.26it/s]

Epoch 9 Validation Accuracy 0.8800956183404115





## Offline Inference without Neighbor Sampling


For offline inference, it is usually desirable to aggregate over the entire neighborhood to eliminate the randomness introduced by neighbor sampling.  However, reusing the minibatch methodology from training is inefficient: nodes that appear in the neighborhoods of many minibatches would have their representations recomputed over and over.  Moreover, simply doing neighbor sampling that takes all neighbors will often exhaust GPU memory, because the number of nodes whose input features are required may be too large.

Instead, you compute the representations layer by layer. First compute the output of the first GNN layer for all nodes, then compute the output of the second GNN layer for all nodes using the first layer's output as input, and so on.  This is a different algorithm from the one used in training.  During training, the outer loop iterates over the nodes and the inner loop over the layers; during inference, the outer loop iterates over the layers and the inner loop over the nodes.
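The loop swap can be sketched without DGL. In the NumPy sketch below, the helper names and toy graph are hypothetical. Each layer assumes a mean aggregator followed by a linear map on the concatenation $[\boldsymbol{h}_v ; \boldsymbol{a}_v]$, with ReLU on all but the last layer, mirroring `SAGE.forward`. Every layer is computed for all nodes, batch by batch, before the next layer starts, so each intermediate representation is computed exactly once. Because every node aggregates over its full neighborhood, the result should not depend on the batch size:

```python
import numpy as np

def mean_layer(src, dst, h_in, nodes, weight, relu):
    """Outputs for `nodes` only: [h_v ; mean of h_u over all predecessors u] @ weight."""
    out = np.zeros((len(nodes), weight.shape[1]))
    for i, v in enumerate(nodes):
        nbrs = src[dst == v]  # full neighborhood, no sampling
        agg = h_in[nbrs].mean(axis=0) if len(nbrs) > 0 else np.zeros(h_in.shape[1])
        out[i] = np.concatenate([h_in[v], agg]) @ weight
    return np.maximum(out, 0.0) if relu else out

def layerwise_inference(src, dst, x, weights, batch_size):
    num_nodes = x.shape[0]
    h = x
    for l, weight in enumerate(weights):               # outer loop: layers
        is_last = l == len(weights) - 1
        out = np.zeros((num_nodes, weight.shape[1]))
        for start in range(0, num_nodes, batch_size):  # inner loop: batches of nodes
            batch = np.arange(start, min(start + batch_size, num_nodes))
            out[batch] = mean_layer(src, dst, h, batch, weight, relu=not is_last)
        h = out  # this layer's outputs for *all* nodes become the next layer's inputs
    return h

# Hypothetical 5-node graph and random weights for a 2-layer model (3 -> 4 -> 2).
rng = np.random.default_rng(0)
src = np.array([0, 1, 2, 3, 4, 0])
dst = np.array([1, 2, 3, 4, 0, 2])
x = rng.normal(size=(5, 3))
weights = [rng.normal(size=(6, 4)), rng.normal(size=(8, 2))]
small = layerwise_inference(src, dst, x, weights, batch_size=2)
whole = layerwise_inference(src, dst, x, weights, batch_size=5)
print(np.allclose(small, whole))  # True
```

The buffer `out` plays the role of `output_features` in the DGL version. It holds one row per node, which is why that version keeps it in CPU memory.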

If some randomness is acceptable (e.g., during model selection in validation), you can still use `dgl.dataloading.MultiLayerNeighborSampler` and `dgl.dataloading.NodeDataLoader` for offline inference, since that is usually faster when evaluating only a small number of nodes.

![Imgur](assets/anim.gif)

In [15]:
def inference(model, graph, input_features, batch_size):
    nodes = torch.arange(graph.number_of_nodes())
    
    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)  # one layer at a time, taking all neighbors
    dataloader = dgl.dataloading.NodeDataLoader(
        graph, nodes, sampler,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=0)
    
    with torch.no_grad():
        for l, layer in enumerate(model.layers):
            # Allocate a buffer of output representations for every node
            # Note that the buffer is on CPU memory.
            output_features = torch.zeros(
                graph.number_of_nodes(), model.n_hidden if l != model.n_layers - 1 else model.n_classes)

            for input_nodes, output_nodes, bipartites in tqdm.tqdm(dataloader):
                bipartite = bipartites[0].to(torch.device('cuda'))

                x = input_features[input_nodes].cuda()

                # the following code is identical to the loop body in model.forward()
                x = layer(bipartite, x)
                if l != model.n_layers - 1:
                    x = F.relu(x)

                output_features[output_nodes] = x.cpu()
            input_features = output_features
    return output_features

The following code loads the best model from the file saved previously and performs offline inference.  It computes the accuracy on the test set afterwards.

In [16]:
model.load_state_dict(torch.load(best_model_path))
all_predictions = inference(model, graph, node_features, 8192)

100%|██████████| 299/299 [00:34<00:00,  8.56it/s]
100%|██████████| 299/299 [00:26<00:00, 11.19it/s]
100%|██████████| 299/299 [00:26<00:00, 11.29it/s]


In [17]:
test_predictions = all_predictions[test_nids].argmax(1)
test_labels = node_labels[test_nids]
# accuracy_score expects (y_true, y_pred)
test_accuracy = sklearn.metrics.accuracy_score(test_labels.numpy(), test_predictions.numpy())
print('Test accuracy:', test_accuracy)

Test accuracy: 0.7298741895385232


## Conclusion

In this tutorial, you have learned how to train a multi-layer GraphSAGE with neighbor sampling on a dataset too large for full-graph training on a GPU.  Because only the sampled minibatches are moved to the GPU, the method works on a single machine with a single GPU for any graph whose structure and features fit in that machine's CPU memory.

## What's next?

The next tutorial will train the same GraphSAGE model in an unsupervised manner with link prediction, i.e. predicting whether an edge exists between two nodes.