# Stochastic Training of GNN with Multiple GPUs

In this tutorial you will learn how to train a multi-layer GraphSAGE model for node classification on the Amazon co-purchase network provided by OGB, using multiple GPUs.  The dataset contains 2.4 million nodes and 61 million edges, so it does not fit in the memory of a single GPU.

This tutorial covers how to:

* Train a GNN model on a single machine with multiple GPUs, on a graph of any size, using `torch.nn.parallel.DistributedDataParallel`.

PyTorch `DistributedDataParallel` (DDP for short) is a common solution for multi-GPU training.  Combining DGL with PyTorch DDP is straightforward, because you do the same two things as in any ordinary PyTorch application:

* Divide the data among the GPUs.
* Distribute the model parameters using PyTorch DDP.

In [1]:
import numpy as np
import dgl
import torch
import dgl.nn as dglnn
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
import torch.nn.functional as F
import torch.multiprocessing as mp
import sklearn.metrics
import tqdm

import utils

Using backend: pytorch


## Load Dataset

The following code is copied from the first tutorial.

In [2]:
def load_data():
    import pickle

    with open('data.pkl', 'rb') as f:
        data = pickle.load(f)
    graph, node_features, node_labels, train_nids, valid_nids, test_nids = data
    utils.prepare_mp(graph)
    
    num_features = node_features.shape[1]
    num_classes = (node_labels.max() + 1).item()
    
    return graph, node_features, node_labels, train_nids, valid_nids, test_nids, num_features, num_classes

## Defining Data Loader for Distributed Data Parallel (DDP)

In PyTorch DDP, each worker process is assigned an integer *rank*.  The rank determines which partition of the dataset the worker process handles.  In terms of the data loader, the only difference between single-GPU and multi-GPU training is therefore that each data loader iterates over only its own partition of the nodes.

In [3]:
def create_dataloader(rank, world_size, graph, nids):
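    # Give each rank an equal-sized contiguous slice of the node IDs.
    # The last len(nids) % world_size nodes are left out.
    partition_size = len(nids) // world_size
    partition_offset = partition_size * rank
    nids = nids[partition_offset:partition_offset+partition_size]
    
    # Sample at most 4 neighbors per node at each of the 3 layers.
    sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4, 4])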
    dataloader = dgl.dataloading.NodeDataLoader(
        graph, nids, sampler,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=0
    )
    
    return dataloader
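
The slicing at the top of `create_dataloader` can be tried on its own.  The following is a minimal sketch with plain Python lists, assuming 10 hypothetical node IDs and 4 workers; `partition_for_rank` and `batches_per_epoch` are illustrative helper names, not part of DGL or PyTorch.

```python
import math


def partition_for_rank(nids, rank, world_size):
    """Return the equal-sized contiguous slice of nids that one rank handles."""
    partition_size = len(nids) // world_size
    partition_offset = partition_size * rank
    return nids[partition_offset:partition_offset + partition_size]


def batches_per_epoch(num_nodes, batch_size):
    """Number of minibatches a loader yields when drop_last=False."""
    return math.ceil(num_nodes / batch_size)


nids = list(range(10))   # hypothetical node IDs
world_size = 4
partitions = [partition_for_rank(nids, r, world_size) for r in range(world_size)]
for rank, part in enumerate(partitions):
    print(rank, part)
```

Every rank receives the same number of nodes, so with the same batch size every rank runs the same number of steps per epoch.  The price is that the last `len(nids) % world_size` nodes (8 and 9 in this sketch) are not visited.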

## Defining Model

The model implementation is exactly the same as in the first tutorial.

In [4]:
class SAGE(nn.Module):
    def __init__(self, in_feats, n_hidden, n_classes, n_layers):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layers = nn.ModuleList()
        self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
        for i in range(1, n_layers - 1):
            self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
        self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
        
    def forward(self, blocks, x):
        for l, (layer, block) in enumerate(zip(self.layers, blocks)):
            x = layer(block, x)
            if l != self.n_layers - 1:
                x = F.relu(x)
        return x
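
The model applies one `SAGEConv` layer per block, so a three-layer model pairs with the three fanouts `[4, 4, 4]` passed to the sampler in `create_dataloader`.  As a rough back-of-the-envelope sketch of why sampled minibatches stay small, the snippet below bounds the number of input nodes a batch can require.  It assumes that each block's source nodes consist of its destination nodes plus at most `fanout` sampled neighbors per destination node; `max_input_nodes` is a hypothetical helper, not a DGL function.

```python
def max_input_nodes(num_seeds, fanouts):
    """Upper bound on the input nodes needed to compute num_seeds outputs.

    Assumes each layer's source nodes are its destination nodes plus at
    most `fanout` sampled neighbors per destination node.
    """
    bound = num_seeds
    for fanout in fanouts:
        bound *= 1 + fanout
    return bound


print(max_input_nodes(1, [4, 4, 4]))      # one seed node
print(max_input_nodes(1024, [4, 4, 4]))   # one full minibatch
```

Under these assumptions a batch of 1024 seed nodes needs features for at most 128,000 nodes, regardless of how large the full graph is.  The actual count is usually lower, because sampled neighborhoods overlap and some nodes have fewer than four neighbors.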

## Distributing the Model to GPUs

PyTorch DDP manages the distribution of models and synchronization of the gradients for you.  In DGL, you can benefit from PyTorch DDP as well by simply wrapping the model with `torch.nn.parallel.DistributedDataParallel`.

The recommended way to distribute training is to have one training process per GPU, so during model instantiation we also specify the process rank, which is equal to the GPU ID.

In [5]:
def init_model(rank, in_feats, n_hidden, n_classes, n_layers):
    model = SAGE(in_feats, n_hidden, n_classes, n_layers).to(rank)
    return DistributedDataParallel(model, device_ids=[rank], output_device=rank)
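
To make "synchronization of the gradients" concrete: during the backward pass, DDP averages each parameter's gradient across all processes, so every replica applies the same update.  The toy example below imitates that averaging in plain Python for a one-parameter model `y = w * x` with mean squared error.  The data and function names are made up for illustration, and no PyTorch call is involved.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]

# Two simulated workers, each holding an equal-sized partition of the data.
local_grads = [
    mse_grad(w, xs[:2], ys[:2]),   # "rank 0"
    mse_grad(w, xs[2:], ys[2:]),   # "rank 1"
]
averaged_grad = sum(local_grads) / len(local_grads)
full_grad = mse_grad(w, xs, ys)

print(local_grads, averaged_grad, full_grad)
```

With equal-sized partitions, the averaged gradient equals the gradient over the combined data in this toy setting, which is why each rank can work on its own slice of the training nodes.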

## The Training Loop for One Process

The training loop looks the same as in other PyTorch DDP applications.

In [6]:
@utils.fix_openmp
def train(rank, world_size, data):
    # data is the output of load_data
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='tcp://127.0.0.1:12345',
        world_size=world_size,
        rank=rank)
    torch.cuda.set_device(rank)
    
    graph, node_features, node_labels, train_nids, valid_nids, test_nids, num_features, num_classes = data
    
    train_dataloader = create_dataloader(rank, world_size, graph, train_nids)
    # Every rank builds a loader over the full validation set (rank 0 of 1),
    # but only rank 0 runs validation below.
    valid_dataloader = create_dataloader(0, 1, graph, valid_nids)
    
    model = init_model(rank, num_features, 128, num_classes, 3)
    opt = torch.optim.Adam(model.parameters())
    torch.distributed.barrier()
    
    best_accuracy = 0
    best_model_path = 'model.pt'
    for epoch in range(20):
        model.train()

        for step, (input_nodes, output_nodes, blocks) in enumerate(train_dataloader):
            inputs = node_features[input_nodes].cuda()
            labels = node_labels[output_nodes].cuda()
            predictions = model(blocks, inputs)

            loss = F.cross_entropy(predictions, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

            accuracy = sklearn.metrics.accuracy_score(labels.cpu().numpy(), predictions.argmax(1).detach().cpu().numpy())

            if rank == 0 and step % 10 == 0:
                print('Epoch {:05d} Step {:05d} Loss {:.04f}'.format(epoch, step, loss.item()))

        torch.distributed.barrier()
        
        if rank == 0:
            model.eval()
            predictions = []
            labels = []
            with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
                for input_nodes, output_nodes, blocks in tq:
                    inputs = node_features[input_nodes].cuda()
                    labels.append(node_labels[output_nodes].numpy())
                    predictions.append(model.module(blocks, inputs).argmax(1).cpu().numpy())
                predictions = np.concatenate(predictions)
                labels = np.concatenate(labels)
                accuracy = sklearn.metrics.accuracy_score(labels, predictions)
                print('Epoch {} Validation Accuracy {}'.format(epoch, accuracy))
                if best_accuracy < accuracy:
                    best_accuracy = accuracy
                    torch.save(model.module.state_dict(), best_model_path)
                    
        torch.distributed.barrier()

In [None]:
if __name__ == '__main__':
    num_gpus = 4    # one training process per GPU
    procs = []
    data = load_data()
    for proc_id in range(num_gpus):
        p = mp.Process(target=train, args=(proc_id, num_gpus, data))
        p.start()
        procs.append(p)
    for p in procs:
        p.join()

['coo', 'csr', 'csc']
Epoch 00000 Step 00000 Loss 6.8194
Epoch 00000 Step 00010 Loss 2.7267
Epoch 00000 Step 00020 Loss 2.2743
Epoch 00000 Step 00030 Loss 1.8698
Epoch 00000 Step 00040 Loss 1.5360


Epoch 0 Validation Accuracy 0.7069145283930525
Epoch 00001 Step 00000 Loss 1.5336
Epoch 00001 Step 00010 Loss 1.3781
Epoch 00001 Step 00020 Loss 1.3528
Epoch 00001 Step 00030 Loss 1.3083
Epoch 00001 Step 00040 Loss 1.2994


Epoch 1 Validation Accuracy 0.791190906085497
Epoch 00002 Step 00000 Loss 1.1379
Epoch 00002 Step 00010 Loss 1.0715
Epoch 00002 Step 00020 Loss 1.0465
Epoch 00002 Step 00030 Loss 1.0820
Epoch 00002 Step 00040 Loss 0.9921


Epoch 2 Validation Accuracy 0.8177911146148564
Epoch 00003 Step 00000 Loss 1.0641
Epoch 00003 Step 00010 Loss 1.0089
Epoch 00003 Step 00020 Loss 0.9286
Epoch 00003 Step 00030 Loss 0.9309
Epoch 00003 Step 00040 Loss 1.0498


Epoch 3 Validation Accuracy 0.8342954504997075
Epoch 00004 Step 00000 Loss 1.0211
Epoch 00004 Step 00010 Loss 0.9226
Epoch 00004 Step 00020 Loss 0.8738
Epoch 00004 Step 00030 Loss 0.8745
Epoch 00004 Step 00040 Loss 0.9150


Epoch 4 Validation Accuracy 0.844365892734532
Epoch 00005 Step 00000 Loss 0.8598
Epoch 00005 Step 00010 Loss 0.8585
Epoch 00005 Step 00020 Loss 0.8589
Epoch 00005 Step 00030 Loss 0.7234
Epoch 00005 Step 00040 Loss 0.8061


Epoch 5 Validation Accuracy 0.8473412506675483
Epoch 00006 Step 00000 Loss 0.8904
Epoch 00006 Step 00010 Loss 0.7703
Epoch 00006 Step 00020 Loss 0.8840
Epoch 00006 Step 00030 Loss 0.8748
Epoch 00006 Step 00040 Loss 0.8140


Epoch 6 Validation Accuracy 0.8504437606489841
Epoch 00007 Step 00000 Loss 0.7839
Epoch 00007 Step 00010 Loss 0.8774
Epoch 00007 Step 00020 Loss 0.7986
Epoch 00007 Step 00030 Loss 0.7188
Epoch 00007 Step 00040 Loss 0.8071


Epoch 7 Validation Accuracy 0.851206672939501
Epoch 00008 Step 00000 Loss 0.7798
Epoch 00008 Step 00010 Loss 0.7694
Epoch 00008 Step 00020 Loss 0.7875
Epoch 00008 Step 00030 Loss 0.7783
Epoch 00008 Step 00040 Loss 0.6415


Epoch 8 Validation Accuracy 0.8576659969992116
Epoch 00009 Step 00000 Loss 0.7488
Epoch 00009 Step 00010 Loss 0.7146
Epoch 00009 Step 00020 Loss 0.7745
Epoch 00009 Step 00030 Loss 0.7785
Epoch 00009 Step 00040 Loss 0.7097


Epoch 9 Validation Accuracy 0.859624138544872
Epoch 00010 Step 00000 Loss 0.7197
Epoch 00010 Step 00010 Loss 0.8006
Epoch 00010 Step 00020 Loss 0.6581
Epoch 00010 Step 00030 Loss 0.7120
Epoch 00010 Step 00040 Loss 0.7275


Epoch 10 Validation Accuracy 0.8615822800905323
Epoch 00011 Step 00000 Loss 0.7184
Epoch 00011 Step 00010 Loss 0.7176
Epoch 00011 Step 00020 Loss 0.6674
Epoch 00011 Step 00030 Loss 0.6842
Epoch 00011 Step 00040 Loss 0.6612


Epoch 11 Validation Accuracy 0.86267578770694
Epoch 00012 Step 00000 Loss 0.7541
Epoch 00012 Step 00010 Loss 0.7052
Epoch 00012 Step 00020 Loss 0.6887
Epoch 00012 Step 00030 Loss 0.6068
Epoch 00012 Step 00040 Loss 0.6764


Epoch 12 Validation Accuracy 0.8621671795132619
Epoch 00013 Step 00000 Loss 0.7154
Epoch 00013 Step 00010 Loss 0.7305
Epoch 00013 Step 00020 Loss 0.6302
Epoch 00013 Step 00030 Loss 0.6839
Epoch 00013 Step 00040 Loss 0.6643


Epoch 13 Validation Accuracy 0.8561147420084938
Epoch 00014 Step 00000 Loss 0.6759
Epoch 00014 Step 00010 Loss 0.7228
Epoch 00014 Step 00020 Loss 0.7989
Epoch 00014 Step 00030 Loss 0.7027
Epoch 00014 Step 00040 Loss 0.7883


Epoch 14 Validation Accuracy 0.8631843959006179
Epoch 00015 Step 00000 Loss 0.6412
Epoch 00015 Step 00010 Loss 0.6829
Epoch 00015 Step 00020 Loss 0.6613
Epoch 00015 Step 00030 Loss 0.6583
Epoch 00015 Step 00040 Loss 0.6196


Epoch 15 Validation Accuracy 0.8673804134984615
Epoch 00016 Step 00000 Loss 0.6129
Epoch 00016 Step 00010 Loss 0.6114
Epoch 00016 Step 00020 Loss 0.6273
Epoch 00016 Step 00030 Loss 0.6406
Epoch 00016 Step 00040 Loss 0.5958


Epoch 16 Validation Accuracy 0.8676855784146683
Epoch 00017 Step 00000 Loss 0.6194
Epoch 00017 Step 00010 Loss 0.6134
Epoch 00017 Step 00020 Loss 0.6693
Epoch 00017 Step 00030 Loss 0.6642


## Conclusion

In this tutorial, you have learned how to train a multi-layer GraphSAGE model for node classification on a large dataset that does not fit in the memory of a single GPU.  The method scales to a graph of any size and works on a single machine with *any number of* GPUs.