Merge branch 'master' into gb_cuda_examples2
mfbalin committed Jan 18, 2024
2 parents 005ff4f + 78fa316 commit 36e6c37
Showing 26 changed files with 890 additions and 85 deletions.
9 changes: 9 additions & 0 deletions CMakeLists.txt
@@ -530,6 +530,10 @@ if(BUILD_GRAPHBOLT)
string(REPLACE ";" "\\;" CUDA_ARCHITECTURES_ESCAPED "${CUDA_ARCHITECTURES}")
file(TO_NATIVE_PATH ${CMAKE_CURRENT_BINARY_DIR} BINDIR)
file(TO_NATIVE_PATH ${CMAKE_COMMAND} CMAKE_CMD)
if(USE_CUDA)
get_target_property(GPU_CACHE_INCLUDE_DIRS gpu_cache INCLUDE_DIRECTORIES)
endif(USE_CUDA)
string(REPLACE ";" "\\;" GPU_CACHE_INCLUDE_DIRS_ESCAPED "${GPU_CACHE_INCLUDE_DIRS}")
if(MSVC)
file(TO_NATIVE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/graphbolt/build.bat BUILD_SCRIPT)
add_custom_target(
@@ -540,6 +544,7 @@ if(BUILD_GRAPHBOLT)
CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}
USE_CUDA=${USE_CUDA}
BINDIR=${BINDIR}
GPU_CACHE_INCLUDE_DIRS="${GPU_CACHE_INCLUDE_DIRS_ESCAPED}"
CFLAGS=${CMAKE_C_FLAGS}
CXXFLAGS=${CMAKE_CXX_FLAGS}
CUDAARCHS="${CUDA_ARCHITECTURES_ESCAPED}"
@@ -557,6 +562,7 @@ if(BUILD_GRAPHBOLT)
CUDA_TOOLKIT_ROOT_DIR=${CUDA_TOOLKIT_ROOT_DIR}
USE_CUDA=${USE_CUDA}
BINDIR=${CMAKE_CURRENT_BINARY_DIR}
GPU_CACHE_INCLUDE_DIRS="${GPU_CACHE_INCLUDE_DIRS_ESCAPED}"
CFLAGS=${CMAKE_C_FLAGS}
CXXFLAGS=${CMAKE_CXX_FLAGS}
CUDAARCHS="${CUDA_ARCHITECTURES_ESCAPED}"
@@ -565,4 +571,7 @@ if(BUILD_GRAPHBOLT)
DEPENDS ${BUILD_SCRIPT}
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/graphbolt)
endif(MSVC)
if(USE_CUDA)
add_dependencies(graphbolt gpu_cache)
endif(USE_CUDA)
endif(BUILD_GRAPHBOLT)
39 changes: 25 additions & 14 deletions examples/multigpu/graphbolt/node_classification.py
@@ -126,9 +126,6 @@ def create_dataloader(
shuffle=shuffle,
drop_uneven_inputs=drop_uneven_inputs,
)
datapipe = datapipe.sample_neighbor(graph, args.fanout)
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])

############################################################################
# [Note]:
# datapipe.copy_to() / gb.CopyTo()
@@ -137,8 +134,14 @@ def create_dataloader(
# [Output]:
# A CopyTo object copying data in the datapipe to a specified device.
############################################################################
datapipe = datapipe.copy_to(device)
dataloader = gb.DataLoader(datapipe, num_workers=args.num_workers)
if not args.cpu_sampling:
datapipe = datapipe.copy_to(device, extra_attrs=["seed_nodes"])
datapipe = datapipe.sample_neighbor(graph, args.fanout)
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
if args.cpu_sampling:
datapipe = datapipe.copy_to(device)

dataloader = gb.DataLoader(datapipe, args.num_workers)

# Return the fully-initialized DataLoader object.
return dataloader
@@ -272,15 +275,18 @@ def run(rank, world_size, args, devices, dataset):
rank=rank,
)

graph = dataset.graph
features = dataset.feature
# Pin the graph and features to enable GPU access.
if not args.cpu_sampling:
dataset.graph.pin_memory_()
dataset.feature.pin_memory_()

train_set = dataset.tasks[0].train_set
valid_set = dataset.tasks[0].validation_set
test_set = dataset.tasks[0].test_set
args.fanout = list(map(int, args.fanout.split(",")))
num_classes = dataset.tasks[0].metadata["num_classes"]

in_size = features.size("node", None, "feat")[0]
in_size = dataset.feature.size("node", None, "feat")[0]
hidden_size = 256
out_size = num_classes

@@ -291,8 +297,8 @@
# Create data loaders.
train_dataloader = create_dataloader(
args,
graph,
features,
dataset.graph,
dataset.feature,
train_set,
device,
drop_last=False,
@@ -301,8 +307,8 @@
)
valid_dataloader = create_dataloader(
args,
graph,
features,
dataset.graph,
dataset.feature,
valid_set,
device,
drop_last=False,
@@ -311,8 +317,8 @@
)
test_dataloader = create_dataloader(
args,
graph,
features,
dataset.graph,
dataset.feature,
test_set,
device,
drop_last=False,
@@ -387,6 +393,11 @@ def parse_args():
parser.add_argument(
"--num-workers", type=int, default=0, help="The number of processes."
)
parser.add_argument(
"--cpu-sampling",
action="store_true",
help="Disables GPU sampling and utilizes the CPU for dataloading.",
)
return parser.parse_args()


6 changes: 6 additions & 0 deletions graphbolt/CMakeLists.txt
@@ -52,6 +52,7 @@ file(GLOB BOLT_SRC ${BOLT_DIR}/*.cc)
if(USE_CUDA)
file(GLOB BOLT_CUDA_SRC
${BOLT_DIR}/cuda/*.cu
${BOLT_DIR}/cuda/*.cc
)
list(APPEND BOLT_SRC ${BOLT_CUDA_SRC})
if(DEFINED ENV{CUDAARCHS})
@@ -75,6 +76,11 @@ if(USE_CUDA)
"../third_party/cccl/thrust"
"../third_party/cccl/cub"
"../third_party/cccl/libcudacxx/include")

message(STATUS "Use HugeCTR gpu_cache for graphbolt with INCLUDE_DIRS $ENV{GPU_CACHE_INCLUDE_DIRS}.")
target_include_directories(${LIB_GRAPHBOLT_NAME} PRIVATE $ENV{GPU_CACHE_INCLUDE_DIRS})
target_link_directories(${LIB_GRAPHBOLT_NAME} PRIVATE ${GPU_CACHE_BUILD_DIR})
target_link_libraries(${LIB_GRAPHBOLT_NAME} gpu_cache)

get_property(archs TARGET ${LIB_GRAPHBOLT_NAME} PROPERTY CUDA_ARCHITECTURES)
message(STATUS "CUDA_ARCHITECTURES for graphbolt: ${archs}")
4 changes: 2 additions & 2 deletions graphbolt/build.bat
@@ -11,7 +11,7 @@ IF x%1x == xx GOTO single

FOR %%X IN (%*) DO (
DEL /S /Q *
"%CMAKE_COMMAND%" -DCMAKE_CONFIGURATION_TYPES=Release -DPYTHON_INTERP=%%X .. -G "Visual Studio 16 2019" || EXIT /B 1
"%CMAKE_COMMAND%" -DGPU_CACHE_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release -DPYTHON_INTERP=%%X .. -G "Visual Studio 16 2019" || EXIT /B 1
msbuild graphbolt.sln /m /nr:false || EXIT /B 1
COPY /Y Release\*.dll "%BINDIR%\graphbolt" || EXIT /B 1
)
@@ -21,7 +21,7 @@ GOTO end
:single

DEL /S /Q *
"%CMAKE_COMMAND%" -DCMAKE_CONFIGURATION_TYPES=Release .. -G "Visual Studio 16 2019" || EXIT /B 1
"%CMAKE_COMMAND%" -DGPU_CACHE_BUILD_DIR=%BINDIR% -DCMAKE_CONFIGURATION_TYPES=Release .. -G "Visual Studio 16 2019" || EXIT /B 1
msbuild graphbolt.sln /m /nr:false || EXIT /B 1
COPY /Y Release\*.dll "%BINDIR%\graphbolt" || EXIT /B 1

2 changes: 1 addition & 1 deletion graphbolt/build.sh
@@ -12,7 +12,7 @@ else
CPSOURCE=*.so
fi

CMAKE_FLAGS="-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DUSE_CUDA=$USE_CUDA"
CMAKE_FLAGS="-DCUDA_TOOLKIT_ROOT_DIR=$CUDA_TOOLKIT_ROOT_DIR -DUSE_CUDA=$USE_CUDA -DGPU_CACHE_BUILD_DIR=$BINDIR"
echo $CMAKE_FLAGS

if [ $# -eq 0 ]; then
108 changes: 108 additions & 0 deletions graphbolt/src/cuda/gpu_cache.cu
@@ -0,0 +1,108 @@
/**
* Copyright (c) 2023 by Contributors
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* @file cuda/gpu_cache.cu
* @brief GPUCache implementation on CUDA.
*/
#include <numeric>

#include "./common.h"
#include "./gpu_cache.h"

namespace graphbolt {
namespace cuda {

GpuCache::GpuCache(const std::vector<int64_t> &shape, torch::ScalarType dtype) {
TORCH_CHECK(shape.size() >= 2, "Shape must at least have 2 dimensions.");
const auto num_items = shape[0];
const int64_t num_feats =
std::accumulate(shape.begin() + 1, shape.end(), 1ll, std::multiplies<>());
const int element_size =
torch::empty(1, torch::TensorOptions().dtype(dtype)).element_size();
num_bytes_ = num_feats * element_size;
num_float_feats_ = (num_bytes_ + sizeof(float) - 1) / sizeof(float);
cache_ = std::make_unique<gpu_cache_t>(
(num_items + bucket_size - 1) / bucket_size, num_float_feats_);
shape_ = shape;
shape_[0] = -1;
dtype_ = dtype;
device_id_ = cuda::GetCurrentStream().device_index();
}

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> GpuCache::Query(
torch::Tensor keys) {
TORCH_CHECK(keys.device().is_cuda(), "Keys should be on a CUDA device.");
TORCH_CHECK(
keys.device().index() == device_id_,
"Keys should be on the correct CUDA device.");
TORCH_CHECK(keys.sizes().size() == 1, "Keys should be a 1D tensor.");
keys = keys.to(torch::kLong);
auto values = torch::empty(
{keys.size(0), num_float_feats_}, keys.options().dtype(torch::kFloat));
auto missing_index =
torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
auto missing_keys =
torch::empty(keys.size(0), keys.options().dtype(torch::kLong));
cuda::CopyScalar<size_t> missing_len;
auto stream = cuda::GetCurrentStream();
cache_->Query(
reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),
values.data_ptr<float>(),
reinterpret_cast<uint64_t *>(missing_index.data_ptr()),
reinterpret_cast<key_t *>(missing_keys.data_ptr()), missing_len.get(),
stream);
values = values.view(torch::kByte)
.slice(1, 0, num_bytes_)
.view(dtype_)
.view(shape_);
// To safely read missing_len, we synchronize
stream.synchronize();
missing_index = missing_index.slice(0, 0, static_cast<size_t>(missing_len));
missing_keys = missing_keys.slice(0, 0, static_cast<size_t>(missing_len));
return std::make_tuple(values, missing_index, missing_keys);
}

void GpuCache::Replace(torch::Tensor keys, torch::Tensor values) {
TORCH_CHECK(keys.device().is_cuda(), "Keys should be on a CUDA device.");
TORCH_CHECK(
keys.device().index() == device_id_,
"Keys should be on the correct CUDA device.");
TORCH_CHECK(values.device().is_cuda(), "Keys should be on a CUDA device.");
TORCH_CHECK(
values.device().index() == device_id_,
"Values should be on the correct CUDA device.");
TORCH_CHECK(
keys.size(0) == values.size(0),
"The first dimensions of keys and values must match.");
TORCH_CHECK(
std::equal(shape_.begin() + 1, shape_.end(), values.sizes().begin() + 1),
"Values should have the correct dimensions.");
TORCH_CHECK(
values.scalar_type() == dtype_, "Values should have the correct dtype.");
keys = keys.to(torch::kLong);
torch::Tensor float_values;
if (num_bytes_ % sizeof(float) != 0) {
float_values = torch::empty(
{values.size(0), num_float_feats_},
values.options().dtype(torch::kFloat));
float_values.view(torch::kByte)
.slice(1, 0, num_bytes_)
.copy_(values.view(torch::kByte).view({values.size(0), -1}));
} else {
float_values = values.view(torch::kByte)
.view({values.size(0), -1})
.view(torch::kFloat)
.contiguous();
}
cache_->Replace(
reinterpret_cast<const key_t *>(keys.data_ptr()), keys.size(0),
float_values.data_ptr<float>(), cuda::GetCurrentStream());
}

c10::intrusive_ptr<GpuCache> GpuCache::Create(
const std::vector<int64_t> &shape, torch::ScalarType dtype) {
return c10::make_intrusive<GpuCache>(shape, dtype);
}

} // namespace cuda
} // namespace graphbolt
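A note on the layout above: the cache stores each row as whole float words, so num_bytes_ = num_feats * element_size is rounded up to num_float_feats_ = ceil(num_bytes_ / sizeof(float)). For example, a cache created with shape {N, 5} and dtype torch::kByte gets num_bytes_ = 5 and num_float_feats_ = ceil(5 / 4) = 2, i.e. 8 bytes per cached row; Query() views its float output as bytes, slices each row back to num_bytes_ = 5, and views the result as the original dtype and shape, so the 3 padding bytes never reach the caller.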
66 changes: 66 additions & 0 deletions graphbolt/src/cuda/gpu_cache.h
@@ -0,0 +1,66 @@
/**
* Copyright (c) 2023 by Contributors
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* @file cuda/gpu_cache.h
* @brief Header file of HugeCTR gpu_cache wrapper.
*/

#ifndef GRAPHBOLT_GPU_CACHE_H_
#define GRAPHBOLT_GPU_CACHE_H_

#include <torch/custom_class.h>
#include <torch/torch.h>

#include <limits>
#include <nv_gpu_cache.hpp>

namespace graphbolt {
namespace cuda {

class GpuCache : public torch::CustomClassHolder {
using key_t = long long;
constexpr static int set_associativity = 2;
constexpr static int WARP_SIZE = 32;
constexpr static int bucket_size = WARP_SIZE * set_associativity;
using gpu_cache_t = ::gpu_cache::gpu_cache<
key_t, uint64_t, std::numeric_limits<key_t>::max(), set_associativity,
WARP_SIZE>;

public:
/**
* @brief Constructor for the GpuCache struct.
*
* @param shape The shape of the GPU cache.
* @param dtype The datatype of items to be stored.
*/
GpuCache(const std::vector<int64_t>& shape, torch::ScalarType dtype);

GpuCache() = default;

std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> Query(
torch::Tensor keys);

void Replace(torch::Tensor keys, torch::Tensor values);

static c10::intrusive_ptr<GpuCache> Create(
const std::vector<int64_t>& shape, torch::ScalarType dtype);

private:
std::vector<int64_t> shape_;
torch::ScalarType dtype_;
std::unique_ptr<gpu_cache_t> cache_;
int64_t num_bytes_;
int64_t num_float_feats_;
torch::DeviceIndex device_id_;
};

// The cu file in HugeCTR gpu cache uses unsigned int and long long.
// Changing to int64_t results in a mismatch of template arguments.
static_assert(
sizeof(long long) == sizeof(int64_t),
"long long and int64_t needs to have the same size."); // NOLINT

} // namespace cuda
} // namespace graphbolt

#endif // GRAPHBOLT_GPU_CACHE_H_
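The header above only declares the cache API (Create, Query, Replace); no caller is included in this commit. The following is a minimal usage sketch, not part of the change, showing how Query and Replace could be combined into a cache-backed lookup. FetchFromBackingStore is a hypothetical callback standing in for whatever produces the missing rows (for example, a feature store in pinned host memory); only the GpuCache calls themselves come from the files above.

// Minimal usage sketch, not part of this commit. `FetchFromBackingStore` is a
// hypothetical stand-in for the slower storage behind the cache.
#include <functional>

#include <torch/torch.h>

#include "./gpu_cache.h"

torch::Tensor CachedLookup(
    graphbolt::cuda::GpuCache& cache, torch::Tensor keys,
    const std::function<torch::Tensor(torch::Tensor)>& FetchFromBackingStore) {
  // Query returns the cached rows plus, for every miss, its position in the
  // output and the key that missed.
  auto [values, missing_index, missing_keys] = cache.Query(keys);
  if (missing_keys.size(0) > 0) {
    // Fill the holes with rows fetched from the backing store ...
    auto missing_values = FetchFromBackingStore(missing_keys);
    values.index_copy_(0, missing_index, missing_values);
    // ... and insert them so that later queries can hit the cache.
    cache.Replace(missing_keys, missing_values);
  }
  return values;
}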
13 changes: 10 additions & 3 deletions graphbolt/src/cuda/index_select_impl.cu
@@ -10,6 +10,7 @@
#include <numeric>

#include "./common.h"
#include "./max_uva_threads.h"
#include "./utils.h"

namespace graphbolt {
@@ -122,17 +123,23 @@ torch::Tensor UVAIndexSelectImpl_(torch::Tensor input, torch::Tensor index) {
if (aligned_feature_size == 1) {
// Use a single thread to process each output row to avoid wasting threads.
const int num_threads = cuda::FindNumThreads(return_len);
const int num_blocks = (return_len + num_threads - 1) / num_threads;
const int num_blocks =
(std::min(return_len, cuda::max_uva_threads.value_or(1 << 20)) +
num_threads - 1) /
num_threads;
CUDA_KERNEL_CALL(
IndexSelectSingleKernel, num_blocks, num_threads, 0, input_ptr,
input_len, index_sorted_ptr, return_len, ret_ptr, permutation_ptr);
} else {
dim3 block(512, 1);
constexpr int BLOCK_SIZE = 512;
dim3 block(BLOCK_SIZE, 1);
while (static_cast<int64_t>(block.x) >= 2 * aligned_feature_size) {
block.x >>= 1;
block.y <<= 1;
}
const dim3 grid((return_len + block.y - 1) / block.y);
const dim3 grid(std::min(
(return_len + block.y - 1) / block.y,
cuda::max_uva_threads.value_or(1 << 20) / BLOCK_SIZE));
if (aligned_feature_size * sizeof(DType) <= GPU_CACHE_LINE_SIZE) {
// When feature size is smaller than GPU cache line size, use unaligned
// version for less SM usage, which is more resource efficient.
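For context on the grid-size cap above: limiting the launch to cuda::max_uva_threads threads (1 << 20 when unset) only covers all return_len rows if the kernels advance by the total number of launched threads, i.e. use a grid-stride loop. The kernel bodies are not shown in this diff, and the motivation for the cap is not stated here (presumably to bound the number of concurrent zero-copy UVA accesses), so the sketch below is only an illustration of the pattern, not the actual IndexSelectSingleKernel.

// Illustrative grid-stride kernel; assumes (not shown in this diff) that the
// index_select kernels iterate this way. With the stride loop, launching fewer
// threads than n, e.g. capped at max_uva_threads, still visits every element.
template <typename DType>
__global__ void GridStrideGather(
    const DType* input, const int64_t* index, DType* output, int64_t n) {
  int64_t i = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(gridDim.x) * blockDim.x;
  for (; i < n; i += stride) {
    output[i] = input[index[i]];  // each thread handles about n / stride rows
  }
}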
