Skip to content

Commit

Permalink
Merge branch 'master' into untyped_storage
Browse files Browse the repository at this point in the history
  • Loading branch information
drivanov committed Jan 10, 2024
2 parents 007ea93 + 04752e9 commit ef5d372
Show file tree
Hide file tree
Showing 12 changed files with 377 additions and 25 deletions.
90 changes: 90 additions & 0 deletions docs/source/graphtransformer/data.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
Prepare Data
============

In this section, we will prepare the data for the Graphormer model introduced before. We can use any dataset containing :class:`~dgl.DGLGraph` objects and standard PyTorch dataloader to feed the data to the model. The key is to define a collate function to group features of multiple graphs into batches. We show an example of the collate function as follows:


.. code:: python
def collate(graphs):
# compute shortest path features, can be done in advance
for g in graphs:
spd, path = dgl.shortest_dist(g, root=None, return_paths=True)
g.ndata["spd"] = spd
g.ndata["path"] = path
num_graphs = len(graphs)
num_nodes = [g.num_nodes() for g in graphs]
max_num_nodes = max(num_nodes)
attn_mask = th.zeros(num_graphs, max_num_nodes, max_num_nodes)
node_feat = []
in_degree, out_degree = [], []
path_data = []
# Since shortest_dist returns -1 for unreachable node pairs and padded
# nodes are unreachable to others, distance relevant to padded nodes
# use -1 padding as well.
dist = -th.ones(
(num_graphs, max_num_nodes, max_num_nodes), dtype=th.long
)
for i in range(num_graphs):
# A binary mask where invalid positions are indicated by True.
# Avoid the case where all positions are invalid.
attn_mask[i, :, num_nodes[i] + 1 :] = 1
# +1 to distinguish padded non-existing nodes from real nodes
node_feat.append(graphs[i].ndata["feat"] + 1)
# 0 for padding
in_degree.append(
th.clamp(graphs[i].in_degrees() + 1, min=0, max=512)
)
out_degree.append(
th.clamp(graphs[i].out_degrees() + 1, min=0, max=512)
)
# Path padding to make all paths to the same length "max_len".
path = graphs[i].ndata["path"]
path_len = path.size(dim=2)
# shape of shortest_path: [n, n, max_len]
max_len = 5
if path_len >= max_len:
shortest_path = path[:, :, :max_len]
else:
p1d = (0, max_len - path_len)
# Use the same -1 padding as shortest_dist for
# invalid edge IDs.
shortest_path = th.nn.functional.pad(path, p1d, "constant", -1)
pad_num_nodes = max_num_nodes - num_nodes[i]
p3d = (0, 0, 0, pad_num_nodes, 0, pad_num_nodes)
shortest_path = th.nn.functional.pad(shortest_path, p3d, "constant", -1)
# +1 to distinguish padded non-existing edges from real edges
edata = graphs[i].edata["feat"] + 1
# shortest_dist pads non-existing edges (at the end of shortest
# paths) with edge IDs -1, and th.zeros(1, edata.shape[1]) stands
# for all padded edge features.
edata = th.cat(
(edata, th.zeros(1, edata.shape[1]).to(edata.device)), dim=0
)
path_data.append(edata[shortest_path])
dist[i, : num_nodes[i], : num_nodes[i]] = graphs[i].ndata["spd"]
# node feat padding
node_feat = th.nn.utils.rnn.pad_sequence(node_feat, batch_first=True)
# degree padding
in_degree = th.nn.utils.rnn.pad_sequence(in_degree, batch_first=True)
out_degree = th.nn.utils.rnn.pad_sequence(out_degree, batch_first=True)
return (
node_feat,
in_degree,
out_degree,
attn_mask,
th.stack(path_data),
dist,
)
In this example, we also omit details like the addition of a virtual node. For more details, please refer to the `Graphormer example <https://github.com/dmlc/dgl/tree/master/examples/core/Graphormer>`_.
12 changes: 12 additions & 0 deletions docs/source/graphtransformer/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
🆕 Tutorial: GraphTransformer
==========

This tutorial introduces the **graphtransformer** module, which is a set of
utility modules for building and training graph transformer models.

.. toctree::
:maxdepth: 2
:titlesonly:

model
data
89 changes: 89 additions & 0 deletions docs/source/graphtransformer/model.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
Build Model
===========

**GraphTransformer** is a graph neural network that uses multi-head self-attention (sparse or dense) to encode the graph structure and node features. It is a generalization of the `Transformer <https://arxiv.org/abs/1706.03762>`_ architecture to arbitrary graphs.

In this tutorial, we will show how to build a graph transformer model with DGL using the `Graphormer <https://arxiv.org/abs/2106.05234>`_ model as an example.

Graphormer is a Transformer model designed for graph-structured data, which encodes the structural information of a graph into the standard Transformer. Specifically, Graphormer utilizes degree encoding to measure the importance of nodes, spatial and path Encoding to measure the relation between node pairs. The degree encoding and the node features serve as input to Graphormer, while the spatial and path encoding act as bias terms in the self-attention module.

Degree Encoding
-------------------
The degree encoder is a learnable embedding layer that encodes the degree of each node into a vector. It takes as input the batched input and output degrees of graph nodes, and outputs the degree embeddings of the nodes.

.. code:: python
degree_encoder = dgl.nn.DegreeEncoder(
max_degree=8, # the maximum degree to cut off
embedding_dim=512 # the dimension of the degree embedding
)
Path Encoding
-------------
The path encoder encodes the edge features on the shortest path between two nodes to get attention bias for the self-attention module. It takes as input the batched edge features in shape and outputs the attention bias based on path encoding.

.. code:: python
path_encoder = PathEncoder(
max_len=5, # the maximum length of the shortest path
feat_dim=512, # the dimension of the edge feature
num_heads=8, # the number of attention heads
)
Spatial Encoding
----------------
The spatial encoder encodes the shortest distance between two nodes to get attention bias for the self-attention module. It takes as input the shortest distance between two nodes and outputs the attention bias based on spatial encoding.

.. code:: python
spatial_encoder = SpatialEncoder(
max_dist=5, # the maximum distance between two nodes
num_heads=8, # the number of attention heads
)
Graphormer Layer
----------------
The Graphormer layer is like a Transformer encoder layer with the Multi-head Attention part replaced with :class:`~dgl.nn.BiasedMHA`. It takes in not only the input node features, but also the attention bias computed computed above, and outputs the updated node features.

We can stack multiple Graphormer layers as a list just like implementing a Transformer encoder in PyTorch.

.. code:: python
layers = th.nn.ModuleList([
GraphormerLayer(
feat_size=512, # the dimension of the input node features
hidden_size=1024, # the dimension of the hidden layer
num_heads=8, # the number of attention heads
dropout=0.1, # the dropout rate
activation=th.nn.ReLU(), # the activation function
norm_first=False, # whether to put the normalization before attention and feedforward
)
for _ in range(6)
])
Model Forward
-------------
Grouping the modules above defines the primary components of the Graphormer model. We then can define the forward process as follows:

.. code:: python
node_feat, in_degree, out_degree, attn_mask, path_data, dist = \
next(iter(dataloader)) # we will use the first batch as an example
num_graphs, max_num_nodes, _ = node_feat.shape
deg_emb = degree_encoder(th.stack((in_degree, out_degree)))
# node feature + degree encoding as input
node_feat = node_feat + deg_emb
# spatial encoding and path encoding serve as attention bias
path_encoding = path_encoder(dist, path_data)
spatial_encoding = spatial_encoder(dist)
attn_bias[:, 1:, 1:, :] = path_encoding + spatial_encoding
# graphormer layers
for layer in layers:
x = layer(
x,
attn_mask=attn_mask,
attn_bias=attn_bias,
)
For simplicity, we omit some details in the forward process. For the complete implementation, please refer to the `Graphormer example <https://github.com/dmlc/dgl/tree/master/examples/core/Graphormer`_.

You can also explore other `utility modules <https://docs.dgl.ai/api/python/nn-pytorch.html#utility-modules-for-graph-transformer>`_ to customize your own graph transformer model. In the next section, we will show how to prepare the data for training.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Welcome to Deep Graph Library Tutorials and Documentation
guide/index
guide_cn/index
guide_ko/index
graphtransformer/index
notebooks/sparse/index
tutorials/cpu/index
tutorials/multi/index
Expand Down
7 changes: 3 additions & 4 deletions graphbolt/src/fused_csc_sampling_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,8 @@ FusedCSCSamplingGraph::GetState() const {

c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::InSubgraph(
const torch::Tensor& nodes) const {
if (utils::is_accessible_from_gpu(indptr_) &&
if (utils::is_on_gpu(nodes) && utils::is_accessible_from_gpu(indptr_) &&
utils::is_accessible_from_gpu(indices_) &&
utils::is_accessible_from_gpu(nodes) &&
(!type_per_edge_.has_value() ||
utils::is_accessible_from_gpu(type_per_edge_.value()))) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(c10::DeviceType::CUDA, "InSubgraph", {
Expand Down Expand Up @@ -616,9 +615,9 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
probs_or_mask = this->EdgeAttribute(probs_name);
}

if (!replace && utils::is_accessible_from_gpu(indptr_) &&
if (!replace && utils::is_on_gpu(nodes) &&
utils::is_accessible_from_gpu(indptr_) &&
utils::is_accessible_from_gpu(indices_) &&
utils::is_accessible_from_gpu(nodes) &&
(!probs_or_mask.has_value() ||
utils::is_accessible_from_gpu(probs_or_mask.value())) &&
(!type_per_edge_.has_value() ||
Expand Down
8 changes: 3 additions & 5 deletions graphbolt/src/index_select.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@ namespace graphbolt {
namespace ops {

torch::Tensor IndexSelect(torch::Tensor input, torch::Tensor index) {
if (input.is_pinned() &&
(index.is_pinned() || index.device().type() == c10::DeviceType::CUDA)) {
if (utils::is_on_gpu(index) && input.is_pinned()) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "UVAIndexSelect",
{ return UVAIndexSelectImpl(input, index); });
Expand All @@ -26,9 +25,8 @@ std::tuple<torch::Tensor, torch::Tensor> IndexSelectCSC(
torch::Tensor indptr, torch::Tensor indices, torch::Tensor nodes) {
TORCH_CHECK(
indices.sizes().size() == 1, "IndexSelectCSC only supports 1d tensors");
if (utils::is_accessible_from_gpu(indptr) &&
utils::is_accessible_from_gpu(indices) &&
utils::is_accessible_from_gpu(nodes)) {
if (utils::is_on_gpu(nodes) && utils::is_accessible_from_gpu(indptr) &&
utils::is_accessible_from_gpu(indices)) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "IndexSelectCSCImpl",
{ return IndexSelectCSCImpl(indptr, indices, nodes); });
Expand Down
3 changes: 1 addition & 2 deletions graphbolt/src/isin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ torch::Tensor IsInCPU(

torch::Tensor IsIn(
const torch::Tensor& elements, const torch::Tensor& test_elements) {
if (utils::is_accessible_from_gpu(elements) &&
utils::is_accessible_from_gpu(test_elements)) {
if (utils::is_on_gpu(elements) && utils::is_on_gpu(test_elements)) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "IsInOperation",
{ return ops::IsIn(elements, test_elements); });
Expand Down
5 changes: 2 additions & 3 deletions graphbolt/src/unique_and_compact.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ namespace sampling {
std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(
const torch::Tensor& src_ids, const torch::Tensor& dst_ids,
const torch::Tensor unique_dst_ids) {
if (utils::is_accessible_from_gpu(src_ids) &&
utils::is_accessible_from_gpu(dst_ids) &&
utils::is_accessible_from_gpu(unique_dst_ids)) {
if (utils::is_on_gpu(src_ids) && utils::is_on_gpu(dst_ids) &&
utils::is_on_gpu(unique_dst_ids)) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "unique_and_compact",
{ return ops::UniqueAndCompact(src_ids, dst_ids, unique_dst_ids); });
Expand Down
9 changes: 8 additions & 1 deletion graphbolt/src/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,18 @@
namespace graphbolt {
namespace utils {

/**
* @brief Checks whether the tensor is stored on the GPU.
*/
inline bool is_on_gpu(torch::Tensor tensor) {
return tensor.device().is_cuda();
}

/**
* @brief Checks whether the tensor is stored on the GPU or the pinned memory.
*/
inline bool is_accessible_from_gpu(torch::Tensor tensor) {
return tensor.is_pinned() || tensor.device().type() == c10::DeviceType::CUDA;
return is_on_gpu(tensor) || tensor.is_pinned();
}

/**
Expand Down
35 changes: 28 additions & 7 deletions python/dgl/graphbolt/impl/ondisk_dataset.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""GraphBolt OnDiskDataset."""

import os
import shutil
from copy import deepcopy
from typing import Dict, List, Union

Expand Down Expand Up @@ -34,7 +35,9 @@


def preprocess_ondisk_dataset(
dataset_dir: str, include_original_edge_id: bool = False
dataset_dir: str,
include_original_edge_id: bool = False,
force_preprocess: bool = False,
) -> str:
"""Preprocess the on-disk dataset. Parse the input config file,
load the data, and save the data in the format that GraphBolt supports.
Expand All @@ -45,6 +48,8 @@ def preprocess_ondisk_dataset(
The path to the dataset directory.
include_original_edge_id : bool, optional
Whether to include the original edge id in the FusedCSCSamplingGraph.
force_preprocess: bool, optional
Whether to force reload the ondisk dataset.
Returns
-------
Expand All @@ -62,13 +67,22 @@ def preprocess_ondisk_dataset(
)

# 0. Check if the dataset is already preprocessed.
preprocess_metadata_path = os.path.join("preprocessed", "metadata.yaml")
processed_dir_prefix = "preprocessed"
preprocess_metadata_path = os.path.join(
processed_dir_prefix, "metadata.yaml"
)
if os.path.exists(os.path.join(dataset_dir, preprocess_metadata_path)):
print("The dataset is already preprocessed.")
return os.path.join(dataset_dir, preprocess_metadata_path)
if force_preprocess:
shutil.rmtree(os.path.join(dataset_dir, processed_dir_prefix))
print(
"The on-disk dataset is re-preprocessing, so the existing "
+ "preprocessed dataset has been removed."
)
else:
print("The dataset is already preprocessed.")
return os.path.join(dataset_dir, preprocess_metadata_path)

print("Start to preprocess the on-disk dataset.")
processed_dir_prefix = "preprocessed"

# Check if the metadata.yaml exists.
metadata_file_path = os.path.join(dataset_dir, "metadata.yaml")
Expand Down Expand Up @@ -376,15 +390,22 @@ class OnDiskDataset(Dataset):
The YAML file path.
include_original_edge_id: bool, optional
Whether to include the original edge id in the FusedCSCSamplingGraph.
force_preprocess: bool, optional
Whether to force reload the ondisk dataset.
"""

def __init__(
self, path: str, include_original_edge_id: bool = False
self,
path: str,
include_original_edge_id: bool = False,
force_preprocess: bool = False,
) -> None:
# Always call the preprocess function first. If already preprocessed,
# the function will return the original path directly.
self._dataset_dir = path
yaml_path = preprocess_ondisk_dataset(path, include_original_edge_id)
yaml_path = preprocess_ondisk_dataset(
path, include_original_edge_id, force_preprocess
)
with open(yaml_path) as f:
self._yaml_data = yaml.load(f, Loader=yaml.loader.SafeLoader)
self._loaded = False
Expand Down

0 comments on commit ef5d372

Please sign in to comment.