From 9727f4afdecb900839b2aaa693b6c00414a2a9df Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 26 Feb 2024 17:56:58 +0800 Subject: [PATCH] Add CUDA iterator to tensor view. --- include/xgboost/linalg.h | 12 +++++++--- include/xgboost/span.h | 4 ++-- src/common/linalg_op.cuh | 42 +++++++++++++++++++++++++-------- tests/cpp/common/test_linalg.cu | 24 ++++++++++++++++++- 4 files changed, 66 insertions(+), 16 deletions(-) diff --git a/include/xgboost/linalg.h b/include/xgboost/linalg.h index 581b2f0804c9..f538adbcdab3 100644 --- a/include/xgboost/linalg.h +++ b/include/xgboost/linalg.h @@ -295,6 +295,9 @@ class TensorView { using ShapeT = std::size_t[kDim]; using StrideT = ShapeT; + using element_type = T; // NOLINT + using value_type = std::remove_cv_t; // NOLINT + private: StrideT stride_{1}; ShapeT shape_{0}; @@ -314,7 +317,7 @@ class TensorView { } template - LINALG_HD size_t MakeSliceDim(size_t new_shape[D], size_t new_stride[D], + LINALG_HD size_t MakeSliceDim(std::size_t new_shape[D], std::size_t new_stride[D], detail::RangeTag &&range) const { static_assert(new_dim < D); static_assert(old_dim < kDim); @@ -528,9 +531,10 @@ class TensorView { LINALG_HD auto Stride(size_t i) const { return stride_[i]; } /** - * \brief Number of items in the tensor. + * @brief Number of items in the tensor. */ [[nodiscard]] LINALG_HD std::size_t Size() const { return size_; } + [[nodiscard]] bool Empty() const { return Size() == 0; } /** * \brief Whether this is a contiguous array, both C and F contiguous returns true. */ @@ -865,7 +869,9 @@ class Tensor { auto HostView() { return this->View(DeviceOrd::CPU()); } auto HostView() const { return this->View(DeviceOrd::CPU()); } - [[nodiscard]] size_t Size() const { return data_.Size(); } + [[nodiscard]] std::size_t Size() const { return data_.Size(); } + [[nodiscard]] bool Empty() const { return Size() == 0; } + auto Shape() const { return common::Span{shape_}; } auto Shape(size_t i) const { return shape_[i]; } diff --git a/include/xgboost/span.h b/include/xgboost/span.h index be8640f73695..29ca76d3c116 100644 --- a/include/xgboost/span.h +++ b/include/xgboost/span.h @@ -701,10 +701,10 @@ class IterSpan { return {data() + _offset, _count == dynamic_extent ? size() - _offset : _count}; } [[nodiscard]] XGBOOST_DEVICE constexpr iterator begin() const noexcept { // NOLINT - return {this, 0}; + return it_; } [[nodiscard]] XGBOOST_DEVICE constexpr iterator end() const noexcept { // NOLINT - return {this, size()}; + return it_ + size(); } }; } // namespace common diff --git a/src/common/linalg_op.cuh b/src/common/linalg_op.cuh index 9ef36598d9de..21fad2dc0b4a 100644 --- a/src/common/linalg_op.cuh +++ b/src/common/linalg_op.cuh @@ -13,15 +13,14 @@ #include "xgboost/context.h" // for Context #include "xgboost/linalg.h" // for TensorView -namespace xgboost { -namespace linalg { +namespace xgboost::linalg { namespace cuda_impl { // Use template specialization to dispatch, Windows + CUDA 11.8 doesn't support extended // lambda inside constexpr if template struct ElementWiseImpl { template - void operator()(linalg::TensorView t, Fn&& fn, cudaStream_t s) { + void operator()(TensorView t, Fn&& fn, cudaStream_t s) { static_assert(D > 1); dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) mutable { std::apply(fn, linalg::UnravelIndex(i, t.Shape())); @@ -32,36 +31,59 @@ struct ElementWiseImpl { template struct ElementWiseImpl { template - void operator()(linalg::TensorView t, Fn&& fn, cudaStream_t s) { + void operator()(TensorView t, Fn&& fn, cudaStream_t s) { dh::LaunchN(t.Size(), s, [=] __device__(std::size_t i) { fn(i); }); } }; template -void ElementWiseKernel(linalg::TensorView t, Fn&& fn, cudaStream_t s = nullptr) { +void ElementWiseKernel(TensorView t, Fn&& fn, cudaStream_t s = nullptr) { dh::safe_cuda(cudaSetDevice(t.Device().ordinal)); cuda_impl::ElementWiseImpl{}(t, fn, s); } } // namespace cuda_impl template -void ElementWiseTransformDevice(linalg::TensorView t, Fn&& fn, cudaStream_t s = nullptr) { +void ElementWiseTransformDevice(TensorView t, Fn&& fn, cudaStream_t s = nullptr) { if (t.Contiguous()) { auto ptr = t.Values().data(); dh::LaunchN(t.Size(), s, [=] __device__(size_t i) { ptr[i] = fn(i, ptr[i]); }); } else { dh::LaunchN(t.Size(), s, [=] __device__(size_t i) mutable { - T& v = detail::Apply(t, linalg::UnravelIndex(i, t.Shape())); + T& v = detail::Apply(t, UnravelIndex(i, t.Shape())); v = fn(i, v); }); } } template -void ElementWiseKernel(Context const* ctx, linalg::TensorView t, Fn&& fn) { +void ElementWiseKernel(Context const* ctx, TensorView t, Fn&& fn) { ctx->IsCUDA() ? cuda_impl::ElementWiseKernel(t, fn) : ElementWiseKernelHost(t, ctx->Threads(), fn); } -} // namespace linalg -} // namespace xgboost + +namespace detail { +template +struct IterOp { + TensorView v; + XGBOOST_DEVICE T& operator()(std::size_t i) { + return detail::Apply(v, UnravelIndex(i, v.Shape())); + } +}; +} // namespace detail + +// naming: thrust begin +// returns a thrust iterator for a tensor view. +template +auto tcbegin(TensorView v) { // NOLINT + return dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), + detail::IterOp>, kDim>{v}); +} + +template +auto tcend(TensorView v) { // NOLINT + return tcbegin(v) + v.Size(); +} +} // namespace xgboost::linalg #endif // XGBOOST_COMMON_LINALG_OP_CUH_ diff --git a/tests/cpp/common/test_linalg.cu b/tests/cpp/common/test_linalg.cu index 5f8bab4a3cc4..bf217842b660 100644 --- a/tests/cpp/common/test_linalg.cu +++ b/tests/cpp/common/test_linalg.cu @@ -1,8 +1,11 @@ /** - * Copyright 2021-2023 by XGBoost Contributors + * Copyright 2021-2024, XGBoost Contributors */ #include +#include // for equal +#include // for sequence +#include "../../../src/common/cuda_context.cuh" #include "../../../src/common/linalg_op.cuh" #include "../helpers.h" #include "xgboost/context.h" @@ -85,4 +88,23 @@ void TestSlice() { TEST(Linalg, GPUElementWise) { TestElementWiseKernel(); } TEST(Linalg, GPUTensorView) { TestSlice(); } + +TEST(Linalg, GPUIter) { + auto ctx = MakeCUDACtx(1); + auto cuctx = ctx.CUDACtx(); + + dh::device_vector data(2 * 3 * 4); + thrust::sequence(cuctx->CTP(), data.begin(), data.end(), 1.0); + + auto t = MakeTensorView(&ctx, dh::ToSpan(data), 2, 3, 4); + static_assert(!std::is_const_v); + static_assert(!std::is_const_v); + + auto n = std::distance(linalg::tcbegin(t), linalg::tcend(t)); + ASSERT_EQ(n, t.Size()); + ASSERT_FALSE(t.Empty()); + + bool eq = thrust::equal(cuctx->CTP(), data.cbegin(), data.cend(), linalg::tcbegin(t)); + ASSERT_TRUE(eq); +} } // namespace xgboost::linalg