Commit bb9d2da
Merge branch 'master' into untyped_storage
drivanov committed Feb 9, 2024
2 parents c1f8acd + 8e6cbd6 commit bb9d2da
Showing 36 changed files with 3,188 additions and 455 deletions.
1 change: 1 addition & 0 deletions docs/source/api/python/dgl.graphbolt.rst
@@ -187,6 +187,7 @@ Utilities
etype_tuple_to_str
isin
seed
index_select
expand_indptr
add_reverse_edges
exclude_seed_edges
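The only change in this file is adding index_select to the documented Utilities of dgl.graphbolt. For reference, a minimal, hypothetical sketch of how such a utility is typically invoked, assuming gb.index_select(input, index) gathers rows like input[index] and accepts a pinned host tensor with a CUDA index (the case the UVA kernels later in this commit target):

    import torch
    import dgl.graphbolt as gb

    # Assumption: feat[idx] semantics; pinned host storage with a GPU index is
    # the intended fast path, with the result returned on the GPU.
    feat = torch.randn(10_000, 128).pin_memory()
    idx = torch.randint(0, 10_000, (256,), device="cuda")
    rows = gb.index_select(feat, idx)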
4 changes: 4 additions & 0 deletions docs/source/guide/minibatch-node.rst
@@ -44,6 +44,8 @@ putting the list of generated MFGs onto GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = gb.BuiltinDataset("ogbn-arxiv").load()
g = dataset.graph
feature = dataset.feature
train_set = dataset.tasks[0].train_set
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
@@ -205,6 +207,8 @@ of node types to node IDs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset = gb.BuiltinDataset("ogbn-mag").load()
g = dataset.graph
feature = dataset.feature
train_set = dataset.tasks[0].train_set
datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
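Both snippets now keep a feature handle (feature = dataset.feature) next to the graph. A short sketch, assembled from the other GraphBolt examples touched by this commit rather than from this file, of how that handle is typically consumed further down the same datapipe; fetch_feature, copy_to, and gb.DataLoader are assumed names here:

    # Fetch the "feat" node feature for each sampled minibatch, move the batch
    # to the target device, and iterate.
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe)
    for minibatch in dataloader:
        x = minibatch.node_features["feat"]  # features of the sampled nodes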
6 changes: 4 additions & 2 deletions docs/source/install/index.rst
@@ -5,11 +5,13 @@ System requirements
-------------------
DGL works with the following operating systems:

* Ubuntu 16.04
* Ubuntu 20.04+
* CentOS 8+
* RHEL 8+
* macOS X
* Windows 10

DGL requires Python version 3.6, 3.7, 3.8 or 3.9.
DGL requires Python version 3.7, 3.8, 3.9, 3.10, 3.11.

DGL supports multiple tensor libraries as backends, e.g., PyTorch, MXNet. For requirements on backends and how to select one, see :ref:`backends`.

43 changes: 28 additions & 15 deletions examples/multigpu/graphbolt/node_classification.py
@@ -151,9 +151,7 @@ def evaluate(rank, model, dataloader, num_classes, device):
y = []
y_hats = []

for step, data in (
tqdm.tqdm(enumerate(dataloader)) if rank == 0 else enumerate(dataloader)
):
for data in tqdm.tqdm(dataloader) if rank == 0 else dataloader:
blocks = data.blocks
x = data.node_features["feat"]
y.append(data.labels)
@@ -271,44 +269,53 @@ def run(rank, world_size, args, devices, dataset):

# Pin the graph and features to enable GPU access.
if args.storage_device == "pinned":
dataset.graph.pin_memory_()
dataset.feature.pin_memory_()
graph = dataset.graph.pin_memory_()
feature = dataset.feature.pin_memory_()
else:
graph = dataset.graph.to(args.storage_device)
feature = dataset.feature.to(args.storage_device)

train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set
test_set = dataset.tasks[0].test_set
args.fanout = list(map(int, args.fanout.split(",")))
num_classes = dataset.tasks[0].metadata["num_classes"]

in_size = dataset.feature.size("node", None, "feat")[0]
in_size = feature.size("node", None, "feat")[0]
hidden_size = 256
out_size = num_classes

if args.gpu_cache_size > 0 and args.storage_device != "cuda":
feature._features[("node", None, "feat")] = gb.GPUCachedFeature(
feature._features[("node", None, "feat")],
args.gpu_cache_size,
)

# Create GraphSAGE model. It should be copied onto a GPU as a replica.
model = SAGE(in_size, hidden_size, out_size).to(device)
model = DDP(model)

# Create data loaders.
train_dataloader = create_dataloader(
args,
dataset.graph,
dataset.feature,
graph,
feature,
train_set,
device,
is_train=True,
)
valid_dataloader = create_dataloader(
args,
dataset.graph,
dataset.feature,
graph,
feature,
valid_set,
device,
is_train=False,
)
test_dataloader = create_dataloader(
args,
dataset.graph,
dataset.feature,
graph,
feature,
test_set,
device,
is_train=False,
@@ -381,12 +388,18 @@ def parse_args():
parser.add_argument(
"--num-workers", type=int, default=0, help="The number of processes."
)
parser.add_argument(
"--gpu-cache-size",
type=int,
default=0,
help="The capacity of the GPU cache, the number of features to store.",
)
parser.add_argument(
"--mode",
default="pinned-cuda",
choices=["cpu-cuda", "pinned-cuda"],
help="Dataset storage placement and Train device: 'cpu' for CPU and RAM,"
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
choices=["cpu-cuda", "pinned-cuda", "cuda-cuda"],
help="Dataset storage placement and Train device: 'cpu' for CPU and RAM"
", 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
)
return parser.parse_args()

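Beyond reusing the handles returned by pin_memory_()/.to(), the new --gpu-cache-size flag lets the example layer a GPU cache over host-resident node features. A hypothetical standalone sketch of the same wrapper; gb.TorchBasedFeature and the read() call are assumptions drawn from other GraphBolt code, and the cache size is a row count, as the help string says:

    import torch
    import dgl.graphbolt as gb

    # Base feature lives in pinned host memory; the cache keeps the hottest
    # rows in GPU memory and falls back to the host copy on a miss (assumed).
    cpu_feat = gb.TorchBasedFeature(torch.randn(1_000_000, 100).pin_memory())
    cached_feat = gb.GPUCachedFeature(cpu_feat, 500_000)  # cache up to 500k rows
    ids = torch.randint(0, 1_000_000, (4096,), device="cuda")
    rows = cached_feat.read(ids)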
38 changes: 16 additions & 22 deletions examples/sampling/graphbolt/link_prediction.py
@@ -79,14 +79,10 @@ def forward(self, blocks, x):
hidden_x = F.relu(hidden_x)
return hidden_x

def inference(self, graph, features, dataloader, device):
def inference(self, graph, features, dataloader, storage_device):
"""Conduct layer-wise inference to get all the node embeddings."""
feature = features.read("node", None, "feat")

buffer_device = torch.device("cpu")
# Enable pin_memory for faster CPU to GPU data transfer if the
# model is running on a GPU.
pin_memory = buffer_device != device
pin_memory = storage_device == "pinned"
buffer_device = torch.device("cpu" if pin_memory else storage_device)

print("Start node embedding inference.")
for layer_idx, layer in enumerate(self.layers):
@@ -99,17 +95,17 @@ def inference(self, graph, features, dataloader, device):
device=buffer_device,
pin_memory=pin_memory,
)
feature = feature.to(device)
for step, data in tqdm.tqdm(enumerate(dataloader)):
x = feature[data.input_nodes]
hidden_x = layer(data.blocks[0], x) # len(blocks) = 1
for data in tqdm.tqdm(dataloader):
# len(blocks) = 1
hidden_x = layer(data.blocks[0], data.node_features["feat"])
if not is_last_layer:
hidden_x = F.relu(hidden_x)
# By design, our seed nodes are contiguous.
y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = hidden_x.to(
buffer_device, non_blocking=True
)
feature = y
if not is_last_layer:
features.update("node", None, "feat", y)

return y

@@ -185,7 +181,9 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# [Role]:
# Initialize a neighbor sampler for sampling the neighborhoods of nodes.
############################################################################
datapipe = datapipe.sample_neighbor(graph, args.fanout)
datapipe = datapipe.sample_neighbor(
graph, args.fanout if is_train else [-1]
)

############################################################################
# [Input]:
@@ -213,12 +211,9 @@ def create_dataloader(args, graph, features, itemset, is_train=True):
# A FeatureFetcher object to fetch node features.
# [Role]:
# Initialize a feature fetcher for fetching features of the sampled
# subgraphs. This step is skipped in evaluation/inference because features
# are updated as a whole during it, thus storing features in minibatch is
# unnecessary.
# subgraphs.
############################################################################
if is_train:
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])

############################################################################
# [Input]:
@@ -286,15 +281,12 @@ def evaluate(args, model, graph, features, all_nodes_set, valid_set, test_set):
model.eval()
evaluator = Evaluator(name="ogbl-citation2")

# Since we need to use all neghborhoods for evaluation, we set the fanout
# to -1.
args.fanout = [-1]
dataloader = create_dataloader(
args, graph, features, all_nodes_set, is_train=False
)

# Compute node embeddings for the entire graph.
node_emb = model.inference(graph, features, dataloader, args.device)
node_emb = model.inference(graph, features, dataloader, args.storage_device)
results = []

# Loop over both validation and test sets.
@@ -340,6 +332,8 @@ def train(args, model, graph, features, train_set):

total_loss += loss.item()
if step + 1 == args.early_stop:
# Early stopping requires a new dataloader to reset its state.
dataloader = create_dataloader(args, graph, features, train_set)
break

end_epoch_time = time.time()
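With the fanout now chosen inside create_dataloader (full neighborhoods, [-1], whenever is_train=False) and the FeatureFetcher kept in the pipeline for inference, evaluation no longer mutates args.fanout or reads the entire feature tensor up front. A condensed outline of the resulting evaluation path, following the names used in this file; it is a sketch of the flow, not the literal code:

    # Build an inference dataloader: is_train=False implies fanout [-1] and the
    # pipeline still fetches the "feat" feature for every minibatch.
    eval_dataloader = create_dataloader(
        args, graph, features, all_nodes_set, is_train=False
    )
    # Layer-wise inference reads minibatch node_features["feat"] and only needs
    # to know where the node features are stored (args.storage_device).
    node_emb = model.inference(graph, features, eval_dataloader, args.storage_device)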
29 changes: 11 additions & 18 deletions examples/sampling/graphbolt/node_classification.py
@@ -131,11 +131,9 @@ def create_dataloader(
# A FeatureFetcher object to fetch node features.
# [Role]:
# Initialize a feature fetcher for fetching features of the sampled
# subgraphs. This step is skipped in inference because features are updated
# as a whole during it, thus storing features in minibatch is unnecessary.
# subgraphs.
############################################################################
if job != "infer":
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])

############################################################################
# [Step-5]:
@@ -194,14 +192,10 @@ def forward(self, blocks, x):
hidden_x = self.dropout(hidden_x)
return hidden_x

def inference(self, graph, features, dataloader, device):
def inference(self, graph, features, dataloader, storage_device):
"""Conduct layer-wise inference to get all the node embeddings."""
feature = features.read("node", None, "feat")

buffer_device = torch.device("cpu")
# Enable pin_memory for faster CPU to GPU data transfer if the
# model is running on a GPU.
pin_memory = buffer_device != device
pin_memory = storage_device == "pinned"
buffer_device = torch.device("cpu" if pin_memory else storage_device)

for layer_idx, layer in enumerate(self.layers):
is_last_layer = layer_idx == len(self.layers) - 1
@@ -213,19 +207,18 @@ def inference(self, graph, features, dataloader, device):
device=buffer_device,
pin_memory=pin_memory,
)
feature = feature.to(device)

for step, data in tqdm(enumerate(dataloader)):
x = feature[data.input_nodes]
hidden_x = layer(data.blocks[0], x) # len(blocks) = 1
for data in tqdm(dataloader):
# len(blocks) = 1
hidden_x = layer(data.blocks[0], data.node_features["feat"])
if not is_last_layer:
hidden_x = F.relu(hidden_x)
hidden_x = self.dropout(hidden_x)
# By design, our output nodes are contiguous.
y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = hidden_x.to(
buffer_device
)
feature = y
if not is_last_layer:
features.update("node", None, "feat", y)

return y

@@ -245,7 +238,7 @@ def layerwise_infer(
num_workers=args.num_workers,
job="infer",
)
pred = model.inference(graph, features, dataloader, args.device)
pred = model.inference(graph, features, dataloader, args.storage_device)
pred = pred[test_set._items[0]]
label = test_set._items[1].to(pred.device)

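The same inference rewrite lands here: each layer consumes data.node_features["feat"] delivered by the FeatureFetcher, and every non-final layer writes its outputs back into the feature store so the next pass is served the previous layer's embeddings instead of the raw features. A schematic outline of that loop under the names used above; out_dim stands in for the per-layer output width computed in the example:

    for layer_idx, layer in enumerate(self.layers):
        is_last_layer = layer_idx == len(self.layers) - 1
        # One output row per node, staged on the buffer device.
        y = torch.empty(graph.total_num_nodes, out_dim, device=buffer_device)
        for data in tqdm(dataloader):
            h = layer(data.blocks[0], data.node_features["feat"])  # len(blocks) = 1
            if not is_last_layer:
                h = F.relu(h)
                h = self.dropout(h)
            # Seed nodes are contiguous by design, so slice assignment suffices.
            y[data.seed_nodes[0] : data.seed_nodes[-1] + 1] = h.to(buffer_device)
        if not is_last_layer:
            features.update("node", None, "feat", y)  # next layer reads these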
12 changes: 6 additions & 6 deletions graphbolt/src/cuda/gpu_cache.cu
@@ -43,20 +43,19 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(
torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
auto missing_keys =
torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
cuda::CopyScalar<size_t> missing_len;
auto stream = cuda::GetCurrentStream();
auto allocator = cuda::GetAllocator();
auto missing_len_device = allocator.AllocateStorage<size_t>(1);
cache_->Query(
reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),
values.data_ptr<float>(),
reinterpret_cast<uint64_t *>(missing_index.data_ptr()),
reinterpret_cast<key_t *>(missing_keys.data_ptr()), missing_len.get(),
stream);
reinterpret_cast<key_t *>(missing_keys.data_ptr()),
missing_len_device.get(), cuda::GetCurrentStream());
values = values.view(torch::kByte)
.slice(1, 0, num_bytes_)
.view(dtype_)
.view(shape_);
// To safely read missing_len, we synchronize
stream.synchronize();
cuda::CopyScalar<size_t> missing_len(missing_len_device.get());
missing_index = missing_index.slice(0, 0, static_cast<size_t>(missing_len));
missing_keys = missing_keys.slice(0, 0, static_cast<size_t>(missing_len));
return std::make_tuple(values, missing_index, missing_keys);
@@ -79,6 +78,7 @@ void GpuCache::Replace(torch::Tensor keys, torch::Tensor values) {
"Values should have the correct dimensions.");
TORCH_CHECK(
values.scalar_type() == dtype_, "Values should have the correct dtype.");
if (keys.numel() == 0) return;
keys = keys.to(torch::kLong);
torch::Tensor float_values;
if (num_bytes_ % sizeof(float) != 0) {
8 changes: 6 additions & 2 deletions graphbolt/src/cuda/index_select_csc_impl.cu
@@ -14,12 +14,13 @@
#include <numeric>

#include "./common.h"
#include "./max_uva_threads.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

constexpr int BLOCK_SIZE = 128;
constexpr int BLOCK_SIZE = CUDA_MAX_NUM_THREADS;

// Given the in_degree array and a permutation, returns in_degree of the output
// and the permuted and modified in_degree of the input. The modified in_degree
@@ -130,7 +131,10 @@ std::tuple<torch::Tensor, torch::Tensor> UVAIndexSelectCSCCopyIndices(
torch::Tensor output_indices =
torch::empty(output_size.value(), options.dtype(indices.scalar_type()));
const dim3 block(BLOCK_SIZE);
const dim3 grid((edge_count_aligned + BLOCK_SIZE - 1) / BLOCK_SIZE);
const dim3 grid(
(std::min(edge_count_aligned, cuda::max_uva_threads.value_or(1 << 20)) +
BLOCK_SIZE - 1) /
BLOCK_SIZE);

// Find the smallest integer type to store the coo_aligned_rows tensor.
const int num_bits = cuda::NumberOfBits(num_nodes);
2 changes: 1 addition & 1 deletion graphbolt/src/cuda/index_select_impl.cu
@@ -131,7 +131,7 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
IndexSelectSingleKernel, num_blocks, num_threads, 0, input_ptr,
input_len, index_sorted_ptr, return_len, ret_ptr, permutation_ptr);
} else {
constexpr int BLOCK_SIZE = 512;
constexpr int BLOCK_SIZE = CUDA_MAX_NUM_THREADS;
dim3 block(BLOCK_SIZE, 1);
while (static_cast<int64_t>(block.x) >= 2 * aligned_feature_size) {
block.x >>= 1;