From a4795d1f8de3f5a388c611a948fa861c0e23b033 Mon Sep 17 00:00:00 2001 From: fis Date: Mon, 10 Aug 2020 04:13:06 +0800 Subject: [PATCH] Revert min value change. --- include/xgboost/span.h | 6 ---- src/common/common.h | 2 -- src/common/hist_util.cc | 28 ++++++++++++++++--- src/common/hist_util.cu | 3 ++ src/common/hist_util.h | 10 ++++++- src/common/quantile.cu | 10 +++++++ src/data/ellpack_page.cuh | 6 +++- src/data/ellpack_page_raw_format.cu | 2 ++ src/predictor/gpu_predictor.cu | 10 +++---- src/tree/gpu_hist/evaluate_splits.cu | 6 ++-- src/tree/gpu_hist/evaluate_splits.cuh | 1 + src/tree/updater_gpu_hist.cu | 3 ++ src/tree/updater_quantile_hist.cc | 3 +- tests/cpp/common/test_hist_util.cc | 4 +++ tests/cpp/common/test_hist_util.cu | 12 ++++++++ .../cpp/data/test_iterative_device_dmatrix.cu | 8 ++++++ tests/cpp/histogram_helpers.h | 4 +++ .../cpp/tree/gpu_hist/test_evaluate_splits.cu | 6 ++++ tests/cpp/tree/test_gpu_hist.cu | 1 + tests/cpp/tree/test_quantile_hist.cc | 2 +- tests/python-gpu/test_gpu_updaters.py | 2 +- 21 files changed, 103 insertions(+), 26 deletions(-) diff --git a/include/xgboost/span.h b/include/xgboost/span.h index b9ba4f321d0b..7cdabc5cc872 100644 --- a/include/xgboost/span.h +++ b/include/xgboost/span.h @@ -82,7 +82,6 @@ namespace common { "\tBlock: [%d, %d, %d], Thread: [%d, %d, %d]\n\n", \ __FILE__, __LINE__, __PRETTY_FUNCTION__, #cond, blockIdx.x, \ blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, threadIdx.z); \ - assert(false); \ asm("trap;"); \ } \ } while (0); @@ -662,11 +661,6 @@ XGBOOST_DEVICE auto as_writable_bytes(Span s) __span_noexcept -> // NOLIN return {reinterpret_cast(s.data()), s.size_bytes()}; } -template class Container, typename... Types, - std::size_t Extent = dynamic_extent> -auto MakeSpan(Container const &container) { - return Span(container); -} } // namespace common } // namespace xgboost diff --git a/src/common/common.h b/src/common/common.h index ce754945b6ef..f9300f3d583c 100644 --- a/src/common/common.h +++ b/src/common/common.h @@ -82,8 +82,6 @@ XGBOOST_DEVICE T1 DivRoundUp(const T1 a, const T2 b) { return static_cast(std::ceil(static_cast(a) / b)); } -constexpr float kTrivialSplit = -std::numeric_limits::infinity(); - /* * Range iterator */ diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index ab761edacbb6..f8e42f2f454a 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -161,8 +161,10 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info, // Data groups, used in ranking. std::vector const& group_ptr = info.group_ptr_; + auto &local_min_vals = p_cuts_->min_vals_.HostVector(); auto &local_cuts = p_cuts_->cut_values_.HostVector(); auto &local_ptrs = p_cuts_->cut_ptrs_.HostVector(); + local_min_vals.resize(end_col - beg_col, 0); for (uint32_t col_id = beg_col; col_id < page.Size() && col_id < end_col; ++col_id) { // Using a local variable makes things easier, but at the cost of memory trashing. @@ -197,11 +199,16 @@ void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info, summary.Reserve(n_bins + 1); summary.SetPrune(out_summary, n_bins + 1); + // Can be use data[1] as the min values so that we don't need to + // store another array? + float mval = summary.data[0].value; + local_min_vals[col_id - beg_col] = mval - (fabs(mval) + 1e-5); + this->AddCutPoint(summary, max_num_bins); bst_float cpt = (summary.size > 0) ? 
summary.data[summary.size - 1].value : - kTrivialSplit; + local_min_vals[col_id - beg_col]; cpt += fabs(cpt) + 1e-5; local_cuts.emplace_back(cpt); @@ -279,10 +286,14 @@ void SparseCuts::Concat( std::vector> const& cuts, uint32_t n_cols) { monitor_.Start(__FUNCTION__); uint32_t nthreads = omp_get_max_threads(); + auto &local_min_vals = p_cuts_->min_vals_.HostVector(); auto &local_cuts = p_cuts_->cut_values_.HostVector(); auto &local_ptrs = p_cuts_->cut_ptrs_.HostVector(); + local_min_vals.resize(n_cols, std::numeric_limits::max()); + size_t min_vals_tail = 0; for (uint32_t t = 0; t < nthreads; ++t) { + auto& thread_min_vals = cuts[t]->p_cuts_->min_vals_.HostVector(); auto& thread_cuts = cuts[t]->p_cuts_->cut_values_.HostVector(); auto& thread_ptrs = cuts[t]->p_cuts_->cut_ptrs_.HostVector(); @@ -303,6 +314,12 @@ void SparseCuts::Concat( for (size_t j = old_iv_size; j < new_iv_size; ++j) { local_cuts[j] = thread_cuts[j-old_iv_size]; } + // merge min values + for (size_t j = 0; j < thread_min_vals.size(); ++j) { + local_min_vals.at(min_vals_tail + j) = + std::min(local_min_vals.at(min_vals_tail + j), thread_min_vals.at(j)); + } + min_vals_tail += thread_min_vals.size(); } monitor_.Stop(__FUNCTION__); } @@ -409,15 +426,18 @@ void DenseCuts::Init // TODO(chenqin): rabit failure recovery assumes no boostrap onetime call after loadcheckpoint // we need to move this allreduce before loadcheckpoint call in future sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size()); + p_cuts_->min_vals_.HostVector().resize(sketchs.size()); - for (auto const& summary : summary_array) { + for (size_t fid = 0; fid < summary_array.size(); ++fid) { WQSketch::SummaryContainer a; a.Reserve(max_num_bins + 1); - a.SetPrune(summary, max_num_bins + 1); + a.SetPrune(summary_array[fid], max_num_bins + 1); + const bst_float mval = a.data[0].value; + p_cuts_->min_vals_.HostVector()[fid] = mval - (fabs(mval) + 1e-5); AddCutPoint(a, max_num_bins); // push a value that is greater than anything const bst_float cpt - = (a.size > 0) ? a.data[a.size - 1].value : kTrivialSplit; + = (a.size > 0) ? a.data[a.size - 1].value : p_cuts_->min_vals_.HostVector()[fid]; // this must be bigger than last value in a scale const bst_float last = cpt + (fabs(cpt) + 1e-5); p_cuts_->cut_values_.HostVector().push_back(last); diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 90368087de3b..020cfb2a1350 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -158,6 +158,9 @@ size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz, // 9. Allocate final cut values, min values, cut ptrs: std::min(rows, bins + 1) * // n_columns + n_columns + n_columns + 1 total += std::min(num_rows, num_bins) * num_columns * sizeof(float); + total += num_columns * + sizeof(std::remove_reference_t().MinValues())>::value_type); total += (num_columns + 1) * sizeof(std::remove_reference_t().Ptrs())>::value_type); diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 6065eef639a0..dbb0b35e4cea 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -47,13 +47,17 @@ class HistogramCuts { public: HostDeviceVector cut_values_; // NOLINT HostDeviceVector cut_ptrs_; // NOLINT + // storing minimum value in a sketch set. 
+ HostDeviceVector min_vals_; // NOLINT HistogramCuts(); HistogramCuts(HistogramCuts const& that) { cut_values_.Resize(that.cut_values_.Size()); cut_ptrs_.Resize(that.cut_ptrs_.Size()); + min_vals_.Resize(that.min_vals_.Size()); cut_values_.Copy(that.cut_values_); cut_ptrs_.Copy(that.cut_ptrs_); + min_vals_.Copy(that.min_vals_); } HistogramCuts(HistogramCuts&& that) noexcept(true) { @@ -63,8 +67,10 @@ class HistogramCuts { HistogramCuts& operator=(HistogramCuts const& that) { cut_values_.Resize(that.cut_values_.Size()); cut_ptrs_.Resize(that.cut_ptrs_.Size()); + min_vals_.Resize(that.min_vals_.Size()); cut_values_.Copy(that.cut_values_); cut_ptrs_.Copy(that.cut_ptrs_); + min_vals_.Copy(that.min_vals_); return *this; } @@ -72,6 +78,7 @@ class HistogramCuts { monitor_ = std::move(that.monitor_); cut_ptrs_ = std::move(that.cut_ptrs_); cut_values_ = std::move(that.cut_values_); + min_vals_ = std::move(that.min_vals_); return *this; } @@ -88,14 +95,15 @@ class HistogramCuts { // these for now. std::vector const& Ptrs() const { return cut_ptrs_.ConstHostVector(); } std::vector const& Values() const { return cut_values_.ConstHostVector(); } + std::vector const& MinValues() const { return min_vals_.ConstHostVector(); } size_t TotalBins() const { return cut_ptrs_.ConstHostVector().back(); } // Return the index of a cut point that is strictly greater than the input // value, or the last available index if none exists BinIdx SearchBin(float value, uint32_t column_id) const { + auto beg = cut_ptrs_.ConstHostVector().at(column_id); auto end = cut_ptrs_.ConstHostVector().at(column_id + 1); - auto beg = cut_ptrs_.ConstHostVector()[column_id]; const auto &values = cut_values_.ConstHostVector(); auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value); BinIdx idx = it - values.cbegin(); diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 326630f92271..ab0ed5af0b47 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -499,6 +499,7 @@ void SketchContainer::AllReduce() { void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { timer_.Start(__func__); dh::safe_cuda(cudaSetDevice(device_)); + p_cuts->min_vals_.Resize(num_columns_); // Sync between workers. this->AllReduce(); @@ -510,6 +511,9 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { // Set up inputs auto d_in_columns_ptr = this->columns_ptr_.ConstDeviceSpan(); + + p_cuts->min_vals_.SetDevice(device_); + auto d_min_values = p_cuts->min_vals_.DeviceSpan(); auto in_cut_values = dh::ToSpan(this->Current()); // Set up output ptr @@ -553,12 +557,18 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { // column is empty, trees cannot split on it. This is just to be consistent with // rest of the library. if (idx == 0) { + d_min_values[column_id] = kRtEps; out_column[0] = kRtEps; assert(out_column.size() == 1); } return; } + if (idx == 0 && !IsCat(d_ft, column_id)) { + auto mval = in_column[idx].value; + d_min_values[column_id] = mval - (fabs(mval) + 1e-5); + } + if (IsCat(d_ft, column_id)) { assert(out_column.size() == in_column.size()); out_column[idx] = in_column[idx].value; diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index 26cc5038f025..0e83f7e6bb9d 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -52,6 +52,8 @@ struct EllpackDeviceAccessor { size_t base_rowid{}; size_t n_rows{}; common::CompressedIterator gidx_iter; + /*! \brief Minimum value for each feature. Size equals to number of features. */ + common::Span min_fvalue; /*! 
\brief Histogram cut pointers. Size equals to (number of features + 1). */ common::Span feature_segments; /*! \brief Histogram cut values. Size equals to (bins per feature * number of features). */ @@ -66,8 +68,10 @@ struct EllpackDeviceAccessor { n_rows(n_rows) ,gidx_iter(gidx_iter){ cuts.cut_values_.SetDevice(device); cuts.cut_ptrs_.SetDevice(device); + cuts.min_vals_.SetDevice(device); gidx_fvalue_map = cuts.cut_values_.ConstDeviceSpan(); feature_segments = cuts.cut_ptrs_.ConstDeviceSpan(); + min_fvalue = cuts.min_vals_.ConstDeviceSpan(); } // Get a matrix element, uses binary search for look up Return NaN if missing // Given a row index and a feature index, returns the corresponding cut value @@ -120,7 +124,7 @@ struct EllpackDeviceAccessor { XGBOOST_DEVICE size_t NumBins() const { return gidx_fvalue_map.size(); } - XGBOOST_DEVICE size_t NumFeatures() const { return feature_segments.size() - 1; } + XGBOOST_DEVICE size_t NumFeatures() const { return min_fvalue.size(); } }; diff --git a/src/data/ellpack_page_raw_format.cu b/src/data/ellpack_page_raw_format.cu index 17380f40566a..d4caf37e2be3 100644 --- a/src/data/ellpack_page_raw_format.cu +++ b/src/data/ellpack_page_raw_format.cu @@ -19,6 +19,7 @@ class EllpackPageRawFormat : public SparsePageFormat { auto* impl = page->Impl(); fi->Read(&impl->Cuts().cut_values_.HostVector()); fi->Read(&impl->Cuts().cut_ptrs_.HostVector()); + fi->Read(&impl->Cuts().min_vals_.HostVector()); fi->Read(&impl->n_rows); fi->Read(&impl->is_dense); fi->Read(&impl->row_stride); @@ -39,6 +40,7 @@ class EllpackPageRawFormat : public SparsePageFormat { auto* impl = page.Impl(); fo->Write(impl->Cuts().cut_values_.ConstHostVector()); fo->Write(impl->Cuts().cut_ptrs_.ConstHostVector()); + fo->Write(impl->Cuts().min_vals_.ConstHostVector()); fo->Write(impl->n_rows); fo->Write(impl->is_dense); fo->Write(impl->row_stride); diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 22c5e6f555e9..9068d69c0810 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -57,7 +57,7 @@ struct SparsePageLoader { bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x; int shared_elements = blockDim.x * num_features; dh::BlockFill(smem, shared_elements, nanf("")); - cub::CTA_SYNC(); + __syncthreads(); if (global_idx < num_rows) { bst_uint elem_begin = d_row_ptr[global_idx]; bst_uint elem_end = d_row_ptr[global_idx + 1]; @@ -66,7 +66,7 @@ struct SparsePageLoader { smem[threadIdx.x * num_features + elem.index] = elem.fvalue; } } - cub::CTA_SYNC(); + __syncthreads(); } } __device__ float GetFvalue(int ridx, int fidx) const { @@ -113,7 +113,7 @@ struct EllpackLoader { // The gradient index needs to be shifted by one as min values are not included in the // cuts. 
if (gidx == matrix.feature_segments[fidx]) { - return common::kTrivialSplit; + return matrix.min_fvalue[fidx]; } return matrix.gidx_fvalue_map[gidx - 1]; } @@ -140,7 +140,7 @@ struct DeviceAdapterLoader { uint32_t global_idx = blockDim.x * blockIdx.x + threadIdx.x; size_t shared_elements = blockDim.x * num_features; dh::BlockFill(smem, shared_elements, nanf("")); - cub::CTA_SYNC(); + __syncthreads(); if (global_idx < num_rows) { auto beg = global_idx * columns; auto end = (global_idx + 1) * columns; @@ -149,7 +149,7 @@ struct DeviceAdapterLoader { } } } - cub::CTA_SYNC(); + __syncthreads(); } DEV_INLINE float GetFvalue(bst_row_t ridx, bst_feature_t fidx) const { diff --git a/src/tree/gpu_hist/evaluate_splits.cu b/src/tree/gpu_hist/evaluate_splits.cu index 47d27c21c639..4006068b3af0 100644 --- a/src/tree/gpu_hist/evaluate_splits.cu +++ b/src/tree/gpu_hist/evaluate_splits.cu @@ -128,7 +128,7 @@ struct UpdateNumeric { int split_gidx = (scan_begin + threadIdx.x) - 1; float fvalue; if (split_gidx < static_cast(gidx_begin)) { - fvalue = common::kTrivialSplit; + fvalue = inputs.min_fvalue[fidx]; } else { fvalue = inputs.feature_values[split_gidx]; } @@ -180,7 +180,7 @@ __device__ void EvaluateFeature( inputs.value_constraint, missing_left); } - cub::CTA_SYNC(); + __syncthreads(); // Find thread with best gain cub::KeyValuePair tuple(threadIdx.x, gain); @@ -231,7 +231,7 @@ __global__ void EvaluateSplitsKernel( best_split = DeviceSplitCandidate(); } - cub::CTA_SYNC(); + __syncthreads(); // If this block is working on the left or right node bool is_left = blockIdx.x < left.feature_set.size(); diff --git a/src/tree/gpu_hist/evaluate_splits.cuh b/src/tree/gpu_hist/evaluate_splits.cuh index 3461795190b5..8ba177d8acdd 100644 --- a/src/tree/gpu_hist/evaluate_splits.cuh +++ b/src/tree/gpu_hist/evaluate_splits.cuh @@ -20,6 +20,7 @@ struct EvaluateSplitInputs { common::Span feature_types; common::Span feature_segments; common::Span feature_values; + common::Span min_fvalue; common::Span gradient_histogram; ValueConstraint value_constraint; common::Span monotonic_constraints; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 57489a1015df..ff3a5cfd34a8 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -318,6 +318,7 @@ struct GPUHistMakerDevice { feature_types, matrix.feature_segments, matrix.gidx_fvalue_map, + matrix.min_fvalue, hist.GetNodeHistogram(nidx), node_value_constraints[nidx], dh::ToSpan(monotone_constraints)}; @@ -356,6 +357,7 @@ struct GPUHistMakerDevice { feature_types, matrix.feature_segments, matrix.gidx_fvalue_map, + matrix.min_fvalue, hist.GetNodeHistogram(left_nidx), node_value_constraints[left_nidx], dh::ToSpan(monotone_constraints)}; @@ -368,6 +370,7 @@ struct GPUHistMakerDevice { feature_types, matrix.feature_segments, matrix.gidx_fvalue_map, + matrix.min_fvalue, hist.GetNodeHistogram(right_nidx), node_value_constraints[right_nidx], dh::ToSpan(monotone_constraints)}; diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 1d38292faf8d..37a90dfebd74 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -25,7 +25,6 @@ #include "./updater_quantile_hist.h" #include "./split_evaluator.h" #include "../common/random.h" -#include "../common/common.h" #include "../common/hist_util.h" #include "../common/row_set.h" #include "../common/column_matrix.h" @@ -1334,7 +1333,7 @@ GradStats QuantileHistMaker::Builder::EnumerateSplit( snode.root_gain); if (i == imin) { // for 
leftmost bin, left bound is the smallest feature value - split_pt = common::kTrivialSplit; + split_pt = gmat.cut.MinValues()[fid]; } else { split_pt = cut_val[i - 1]; } diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 0c0d63ca691c..7a0ff9a47215 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -190,6 +190,7 @@ TEST(SparseCuts, SingleThreadedBuild) { ASSERT_EQ(hmat.cut.Ptrs().size(), cuts.Ptrs().size()); ASSERT_EQ(hmat.cut.Ptrs(), cuts.Ptrs()); ASSERT_EQ(hmat.cut.Values(), cuts.Values()); + ASSERT_EQ(hmat.cut.MinValues(), cuts.MinValues()); } TEST(SparseCuts, MultiThreadedBuild) { @@ -253,6 +254,7 @@ TEST(HistUtil, DenseCutsCategorical) { DenseCuts dense(&cuts); dense.Build(dmat.get(), num_bins); auto cuts_from_sketch = cuts.Values(); + EXPECT_LT(cuts.MinValues()[0], x_sorted.front()); EXPECT_GT(cuts_from_sketch.front(), x_sorted.front()); EXPECT_GE(cuts_from_sketch.back(), x_sorted.back()); EXPECT_EQ(cuts_from_sketch.size(), num_categories); @@ -342,6 +344,7 @@ TEST(HistUtil, SparseCutsCategorical) { SparseCuts sparse(&cuts); sparse.Build(dmat.get(), num_bins); auto cuts_from_sketch = cuts.Values(); + EXPECT_LT(cuts.MinValues()[0], x_sorted.front()); EXPECT_GT(cuts_from_sketch.front(), x_sorted.front()); EXPECT_GE(cuts_from_sketch.back(), x_sorted.back()); EXPECT_EQ(cuts_from_sketch.size(), num_categories); @@ -464,5 +467,6 @@ TEST(HistUtil, SparseIndexBinData) { } } } + } // namespace common } // namespace xgboost diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index c1f92982b88d..ca41c1a9b218 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -45,6 +45,7 @@ TEST(HistUtil, DeviceSketch) { EXPECT_EQ(device_cuts.Values(), host_cuts.Values()); EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs()); + EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues()); } TEST(HistUtil, SketchBatchNumElements) { @@ -107,6 +108,7 @@ TEST(HistUtil, DeviceSketchDeterminism) { for (size_t r = 0; r < kRounds; ++r) { auto new_sketch = DeviceSketch(0, dmat.get(), num_bins); ASSERT_EQ(reference_sketch.Values(), new_sketch.Values()); + ASSERT_EQ(reference_sketch.MinValues(), new_sketch.MinValues()); } } @@ -194,6 +196,7 @@ TEST(HistUitl, DeviceSketchWeights) { for (auto num_bins : bin_sizes) { auto cuts = DeviceSketch(0, dmat.get(), num_bins); auto wcuts = DeviceSketch(0, weighted_dmat.get(), num_bins); + ASSERT_EQ(cuts.MinValues(), wcuts.MinValues()); ASSERT_EQ(cuts.Ptrs(), wcuts.Ptrs()); ASSERT_EQ(cuts.Values(), wcuts.Values()); ValidateCuts(cuts, dmat.get(), num_bins); @@ -300,6 +303,7 @@ TEST(HistUtil, AdapterDeviceSketch) { EXPECT_EQ(device_cuts.Values(), host_cuts.Values()); EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs()); + EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues()); } TEST(HistUtil, AdapterDeviceSketchMemory) { @@ -438,6 +442,7 @@ TEST(HistUtil, SketchingEquivalent) { adapter, num_bins, std::numeric_limits::quiet_NaN()); EXPECT_EQ(dmat_cuts.Values(), adapter_cuts.Values()); EXPECT_EQ(dmat_cuts.Ptrs(), adapter_cuts.Ptrs()); + EXPECT_EQ(dmat_cuts.MinValues(), adapter_cuts.MinValues()); ValidateBatchedCuts(adapter, num_bins, num_columns, num_rows, dmat.get()); } @@ -462,11 +467,15 @@ TEST(HistUtil, DeviceSketchFromGroupWeights) { HistogramCuts cuts = DeviceSketch(0, m.get(), kBins, 0); ASSERT_EQ(cuts.Values().size(), weighted_cuts.Values().size()); + ASSERT_EQ(cuts.MinValues().size(), weighted_cuts.MinValues().size()); 
ASSERT_EQ(cuts.Ptrs().size(), weighted_cuts.Ptrs().size()); for (size_t i = 0; i < cuts.Values().size(); ++i) { EXPECT_EQ(cuts.Values()[i], weighted_cuts.Values()[i]) << "i:"<< i; } + for (size_t i = 0; i < cuts.MinValues().size(); ++i) { + ASSERT_EQ(cuts.MinValues()[i], weighted_cuts.MinValues()[i]); + } for (size_t i = 0; i < cuts.Ptrs().size(); ++i) { ASSERT_EQ(cuts.Ptrs().at(i), weighted_cuts.Ptrs().at(i)); } @@ -523,6 +532,9 @@ void TestAdapterSketchFromWeights(bool with_group) { for (size_t i = 0; i < cuts.Values().size(); ++i) { EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]); } + for (size_t i = 0; i < cuts.MinValues().size(); ++i) { + ASSERT_EQ(cuts.MinValues()[i], non_weighted.MinValues()[i]); + } for (size_t i = 0; i < cuts.Ptrs().size(); ++i) { ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i)); } diff --git a/tests/cpp/data/test_iterative_device_dmatrix.cu b/tests/cpp/data/test_iterative_device_dmatrix.cu index d076f8c7850e..e440103365e4 100644 --- a/tests/cpp/data/test_iterative_device_dmatrix.cu +++ b/tests/cpp/data/test_iterative_device_dmatrix.cu @@ -39,19 +39,27 @@ void TestEquivalent(float sparsity) { auto from_data = ellpack.Impl()->GetDeviceAccessor(0); std::vector cuts_from_iter(from_iter.gidx_fvalue_map.size()); + std::vector min_fvalues_iter(from_iter.min_fvalue.size()); std::vector cut_ptrs_iter(from_iter.feature_segments.size()); dh::CopyDeviceSpanToVector(&cuts_from_iter, from_iter.gidx_fvalue_map); + dh::CopyDeviceSpanToVector(&min_fvalues_iter, from_iter.min_fvalue); dh::CopyDeviceSpanToVector(&cut_ptrs_iter, from_iter.feature_segments); std::vector cuts_from_data(from_data.gidx_fvalue_map.size()); + std::vector min_fvalues_data(from_data.min_fvalue.size()); std::vector cut_ptrs_data(from_data.feature_segments.size()); dh::CopyDeviceSpanToVector(&cuts_from_data, from_data.gidx_fvalue_map); + dh::CopyDeviceSpanToVector(&min_fvalues_data, from_data.min_fvalue); dh::CopyDeviceSpanToVector(&cut_ptrs_data, from_data.feature_segments); ASSERT_EQ(cuts_from_iter.size(), cuts_from_data.size()); for (size_t i = 0; i < cuts_from_iter.size(); ++i) { EXPECT_NEAR(cuts_from_iter[i], cuts_from_data[i], kRtEps); } + ASSERT_EQ(min_fvalues_iter.size(), min_fvalues_data.size()); + for (size_t i = 0; i < min_fvalues_iter.size(); ++i) { + ASSERT_NEAR(min_fvalues_iter[i], min_fvalues_data[i], kRtEps); + } ASSERT_EQ(cut_ptrs_iter.size(), cut_ptrs_data.size()); for (size_t i = 0; i < cut_ptrs_iter.size(); ++i) { ASSERT_EQ(cut_ptrs_iter[i], cut_ptrs_data[i]); diff --git a/tests/cpp/histogram_helpers.h b/tests/cpp/histogram_helpers.h index 3158f20e8cbc..013020784045 100644 --- a/tests/cpp/histogram_helpers.h +++ b/tests/cpp/histogram_helpers.h @@ -14,6 +14,9 @@ class HistogramCutsWrapper : public common::HistogramCuts { void SetPtrs(std::vector ptrs) { SuperT::cut_ptrs_.HostVector() = std::move(ptrs); } + void SetMins(std::vector mins) { + SuperT::min_vals_.HostVector() = std::move(mins); + } }; } // anonymous namespace @@ -33,6 +36,7 @@ inline std::unique_ptr BuildEllpackPage( 0.25f, 0.74f, 2.00f, 0.26f, 0.74f, 1.98f, 0.26f, 0.71f, 1.83f}); + cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f}); bst_row_t row_stride = 0; const auto &offset_vec = batch.offset.ConstHostVector(); diff --git a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu index d734abf5d8f9..4b9670d83ee6 100644 --- a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu +++ b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu @@ -38,6 +38,7 @@ void 
TestEvaluateSingleSplit(bool is_categorical) { d_feature_types, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), + dh::ToSpan(feature_min_values), dh::ToSpan(feature_histogram), ValueConstraint(), dh::ToSpan(monotonic_constraints)}; @@ -82,6 +83,7 @@ TEST(GpuHist, EvaluateSingleSplitMissing) { {}, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), + dh::ToSpan(feature_min_values), dh::ToSpan(feature_histogram), ValueConstraint(), dh::ToSpan(monotonic_constraints)}; @@ -134,6 +136,7 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) { {}, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), + dh::ToSpan(feature_min_values), dh::ToSpan(feature_histogram), ValueConstraint(), dh::ToSpan(monotonic_constraints)}; @@ -171,6 +174,7 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) { {}, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), + dh::ToSpan(feature_min_values), dh::ToSpan(feature_histogram), ValueConstraint(), dh::ToSpan(monotonic_constraints)}; @@ -209,6 +213,7 @@ TEST(GpuHist, EvaluateSplits) { {}, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), + dh::ToSpan(feature_min_values), dh::ToSpan(feature_histogram_left), ValueConstraint(), dh::ToSpan(monotonic_constraints)}; @@ -220,6 +225,7 @@ TEST(GpuHist, EvaluateSplits) { {}, dh::ToSpan(feature_segments), dh::ToSpan(feature_values), + dh::ToSpan(feature_min_values), dh::ToSpan(feature_histogram_right), ValueConstraint(), dh::ToSpan(monotonic_constraints)}; diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index f2c3bc5d4e3b..fad6466ce058 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -133,6 +133,7 @@ TEST(GpuHist, BuildHistSharedMem) { HistogramCutsWrapper GetHostCutMatrix () { HistogramCutsWrapper cmat; cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24}); + cmat.SetMins({0.1f, 0.2f, 0.3f, 0.1f, 0.2f, 0.3f, 0.2f, 0.2f}); // 24 cut fields, 3 cut fields for each feature (column). // Each row of the cut represents the cuts for a data column. cmat.SetValues({0.30f, 0.67f, 1.64f, diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc index 03767e3139a9..1b6ab89e9992 100644 --- a/tests/cpp/tree/test_quantile_hist.cc +++ b/tests/cpp/tree/test_quantile_hist.cc @@ -93,7 +93,7 @@ class QuantileHistMock : public QuantileHistMaker { if (bin_id > gmat.cut.Ptrs()[fid]) { ASSERT_GE(inst[j].fvalue, gmat.cut.Values()[bin_id - 1]); } else { - ASSERT_NE(inst[j].fvalue, common::kTrivialSplit); + ASSERT_GE(inst[j].fvalue, gmat.cut.MinValues()[fid]); } } } diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index 1e7b29a140fa..a1b607865305 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -65,7 +65,7 @@ def run_categorical_basic(self, cat, onehot, label, rounds): assert tm.non_increasing(by_builtin_results['Train']['rmse']) @given(strategies.integers(10, 400), strategies.integers(5, 10), - strategies.integers(4, 10), strategies.integers(4, 8)) + strategies.integers(1, 6), strategies.integers(4, 8)) @settings(deadline=None) @pytest.mark.skipif(**tm.no_pandas()) def test_categorical(self, rows, cols, rounds, cats):
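
For reference, the mechanism this revert restores is that histogram cuts carry a per-feature minimum (`min_vals_`) sitting strictly below the smallest observed value, so the leftmost bin has a finite lower bound instead of the `kTrivialSplit` (-inf) sentinel, while bin lookup stays a plain `upper_bound` over each feature's cut values. The sketch below is illustrative only and not part of the patch: `SimpleCuts`, `AddFeature`, and the literal sample values are invented for the example, while the `1e-5` padding and the `SearchBin` logic mirror the hunks in `src/common/hist_util.cc` and `src/common/hist_util.h`.

```cpp
// Minimal sketch (assumed names, not part of the patch) of per-feature
// minimum-value bookkeeping for histogram cuts, as restored by this revert.
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

struct SimpleCuts {
  std::vector<float> cut_values;   // concatenated bin upper bounds for all features
  std::vector<uint32_t> cut_ptrs;  // per-feature offsets into cut_values, size = n_features + 1
  std::vector<float> min_vals;     // per-feature lower bound, size = n_features

  // Build cuts for one feature from its sorted values.
  void AddFeature(const std::vector<float>& sorted_vals) {
    float mn = sorted_vals.front();
    float mx = sorted_vals.back();
    // Lower bound sits strictly below the smallest value
    // (mirrors `mval - (fabs(mval) + 1e-5)` in the patch).
    min_vals.push_back(mn - (std::fabs(mn) + 1e-5f));
    for (std::size_t i = 0; i + 1 < sorted_vals.size(); ++i) {
      cut_values.push_back(sorted_vals[i]);
    }
    // Last cut must be greater than any observed value.
    cut_values.push_back(mx + (std::fabs(mx) + 1e-5f));
    cut_ptrs.push_back(static_cast<uint32_t>(cut_values.size()));
  }

  // Index of the first cut strictly greater than value, clamped to the
  // feature's last bin (mirrors HistogramCuts::SearchBin).
  uint32_t SearchBin(float value, uint32_t fidx) const {
    uint32_t beg = cut_ptrs[fidx];
    uint32_t end = cut_ptrs[fidx + 1];
    auto it = std::upper_bound(cut_values.begin() + beg,
                               cut_values.begin() + end, value);
    auto idx = static_cast<uint32_t>(it - cut_values.begin());
    return idx == end ? end - 1 : idx;
  }
};

int main() {
  SimpleCuts cuts;
  cuts.cut_ptrs.push_back(0);
  cuts.AddFeature({0.25f, 0.74f, 2.00f});
  // Values in (min_vals[0], 0.25] fall into the first bin; the stored
  // per-feature minimum, not -inf, is the leftmost split's lower bound.
  assert(cuts.SearchBin(0.1f, 0) == 0);
  assert(cuts.min_vals[0] < 0.25f);
  return 0;
}
```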