
v0.9.0

@BarclayII released this 18 Jul 15:43
c7edb66

This is a major update with several new features, including a graph prediction pipeline in DGL-Go, cuGraph support, mixed precision support, and more.

Starting from 0.9, we also ship arm64 builds for Linux and OSX.

DGL-Go

DGL-Go now supports training GNNs for graph property prediction tasks. It includes two popular GNN models: Graph Isomorphism Network (GIN) and Principal Neighborhood Aggregation (PNA). For example, to train a GIN model on the ogbg-molpcba dataset, first generate a YAML configuration file with the command:

dgl configure graphpred --data ogbg-molpcba --model gin

This generates the following configuration file, which users can then adjust manually.

version: 0.0.2
pipeline_name: graphpred
pipeline_mode: train
device: cpu                     # Torch device name, e.g., cpu or cuda or cuda:0
data:
    name: ogbg-molpcba
    split_ratio:                # Ratio to generate data split, for example set to [0.8, 0.1, 0.1] for 80% train/10% val/10% test. Leave blank to use builtin split in original dataset
model:
    name: gin
    embed_size: 300             # Embedding size
    num_layers: 5               # Number of layers
    dropout: 0.5                # Dropout rate
    virtual_node: false         # Whether to use virtual node
general_pipeline:
    num_runs: 1                 # Number of experiments to run
    train_batch_size: 32        # Graph batch size when training
    eval_batch_size: 32         # Graph batch size when evaluating
    num_workers: 4              # Number of workers for data loading
    optimizer:
        name: Adam
        lr: 0.001
        weight_decay: 0
    lr_scheduler:
        name: StepLR
        step_size: 100
        gamma: 1
    loss: BCEWithLogitsLoss
    metric: roc_auc_score
    num_epochs: 100             # Number of training epochs
    save_path: results          # Directory to save the experiment results
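Two fields in the generated file are easy to misread. `lr_scheduler` follows PyTorch's StepLR rule, which multiplies the learning rate by `gamma` once every `step_size` scheduler steps, so with `gamma: 1` the rate simply stays at 0.001. The sketch below assumes the scheduler is stepped once per epoch. `split_sizes` is a hypothetical helper that shows how a `split_ratio` such as [0.8, 0.1, 0.1] maps to set sizes; DGL-Go's actual rounding may differ.

```python
import math


def step_lr(base_lr, step_size, gamma, step):
    """Learning rate after `step` scheduler steps under PyTorch's StepLR rule."""
    # The rate is multiplied by gamma once every step_size steps
    return base_lr * gamma ** (step // step_size)


def split_sizes(num_graphs, split_ratio):
    """Hypothetical helper: turn a split_ratio such as [0.8, 0.1, 0.1] into set sizes.

    Assumption: the train and validation sizes are floored and the test set
    takes the remainder; DGL-Go's own rounding may differ.
    """
    if not math.isclose(sum(split_ratio), 1.0):
        raise ValueError("split_ratio should sum to 1")
    train = math.floor(num_graphs * split_ratio[0])
    val = math.floor(num_graphs * split_ratio[1])
    return train, val, num_graphs - train - val


# Generated config: lr 0.001, step_size 100, gamma 1, num_epochs 100
print([step_lr(0.001, 100, 1, epoch) for epoch in (0, 50, 99)])  # [0.001, 0.001, 0.001]
# Even with gamma 0.5, the first decay would only happen at epoch 100, after training ends
print(step_lr(0.001, 100, 0.5, 99), step_lr(0.001, 100, 0.5, 100))
print(split_sizes(1000, [0.8, 0.1, 0.1]))  # (800, 100, 100)
```

In other words, to get an actual decay within 100 epochs, both `gamma` below 1 and a `step_size` smaller than `num_epochs` are needed.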

Alternatively, users can fetch model recipes with pre-defined hyperparameters from the original experiments:

dgl recipe get graphpred_pcba_gin.yaml

To launch training:

dgl train --cfg graphpred_ogbg-molpcba_gin.yaml

Another addition is a new command for running inference with a trained model on another dataset. For example, the following shows how to apply the GIN model trained on ogbg-molpcba to ogbg-molhiv.

# Generate an inference configuration file from a saved experiment checkpoint
dgl configure-apply graphpred --data ogbg-molhiv --cpt results/run_0.pth

# Apply the trained model for inference
dgl apply --cfg apply_graphpred_ogbg-molhiv_gin.yaml

The model predictions are saved to a CSV file.

Mixed Precision

DGL is compatible with the PyTorch Automatic Mixed Precision (AMP) package for mixed precision training, which reduces both training time and GPU memory consumption. This feature requires PyTorch 1.6+ and Python 3.7+.

By wrapping the forward pass in torch.cuda.amp.autocast(), PyTorch automatically selects the appropriate data type for each op and tensor. Half-precision tensors are memory efficient, and most operators on half-precision tensors run faster because they leverage GPU Tensor Cores.

import torch.nn.functional as F
from torch.cuda.amp import autocast

def forward(g, feat, label, mask, model):
    # Ops inside autocast run in float16 where PyTorch deems it safe
    with autocast(enabled=True):
        logit = model(g, feat)
        loss = F.cross_entropy(logit[mask], label[mask])
    return loss
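To put the memory claim above in numbers: a float16 value takes 2 bytes, half of float32's 4. The arithmetic below uses Python's `struct` module and, as an illustrative shape, the Reddit node feature matrix (232,965 nodes × 602 features) used in the full example. Keep in mind that autocast does not convert tensors you already hold in float32, such as the input features; the saving applies to the float16 tensors that autocast-ed ops produce, such as activations saved for the backward pass.

```python
import struct

FP16_BYTES = struct.calcsize('e')  # IEEE 754 half precision
FP32_BYTES = struct.calcsize('f')  # IEEE 754 single precision


def tensor_mib(shape, bytes_per_element):
    """Memory footprint of a dense tensor with the given shape, in MiB."""
    num_elements = 1
    for dim in shape:
        num_elements *= dim
    return num_elements * bytes_per_element / 2**20


reddit_feat_shape = (232_965, 602)  # nodes x input features in RedditDataset
print(f"float32: {tensor_mib(reddit_feat_shape, FP32_BYTES):.0f} MiB")
print(f"float16: {tensor_mib(reddit_feat_shape, FP16_BYTES):.0f} MiB")
```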

Small gradients in float16 format may underflow (flush to zero). PyTorch provides a GradScaler module to address this issue. It multiplies the loss by a scale factor and runs the backward pass on the scaled loss, so small gradients stay representable in float16. It then unscales the computed gradients before the optimizer updates the parameters. The scale factor is determined automatically.

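from torch.cuda.amp import GradScaler

scaler = GradScaler()

def backward(scaler, loss, optimizer):
    # Backpropagate the scaled loss so small float16 gradients do not flush to zero
    scaler.scale(loss).backward()
    # Unscale the gradients; the optimizer step is skipped if they contain inf/NaN
    scaler.step(optimizer)
    # Adjust the scale factor for the next iteration
    scaler.update()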
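The two ideas behind GradScaler can be reproduced without a GPU. The sketch below uses Python's `struct` half-precision format ('e') to show a small gradient flushing to zero in float16 and surviving once it is scaled, then mimics the dynamic scale update in plain Python. `DynamicLossScale` is an illustrative class, not PyTorch's implementation: its defaults mirror GradScaler's (`init_scale=2**16`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`), and it folds the work of `scaler.step()` and `scaler.update()` into a single method.

```python
import math
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE float16 value and back."""
    return struct.unpack('e', struct.pack('e', x))[0]


class DynamicLossScale:
    """Illustrative, pure-Python sketch of GradScaler-style dynamic loss scaling."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps with finite gradients

    def step(self, scaled_grads):
        """Return unscaled gradients, or None if the update should be skipped."""
        if not all(math.isfinite(g) for g in scaled_grads):
            # inf/NaN means the scale was too large: skip the update and back off
            self.scale *= self.backoff_factor
            self._good_steps = 0
            return None
        unscaled = [g / self.scale for g in scaled_grads]
        self._good_steps += 1
        if self._good_steps == self.growth_interval:
            # A long run of finite gradients: try a larger scale next time
            self.scale *= self.growth_factor
            self._good_steps = 0
        return unscaled


grad = 1e-8                               # below float16's smallest subnormal
print(to_fp16(grad))                      # 0.0: flushed to zero
print(to_fp16(grad * 2.0**16) / 2.0**16)  # about 1e-8: scaling preserved it

demo = DynamicLossScale(init_scale=8.0, growth_interval=2)
print(demo.step([float('inf'), 1.0]), demo.scale)  # None 4.0
print(demo.step([4.0]), demo.scale)                # [1.0] 4.0
print(demo.step([4.0]), demo.scale)                # [1.0] 8.0
```

Backing off on overflow and growing after a run of clean steps is what lets the scale factor be chosen automatically: it hovers near the largest value that does not overflow float16.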

Putting everything together, we have the example below.

import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import RedditDataset
from dgl.nn import GATConv
from dgl.transforms import AddSelfLoop

# forward(), backward() and scaler are the ones defined in the snippets above.

class GAT(nn.Module):
    def __init__(self, in_feats, num_classes, num_hidden=256, num_heads=2):
        super().__init__()
        self.conv1 = GATConv(in_feats, num_hidden, num_heads, activation=F.elu)
        # The output layer produces one logit per class
        self.conv2 = GATConv(num_hidden * num_heads, num_classes, num_heads)

    def forward(self, g, h):
        h = self.conv1(g, h).flatten(1)  # concatenate the attention heads
        h = self.conv2(g, h).mean(1)     # average the attention heads
        return h

device = torch.device('cuda')

# Pass the transform by keyword: RedditDataset's first positional argument is self_loop
transform = AddSelfLoop()
data = RedditDataset(transform=transform)

g = data[0]
g = g.int().to(device)
train_mask = g.ndata['train_mask']
feat = g.ndata['feat']
label = g.ndata['label']
in_feats = feat.shape[1]

model = GAT(in_feats, data.num_classes).to(device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

for epoch in range(100):
    optimizer.zero_grad()
    loss = forward(g, feat, label, train_mask, model)
    backward(scaler, loss, optimizer)

Thanks to @nv-dlasalle, @ndickson-nvidia, @yaox12, and others for their support!

cuGraph Interface

The RAPIDS cuGraph library provides a collection of GPU-accelerated algorithms for graph analytics, such as centrality computation and community detection. According to its documentation, “the latest NVIDIA GPUs (RAPIDS supports Pascal and later GPU architectures) make graph analytics 1000x faster on average over NetworkX”.

To install cuGraph, we recommend the following steps:

conda install mamba -n base -c conda-forge

mamba create -n dgl_and_cugraph -c dglteam -c rapidsai-nightly -c nvidia -c pytorch -c conda-forge cugraph pytorch torchvision torchaudio cudatoolkit=11.3 dgl-cuda11.3 tqdm

conda activate dgl_and_cugraph

DGL is now compatible with cuGraph: a DGLGraph object can be converted to a cuGraph graph object and back, giving DGL users access to the efficient graph analytics implementations in cuGraph. For example, users can perform community detection on a graph with the Louvain method available in cuGraph.

import cugraph

from dgl.data import CoraGraphDataset

dataset = CoraGraphDataset()
g = dataset[0].to('cuda')
cugraph_g = g.to_cugraph()
cugraph_g = cugraph_g.to_undirected()
parts, modularity_score = cugraph.louvain(cugraph_g)

The community membership of each node, stored in parts['partition'], can then be used as auxiliary node labels or node features.
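cuGraph's Louvain result `parts` pairs each vertex ID with a community ID in its 'vertex' and 'partition' columns. Assuming its rows are not guaranteed to follow node-ID order, they should be scattered by vertex ID before being attached to the DGLGraph, and the community IDs remapped to a contiguous range before one-hot encoding. The plain-Python sketch below shows only that bookkeeping, with hypothetical toy lists standing in for the two columns; on real data you would do the same with cuDF or PyTorch operations on the GPU.

```python
def partition_to_labels(vertices, partitions, num_nodes):
    """Scatter (vertex, partition) pairs into a list indexed by node ID."""
    labels = [None] * num_nodes
    for v, p in zip(vertices, partitions):
        labels[v] = p
    # Every node should have been assigned to a community
    assert all(lab is not None for lab in labels), "some nodes have no community"
    # Remap community IDs to 0..k-1 so they can index classes or one-hot columns
    remap = {c: i for i, c in enumerate(sorted(set(labels)))}
    return [remap[lab] for lab in labels]


def one_hot(labels):
    """Encode contiguous community IDs as one-hot feature rows."""
    k = max(labels) + 1
    return [[1.0 if lab == j else 0.0 for j in range(k)] for lab in labels]


# Hypothetical toy stand-ins for parts['vertex'] and parts['partition']
vertices = [2, 0, 3, 1]
partitions = [7, 5, 7, 5]
labels = partition_to_labels(vertices, partitions, num_nodes=4)
print(labels)           # [0, 0, 1, 1]
print(one_hot(labels))  # [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]
```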

If you have modified the structure of a cuGraph graph object or loaded graph data with cuGraph, you can also convert it to a DGLGraph object.

import dgl
g = dgl.from_cugraph(cugraph_g)

Credits to @VibhuJawa!

Arm64 builds

Builds for Linux AArch64 and OSX M1 (arm64) are now available. Install them as usual with pip or conda:

pip install dgl-cuXX -f https://data.dgl.ai/wheels/repo.html
conda install -c dglteam dgl-cudaXX.X   # currently not available for OSX M1
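Because the conda packages are not yet available for OSX M1, it helps to confirm which platform the Python interpreter reports before picking an install command. The sketch below relies on the standard values of `platform.system()` and `platform.machine()`: 'aarch64' on Linux AArch64 and 'arm64' on Apple silicon. An x86_64 Python running under Rosetta reports 'x86_64' and will not pick up arm64 builds. `is_arm64_target` is a hypothetical helper.

```python
import platform

# (platform.system(), platform.machine()) pairs for the two new arm64 targets
ARM64_TARGETS = {
    ("Linux", "aarch64"),  # Linux AArch64
    ("Darwin", "arm64"),   # macOS on Apple M1 (native, not under Rosetta)
}


def is_arm64_target(system=None, machine=None):
    """Hypothetical helper: True if this interpreter runs on a new arm64 target."""
    system = system or platform.system()
    machine = machine or platform.machine()
    return (system, machine) in ARM64_TARGETS


print(platform.system(), platform.machine(), is_arm64_target())
```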

Quality-of-life updates

System optimizations

  • Enable using UVA and FP16 with the SparseAdam optimizer (#3885, @nv-dlasalle)
  • Enable USE_EPOLL by default in distributed training (#4167)
  • Optimize the use of alternative streams in the dataloader (#4177, @yaox12)
  • Redirect AllocWorkspace to PyTorch's allocator if available (#4199, @yaox12)

Bug fixes

Misc

  • Test pipeline for distributed training (#4122, @Kh4L)