Skip to content

Commit

Permalink
Merge branch 'master' into untyped_storage
Browse files Browse the repository at this point in the history
  • Loading branch information
drivanov committed Jan 19, 2024
2 parents 4283dcb + 2e6ded0 commit 6a74b84
Show file tree
Hide file tree
Showing 41 changed files with 1,192 additions and 197 deletions.
9 changes: 9 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,10 @@ if(BUILD_GRAPHBOLT)
string(REPLACE ";" "\\;" CUDA_ARCHITECTURES_ESCAPED "${CUDA_ARCHITECTURES}")
file(TO_NATIVE_PATH ${CMAKE_CURRENT_BINARY_DIR} BINDIR)
file(TO_NATIVE_PATH ${CMAKE_COMMAND} CMAKE_CMD)
if(USE_CUDA)
get_target_property(GPU_CACHE_INCLUDE_DIRS gpu_cache INCLUDE_DIRECTORIES)
endif(USE_CUDA)
string(REPLACE ";" "\\;" GPU_CACHE_INCLUDE_DIRS_ESCAPED "${GPU_CACHE_INCLUDE_DIRS}")
if(MSVC)
file(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/graphbolt/build.bat BUILD_SCRIPT)
add_custom_target(
Expand All @@ -540,6 +544,7 @@ if(BUILD_GRAPHBOLT)
CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}
USE_CUDA=${USE_CUDA}
BINDIR=${BINDIR}
GPU_CACHE_INCLUDE_DIRS="${GPU_CACHE_INCLUDE_DIRS_ESCAPED}"
CFLAGS=${CMAKE_C_FLAGS}
CXXFLAGS=${CMAKE_CXX_FLAGS}
CUDAARCHS="${CUDA_ARCHITECTURES_ESCAPED}"
Expand All @@ -557,6 +562,7 @@ if(BUILD_GRAPHBOLT)
CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}
USE_CUDA=${USE_CUDA}
BINDIR=${CMAKE_CURRENT_BINARY_DIR}
GPU_CACHE_INCLUDE_DIRS="${GPU_CACHE_INCLUDE_DIRS_ESCAPED}"
CFLAGS=${CMAKE_C_FLAGS}
CXXFLAGS=${CMAKE_CXX_FLAGS}
CUDAARCHS="${CUDA_ARCHITECTURES_ESCAPED}"
Expand All @@ -565,4 +571,7 @@ if(BUILD_GRAPHBOLT)
DEPENDS ${BUILD_SCRIPT}
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/graphbolt)
endif(MSVC)
if(USE_CUDA)
add_dependencies(graphbolt gpu_cache)
endif(USE_CUDA)
endif(BUILD_GRAPHBOLT)
1 change: 1 addition & 0 deletions docs/source/api/python/dgl.graphbolt.rst
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ Utilities
etype_tuple_to_str
isin
seed
expand_indptr
add_reverse_edges
exclude_seed_edges
compact_csc_format
Expand Down
4 changes: 2 additions & 2 deletions examples/core/rgcn/hetero_rgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def prepare_data(args, device):
# Initialize a train sampler that samples neighbors for multi-layer graph
# convolution. It samples 25 and 10 neighbors for the first and second
# layers respectively.
sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 10])
sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 10], fused=False)
num_workers = args.num_workers
train_loader = dgl.dataloading.DataLoader(
g,
Expand Down Expand Up @@ -488,7 +488,7 @@ def evaluate(
else:
evaluator = MAG240MEvaluator()

sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 10])
sampler = dgl.dataloading.MultiLayerNeighborSampler([25, 10], fused=False)
dataloader = dgl.dataloading.DataLoader(
g,
idx,
Expand Down
39 changes: 25 additions & 14 deletions examples/multigpu/graphbolt/node_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,6 @@ def create_dataloader(
shuffle=shuffle,
drop_uneven_inputs=drop_uneven_inputs,
)
datapipe = datapipe.sample_neighbor(graph, args.fanout)
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])

############################################################################
# [Note]:
# datapipe.copy_to() / gb.CopyTo()
Expand All @@ -137,8 +134,14 @@ def create_dataloader(
# [Output]:
# A CopyTo object copying data in the datapipe to a specified device.
############################################################################
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe, num_workers=args.num_workers)
if not args.cpu_sampling:
datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"])
datapipe = datapipe.sample_neighbor(graph, args.fanout)
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
if args.cpu_sampling:
datapipe = datapipe.copy_to(device)

dataloader = gb.DataLoader(datapipe, args.num_workers)

# Return the fully-initialized DataLoader object.
return dataloader
Expand Down Expand Up @@ -272,15 +275,18 @@ def run(rank, world_size, args, devices, dataset):
rank=rank,
)

graph = dataset.graph
features = dataset.feature
# Pin the graph and features to enable GPU access.
if not args.cpu_sampling:
dataset.graph.pin_memory_()
dataset.feature.pin_memory_()

train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set
test_set = dataset.tasks[0].test_set
args.fanout = list(map(int, args.fanout.split(",")))
num_classes = dataset.tasks[0].metadata["num_classes"]

in_size = features.size("node", None, "feat")[0]
in_size = dataset.feature.size("node", None, "feat")[0]
hidden_size = 256
out_size = num_classes

Expand All @@ -291,8 +297,8 @@ def run(rank, world_size, args, devices, dataset):
# Create data loaders.
train_dataloader = create_dataloader(
args,
graph,
features,
dataset.graph,
dataset.feature,
train_set,
device,
drop_last=False,
Expand All @@ -301,8 +307,8 @@ def run(rank, world_size, args, devices, dataset):
)
valid_dataloader = create_dataloader(
args,
graph,
features,
dataset.graph,
dataset.feature,
valid_set,
device,
drop_last=False,
Expand All @@ -311,8 +317,8 @@ def run(rank, world_size, args, devices, dataset):
)
test_dataloader = create_dataloader(
args,
graph,
features,
dataset.graph,
dataset.feature,
test_set,
device,
drop_last=False,
Expand Down Expand Up @@ -387,6 +393,11 @@ def parse_args():
parser.add_argument(
"--num-workers", type=int, default=0, help="The number of processes."
)
parser.add_argument(
"--cpu-sampling",
action="store_true",
help="Disables GPU sampling and utilizes the CPU for dataloading.",
)
return parser.parse_args()


Expand Down
6 changes: 6 additions & 0 deletions graphbolt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ file(GLOB BOLT_SRC ${BOLT_DIR}/*.cc)
if(USE_CUDA)
file(GLOB BOLT_CUDA_SRC
${BOLT_DIR}/cuda/*.cu
${BOLT_DIR}/cuda/*.cc
)
list(APPEND BOLT_SRC ${BOLT_CUDA_SRC})
if(DEFINED ENV{CUDAARCHS})
Expand All @@ -75,6 +76,11 @@ if(USE_CUDA)
"../third_party/cccl/thrust"
"../third_party/cccl/cub"
"../third_party/cccl/libcudacxx/include")

message(STATUS "Use HugeCTR gpu_cache for graphbolt with INCLUDE_DIRS $ENV{GPU_CACHE_INCLUDE_DIRS}.")
target_include_directories(${LIB_GRAPHBOLT_NAME} PRIVATE $ENV{GPU_CACHE_INCLUDE_DIRS})
target_link_directories(${LIB_GRAPHBOLT_NAME} PRIVATE ${GPU_CACHE_BUILD_DIR})
target_link_libraries(${LIB_GRAPHBOLT_NAME} gpu_cache)

get_property(archs TARGET ${LIB_GRAPHBOLT_NAME} PROPERTY CUDA_ARCHITECTURES)
message(STATUS "CUDA_ARCHITECTURES for graphbolt: ${archs}")
Expand Down
4 changes: 2 additions & 2 deletions graphbolt/build.bat
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ IF x%1x == xx GOTO single

FOR %%X IN (%*) DO (
DEL /S /Q *
"%CMAKE_COMMAND%" -DCMAKE_CONFIGURATION_TYPES=Release -DPYTHON_INTERP=%%X .. -G "Visual Studio 16 2019" || EXIT /B 1
"%CMAKE_COMMAND%" -DGPU_CACHE_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release -DPYTHON_INTERP=%%X .. -G "Visual Studio 16 2019" || EXIT /B 1
msbuild graphbolt.sln /m /nr:false || EXIT /B 1
COPY /Y Release\*.dll "%BINDIR%\graphbolt" || EXIT /B 1
)
Expand All @@ -21,7 +21,7 @@ GOTO end
:single

DEL /S /Q *
"%CMAKE_COMMAND%" -DCMAKE_CONFIGURATION_TYPES=Release .. -G "Visual Studio 16 2019" || EXIT /B 1
"%CMAKE_COMMAND%" -DGPU_CACHE_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release .. -G "Visual Studio 16 2019" || EXIT /B 1
msbuild graphbolt.sln /m /nr:false || EXIT /B 1
COPY /Y Release\*.dll "%BINDIR%\graphbolt" || EXIT /B 1

Expand Down
2 changes: 1 addition & 1 deletion graphbolt/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ else
CPSOURCE=*.so
fi

CMAKE_FLAGS="-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DUSE_CUDA=$USE_CUDA"
CMAKE_FLAGS="-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DUSE_CUDA=$USE_CUDA -DGPU_CACHE_BUILD_DIR=$BINDIR"
echo $CMAKE_FLAGS

if [ $# -eq 0 ]; then
Expand Down
24 changes: 17 additions & 7 deletions graphbolt/include/graphbolt/cuda_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
* @file graphbolt/cuda_ops.h
* @brief Available CUDA operations in Graphbolt.
*/
#ifndef GRAPHBOLT_CUDA_OPS_H_
#define GRAPHBOLT_CUDA_OPS_H_

#include <torch/script.h>

Expand Down Expand Up @@ -162,16 +164,22 @@ torch::Tensor ExclusiveCumSum(torch::Tensor input);
torch::Tensor UVAIndexSelectImpl(torch::Tensor input, torch::Tensor index);

/**
* @brief CSRToCOO implements conversion from a given indptr offset tensor to a
* COO format tensor including ids in [0, indptr.size(0) - 1).
* @brief ExpandIndptrImpl implements conversion from a given indptr offset
* tensor to a COO format tensor. If node_ids is not given, it is assumed to be
* equal to torch::arange(indptr.size(0) - 1, dtype=dtype).
*
* @param input A tensor containing IDs.
* @param output_dtype Dtype of output.
* @param indptr The indptr offset tensor.
* @param dtype The dtype of the returned output tensor.
* @param node_ids Optional 1D tensor representing the node ids.
* @param output_size Optional value of indptr[-1]. Passing it eliminates
* CPU-GPU synchronization.
*
* @return
* - The resulting tensor with output_dtype.
* @return The resulting tensor.
*/
torch::Tensor CSRToCOO(torch::Tensor indptr, torch::ScalarType output_dtype);
torch::Tensor ExpandIndptrImpl(
torch::Tensor indptr, torch::ScalarType dtype,
torch::optional<torch::Tensor> node_ids = torch::nullopt,
torch::optional<int64_t> output_size = torch::nullopt);

/**
* @brief Removes duplicate elements from the concatenated 'unique_dst_ids' and
Expand Down Expand Up @@ -214,3 +222,5 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> UniqueAndCompact(

} // namespace ops
} // namespace graphbolt

#endif // GRAPHBOLT_CUDA_OPS_H_
4 changes: 4 additions & 0 deletions graphbolt/include/graphbolt/cuda_sampling_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
* @file graphbolt/cuda_sampling_ops.h
* @brief Available CUDA sampling operations in Graphbolt.
*/
#ifndef GRAPHBOLT_CUDA_SAMPLING_OPS_H_
#define GRAPHBOLT_CUDA_SAMPLING_OPS_H_

#include <graphbolt/fused_sampled_subgraph.h>
#include <torch/script.h>
Expand Down Expand Up @@ -65,3 +67,5 @@ c10::intrusive_ptr<sampling::FusedSampledSubgraph> InSubgraph(

} // namespace ops
} // namespace graphbolt

#endif // GRAPHBOLT_CUDA_SAMPLING_OPS_H_
80 changes: 0 additions & 80 deletions graphbolt/src/cuda/csr_to_coo.cu

This file was deleted.

0 comments on commit 6a74b84

Please sign in to comment.