From 655cf17b60538a8655ca2abff29c74d3ce3a995c Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Sun, 23 Feb 2020 06:27:03 +0800
Subject: [PATCH] Predict on Ellpack. (#5327)

* Unify GPU prediction node.
* Add `PageExists`.
* Dispatch prediction on input data for GPU Predictor.
---
 include/xgboost/data.h                     |  27 ++-
 include/xgboost/tree_model.h               |   8 +-
 src/data/ellpack_page.cu                   |   4 +-
 src/data/ellpack_page.cuh                  |   9 +-
 src/data/simple_dmatrix.cc                 |  10 +-
 src/data/simple_dmatrix.h                  |   8 +
 src/data/sparse_page_dmatrix.cc            |   4 +-
 src/data/sparse_page_dmatrix.h             |   7 +
 src/predictor/gpu_predictor.cu             | 206 ++++++++++-----------
 src/tree/tree_model.cc                     |   4 +-
 src/tree/updater_gpu_hist.cu               |   4 +-
 src/tree/updater_refresh.cc                |   2 +-
 tests/cpp/data/test_ellpack_page.cu        |   2 +-
 tests/cpp/data/test_sparse_page_dmatrix.cu |   2 +-
 tests/cpp/helpers.h                        |   1 -
 tests/cpp/predictor/test_gpu_predictor.cu  |  18 +-
 tests/cpp/predictor/test_predictor.cc      |  55 +++++-
 tests/cpp/predictor/test_predictor.h       |  70 +++++++
 tests/python/regression_test_utilities.py  |  13 ++
 19 files changed, 320 insertions(+), 134 deletions(-)
 create mode 100644 tests/cpp/predictor/test_predictor.h

diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index 2894ebafaadd..3057dcdc965c 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -168,12 +168,19 @@ struct BatchParam {
   /*! \brief The GPU device to use. */
   int gpu_id;
   /*! \brief Maximum number of bins per feature for histograms. */
-  int max_bin;
+  int max_bin { 0 };
   /*! \brief Number of rows in a GPU batch, used for finding quantiles on GPU. */
   int gpu_batch_nrows;
   /*! \brief Page size for external memory mode. */
   size_t gpu_page_size;
-
+  BatchParam() = default;
+  BatchParam(int32_t device, int32_t max_bin, int32_t gpu_batch_nrows,
+             size_t gpu_page_size = 0) :
+      gpu_id{device},
+      max_bin{max_bin},
+      gpu_batch_nrows{gpu_batch_nrows},
+      gpu_page_size{gpu_page_size}
+  {}
   inline bool operator!=(const BatchParam& other) const {
     return gpu_id != other.gpu_id ||
         max_bin != other.max_bin ||
@@ -438,6 +445,9 @@ class DMatrix {
    */
   template <typename T>
   BatchSet<T> GetBatches(const BatchParam& param = {});
+  template <typename T>
+  bool PageExists() const;
+
   // the following are column meta data, should be able to answer them fast.
   /*! \return Whether the data columns single column block. */
   virtual bool SingleColBlock() const = 0;
@@ -493,6 +503,9 @@ class DMatrix {
   virtual BatchSet<CSCPage> GetColumnBatches() = 0;
   virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
   virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
+
+  virtual bool EllpackExists() const = 0;
+  virtual bool SparsePageExists() const = 0;
 };

 template<>
@@ -500,6 +513,16 @@ inline BatchSet<SparsePage> DMatrix::GetBatches(const BatchParam&) {
   return GetRowBatches();
 }

+template<>
+inline bool DMatrix::PageExists<EllpackPage>() const {
+  return this->EllpackExists();
+}
+
+template<>
+inline bool DMatrix::PageExists<SparsePage>() const {
+  return this->SparsePageExists();
+}
+
 template<>
 inline BatchSet<CSCPage> DMatrix::GetBatches(const BatchParam&) {
   return GetColumnBatches();
diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h
index 69f3a718b662..5c0b1caadfea 100644
--- a/include/xgboost/tree_model.h
+++ b/include/xgboost/tree_model.h
@@ -105,7 +105,7 @@ class RegTree : public Model {
   /*! \brief tree node */
   class Node {
    public:
-    Node() {
+    XGBOOST_DEVICE Node() {
       // assert compact alignment
       static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
                     "Node: 64 bit align");
@@ -422,7 +422,7 @@ class RegTree : public Model {
    * \param i feature index.
   * \return the i-th feature value
    */
-  bst_float Fvalue(size_t i) const;
+  bst_float GetFvalue(size_t i) const;
   /*!
    * \brief check whether i-th entry is missing
    * \param i feature index.
@@ -565,7 +565,7 @@ inline size_t RegTree::FVec::Size() const {
   return data_.size();
 }

-inline bst_float RegTree::FVec::Fvalue(size_t i) const {
+inline bst_float RegTree::FVec::GetFvalue(size_t i) const {
   return data_[i].fvalue;
 }

@@ -577,7 +577,7 @@ inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
   bst_node_t nid = 0;
   while (!(*this)[nid].IsLeaf()) {
     unsigned split_index = (*this)[nid].SplitIndex();
-    nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
+    nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
   }
   return nid;
 }
diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu
index 760d47b06f8d..ff86d77aada2 100644
--- a/src/data/ellpack_page.cu
+++ b/src/data/ellpack_page.cu
@@ -31,8 +31,8 @@ __global__ void CompressBinEllpackKernel(
     common::CompressedByteT* __restrict__ buffer,  // gidx_buffer
     const size_t* __restrict__ row_ptrs,           // row offset of input data
     const Entry* __restrict__ entries,             // One batch of input data
-    const float* __restrict__ cuts,                // HistogramCuts::cut
-    const uint32_t* __restrict__ cut_rows,         // HistogramCuts::row_ptrs
+    const float* __restrict__ cuts,                // HistogramCuts::cut_values_
+    const uint32_t* __restrict__ cut_rows,         // HistogramCuts::cut_ptrs_
     size_t base_row,                               // batch_row_begin
     size_t n_rows,
     size_t row_stride,
diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh
index fcf89ab8fe98..4d3c7a185ba3 100644
--- a/src/data/ellpack_page.cuh
+++ b/src/data/ellpack_page.cuh
@@ -76,6 +76,9 @@ struct EllpackInfo {
   size_t NumSymbols() const {
     return n_bins + 1;
   }
+  size_t NumFeatures() const {
+    return min_fvalue.size();
+  }
 };

 /** \brief Struct for accessing and manipulating an ellpack matrix on the
@@ -89,7 +92,7 @@ struct EllpackMatrix {
   // Get a matrix element, uses binary search for look up
   // Return NaN if missing
   // Given a row index and a feature index, returns the corresponding cut value
-  __device__ bst_float GetElement(size_t ridx, size_t fidx) const {
+  __device__ int32_t GetBinIndex(size_t ridx, size_t fidx) const {
     ridx -= base_rowid;
     auto row_begin = info.row_stride * ridx;
     auto row_end = row_begin + info.row_stride;
@@ -103,6 +106,10 @@ struct EllpackMatrix {
                               info.feature_segments[fidx],
                               info.feature_segments[fidx + 1]);
     }
+    return gidx;
+  }
+  __device__ bst_float GetFvalue(size_t ridx, size_t fidx) const {
+    auto gidx = GetBinIndex(ridx, fidx);
     if (gidx == -1) {
       return nan("");
     }
diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc
index 8b3d241d6bbd..345f20fd4ee0 100644
--- a/src/data/simple_dmatrix.cc
+++ b/src/data/simple_dmatrix.cc
@@ -61,11 +61,15 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
 }

 BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param) {
-  CHECK_GE(param.gpu_id, 0);
-  CHECK_GE(param.max_bin, 2);
   // ELLPACK page doesn't exist, generate it
-  if (!ellpack_page_) {
+  if (!(batch_param_ != BatchParam{})) {
+    CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
+  }
+  if (!ellpack_page_ || (batch_param_ != param && param != BatchParam{})) {
+    CHECK_GE(param.gpu_id, 0);
+    CHECK_GE(param.max_bin, 2);
     ellpack_page_.reset(new EllpackPage(this, param));
+    batch_param_ = param;
   }
   auto begin_iter =
       BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h
index 89ca77856179..2fdd53036366 100644
--- a/src/data/simple_dmatrix.h
+++ b/src/data/simple_dmatrix.h
@@ -48,6 +48,14 @@ class SimpleDMatrix : public DMatrix {
   std::unique_ptr<CSCPage> column_page_;
   std::unique_ptr<SortedCSCPage> sorted_column_page_;
   std::unique_ptr<EllpackPage> ellpack_page_;
+  BatchParam batch_param_;
+
+  bool EllpackExists() const override {
+    return static_cast<bool>(ellpack_page_);
+  }
+  bool SparsePageExists() const override {
+    return true;
+  }
 };
 }  // namespace data
 }  // namespace xgboost
diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc
index e04bb0c8c05c..be0730ceb69b 100644
--- a/src/data/sparse_page_dmatrix.cc
+++ b/src/data/sparse_page_dmatrix.cc
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2014 by Contributors
+ * Copyright 2014-2020 by Contributors
  * \file sparse_page_dmatrix.cc
  * \brief The external memory version of Page Iterator.
  * \author Tianqi Chen
@@ -47,7 +47,7 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& param) {
   CHECK_GE(param.gpu_id, 0);
   CHECK_GE(param.max_bin, 2);
   // Lazily instantiate
-  if (!ellpack_source_ || batch_param_ != param) {
+  if (!ellpack_source_ || (batch_param_ != param && param != BatchParam{})) {
     ellpack_source_.reset(new EllpackPageSource(this, cache_info_, param));
     batch_param_ = param;
   }
diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h
index a227cecccacb..46a48a8be6bf 100644
--- a/src/data/sparse_page_dmatrix.h
+++ b/src/data/sparse_page_dmatrix.h
@@ -58,6 +58,13 @@ class SparsePageDMatrix : public DMatrix {
   std::string cache_info_;
   // Store column densities to avoid recalculating
   std::vector<float> col_density_;
+
+  bool EllpackExists() const override {
+    return static_cast<bool>(ellpack_source_);
+  }
+  bool SparsePageExists() const override {
+    return static_cast<bool>(row_source_);
+  }
 };
 }  // namespace data
 }  // namespace xgboost
diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index 75917ccdca3c..5bbda52af21c 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -14,6 +14,7 @@
 #include "xgboost/host_device_vector.h"

 #include "../gbm/gbtree_model.h"
+#include "../data/ellpack_page.cuh"
 #include "../common/common.h"
 #include "../common/device_helpers.cuh"

@@ -22,78 +23,32 @@ namespace predictor {

 DMLC_REGISTRY_FILE_TAG(gpu_predictor);

-/**
- * \struct DevicePredictionNode
- *
- * \brief Packed 16 byte representation of a tree node for use in device
- * prediction
- */
-struct DevicePredictionNode {
-  XGBOOST_DEVICE DevicePredictionNode()
-      : fidx{-1}, left_child_idx{-1}, right_child_idx{-1} {}
-
-  union NodeValue {
-    float leaf_weight;
-    float fvalue;
-  };
-
-  int fidx;
-  int left_child_idx;
-  int right_child_idx;
-  NodeValue val{};
-
-  DevicePredictionNode(const RegTree::Node& n) {  // NOLINT
-    static_assert(sizeof(DevicePredictionNode) == 16, "Size is not 16 bytes");
-    this->left_child_idx = n.LeftChild();
-    this->right_child_idx = n.RightChild();
-    this->fidx = n.SplitIndex();
-    if (n.DefaultLeft()) {
-      fidx |= (1U << 31);
-    }
-
-    if (n.IsLeaf()) {
-      this->val.leaf_weight = n.LeafValue();
-    } else {
-      this->val.fvalue = n.SplitCond();
-    }
-  }
-
-  XGBOOST_DEVICE bool IsLeaf() const { return left_child_idx == -1; }
-
-  XGBOOST_DEVICE int GetFidx() const { return fidx & ((1U << 31) - 1U); }
-
-  XGBOOST_DEVICE bool MissingLeft() const { return (fidx >> 31) != 0; }
-
-  XGBOOST_DEVICE int MissingIdx() const {
-    if (MissingLeft()) {
-      return this->left_child_idx;
-    } else {
-      return this->right_child_idx;
-    }
-  }
-
-  XGBOOST_DEVICE float GetFvalue() const { return val.fvalue; }
+struct SparsePageView {
+  common::Span<const Entry> d_data;
+  common::Span<const bst_row_t> d_row_ptr;

-  XGBOOST_DEVICE float GetWeight() const { return val.leaf_weight; }
+  XGBOOST_DEVICE SparsePageView(common::Span<const Entry> data,
+                                common::Span<const bst_row_t> row_ptr) :
+      d_data{data}, d_row_ptr{row_ptr} {}
 };

-struct ElementLoader {
+struct SparsePageLoader {
   bool use_shared;
   common::Span<const bst_row_t> d_row_ptr;
   common::Span<const Entry> d_data;
-  int num_features;
+  bst_feature_t num_features;
   float* smem;
   size_t entry_start;

-  __device__ ElementLoader(bool use_shared, common::Span<const bst_row_t> row_ptr,
-                           common::Span<const Entry> entry, int num_features,
-                           float* smem, int num_rows, size_t entry_start)
+  __device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features,
+                              bst_row_t num_rows, size_t entry_start)
       : use_shared(use_shared),
-        d_row_ptr(row_ptr),
-        d_data(entry),
+        d_row_ptr(data.d_row_ptr),
+        d_data(data.d_data),
         num_features(num_features),
-        smem(smem),
         entry_start(entry_start) {
+    extern __shared__ float _smem[];
+    smem = _smem;
     // Copy instances
     if (use_shared) {
       bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
@@ -111,7 +66,7 @@
       __syncthreads();
     }
   }
-  __device__ float GetFvalue(int ridx, int fidx) {
+  __device__ float GetFvalue(int ridx, int fidx) const {
     if (use_shared) {
       return smem[threadIdx.x * num_features + fidx];
     } else {
@@ -141,52 +96,69 @@
   }
 };

-__device__ float GetLeafWeight(bst_uint ridx, const DevicePredictionNode* tree,
-                               ElementLoader* loader) {
-  DevicePredictionNode n = tree[0];
+struct EllpackLoader {
+  EllpackMatrix const& matrix;
+  XGBOOST_DEVICE EllpackLoader(EllpackMatrix const& m, bool use_shared, bst_feature_t num_features,
+                               bst_row_t num_rows, size_t entry_start) : matrix{m} {}
+  __device__ __forceinline__ float GetFvalue(int ridx, int fidx) const {
+    auto gidx = matrix.GetBinIndex(ridx, fidx);
+    if (gidx == -1) {
+      return nan("");
+    }
+    // The gradient index needs to be shifted by one as min values are not included in the
+    // cuts.
+    if (gidx == matrix.info.feature_segments[fidx]) {
+      return matrix.info.min_fvalue[fidx];
+    }
+    return matrix.info.gidx_fvalue_map[gidx - 1];
+  }
+};
+
+template <typename Loader>
+__device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
+                               Loader* loader) {
+  RegTree::Node n = tree[0];
   while (!n.IsLeaf()) {
-    float fvalue = loader->GetFvalue(ridx, n.GetFidx());
+    float fvalue = loader->GetFvalue(ridx, n.SplitIndex());
     // Missing value
     if (isnan(fvalue)) {
-      n = tree[n.MissingIdx()];
+      n = tree[n.DefaultChild()];
     } else {
-      if (fvalue < n.GetFvalue()) {
-        n = tree[n.left_child_idx];
+      if (fvalue < n.SplitCond()) {
+        n = tree[n.LeftChild()];
       } else {
-        n = tree[n.right_child_idx];
+        n = tree[n.RightChild()];
       }
     }
   }
-  return n.GetWeight();
+  return n.LeafValue();
 }

-template <int BLOCK_THREADS>
-__global__ void PredictKernel(common::Span<const DevicePredictionNode> d_nodes,
+template <typename Loader, typename Data>
+__global__ void PredictKernel(Data data,
+                              common::Span<const RegTree::Node> d_nodes,
                               common::Span<float> d_out_predictions,
                               common::Span<size_t> d_tree_segments,
                               common::Span<int> d_tree_group,
-                              common::Span<const bst_row_t> d_row_ptr,
-                              common::Span<const Entry> d_data, size_t tree_begin,
-                              size_t tree_end, size_t num_features,
+                              size_t tree_begin, size_t tree_end, size_t num_features,
                               size_t num_rows, size_t entry_start,
                               bool use_shared, int num_group) {
-  extern __shared__ float smem[];
   bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  ElementLoader loader(use_shared, d_row_ptr, d_data, num_features, smem,
-                       num_rows, entry_start);
+  Loader loader(data, use_shared, num_features, num_rows, entry_start);
   if (global_idx >= num_rows) return;
   if (num_group == 1) {
     float sum = 0;
     for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
-      const DevicePredictionNode* d_tree =
+      const RegTree::Node* d_tree =
           &d_nodes[d_tree_segments[tree_idx - tree_begin]];
-      sum += GetLeafWeight(global_idx, d_tree, &loader);
+      float leaf = GetLeafWeight(global_idx, d_tree, &loader);
+      sum += leaf;
     }
     d_out_predictions[global_idx] += sum;
   } else {
     for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
       int tree_group = d_tree_group[tree_idx];
-      const DevicePredictionNode* d_tree =
+      const RegTree::Node* d_tree =
           &d_nodes[d_tree_segments[tree_idx - tree_begin]];
       bst_uint out_prediction_idx = global_idx * num_group + tree_group;
       d_out_predictions[out_prediction_idx] +=
@@ -198,13 +170,13 @@ __global__ void PredictKernel(common::Span<const DevicePredictionNode> d_nodes,
 class GPUPredictor : public xgboost::Predictor {
  private:
   void InitModel(const gbm::GBTreeModel& model,
-                 const thrust::host_vector<size_t>& h_tree_segments,
-                 const thrust::host_vector<DevicePredictionNode>& h_nodes,
-                 size_t tree_begin, size_t tree_end) {
+                 const thrust::host_vector<size_t>& h_tree_segments,
+                 const thrust::host_vector<RegTree::Node>& h_nodes,
+                 size_t tree_begin, size_t tree_end) {
     dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
     nodes_.resize(h_nodes.size());
     dh::safe_cuda(cudaMemcpyAsync(nodes_.data().get(), h_nodes.data(),
-                                  sizeof(DevicePredictionNode) * h_nodes.size(),
+                                  sizeof(RegTree::Node) * h_nodes.size(),
                                   cudaMemcpyHostToDevice));
     tree_segments_.resize(h_tree_segments.size());
     dh::safe_cuda(cudaMemcpyAsync(tree_segments_.data().get(), h_tree_segments.data(),
@@ -219,15 +191,11 @@ class GPUPredictor : public xgboost::Predictor {
     this->num_group_ = model.learner_model_param_->num_output_group;
   }

-  void PredictInternal(const SparsePage& batch,
-                       size_t num_features,
+  void PredictInternal(const SparsePage& batch, size_t num_features,
                        HostDeviceVector<bst_float>* predictions,
                        size_t batch_offset) {
-    dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
-    batch.data.SetDevice(generic_param_->gpu_id);
     batch.offset.SetDevice(generic_param_->gpu_id);
-    predictions->SetDevice(generic_param_->gpu_id);
-
+    batch.data.SetDevice(generic_param_->gpu_id);
     const uint32_t BLOCK_THREADS = 128;
     size_t num_rows = batch.Size();
     auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
@@ -240,12 +208,29 @@ class GPUPredictor : public xgboost::Predictor {
       use_shared = false;
     }
     size_t entry_start = 0;
-
+    SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan()};
     dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
-        PredictKernel<BLOCK_THREADS>,
+        PredictKernel<SparsePageLoader, SparsePageView>,
+        data,
         dh::ToSpan(nodes_), predictions->DeviceSpan().subspan(batch_offset),
-        dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_), batch.offset.DeviceSpan(),
-        batch.data.DeviceSpan(), this->tree_begin_, this->tree_end_, num_features, num_rows,
+        dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_),
+        this->tree_begin_, this->tree_end_, num_features, num_rows,
         entry_start, use_shared, this->num_group_);
+  }
+  void PredictInternal(EllpackMatrix const& batch, HostDeviceVector<bst_float>* out_preds,
+                       size_t batch_offset) {
+    const uint32_t BLOCK_THREADS = 256;
+    size_t num_rows = batch.n_rows;
+    auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
+
+    bool use_shared = false;
+    size_t entry_start = 0;
+    dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} (
+        PredictKernel<EllpackLoader, EllpackMatrix>,
+        batch,
+        dh::ToSpan(nodes_), out_preds->DeviceSpan().subspan(batch_offset),
+        dh::ToSpan(tree_segments_), dh::ToSpan(tree_group_),
+        this->tree_begin_, this->tree_end_, batch.info.NumFeatures(), num_rows,
         entry_start, use_shared, this->num_group_);
   }

@@ -261,7 +246,7 @@
       h_tree_segments.push_back(sum);
     }

-    thrust::host_vector<DevicePredictionNode> h_nodes(h_tree_segments.back());
+    thrust::host_vector<RegTree::Node> h_nodes(h_tree_segments.back());
     for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
       auto& src_nodes = model.trees.at(tree_idx)->GetNodes();
       std::copy(src_nodes.begin(), src_nodes.end(),
@@ -270,26 +255,31 @@
     InitModel(model, h_tree_segments, h_nodes, tree_begin, tree_end);
   }

-  void DevicePredictInternal(DMatrix* dmat,
-                             HostDeviceVector<bst_float>* out_preds,
+  void DevicePredictInternal(DMatrix* dmat, HostDeviceVector<bst_float>* out_preds,
                              const gbm::GBTreeModel& model, size_t tree_begin,
                              size_t tree_end) {
+    dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
     if (tree_end - tree_begin == 0) { return; }
     monitor_.StartCuda("DevicePredictInternal");

     InitModel(model, tree_begin, tree_end);
+    out_preds->SetDevice(generic_param_->gpu_id);

-    size_t batch_offset = 0;
-    for (auto &batch : dmat->GetBatches<SparsePage>()) {
-      batch.offset.SetDevice(generic_param_->gpu_id);
-      batch.data.SetDevice(generic_param_->gpu_id);
-      PredictInternal(batch, model.learner_model_param_->num_feature,
-                      out_preds, batch_offset);
-      batch_offset += batch.Size() * model.learner_model_param_->num_output_group;
+    if (dmat->PageExists<EllpackPage>()) {
+      size_t batch_offset = 0;
+      for (auto const& page : dmat->GetBatches<EllpackPage>()) {
+        this->PredictInternal(page.Impl()->matrix, out_preds, batch_offset);
+        batch_offset += page.Impl()->matrix.n_rows;
+      }
+    } else {
+      size_t batch_offset = 0;
+      for (auto &batch : dmat->GetBatches<SparsePage>()) {
+        this->PredictInternal(batch, model.learner_model_param_->num_feature,
+                              out_preds, batch_offset);
+        batch_offset += batch.Size() * model.learner_model_param_->num_output_group;
+      }
     }
-
     monitor_.StopCuda("DevicePredictInternal");
   }

@@ -418,7 +408,7 @@ class GPUPredictor : public xgboost::Predictor {
   }
   common::Monitor monitor_;
-  dh::device_vector<DevicePredictionNode> nodes_;
+  dh::device_vector<RegTree::Node> nodes_;
   dh::device_vector<size_t> tree_segments_;
   dh::device_vector<int> tree_group_;
   size_t max_shared_memory_bytes_;
diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc
index f8d24a6ce2ec..f6cdc06e2d19 100644
--- a/src/tree/tree_model.cc
+++ b/src/tree/tree_model.cc
@@ -792,7 +792,7 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
   bst_node_t nid = 0;
   while (!(*this)[nid].IsLeaf()) {
     split_index = (*this)[nid].SplitIndex();
-    nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
+    nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
     bst_float new_value = this->node_mean_values_[nid];
     // update feature weight
     out_contribs[split_index] += new_value - node_value;
@@ -924,7 +924,7 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
     unsigned hot_index = 0;
     if (feat.IsMissing(split_index)) {
       hot_index = node.DefaultChild();
-    } else if (feat.Fvalue(split_index) < node.SplitCond()) {
+    } else if (feat.GetFvalue(split_index) < node.SplitCond()) {
       hot_index = node.LeftChild();
     } else {
       hot_index = node.RightChild();
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index eafa227b05e3..1d0592aa3595 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -688,7 +688,7 @@ struct GPUHistMakerDevice {
         [=] __device__(bst_uint ridx) {
           // given a row index, returns the node id it belongs to
           bst_float cut_value =
-              d_matrix.GetElement(ridx, split_node.SplitIndex());
+              d_matrix.GetFvalue(ridx, split_node.SplitIndex());
           // Missing value
           int new_position = 0;
           if (isnan(cut_value)) {
@@ -737,7 +737,7 @@ struct GPUHistMakerDevice {
       auto node = d_nodes[position];

       while (!node.IsLeaf()) {
-        bst_float element = d_matrix.GetElement(row_id, node.SplitIndex());
+        bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex());
         // Missing value
         if (isnan(element)) {
           position = node.DefaultChild();
diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc
index feb831e5b03f..c954712d71aa 100644
--- a/src/tree/updater_refresh.cc
+++ b/src/tree/updater_refresh.cc
@@ -119,7 +119,7 @@ class TreeRefresher: public TreeUpdater {
       // traverse tree
       while (!tree[pid].IsLeaf()) {
         unsigned split_index = tree[pid].SplitIndex();
-        pid = tree.GetNext(pid, feat.Fvalue(split_index), feat.IsMissing(split_index));
+        pid = tree.GetNext(pid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
         gstats[pid].Add(gpair[ridx]);
       }
     }
diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu
index 6479e9feea93..b1c6f09eb5b3 100644
--- a/tests/cpp/data/test_ellpack_page.cu
+++ b/tests/cpp/data/test_ellpack_page.cu
@@ -89,7 +89,7 @@ struct ReadRowFunction {
       : matrix(std::move(matrix)), row(row), row_data_d(row_data_d) {}

   __device__ void operator()(size_t col) {
-    auto value = matrix.GetElement(row, col);
+    auto value = matrix.GetFvalue(row, col);
     if (isnan(value)) {
       value = -1;
     }
diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cu b/tests/cpp/data/test_sparse_page_dmatrix.cu
index 920abec8d3bb..5f9603850eb8 100644
--- a/tests/cpp/data/test_sparse_page_dmatrix.cu
+++ b/tests/cpp/data/test_sparse_page_dmatrix.cu
@@ -86,7 +86,7 @@ struct ReadRowFunction {
       : matrix(std::move(matrix)), row(row), row_data_d(row_data_d) {}

   __device__ void operator()(size_t col) {
-    auto value = matrix.GetElement(row, col);
+    auto value = matrix.GetFvalue(row, col);
     if (isnan(value)) {
       value = -1;
     }
diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h
index aaad7955ea14..8001076bae52 100644
--- a/tests/cpp/helpers.h
+++ b/tests/cpp/helpers.h
@@ -21,7 +21,6 @@
 #include

 #include "../../src/common/common.h"
-#include "../../src/common/hist_util.h"
 #include "../../src/gbm/gbtree_model.h"
 #if defined(__CUDACC__)
 #include "../../src/data/ellpack_page.cuh"
diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu
index 48c68335b846..665583df17ba 100644
--- a/tests/cpp/predictor/test_gpu_predictor.cu
+++ b/tests/cpp/predictor/test_gpu_predictor.cu
@@ -11,11 +11,12 @@
 #include "gtest/gtest.h"
 #include "../helpers.h"
 #include "../../../src/gbm/gbtree_model.h"
+#include "test_predictor.h"

 namespace xgboost {
 namespace predictor {

-TEST(GpuPredictor, Basic) {
+TEST(GPUPredictor, Basic) {
   auto cpu_lparam = CreateEmptyGenericParam(-1);
   auto gpu_lparam = CreateEmptyGenericParam(0);

@@ -56,7 +57,20 @@ TEST(GPUPredictor, Basic) {
   }
 }

-TEST(gpu_predictor, ExternalMemoryTest) {
+TEST(GPUPredictor, EllpackBasic) {
+  for (size_t bins = 2; bins < 258; bins += 16) {
+    size_t rows = bins * 16;
+    TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", rows, bins);
+    TestPredictionFromGradientIndex<EllpackPage>("gpu_predictor", bins, bins);
+  }
+}
+
+TEST(GPUPredictor, EllpackTraining) {
+  size_t constexpr kRows { 128 };
+  TestTrainingPrediction(kRows, "gpu_hist");
+}
+
+TEST(GPUPredictor, ExternalMemoryTest) {
   auto lparam = CreateEmptyGenericParam(0);
   std::unique_ptr<Predictor> gpu_predictor =
       std::unique_ptr<Predictor>(Predictor::Create("gpu_predictor", &lparam));
diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc
index 927697e6bda2..3000d2fd448f 100644
--- a/tests/cpp/predictor/test_predictor.cc
+++ b/tests/cpp/predictor/test_predictor.cc
@@ -2,13 +2,16 @@
  * Copyright 2020 by Contributors
  */

-#include
 #include
 #include
 #include
+#include
+#include
+
+#include "test_predictor.h"

 #include "../helpers.h"
-#include "xgboost/generic_parameters.h"
+#include "../../../src/common/io.h"

 namespace xgboost {
 TEST(Predictor, PredictionCache) {
@@ -30,4 +33,52 @@ TEST(Predictor, PredictionCache) {
   add_cache();
   EXPECT_ANY_THROW(container.Entry(m));
 }
+
+// Only run this test when CUDA is enabled.
+void TestTrainingPrediction(size_t rows, std::string tree_method) {
+  size_t constexpr kCols = 16;
+  size_t constexpr kClasses = 3;
+  size_t constexpr kIters = 3;
+
+  std::unique_ptr<Learner> learner;
+  auto train = [&](std::string predictor, HostDeviceVector<float>* out) {
+    auto pp_m = CreateDMatrix(rows, kCols, 0);
+    auto p_m = *pp_m;
+
+    auto &h_label = p_m->Info().labels_.HostVector();
+    h_label.resize(rows);
+
+    for (size_t i = 0; i < rows; ++i) {
+      h_label[i] = i % kClasses;
+    }
+
+    learner.reset(Learner::Create({}));
+    learner->SetParam("tree_method", tree_method);
+    learner->SetParam("objective", "multi:softprob");
+    learner->SetParam("predictor", predictor);
+    learner->SetParam("num_feature", std::to_string(kCols));
+    learner->SetParam("num_class", std::to_string(kClasses));
+    learner->Configure();
+
+    for (size_t i = 0; i < kIters; ++i) {
+      learner->UpdateOneIter(i, p_m);
+    }
+    learner->Predict(p_m, false, out);
+    delete pp_m;
+  };
+  // Alternate the predictor: the CPU predictor cannot use ellpack, while the GPU
+  // predictor cannot use the CPU histogram index, so it's guaranteed that one of the
+  // following runs is not predicting from the histogram index.  Note: as of this
+  // writing only the GPU predictor supports predicting from the gradient index; the
+  // test is written for future portability.
+  HostDeviceVector<float> predictions_0;
+  train("cpu_predictor", &predictions_0);
+
+  HostDeviceVector<float> predictions_1;
+  train("gpu_predictor", &predictions_1);
+
+  for (size_t i = 0; i < rows; ++i) {
+    EXPECT_NEAR(predictions_1.ConstHostVector()[i],
+                predictions_0.ConstHostVector()[i], kRtEps);
+  }
+}
 }  // namespace xgboost
diff --git a/tests/cpp/predictor/test_predictor.h b/tests/cpp/predictor/test_predictor.h
new file mode 100644
index 000000000000..e00c61641650
--- /dev/null
+++ b/tests/cpp/predictor/test_predictor.h
@@ -0,0 +1,70 @@
+#ifndef XGBOOST_TEST_PREDICTOR_H_
+#define XGBOOST_TEST_PREDICTOR_H_
+
+#include
+#include
+#include
+#include "../helpers.h"
+
+namespace xgboost {
+template <typename Page>
+void TestPredictionFromGradientIndex(std::string name, size_t rows, int32_t bins) {
+  constexpr size_t kCols { 8 }, kClasses { 3 };
+
+  LearnerModelParam param;
+  param.num_feature = kCols;
+  param.num_output_group = kClasses;
+  param.base_score = 0.5;
+
+  auto lparam = CreateEmptyGenericParam(0);
+
+  std::unique_ptr<Predictor> predictor =
+      std::unique_ptr<Predictor>(Predictor::Create(name, &lparam));
+  predictor->Configure({});
+
+  gbm::GBTreeModel model = CreateTestModel(&param, kClasses);
+
+  {
+    auto pp_ellpack = CreateDMatrix(rows, kCols, 0);
+    auto p_ellpack = *pp_ellpack;
+    // Use same number of bins as rows.
+    for (auto const &page DMLC_ATTRIBUTE_UNUSED :
+         p_ellpack->GetBatches<Page>({0, static_cast<int32_t>(bins), 0})) {
+    }
+
+    auto pp_precise = CreateDMatrix(rows, kCols, 0);
+    auto p_precise = *pp_precise;
+
+    PredictionCacheEntry approx_out_predictions;
+    predictor->PredictBatch(p_ellpack.get(), &approx_out_predictions, model, 0);
+
+    PredictionCacheEntry precise_out_predictions;
+    predictor->PredictBatch(p_precise.get(), &precise_out_predictions, model, 0);
+
+    for (size_t i = 0; i < rows; ++i) {
+      CHECK_EQ(approx_out_predictions.predictions.HostVector()[i],
+               precise_out_predictions.predictions.HostVector()[i]);
+    }
+
+    delete pp_precise;
+    delete pp_ellpack;
+  }
+
+  {
+    // The predictor should never try to create the histogram index by itself, as only
+    // the histogram index from the training data is valid, and the predictor doesn't
+    // know which matrix was used for training.
+    auto pp_dmat = CreateDMatrix(rows, kCols, 0);
+    auto p_dmat = *pp_dmat;
+    PredictionCacheEntry precise_out_predictions;
+    predictor->PredictBatch(p_dmat.get(), &precise_out_predictions, model, 0);
+    ASSERT_FALSE(p_dmat->PageExists<EllpackPage>());
+    delete pp_dmat;
+  }
+}
+
+void TestTrainingPrediction(size_t rows, std::string tree_method);
+
+}  // namespace xgboost
+
+#endif  // XGBOOST_TEST_PREDICTOR_H_
diff --git a/tests/python/regression_test_utilities.py b/tests/python/regression_test_utilities.py
index 1a3b80690408..f6abc8732bd1 100644
--- a/tests/python/regression_test_utilities.py
+++ b/tests/python/regression_test_utilities.py
@@ -25,6 +25,19 @@ def __init__(self, name, get_dataset, objective, metric,
         self.w = None
         self.use_external_memory = use_external_memory

+    def __str__(self):
+        a = 'name: {name}\nobjective:{objective}, metric:{metric}, '.format(
+            name=self.name,
+            objective=self.objective,
+            metric=self.metric)
+        b = 'external memory:{use_external_memory}\n'.format(
+            use_external_memory=self.use_external_memory
+        )
+        return a + b
+
+    def __repr__(self):
+        return self.__str__()
+

 def get_boston():
     data = datasets.load_boston()
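
Editor's sketch of the dispatch introduced above, for illustration only: `PageExists<T>()` reports whether a `DMatrix` already carries a page of the given type, and a default-constructed `BatchParam` (its `max_bin` is now `{ 0 }`) stands for "no parameters supplied", which `GetEllpackBatches` treats as "reuse the cached page". The `DispatchPredict` helper below is hypothetical and not part of this patch; it only mirrors the branch in `DevicePredictInternal`.

// Sketch only -- mirrors DevicePredictInternal in the patch above.
// DispatchPredict is a hypothetical helper; PageExists<T>, GetBatches<T> and
// BatchParam are the interfaces this patch adds to include/xgboost/data.h.
#include <xgboost/data.h>

namespace xgboost {
void DispatchPredict(DMatrix* dmat) {
  if (dmat->PageExists<EllpackPage>()) {
    // A compressed ELLPACK (histogram-index) page already lives on the device.
    // An empty BatchParam{} asks GetEllpackBatches to reuse the cached page
    // instead of rebuilding it with new parameters.
    for (auto const& page : dmat->GetBatches<EllpackPage>(BatchParam{})) {
      // ... launch PredictKernel<EllpackLoader, EllpackMatrix> on the page ...
      (void)page;
    }
  } else {
    // Fall back to the CSR SparsePage path; each batch is staged onto the
    // device and read through SparsePageLoader.
    for (auto const& batch : dmat->GetBatches<SparsePage>()) {
      // ... launch PredictKernel<SparsePageLoader, SparsePageView> ...
      (void)batch;
    }
  }
}
}  // namespace xgboost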