Commit
Merge branch 'master' into spot_target
drivanov committed Jan 17, 2024
2 parents 6612adf + 053c822 commit 49ee714
Showing 52 changed files with 1,659 additions and 787 deletions.
2 changes: 1 addition & 1 deletion conda/dgl/meta.yaml
@@ -1,6 +1,6 @@
package:
name: dgl{{ environ.get('DGL_PACKAGE_SUFFIX', '') }}
version: 2.0{{ environ.get('DGL_VERSION_SUFFIX', '') }}
version: 2.1{{ environ.get('DGL_VERSION_SUFFIX', '') }}

source:
git_rev: {{ environ.get('DGL_RELEASE_BRANCH', 'master') }}
1 change: 1 addition & 0 deletions docs/source/graphtransformer/data.rst
@@ -5,6 +5,7 @@ In this section, we will prepare the data for the Graphormer model introduced be


.. code:: python
def collate(graphs):
# compute shortest path features, can be done in advance
for g in graphs:
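The hunk above only shows the opening of ``collate``. Below is a minimal sketch of how such a step might precompute the shortest-path features, assuming ``dgl.shortest_dist`` with ``return_paths=True`` is available; the function body here is illustrative, not the example's actual code.

.. code:: python

    import dgl

    def collate_sketch(graphs):
        # Precompute shortest-path features for each graph; this can be
        # done once, in advance of training.
        for g in graphs:
            spd, path = dgl.shortest_dist(g, root=None, return_paths=True)
            g.ndata["spd"] = spd      # pairwise shortest distances
            g.ndata["path"] = path    # edge ids along each shortest path
        # The real collate additionally pads and batches node features,
        # distances, and paths into dense tensors.
        return graphs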
6 changes: 3 additions & 3 deletions docs/source/graphtransformer/index.rst
@@ -1,8 +1,8 @@
🆕 Tutorial: GraphTransformer
🆕 Tutorial: Graph Transformer
==========

This tutorial introduces the **graphtransformer** module, which is a set of
utility modules for building and training graph transformer models.
This tutorial introduces the **graph transformer** (:mod:`~dgl.nn.gt`) module,
which is a set of utility modules for building and training graph transformer models.

.. toctree::
:maxdepth: 2
7 changes: 6 additions & 1 deletion docs/source/graphtransformer/model.rst
@@ -12,6 +12,7 @@ Degree Encoding
The degree encoder is a learnable embedding layer that encodes the degree of each node into a vector. It takes as input the batched input and output degrees of graph nodes, and outputs the degree embeddings of the nodes.

.. code:: python
degree_encoder = dgl.nn.DegreeEncoder(
max_degree=8, # the maximum degree to cut off
embedding_dim=512 # the dimension of the degree embedding
@@ -22,6 +23,7 @@ Path Encoding
The path encoder encodes the edge features on the shortest path between two nodes to get attention bias for the self-attention module. It takes as input the batched edge features on the shortest paths and outputs the attention bias based on path encoding.

.. code:: python
path_encoder = PathEncoder(
max_len=5, # the maximum length of the shortest path
feat_dim=512, # the dimension of the edge feature
@@ -33,6 +35,7 @@ Spatial Encoding
The spatial encoder encodes the shortest distance between two nodes to get attention bias for the self-attention module. It takes as input the shortest distance between two nodes and outputs the attention bias based on spatial encoding.

.. code:: python
spatial_encoder = SpatialEncoder(
max_dist=5, # the maximum distance between two nodes
num_heads=8, # the number of attention heads
@@ -46,6 +49,7 @@ The Graphormer layer is like a Transformer encoder layer with the Multi-head Att
We can stack multiple Graphormer layers as a list just like implementing a Transformer encoder in PyTorch.

.. code:: python
layers = th.nn.ModuleList([
GraphormerLayer(
feat_size=512, # the dimension of the input node features
@@ -63,6 +67,7 @@ Model Forward
Grouping the modules above defines the primary components of the Graphormer model. We then can define the forward process as follows:

.. code:: python
node_feat, in_degree, out_degree, attn_mask, path_data, dist = \
next(iter(dataloader)) # we will use the first batch as an example
num_graphs, max_num_nodes, _ = node_feat.shape
@@ -84,6 +89,6 @@ Grouping the modules above defines the primary components of the Graphormer mode
attn_bias=attn_bias,
)
For simplicity, we omit some details in the forward process. For the complete implementation, please refer to the `Graphormer example <https://github.com/dmlc/dgl/tree/master/examples/core/Graphormer`_.
For simplicity, we omit some details in the forward process. For the complete implementation, please refer to the `Graphormer example <https://github.com/dmlc/dgl/tree/master/examples/core/Graphormer>`_.

You can also explore other `utility modules <https://docs.dgl.ai/api/python/nn-pytorch.html#utility-modules-for-graph-transformer>`_ to customize your own graph transformer model. In the next section, we will show how to prepare the data for training.
@@ -122,8 +122,10 @@ The ``graph`` field is used to specify the graph structure. It has two fields:
homogeneous graphs. For heterogeneous graphs, it is the edge type.
- ``format``: ``string``

The ``format`` field is used to specify the format of the edge data. It can
only be ``csv`` for now.
The ``format`` field is used to specify the format of the edge data. It
can be ``csv`` or ``numpy``. If it is ``csv``, no ``index`` or ``header``
fields are needed. If it is ``numpy``, the array is required to be of shape
``(2, num_edges)``. The ``numpy`` format is recommended for large graphs.
- ``path``: ``string``

The ``path`` field is used to specify the path of the edge data. It is
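As a side note on the ``numpy`` option described above, here is a minimal sketch (the file name and edge list are illustrative) of producing an edge array in the required ``(2, num_edges)`` shape:

.. code:: python

    import numpy as np

    # Illustrative edge list: four edges given as (src, dst) pairs.
    src = np.array([0, 1, 2, 3])
    dst = np.array([1, 2, 3, 0])

    # Row 0 holds source ids, row 1 holds destination ids.
    edges = np.stack([src, dst])   # shape (2, 4)
    np.save("edges.npy", edges)    # the saved file is what the ``path`` field points to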
39 changes: 28 additions & 11 deletions examples/sampling/graphbolt/link_prediction.py
@@ -144,6 +144,16 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
shuffle=is_train,
)

############################################################################
# [Input]:
# 'device': The device to copy the data to.
# [Output]:
# A CopyTo object to copy the data to the specified device. Copying here
# ensures that the rest of the operations run on the GPU.
############################################################################
if args.storage_device != "cpu":
datapipe = datapipe.copy_to(device=args.device)

############################################################################
# [Input]:
# 'args.neg_ratio': Specify the ratio of negative to positive samples.
@@ -216,7 +226,8 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# [Output]:
# A CopyTo object to copy the data to the specified device.
############################################################################
datapipe = datapipe.copy_to(device=args.device)
if args.storage_device == "cpu":
datapipe = datapipe.copy_to(device=args.device)

############################################################################
# [Input]:
@@ -304,11 +315,11 @@ def train(args, model, graph, features, train_set):
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
dataloader = create_dataloader(args, graph, features, train_set)

for epoch in tqdm.trange(args.epochs):
for epoch in range(args.epochs):
model.train()
total_loss = 0
start_epoch_time = time.time()
for step, data in enumerate(dataloader):
for step, data in tqdm.tqdm(enumerate(dataloader)):
# Get node pairs with labels for loss calculation.
compacted_pairs, labels = data.node_pairs_with_labels

@@ -366,24 +377,30 @@ def parse_args():
help="Whether to exclude reverse edges during sampling. Default: 1",
)
parser.add_argument(
"--device",
default="cpu",
choices=["cpu", "cuda"],
help="Train device: 'cpu' for CPU, 'cuda' for GPU.",
"--mode",
default="pinned-cuda",
choices=["cpu-cpu", "cpu-cuda", "pinned-cuda", "cuda-cuda"],
help="Dataset storage placement and Train device: 'cpu' for CPU and RAM,"
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
)
return parser.parse_args()


def main(args):
if not torch.cuda.is_available():
args.device = "cpu"
print(f"Training in {args.device} mode.")
args.mode = "cpu-cpu"
print(f"Training in {args.mode} mode.")
args.storage_device, args.device = args.mode.split("-")
args.device = torch.device(args.device)

# Load and preprocess dataset.
print("Loading data")
dataset = gb.BuiltinDataset("ogbl-citation2").load()
graph = dataset.graph
features = dataset.feature

# Move the dataset to the selected storage.
graph = dataset.graph.to(args.storage_device)
features = dataset.feature.to(args.storage_device)

train_set = dataset.tasks[0].train_set
args.fanout = list(map(int, args.fanout.split(",")))

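A quick note on the new --mode flag introduced above: the combined string is split into a storage placement and a compute device, which the example then applies to the dataset and to training respectively. A minimal sketch of the mapping (variable names follow the diff; the chosen value is illustrative):

    mode = "pinned-cuda"  # one of: "cpu-cpu", "cpu-cuda", "pinned-cuda", "cuda-cuda"
    storage_device, device = mode.split("-")   # -> ("pinned", "cuda")
    # storage_device: where the graph and features live
    #   ("cpu" = RAM, "pinned" = page-locked RAM, "cuda" = GPU memory)
    # device: where sampling and training run ("cpu" or "cuda")
    # The dataset is then moved with dataset.graph.to(storage_device) and
    # dataset.feature.to(storage_device), as shown in main().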
48 changes: 32 additions & 16 deletions examples/sampling/graphbolt/node_classification.py
@@ -92,6 +92,19 @@ def create_dataloader(

############################################################################
# [Step-2]:
# self.copy_to()
# [Input]:
# 'device': The device to copy the data to.
# 'extra_attrs': The extra attributes to copy.
# [Output]:
# A CopyTo object to copy the data to the specified device. Copying here
# ensures that the rest of the operations run on the GPU.
############################################################################
if args.storage_device != "cpu":
datapipe = datapipe.copy_to(device=device, extra_attrs=["seed_nodes"])

############################################################################
# [Step-3]:
# self.sample_neighbor()
# [Input]:
# 'graph': The network topology for sampling.
@@ -109,7 +122,7 @@ def create_dataloader(
)

############################################################################
# [Step-3]:
# [Step-4]:
# self.fetch_feature()
# [Input]:
# 'features': The node features.
@@ -125,17 +138,18 @@ def create_dataloader(
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])

############################################################################
# [Step-4]:
# [Step-5]:
# self.copy_to()
# [Input]:
# 'device': The device to copy the data to.
# [Output]:
# A CopyTo object to copy the data to the specified device.
############################################################################
datapipe = datapipe.copy_to(device=device)
if args.storage_device == "cpu":
datapipe = datapipe.copy_to(device=device)

############################################################################
# [Step-5]:
# [Step-6]:
# gb.DataLoader()
# [Input]:
# 'datapipe': The datapipe object to be used for data loading.
@@ -259,7 +273,7 @@ def evaluate(args, model, graph, features, itemset, num_classes):
job="evaluate",
)

for step, data in tqdm(enumerate(dataloader)):
for step, data in tqdm(enumerate(dataloader), "Evaluating"):
x = data.node_features["feat"]
y.append(data.labels)
y_hats.append(model(data.blocks, x))
@@ -289,7 +303,7 @@ def train(args, graph, features, train_set, valid_set, num_classes, model):
t0 = time.time()
model.train()
total_loss = 0
for step, data in enumerate(dataloader):
for step, data in tqdm(enumerate(dataloader), "Training"):
# The input features from the source nodes in the first layer's
# computation graph.
x = data.node_features["feat"]
@@ -349,28 +363,30 @@ def parse_args():
" identical with the number of layers in your model. Default: 10,10,10",
)
parser.add_argument(
"--device",
default="cpu",
choices=["cpu", "cuda"],
help="Train device: 'cpu' for CPU, 'cuda' for GPU.",
"--mode",
default="pinned-cuda",
choices=["cpu-cpu", "cpu-cuda", "pinned-cuda", "cuda-cuda"],
help="Dataset storage placement and Train device: 'cpu' for CPU and RAM,"
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
)
return parser.parse_args()


def main(args):
if not torch.cuda.is_available():
args.device = "cpu"
print(f"Training in {args.device} mode.")
args.mode = "cpu-cpu"
print(f"Training in {args.mode} mode.")
args.storage_device, args.device = args.mode.split("-")
args.device = torch.device(args.device)

# Load and preprocess dataset.
print("Loading data...")
dataset = gb.BuiltinDataset("ogbn-products").load()

graph = dataset.graph
# Currently the neighbor-sampling process can only be done on the CPU,
# therefore there is no need to copy the graph to the GPU.
features = dataset.feature
# Move the dataset to the selected storage.
graph = dataset.graph.to(args.storage_device)
features = dataset.feature.to(args.storage_device)

train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set
test_set = dataset.tasks[0].test_set
14 changes: 10 additions & 4 deletions examples/sampling/graphbolt/quickstart/link_prediction.py
@@ -18,14 +18,17 @@
############################################################################
# (HIGHLIGHT) Create a single process dataloader with dgl graphbolt package.
############################################################################
def create_dataloader(dateset, device, is_train=True):
def create_dataloader(dataset, device, is_train=True):
# The second of two tasks in the dataset is link prediction.
task = dataset.tasks[1]
itemset = task.train_set if is_train else task.test_set

# Sample seed edges from the itemset.
datapipe = gb.ItemSampler(itemset, batch_size=256)

# Copy the mini-batch to the designated device for sampling and training.
datapipe = datapipe.copy_to(device)

if is_train:
# Sample negative edges for the seed edges.
datapipe = datapipe.sample_uniform_negative(
@@ -47,9 +50,6 @@ def create_dataloader(dateset, device, is_train=True):
dataset.feature, node_feature_keys=["feat"]
)

# Copy the mini-batch to the designated device for training.
datapipe = datapipe.copy_to(device)

# Initiate the dataloader for the datapipe.
return gb.DataLoader(datapipe)

@@ -158,6 +158,12 @@ def train(model, dataset, device):
print("Loading data...")
dataset = gb.BuiltinDataset("cora").load()

# If a CUDA device is selected, we pin the graph and the features so that
# the GPU can access them.
if device == torch.device("cuda:0"):
dataset.graph.pin_memory_()
dataset.feature.pin_memory_()

in_size = dataset.feature.size("node", None, "feat")[0]
model = GraphSAGE(in_size).to(device)

14 changes: 10 additions & 4 deletions examples/sampling/graphbolt/quickstart/node_classification.py
@@ -13,10 +13,13 @@
############################################################################
# (HIGHLIGHT) Create a single process dataloader with dgl graphbolt package.
############################################################################
def create_dataloader(dateset, itemset, device):
def create_dataloader(dataset, itemset, device):
# Sample seed nodes from the itemset.
datapipe = gb.ItemSampler(itemset, batch_size=16)

# Copy the mini-batch to the designated device for sampling and training.
datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"])

# Sample neighbors for the seed nodes.
datapipe = datapipe.sample_neighbor(dataset.graph, fanouts=[4, 2])

@@ -25,9 +28,6 @@ def create_dataloader(dateset, itemset, device):
dataset.feature, node_feature_keys=["feat"]
)

# Copy the mini-batch to the designated device for training.
datapipe = datapipe.copy_to(device)

# Initiate the dataloader for the datapipe.
return gb.DataLoader(datapipe)

@@ -119,6 +119,12 @@ def train(model, dataset, device):
print("Loading data...")
dataset = gb.BuiltinDataset("cora").load()

# If a CUDA device is selected, we pin the graph and the features so that
# the GPU can access them.
if device == torch.device("cuda:0"):
dataset.graph.pin_memory_()
dataset.feature.pin_memory_()

in_size = dataset.feature.size("node", None, "feat")[0]
out_size = dataset.tasks[0].metadata["num_classes"]
model = GCN(in_size, out_size).to(device)
6 changes: 4 additions & 2 deletions examples/sampling/graphbolt/rgcn/hetero_rgcn.py
@@ -430,14 +430,15 @@ def evaluate(
else:
evaluator = MAG240MEvaluator()

num_etype = len(g.num_edges)
data_loader = create_dataloader(
name,
g,
features,
item_set,
device,
batch_size=4096,
fanouts=[25, 10],
fanouts=[torch.full((num_etype,), 25), torch.full((num_etype,), 10)],
shuffle=False,
num_workers=num_workers,
)
@@ -491,14 +492,15 @@ def train(
print("Start to train...")
category = "paper"

num_etype = len(g.num_edges)
data_loader = create_dataloader(
name,
g,
features,
train_set,
device,
batch_size=1024,
fanouts=[25, 10],
fanouts=[torch.full((num_etype,), 25), torch.full((num_etype,), 10)],
shuffle=True,
num_workers=num_workers,
)
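For context on the fanout change in this file: with a heterogeneous graph, each per-layer fanout is now a tensor holding one value per edge type (num_etype is len(g.num_edges) in the diff). A minimal sketch of what the new expressions evaluate to, using an illustrative edge-type count:

    import torch

    num_etype = 3  # illustrative; the example derives it from len(g.num_edges)
    fanouts = [torch.full((num_etype,), 25), torch.full((num_etype,), 10)]
    # fanouts[0] -> tensor([25, 25, 25]): 25 neighbors per edge type in layer 1
    # fanouts[1] -> tensor([10, 10, 10]): 10 neighbors per edge type in layer 2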
