diff --git a/doc/parameter.rst b/doc/parameter.rst
index 1c1afd1c0698..9bc328b45eb9 100644
--- a/doc/parameter.rst
+++ b/doc/parameter.rst
@@ -230,6 +230,20 @@ Parameters for Tree Booster
   list is a group of indices of features that are allowed to interact with each other.
   See tutorial for more information
 
+Additional parameters for ``gpu_hist`` tree method
+==================================================
+
+* ``single_precision_histogram``, [default=``false``]
+
+  - Use single precision to build histograms.  See the document for GPU support for more details.
+
+* ``deterministic_histogram``, [default=``true``]
+
+  - Build histograms on GPU deterministically.  Histogram building is otherwise not
+    deterministic due to the non-associativity of floating point summation.  We employ a
+    pre-rounding routine to mitigate the issue, which may lead to slightly lower accuracy.
+    Set to ``false`` to disable it.
+
 Additional parameters for Dart Booster (``booster=dart``)
 =========================================================
 
diff --git a/include/xgboost/base.h b/include/xgboost/base.h
index 0b3a1e0af5f1..52ed13d29733 100644
--- a/include/xgboost/base.h
+++ b/include/xgboost/base.h
@@ -135,15 +135,15 @@ class GradientPairInternal {
   /*! \brief second order gradient statistics */
   T hess_;
 
-  XGBOOST_DEVICE void SetGrad(float g) { grad_ = g; }
-  XGBOOST_DEVICE void SetHess(float h) { hess_ = h; }
+  XGBOOST_DEVICE void SetGrad(T g) { grad_ = g; }
+  XGBOOST_DEVICE void SetHess(T h) { hess_ = h; }
 
  public:
   using ValueT = T;
 
   XGBOOST_DEVICE GradientPairInternal() : grad_(0), hess_(0) {}
 
-  XGBOOST_DEVICE GradientPairInternal(float grad, float hess) {
+  XGBOOST_DEVICE GradientPairInternal(T grad, T hess) {
     SetGrad(grad);
     SetHess(hess);
   }
@@ -160,8 +160,8 @@ class GradientPairInternal {
     SetHess(g.GetHess());
   }
 
-  XGBOOST_DEVICE float GetGrad() const { return grad_; }
-  XGBOOST_DEVICE float GetHess() const { return hess_; }
+  XGBOOST_DEVICE T GetGrad() const { return grad_; }
+  XGBOOST_DEVICE T GetHess() const { return hess_; }
 
   XGBOOST_DEVICE GradientPairInternal &operator+=(
       const GradientPairInternal &rhs) {
@@ -234,24 +234,6 @@ class GradientPairInternal {
     return os;
   }
 };
-
-template<>
-inline XGBOOST_DEVICE float GradientPairInternal<int64_t>::GetGrad() const {
-  return grad_ * 1e-4f;
-}
-template<>
-inline XGBOOST_DEVICE float GradientPairInternal<int64_t>::GetHess() const {
-  return hess_ * 1e-4f;
-}
-template<>
-inline XGBOOST_DEVICE void GradientPairInternal<int64_t>::SetGrad(float g) {
-  grad_ = static_cast<int64_t>(std::round(g * 1e4));
-}
-template<>
-inline XGBOOST_DEVICE void GradientPairInternal<int64_t>::SetHess(float h) {
-  hess_ = static_cast<int64_t>(std::round(h * 1e4));
-}
-
 }  // namespace detail
 
 /*! \brief gradient statistics pair usually needed in gradient boosting */
@@ -260,11 +242,6 @@ using GradientPair = detail::GradientPairInternal<float>;
 /*! \brief High precision gradient statistics pair */
 using GradientPairPrecise = detail::GradientPairInternal<double>;
 
-/*! \brief High precision gradient statistics pair with integer backed
- * storage.  Operators are associative where floating point versions are not
- * associative. */
-using GradientPairInteger = detail::GradientPairInternal<int64_t>;
-
 using Args = std::vector<std::pair<std::string, std::string> >;
 
 /*! \brief small eps gap for minimum split decision. */
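Note on the change above: the removed `GradientPairInteger` bought associativity (and hence run-to-run determinism) with fixed-point integer storage; the rest of this patch replaces it with pre-rounded floating point (see `histogram.cu` below). The following is a standalone sketch of the idea, not part of the patch — `RoundingFactor` is a simplified copy of the patch's `CreateRoundingFactor`, and `main` is purely illustrative:

    // Standalone sketch: float addition is not associative, but after
    // truncating every addend with a power-of-two factor M via (M + x) - M,
    // any summation order produces the same bits.
    #include <cmath>
    #include <cstdio>
    #include <limits>

    float RoundingFactor(float max_abs, int n) {
      float delta = max_abs / (1.0f - 2 * n * std::numeric_limits<float>::epsilon());
      int exp;
      std::frexp(delta, &exp);       // delta = x * 2^exp, |x| in [0.5, 1)
      return std::ldexp(1.0f, exp);  // M = 2^ceil(log2(delta))
    }

    int main() {
      float v[] = {1e8f, 1.0f, -1e8f, 1.0f};
      float a = ((v[0] + v[1]) + v[2]) + v[3];  // 1.0f is absorbed by 1e8f first
      float b = ((v[0] + v[2]) + v[1]) + v[3];  // cancellation happens first
      std::printf("plain sums:     %g vs %g\n", a, b);  // 1 vs 2

      float m = RoundingFactor(1e8f, 4);
      int order1[] = {0, 1, 2, 3}, order2[] = {0, 2, 1, 3};
      float ta = 0.0f, tb = 0.0f;
      for (int i : order1) ta += (m + v[i]) - m;  // truncated addends
      for (int i : order2) tb += (m + v[i]) - m;
      std::printf("truncated sums: %g vs %g\n", ta, tb);  // bit-identical
      return 0;
    }

Losing the two small addends to truncation here is the "slightly lower accuracy" the documentation above warns about.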
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index f134c0399999..d4d24621084f 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -1674,7 +1674,6 @@ def get_score(self, fmap='', importance_type='weight'):
         if importance_type == 'weight':
             # do a simpler tree dump to save time
             trees = self.get_dump(fmap, with_stats=False)
-
             fmap = {}
             for tree in trees:
                 for line in tree.split('\n'):
diff --git a/python-package/xgboost/plotting.py b/python-package/xgboost/plotting.py
index 7729515834b1..8fa306fcf3d9 100644
--- a/python-package/xgboost/plotting.py
+++ b/python-package/xgboost/plotting.py
@@ -68,7 +68,9 @@ def plot_importance(booster, ax=None, height=0.2,
         raise ValueError('tree must be Booster, XGBModel or dict instance')
 
     if not importance:
-        raise ValueError('Booster.get_score() results in empty')
+        raise ValueError(
+            'Booster.get_score() results in empty. ' +
+            'This may be caused by having all trees as decision dumps.')
 
     tuples = [(k, importance[k]) for k in importance]
     if max_num_features is not None:
diff --git a/src/common/observer.h b/src/common/observer.h
index c047cc79b023..640c4ec470a0 100644
--- a/src/common/observer.h
+++ b/src/common/observer.h
@@ -16,12 +16,12 @@
 #include "xgboost/base.h"
 #include "xgboost/tree_model.h"
 
-#if defined(XGBOOST_STRICT_R_MODE)
+#if defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1
 #define OBSERVER_PRINT LOG(INFO)
 #define OBSERVER_ENDL ""
 #define OBSERVER_NEWLINE ""
 #else
-#define OBSERVER_PRINT std::cout
+#define OBSERVER_PRINT std::cout << std::setprecision(17)
 #define OBSERVER_ENDL std::endl
 #define OBSERVER_NEWLINE "\n"
 #endif  // defined(XGBOOST_STRICT_R_MODE)
diff --git a/src/data/ellpack_page_source.cc b/src/data/ellpack_page_source.cc
index 5feb36022f1b..4f5453630563 100644
--- a/src/data/ellpack_page_source.cc
+++ b/src/data/ellpack_page_source.cc
@@ -29,14 +29,14 @@ bool EllpackPageSource::Next() {
 EllpackPage& EllpackPageSource::Value() {
   LOG(FATAL) << "Internal Error: "
       "XGBoost is not compiled with CUDA but EllpackPageSource is required";
-  EllpackPage* page;
+  EllpackPage* page { nullptr };
   return *page;
 }
 
 const EllpackPage& EllpackPageSource::Value() const {
   LOG(FATAL) << "Internal Error: "
       "XGBoost is not compiled with CUDA but EllpackPageSource is required";
-  EllpackPage* page;
+  EllpackPage* page { nullptr };
   return *page;
 }
diff --git a/src/learner.cc b/src/learner.cc
index 1c1eda095d6c..deafc589c9ba 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -734,6 +734,7 @@ class LearnerImpl : public Learner {
 
     monitor_.Start("PredictRaw");
     this->PredictRaw(train.get(), &predt, true);
+    TrainingObserver::Instance().Observe(predt.predictions, "Predictions");
     monitor_.Stop("PredictRaw");
 
     monitor_.Start("GetGradient");
diff --git a/src/tree/gpu_hist/gradient_based_sampler.cuh b/src/tree/gpu_hist/gradient_based_sampler.cuh
index 41099e3bc134..7e5ff1082333 100644
--- a/src/tree/gpu_hist/gradient_based_sampler.cuh
+++ b/src/tree/gpu_hist/gradient_based_sampler.cuh
@@ -25,6 +25,7 @@ class SamplingStrategy {
  public:
   /*! \brief Sample from a DMatrix based on the given gradient pairs. */
   virtual GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) = 0;
+  virtual ~SamplingStrategy() = default;
 };
 
 /*! \brief No sampling in in-memory mode. */
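The defaulted virtual destructor added above matters because concrete samplers are owned and destroyed through `SamplingStrategy` pointers; deleting a derived object through a base pointer whose destructor is not virtual is undefined behavior. A minimal standalone illustration with hypothetical stand-in types:

    #include <cstdio>
    #include <memory>

    struct Strategy {                     // stand-in for SamplingStrategy
      virtual void Run() = 0;
      virtual ~Strategy() = default;      // without this, the deletion below is UB
    };

    struct Concrete : public Strategy {   // stand-in for a concrete sampler
      void Run() override { std::printf("sampling\n"); }
      ~Concrete() override { std::printf("derived state released\n"); }
    };

    int main() {
      std::unique_ptr<Strategy> s = std::make_unique<Concrete>();
      s->Run();
      return 0;  // unique_ptr deletes through Strategy*, dispatching to ~Concrete()
    }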
diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu
new file mode 100644
index 000000000000..88ebfec61552
--- /dev/null
+++ b/src/tree/gpu_hist/histogram.cu
@@ -0,0 +1,184 @@
+/*!
+ * Copyright 2020 by XGBoost Contributors
+ */
+#include <thrust/reduce.h>
+#include <thrust/iterator/transform_iterator.h>
+#include <algorithm>
+#include <ctgmath>
+#include <limits>
+
+#include "xgboost/base.h"
+#include "row_partitioner.cuh"
+
+#include "histogram.cuh"
+
+#include "../../data/ellpack_page.cuh"
+#include "../../common/device_helpers.cuh"
+
+namespace xgboost {
+namespace tree {
+// The following 2 functions are slightly modified versions of those in fbcuda.
+
+/* \brief Constructs a rounding factor used to truncate elements in a sum such that the
+   sum of the truncated elements is the same no matter what the order of the sum is.
+
+ * Algorithm 5: Reproducible Sequential Sum in 'Fast Reproducible Floating-Point
+ * Summation' by Demmel and Nguyen
+
+ * In algorithm 5 the bound is calculated as $max(|v_i|) * n$.  Here we use the bound
+ *
+ * \begin{equation}
+ *   max( fl(\sum^{V}_{v_i>0}{v_i}), fl(\sum^{V}_{v_i<0}|v_i|) )
+ * \end{equation}
+ *
+ * to avoid outliers, as the full reduction is reproducible on GPU with reduction tree.
+ */
+template <typename T>
+DEV_INLINE __host__ T CreateRoundingFactor(T max_abs, int n) {
+  T delta = max_abs / (static_cast<T>(1.0) - 2 * n * std::numeric_limits<T>::epsilon());
+
+  // Calculate ceil(log_2(delta)).
+  // frexpf() calculates exp and returns `x` such that
+  // delta = x * 2^exp, where `x` in (-1.0, -0.5] U [0.5, 1).
+  // Because |x| < 1, exp is exactly ceil(log_2(delta)).
+  int exp;
+  std::frexp(delta, &exp);
+
+  // return M = 2 ^ ceil(log_2(delta))
+  return std::ldexp(static_cast<T>(1.0), exp);
+}
+
+namespace {
+struct Pair {
+  GradientPair first;
+  GradientPair second;
+};
+DEV_INLINE Pair operator+(Pair const& lhs, Pair const& rhs) {
+  return {lhs.first + rhs.first, lhs.second + rhs.second};
+}
+}  // anonymous namespace
+
+struct Clip : public thrust::unary_function<GradientPair, Pair> {
+  static DEV_INLINE float Pclip(float v) {
+    return v > 0 ? v : 0;
+  }
+  static DEV_INLINE float Nclip(float v) {
+    return v < 0 ? abs(v) : 0;
+  }
+
+  DEV_INLINE Pair operator()(GradientPair x) const {
+    auto pg = Pclip(x.GetGrad());
+    auto ph = Pclip(x.GetHess());
+
+    auto ng = Nclip(x.GetGrad());
+    auto nh = Nclip(x.GetHess());
+
+    return { GradientPair{ pg, ph }, GradientPair{ ng, nh } };
+  }
+};
+
+template <typename GradientSumT>
+GradientSumT CreateRoundingFactor(common::Span<GradientPair> gpair) {
+  using T = typename GradientSumT::ValueT;
+  dh::XGBCachingDeviceAllocator<char> alloc;
+
+  thrust::device_ptr<GradientPair> gpair_beg {gpair.data()};
+  thrust::device_ptr<GradientPair> gpair_end {gpair.data() + gpair.size()};
+  auto beg = thrust::make_transform_iterator(gpair_beg, Clip());
+  auto end = thrust::make_transform_iterator(gpair_end, Clip());
+  Pair p = thrust::reduce(thrust::cuda::par(alloc), beg, end, Pair{});
+  GradientPair positive_sum {p.first}, negative_sum {p.second};
+
+  auto histogram_rounding = GradientSumT {
+    CreateRoundingFactor<T>(std::max(positive_sum.GetGrad(), negative_sum.GetGrad()),
+                            gpair.size()),
+    CreateRoundingFactor<T>(std::max(positive_sum.GetHess(), negative_sum.GetHess()),
+                            gpair.size()) };
+  return histogram_rounding;
+}
+
+template GradientPairPrecise CreateRoundingFactor(common::Span<GradientPair> gpair);
+template GradientPair CreateRoundingFactor(common::Span<GradientPair> gpair);
+
+template <typename GradientSumT>
+__global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix,
+                                    common::Span<const RowPartitioner::RowIndexT> d_ridx,
+                                    GradientSumT* __restrict__ d_node_hist,
+                                    const GradientPair* __restrict__ d_gpair,
+                                    size_t n_elements,
+                                    GradientSumT const rounding,
+                                    bool use_shared_memory_histograms) {
+  using T = typename GradientSumT::ValueT;
+  extern __shared__ char smem[];
+  GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem);  // NOLINT
+  if (use_shared_memory_histograms) {
+    dh::BlockFill(smem_arr, matrix.info.n_bins, GradientSumT());
+    __syncthreads();
+  }
+  for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
+    int ridx = d_ridx[idx / matrix.info.row_stride];
+    int gidx =
+        matrix.gidx_iter[ridx * matrix.info.row_stride + idx % matrix.info.row_stride];
+    if (gidx != matrix.info.n_bins) {
+      GradientSumT truncated {
+        TruncateWithRoundingFactor<T>(rounding.GetGrad(), d_gpair[ridx].GetGrad()),
+        TruncateWithRoundingFactor<T>(rounding.GetHess(), d_gpair[ridx].GetHess()),
+      };
+      // If we are not using shared memory, accumulate the values directly into
+      // global memory
+      GradientSumT* atomic_add_ptr =
+          use_shared_memory_histograms ? smem_arr : d_node_hist;
+      dh::AtomicAddGpair(atomic_add_ptr + gidx, truncated);
+    }
+  }
+
+  if (use_shared_memory_histograms) {
+    // Write shared memory back to global memory
+    __syncthreads();
+    for (auto i : dh::BlockStrideRange(static_cast<size_t>(0), matrix.info.n_bins)) {
+      GradientSumT truncated {
+        TruncateWithRoundingFactor<T>(rounding.GetGrad(), smem_arr[i].GetGrad()),
+        TruncateWithRoundingFactor<T>(rounding.GetHess(), smem_arr[i].GetHess()),
+      };
+      dh::AtomicAddGpair(d_node_hist + i, truncated);
+    }
+  }
+}
+
+template <typename GradientSumT>
+void BuildGradientHistogram(EllpackMatrix const& matrix,
+                            common::Span<GradientPair const> gpair,
+                            common::Span<const uint32_t> d_ridx,
+                            common::Span<GradientSumT> histogram,
+                            GradientSumT rounding, bool shared) {
+  const size_t smem_size =
+      shared
+          ? sizeof(GradientSumT) * matrix.info.n_bins
+          : 0;
+  auto n_elements = d_ridx.size() * matrix.info.row_stride;
+
+  uint32_t items_per_thread = 8;
+  uint32_t block_threads = 256;
+  auto grid_size = static_cast<uint32_t>(
+      common::DivRoundUp(n_elements, items_per_thread * block_threads));
+  dh::LaunchKernel {grid_size, block_threads, smem_size} (
+      SharedMemHistKernel<GradientSumT>,
+      matrix, d_ridx, histogram.data(), gpair.data(), n_elements,
+      rounding, shared);
+}
+
+template void BuildGradientHistogram(
+    EllpackMatrix const& matrix,
+    common::Span<GradientPair const> gpair,
+    common::Span<const uint32_t> ridx,
+    common::Span<GradientPair> histogram,
+    GradientPair rounding, bool shared);
+
+template void BuildGradientHistogram(
+    EllpackMatrix const& matrix,
+    common::Span<GradientPair const> gpair,
+    common::Span<const uint32_t> ridx,
+    common::Span<GradientPairPrecise> histogram,
+    GradientPairPrecise rounding, bool shared);
+}  // namespace tree
+}  // namespace xgboost
diff --git a/src/tree/gpu_hist/histogram.cuh b/src/tree/gpu_hist/histogram.cuh
new file mode 100644
index 000000000000..3cacae352507
--- /dev/null
+++ b/src/tree/gpu_hist/histogram.cuh
@@ -0,0 +1,29 @@
+/*!
+ * Copyright 2020 by XGBoost Contributors
+ */
+#ifndef HISTOGRAM_CUH_
+#define HISTOGRAM_CUH_
+#include <thrust/transform.h>
+#include "../../data/ellpack_page.cuh"
+
+namespace xgboost {
+namespace tree {
+
+template <typename GradientSumT>
+GradientSumT CreateRoundingFactor(common::Span<GradientPair> gpair);
+
+template <typename T>
+DEV_INLINE T TruncateWithRoundingFactor(T const rounding_factor, float const x) {
+  return (rounding_factor + static_cast<T>(x)) - rounding_factor;
+}
+
+template <typename GradientSumT>
+void BuildGradientHistogram(EllpackMatrix const& matrix,
+                            common::Span<GradientPair const> gpair,
+                            common::Span<const uint32_t> ridx,
+                            common::Span<GradientSumT> histogram,
+                            GradientSumT rounding, bool shared);
+}  // namespace tree
+}  // namespace xgboost
+
+#endif  // HISTOGRAM_CUH_
diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh
index b64eeaae6e70..81ce1819e199 100644
--- a/src/tree/updater_gpu_common.cuh
+++ b/src/tree/updater_gpu_common.cuh
@@ -91,6 +91,16 @@ struct DeviceSplitCandidate {
     }
   }
   XGBOOST_DEVICE bool IsValid() const { return loss_chg > 0.0f; }
+
+  friend std::ostream& operator<<(std::ostream& os, DeviceSplitCandidate const& c) {
+    os << "loss_chg:" << c.loss_chg << ", "
+       << "dir: " << c.dir << ", "
+       << "findex: " << c.findex << ", "
+       << "fvalue: " << c.fvalue << ", "
+       << "left sum: " << c.left_sum << ", "
+       << "right sum: " << c.right_sum << std::endl;
+    return os;
+  }
 };
 
 struct DeviceSplitCandidateReduceOp {
@@ -186,6 +196,5 @@ struct SumCallbackOp {
 XGBOOST_DEVICE inline int MaxNodesDepth(int depth) {
   return (1 << (depth + 1)) - 1;
 }
-
 }  // namespace tree
 }  // namespace xgboost
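`SharedMemHistKernel` in the new `histogram.cu` follows the standard two-level CUDA histogram pattern: each block accumulates into a shared-memory copy of the histogram, where atomic conflicts are cheap and block-local, then flushes once per bin into global memory. A stripped-down standalone analogue with plain `int` counts (names and sizes are illustrative, not from the patch):

    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void Hist(const int* bins, int n, int* global_hist, int n_bins) {
      extern __shared__ int smem[];
      for (int i = threadIdx.x; i < n_bins; i += blockDim.x) smem[i] = 0;
      __syncthreads();
      // Grid-stride loop, like dh::GridStrideRange in the kernel above.
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
           i += gridDim.x * blockDim.x) {
        atomicAdd(&smem[bins[i]], 1);         // block-local accumulation
      }
      __syncthreads();
      for (int i = threadIdx.x; i < n_bins; i += blockDim.x) {
        atomicAdd(&global_hist[i], smem[i]);  // one global flush per block/bin
      }
    }

    int main() {
      int constexpr kN = 1 << 20, kBins = 256;
      int *bins, *hist;
      cudaMallocManaged(&bins, kN * sizeof(int));
      cudaMallocManaged(&hist, kBins * sizeof(int));
      for (int i = 0; i < kN; ++i) bins[i] = i % kBins;
      cudaMemset(hist, 0, kBins * sizeof(int));
      Hist<<<128, 256, kBins * sizeof(int)>>>(bins, kN, hist, kBins);
      cudaDeviceSynchronize();
      std::printf("hist[0] = %d (expect %d)\n", hist[0], kN / kBins);
      return 0;
    }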
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 1d0592aa3595..4f02343dd636 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2017-2019 XGBoost contributors
+ * Copyright 2017-2020 XGBoost contributors
  */
 #include <thrust/copy.h>
 #include <thrust/functional.h>
@@ -31,10 +31,10 @@
 #include "constraints.cuh"
 #include "gpu_hist/gradient_based_sampler.cuh"
 #include "gpu_hist/row_partitioner.cuh"
+#include "gpu_hist/histogram.cuh"
 
 namespace xgboost {
 namespace tree {
-
 #if !defined(GTEST_TEST)
 DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
 #endif  // !defined(GTEST_TEST)
@@ -43,6 +43,7 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist);
 struct GPUHistMakerTrainParam
     : public XGBoostParameter<GPUHistMakerTrainParam> {
   bool single_precision_histogram;
+  bool deterministic_histogram;
   // number of rows in a single GPU batch
   int gpu_batch_nrows;
   bool debug_synchronize;
@@ -50,6 +51,8 @@
   DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) {
     DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe(
         "Use single precision to build histograms.");
+    DMLC_DECLARE_FIELD(deterministic_histogram).set_default(true).describe(
+        "Pre-round the gradient for obtaining deterministic gradient histogram.");
     DMLC_DECLARE_FIELD(gpu_batch_nrows)
         .set_lower_bound(-1)
         .set_default(0)
@@ -336,6 +339,9 @@
   bool HistogramExists(int nidx) const {
     return nidx_map_.find(nidx) != nidx_map_.cend();
   }
+  int Bins() const {
+    return n_bins_;
+  }
   size_t HistogramSize() const {
     return n_bins_ * kNumItemsInGradientSum;
   }
@@ -402,40 +408,6 @@ struct CalcWeightTrainParam {
       learning_rate(p.learning_rate) {}
 };
 
-template <typename GradientSumT>
-__global__ void SharedMemHistKernel(xgboost::EllpackMatrix matrix,
-                                    common::Span<const RowPartitioner::RowIndexT> d_ridx,
-                                    GradientSumT* d_node_hist,
-                                    const GradientPair* d_gpair, size_t n_elements,
-                                    bool use_shared_memory_histograms) {
-  extern __shared__ char smem[];
-  GradientSumT* smem_arr = reinterpret_cast<GradientSumT*>(smem);  // NOLINT
-  if (use_shared_memory_histograms) {
-    dh::BlockFill(smem_arr, matrix.info.n_bins, GradientSumT());
-    __syncthreads();
-  }
-  for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
-    int ridx = d_ridx[idx / matrix.info.row_stride];
-    int gidx =
-        matrix.gidx_iter[ridx * matrix.info.row_stride + idx % matrix.info.row_stride];
-    if (gidx != matrix.info.n_bins) {
-      // If we are not using shared memory, accumulate the values directly into
-      // global memory
-      GradientSumT* atomic_add_ptr =
-          use_shared_memory_histograms ? smem_arr : d_node_hist;
-      dh::AtomicAddGpair(atomic_add_ptr + gidx, d_gpair[ridx]);
-    }
-  }
-
-  if (use_shared_memory_histograms) {
-    // Write shared memory back to global memory
-    __syncthreads();
-    for (auto i : dh::BlockStrideRange(static_cast<size_t>(0), matrix.info.n_bins)) {
-      dh::AtomicAddGpair(d_node_hist + i, smem_arr[i]);
-    }
-  }
-}
-
 // Manage memory for a single GPU
 template <typename GradientSumT>
 struct GPUHistMakerDevice {
@@ -460,9 +432,12 @@ struct GPUHistMakerDevice {
   bst_uint n_rows;
 
   TrainParam param;
+  bool deterministic_histogram;
   bool prediction_cache_initialised;
   bool use_shared_memory_histograms {false};
 
+  GradientSumT histogram_rounding;
+
   dh::CubMemory temp_memory;
   dh::PinnedMemory pinned_memory;
 
@@ -486,6 +461,7 @@ struct GPUHistMakerDevice {
                      TrainParam _param,
                      uint32_t column_sampler_seed,
                      uint32_t n_features,
+                     bool deterministic_histogram,
                      BatchParam _batch_param)
       : device_id(_device_id),
         page(_page),
@@ -494,6 +470,7 @@ struct GPUHistMakerDevice {
         prediction_cache_initialised(false),
         column_sampler(column_sampler_seed),
         interaction_constraints(param, n_features),
+        deterministic_histogram{deterministic_histogram},
         batch_param(_batch_param) {
     sampler.reset(new GradientBasedSampler(page,
                                            n_rows,
@@ -551,6 +528,12 @@ struct GPUHistMakerDevice {
     page = sample.page;
     gpair = sample.gpair;
 
+    if (deterministic_histogram) {
+      histogram_rounding = CreateRoundingFactor<GradientSumT>(this->gpair);
+    } else {
+      histogram_rounding = GradientSumT{0.0, 0.0};
+    }
+
     row_partitioner.reset();  // Release the device memory first before reallocating
     row_partitioner.reset(new RowPartitioner(device_id, n_rows));
     hist.Reset();
@@ -644,20 +627,8 @@ struct GPUHistMakerDevice {
     auto d_ridx = row_partitioner->GetRows(nidx);
     auto d_gpair = gpair.data();
-    auto n_elements = d_ridx.size() * page->matrix.info.row_stride;
-
-    const size_t smem_size =
-        use_shared_memory_histograms
-            ? sizeof(GradientSumT) * page->matrix.info.n_bins
-            : 0;
-    uint32_t items_per_thread = 8;
-    uint32_t block_threads = 256;
-    auto grid_size = static_cast<uint32_t>(
-        common::DivRoundUp(n_elements, items_per_thread * block_threads));
-    dh::LaunchKernel {grid_size, block_threads, smem_size} (
-        SharedMemHistKernel<GradientSumT>,
-        page->matrix, d_ridx, d_node_hist.data(), d_gpair, n_elements,
-        use_shared_memory_histograms);
+    BuildGradientHistogram(page->matrix, gpair, d_ridx, d_node_hist,
+                           histogram_rounding, use_shared_memory_histograms);
   }
 
   void SubtractionTrick(int nidx_parent, int nidx_histogram,
@@ -707,7 +678,7 @@ struct GPUHistMakerDevice {
   // After tree update is finished, update the position of all training
   // instances to their final leaf.  This information is used later to update the
   // prediction cache
This information is used later to update the // prediction cache - void FinalisePosition(RegTree* p_tree, DMatrix* p_fmat) { + void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat) { const auto d_nodes = temp_memory.GetSpan(p_tree->GetNodes().size()); dh::safe_cuda(cudaMemcpy(d_nodes.data(), p_tree->GetNodes().data(), @@ -870,16 +841,21 @@ struct GPUHistMakerDevice { } void InitRoot(RegTree* p_tree, dh::AllReducer* reducer, int64_t num_columns) { - constexpr int kRootNIdx = 0; - - dh::SumReduction(temp_memory, gpair, node_sum_gradients_d, gpair.size()); + constexpr bst_node_t kRootNIdx = 0; + dh::XGBCachingDeviceAllocator alloc; + GradientPair root_sum = thrust::reduce( + thrust::cuda::par(alloc), + thrust::device_ptr(gpair.data()), + thrust::device_ptr(gpair.data() + gpair.size())); + dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients_d.data(), &root_sum, sizeof(root_sum), + cudaMemcpyHostToDevice)); reducer->AllReduceSum( reinterpret_cast(node_sum_gradients_d.data()), reinterpret_cast(node_sum_gradients_d.data()), 2); reducer->Synchronize(); - dh::safe_cuda(cudaMemcpy(node_sum_gradients.data(), - node_sum_gradients_d.data(), sizeof(GradientPair), - cudaMemcpyDeviceToHost)); + dh::safe_cuda(cudaMemcpyAsync(node_sum_gradients.data(), + node_sum_gradients_d.data(), sizeof(GradientPair), + cudaMemcpyDeviceToHost)); this->BuildHist(kRootNIdx); this->AllReduceHist(kRootNIdx, reducer); @@ -1055,6 +1031,7 @@ class GPUHistMakerSpecialised { param_, column_sampling_seed, info_->num_col_, + hist_maker_param_.deterministic_histogram, batch_param)); monitor_.StartCuda("InitHistogram"); diff --git a/tests/cpp/common/test_gpu_hist_util.cu b/tests/cpp/common/test_gpu_hist_util.cu index ff5683290e9b..9bbd404c7dc8 100644 --- a/tests/cpp/common/test_gpu_hist_util.cu +++ b/tests/cpp/common/test_gpu_hist_util.cu @@ -76,6 +76,20 @@ void TestDeviceSketch(bool use_external_memory) { ASSERT_LT(fabs(hmat_cpu.Values()[i] - hmat_gpu.Values()[i]), eps * nrows); } + // Determinstic + size_t constexpr kRounds { 100 }; + for (size_t r = 0; r < kRounds; ++r) { + HistogramCuts new_sketch; + DeviceSketch(device, max_bin, gpu_batch_nrows, dmat->get(), &new_sketch); + ASSERT_EQ(hmat_gpu.Values().size(), new_sketch.Values().size()); + for (size_t i = 0; i < hmat_gpu.Values().size(); ++i) { + ASSERT_EQ(hmat_gpu.Values()[i], new_sketch.Values()[i]); + } + for (size_t i = 0; i < hmat_gpu.MinValues().size(); ++i) { + ASSERT_EQ(hmat_gpu.MinValues()[i], new_sketch.MinValues()[i]); + } + } + delete dmat; } diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h index 8001076bae52..e2161933bc5a 100644 --- a/tests/cpp/helpers.h +++ b/tests/cpp/helpers.h @@ -223,9 +223,10 @@ inline GenericParameter CreateEmptyGenericParam(int gpu_id) { return tparam; } -inline HostDeviceVector GenerateRandomGradients(const size_t n_rows) { +inline HostDeviceVector GenerateRandomGradients(const size_t n_rows, + float lower= 0.0f, float upper = 1.0f) { xgboost::SimpleLCG gen; - xgboost::SimpleRealUniformDistribution dist(0.0f, 1.0f); + xgboost::SimpleRealUniformDistribution dist(lower, upper); std::vector h_gpair(n_rows); for (auto &gpair : h_gpair) { bst_float grad = dist(&gen); @@ -287,6 +288,5 @@ inline std::unique_ptr BuildEllpackPage( return page; } #endif - } // namespace xgboost #endif diff --git a/tests/cpp/test_serialization.cc b/tests/cpp/test_serialization.cc index a6eb836c0cf2..9f69020318de 100644 --- a/tests/cpp/test_serialization.cc +++ b/tests/cpp/test_serialization.cc @@ -605,6 +605,10 @@ TEST_F(MultiClassesSerializationTest, 
diff --git a/tests/cpp/test_serialization.cc b/tests/cpp/test_serialization.cc
index a6eb836c0cf2..9f69020318de 100644
--- a/tests/cpp/test_serialization.cc
+++ b/tests/cpp/test_serialization.cc
@@ -605,6 +605,10 @@ TEST_F(MultiClassesSerializationTest, GPU_Hist) {
                  {"seed", "0"},
                  {"nthread", "1"},
                  {"max_depth", std::to_string(kClasses)},
+                 // Somehow rebuilding the cache can generate slightly
+                 // different results (1e-7) with the CPU predictor for
+                 // some entries.
+                 {"predictor", "gpu_predictor"},
                  {"enable_experimental_json_serialization", "1"},
                  {"tree_method", "gpu_hist"}},
                 fmap_, *pp_dmat_);
diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu
new file mode 100644
index 000000000000..c57beb3c706a
--- /dev/null
+++ b/tests/cpp/tree/gpu_hist/test_histogram.cu
@@ -0,0 +1,69 @@
+#include <gtest/gtest.h>
+#include "../../helpers.h"
+#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
+#include "../../../../src/tree/gpu_hist/histogram.cuh"
+
+namespace xgboost {
+namespace tree {
+
+template <typename Gradient>
+void TestDeterministicHistogram() {
+  size_t constexpr kBins = 24, kCols = 8, kRows = 32768, kRounds = 16;
+  float constexpr kLower = -1e-2, kUpper = 1e2;
+
+  auto pp_m = CreateDMatrix(kRows, kCols, 0.5);
+  auto& matrix = **pp_m;
+  BatchParam batch_param{0, static_cast<int32_t>(kBins), 0, 0};
+
+  for (auto const& batch : matrix.GetBatches<EllpackPage>(batch_param)) {
+    auto* page = batch.Impl();
+
+    tree::RowPartitioner row_partitioner(0, kRows);
+    auto ridx = row_partitioner.GetRows(0);
+
+    dh::device_vector<Gradient> histogram(kBins * kCols);
+    auto d_histogram = dh::ToSpan(histogram);
+    auto gpair = GenerateRandomGradients(kRows, kLower, kUpper);
+    gpair.SetDevice(0);
+
+    auto rounding = CreateRoundingFactor<Gradient>(gpair.DeviceSpan());
+    BuildGradientHistogram(page->matrix, gpair.DeviceSpan(), ridx,
+                           d_histogram, rounding, true);
+
+    for (size_t i = 0; i < kRounds; ++i) {
+      dh::device_vector<Gradient> new_histogram(kBins * kCols);
+      auto d_histogram = dh::ToSpan(new_histogram);
+
+      auto rounding = CreateRoundingFactor<Gradient>(gpair.DeviceSpan());
+      BuildGradientHistogram(page->matrix, gpair.DeviceSpan(), ridx,
+                             d_histogram, rounding, true);
+
+      for (size_t j = 0; j < new_histogram.size(); ++j) {
+        ASSERT_EQ(((Gradient)new_histogram[j]).GetGrad(),
+                  ((Gradient)histogram[j]).GetGrad());
+        ASSERT_EQ(((Gradient)new_histogram[j]).GetHess(),
+                  ((Gradient)histogram[j]).GetHess());
+      }
+    }
+
+    {
+      auto gpair = GenerateRandomGradients(kRows, kLower, kUpper);
+      gpair.SetDevice(0);
+      dh::device_vector<Gradient> baseline(kBins * kCols);
+      BuildGradientHistogram(page->matrix, gpair.DeviceSpan(), ridx,
+                             dh::ToSpan(baseline), rounding, true);
+      for (size_t i = 0; i < baseline.size(); ++i) {
+        EXPECT_NEAR(((Gradient)baseline[i]).GetGrad(),
+                    ((Gradient)histogram[i]).GetGrad(),
+                    ((Gradient)baseline[i]).GetGrad() * 1e-3);
+      }
+    }
+  }
+  delete pp_m;
+}
+
+TEST(Histogram, GPUDeterministic) {
+  TestDeterministicHistogram<GradientPair>();
+  TestDeterministicHistogram<GradientPairPrecise>();
+}
+}  // namespace tree
+}  // namespace xgboost
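A note on the `(Gradient)histogram[j]` casts in the new test: `thrust::device_vector<T>::operator[]` returns a `thrust::device_reference<T>`, and converting it to `T` copies that single element from device to host so host code can call `GetGrad()`/`GetHess()`. A minimal demonstration:

    #include <cstdio>
    #include <thrust/device_vector.h>

    int main() {
      thrust::device_vector<float> v(4, 2.5f);
      // v[2] is a device_reference<float>; the cast copies the element to host.
      float host_copy = static_cast<float>(v[2]);
      std::printf("%f\n", host_copy);
      return 0;
    }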
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index a8d74d4d43af..8c38c09b8660 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -83,7 +83,8 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   param.Init(args);
   auto page = BuildEllpackPage(kNRows, kNCols);
   BatchParam batch_param{};
-  GPUHistMakerDevice<GradientSumT> maker(0, page.get(), kNRows, param, kNCols, kNCols, batch_param);
+  GPUHistMakerDevice<GradientSumT> maker(0, page.get(), kNRows, param, kNCols, kNCols,
+                                         true, batch_param);
   maker.InitHistogram();
 
   xgboost::SimpleLCG gen;
@@ -187,7 +188,7 @@ TEST(GpuHist, EvaluateSplits) {
   auto page = BuildEllpackPage(kNRows, kNCols);
   BatchParam batch_param{};
   GPUHistMakerDevice<GradientPairPrecise>
-      maker(0, page.get(), kNRows, param, kNCols, kNCols, batch_param);
+      maker(0, page.get(), kNRows, param, kNCols, kNCols, true, batch_param);
 
   // Initialize GPUHistMakerDevice::node_sum_gradients
   maker.node_sum_gradients = {{6.4f, 12.8f}};