Predict on Ellpack. (#5327)
* Unify the GPU prediction node (reuse `RegTree::Node` on the device).
* Add `DMatrix::PageExists<T>()`.
* Dispatch prediction on the input data type in the GPU predictor.
trivialfis committed Feb 22, 2020
1 parent 70a91ec commit 655cf17
Showing 19 changed files with 320 additions and 134 deletions.
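
The dispatch described in the commit message can be pictured with a short sketch. It relies only on the `DMatrix::PageExists<T>()` and `GetBatches<T>()` templates shown in the diff below; `PredictFromEllpack` and `PredictFromSparsePage` are hypothetical placeholders, not functions introduced by this commit.

#include <xgboost/data.h>

// Placeholder helpers -- illustrative only, not part of this commit.
void PredictFromEllpack(xgboost::EllpackPage const& page);
void PredictFromSparsePage(xgboost::SparsePage const& page);

// Hypothetical caller: choose the prediction path based on which page type
// the DMatrix already holds.
void PredictDMatrix(xgboost::DMatrix* dmat, xgboost::BatchParam const& param) {
  if (dmat->PageExists<xgboost::EllpackPage>()) {
    // The data is already compressed into ELLPACK pages (e.g. it was used
    // for GPU training), so predict directly from them.
    for (auto const& page : dmat->GetBatches<xgboost::EllpackPage>(param)) {
      PredictFromEllpack(page);
    }
  } else {
    // Otherwise fall back to the regular sparse row pages.
    for (auto const& page : dmat->GetBatches<xgboost::SparsePage>()) {
      PredictFromSparsePage(page);
    }
  }
}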
27 changes: 25 additions & 2 deletions include/xgboost/data.h
@@ -168,12 +168,19 @@ struct BatchParam {
/*! \brief The GPU device to use. */
int gpu_id;
/*! \brief Maximum number of bins per feature for histograms. */
int max_bin;
int max_bin { 0 };
/*! \brief Number of rows in a GPU batch, used for finding quantiles on GPU. */
int gpu_batch_nrows;
/*! \brief Page size for external memory mode. */
size_t gpu_page_size;

BatchParam() = default;
BatchParam(int32_t device, int32_t max_bin, int32_t gpu_batch_nrows,
size_t gpu_page_size = 0) :
gpu_id{device},
max_bin{max_bin},
gpu_batch_nrows{gpu_batch_nrows},
gpu_page_size{gpu_page_size}
{}
inline bool operator!=(const BatchParam& other) const {
return gpu_id != other.gpu_id ||
max_bin != other.max_bin ||
@@ -438,6 +445,9 @@ class DMatrix {
*/
template<typename T>
BatchSet<T> GetBatches(const BatchParam& param = {});
template <typename T>
bool PageExists() const;

// the following are column meta data, should be able to answer them fast.
/*! \return Whether the data columns single column block. */
virtual bool SingleColBlock() const = 0;
@@ -493,13 +503,26 @@ class DMatrix {
virtual BatchSet<CSCPage> GetColumnBatches() = 0;
virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;

virtual bool EllpackExists() const = 0;
virtual bool SparsePageExists() const = 0;
};

template<>
inline BatchSet<SparsePage> DMatrix::GetBatches(const BatchParam&) {
return GetRowBatches();
}

template<>
inline bool DMatrix::PageExists<EllpackPage>() const {
return this->EllpackExists();
}

template<>
inline bool DMatrix::PageExists<SparsePage>() const {
return this->SparsePageExists();
}

template<>
inline BatchSet<CSCPage> DMatrix::GetBatches(const BatchParam&) {
return GetColumnBatches();
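A note on the new `BatchParam` constructor: a value-initialized `BatchParam{}` acts as a "no parameter supplied" sentinel, which is what the `param != BatchParam{}` checks in the DMatrix implementations further down rely on. A minimal sketch, assuming nothing beyond the members shown in the hunk above; the parameter values are arbitrary examples.

#include <xgboost/data.h>

// Illustrative only: how the new constructor and the default-constructed
// sentinel interact.
void BatchParamExample() {
  // device (gpu_id) = 0, max_bin = 256, gpu_batch_nrows = 0; gpu_page_size
  // falls back to its default argument of 0.
  xgboost::BatchParam param{0, 256, 0};

  // A value-initialized BatchParam{} serves as the "not initialized" sentinel
  // that the DMatrix implementations compare against.
  bool supplied = (param != xgboost::BatchParam{});  // true: max_bin differs
  (void)supplied;
}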
8 changes: 4 additions & 4 deletions include/xgboost/tree_model.h
@@ -105,7 +105,7 @@ class RegTree : public Model {
/*! \brief tree node */
class Node {
public:
Node() {
XGBOOST_DEVICE Node() {
// assert compact alignment
static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
"Node: 64 bit align");
@@ -422,7 +422,7 @@ class RegTree : public Model {
* \param i feature index.
* \return the i-th feature value
*/
bst_float Fvalue(size_t i) const;
bst_float GetFvalue(size_t i) const;
/*!
* \brief check whether i-th entry is missing
* \param i feature index.
@@ -565,7 +565,7 @@ inline size_t RegTree::FVec::Size() const {
return data_.size();
}

inline bst_float RegTree::FVec::Fvalue(size_t i) const {
inline bst_float RegTree::FVec::GetFvalue(size_t i) const {
return data_[i].fvalue;
}

@@ -577,7 +577,7 @@ inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const {
bst_node_t nid = 0;
while (!(*this)[nid].IsLeaf()) {
unsigned split_index = (*this)[nid].SplitIndex();
nid = this->GetNext(nid, feat.Fvalue(split_index), feat.IsMissing(split_index));
nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index));
}
return nid;
}
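The `XGBOOST_DEVICE` default constructor lets device code construct `RegTree::Node` objects, and `Fvalue` becomes `GetFvalue` for consistency. As a small host-side illustration of the renamed accessor in use: a hedged sketch in which `PredictOneRow` is a hypothetical helper and `feat` is assumed to be already filled by the caller.

#include <xgboost/tree_model.h>

// Hedged sketch: score one row that has already been loaded into an FVec.
float PredictOneRow(const xgboost::RegTree& tree,
                    const xgboost::RegTree::FVec& feat) {
  // GetLeafIndex() walks the tree using GetFvalue()/IsMissing(), as shown above.
  xgboost::bst_node_t nid = tree.GetLeafIndex(feat);
  return tree[nid].LeafValue();
}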
4 changes: 2 additions & 2 deletions src/data/ellpack_page.cu
@@ -31,8 +31,8 @@ __global__ void CompressBinEllpackKernel(
common::CompressedByteT* __restrict__ buffer, // gidx_buffer
const size_t* __restrict__ row_ptrs, // row offset of input data
const Entry* __restrict__ entries, // One batch of input data
const float* __restrict__ cuts, // HistogramCuts::cut
const uint32_t* __restrict__ cut_rows, // HistogramCuts::row_ptrs
const float* __restrict__ cuts, // HistogramCuts::cut_values_
const uint32_t* __restrict__ cut_rows, // HistogramCuts::cut_ptrs_
size_t base_row, // batch_row_begin
size_t n_rows,
size_t row_stride,
9 changes: 8 additions & 1 deletion src/data/ellpack_page.cuh
@@ -76,6 +76,9 @@ struct EllpackInfo {
size_t NumSymbols() const {
return n_bins + 1;
}
size_t NumFeatures() const {
return min_fvalue.size();
}
};

/** \brief Struct for accessing and manipulating an ellpack matrix on the
@@ -89,7 +92,7 @@ struct EllpackMatrix {

// Get a matrix element, uses binary search for look up Return NaN if missing
// Given a row index and a feature index, returns the corresponding cut value
__device__ bst_float GetElement(size_t ridx, size_t fidx) const {
__device__ int32_t GetBinIndex(size_t ridx, size_t fidx) const {
ridx -= base_rowid;
auto row_begin = info.row_stride * ridx;
auto row_end = row_begin + info.row_stride;
@@ -103,6 +106,10 @@
info.feature_segments[fidx],
info.feature_segments[fidx + 1]);
}
return gidx;
}
__device__ bst_float GetFvalue(size_t ridx, size_t fidx) const {
auto gidx = GetBinIndex(ridx, fidx);
if (gidx == -1) {
return nan("");
}
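Splitting `GetElement` into `GetBinIndex` and `GetFvalue` lets device code retrieve either the compressed bin index or the reconstructed cut value. Below is a hedged device-side sketch of the latter in a single-tree traversal; the `PredictRow` helper and the `nodes` array are illustrative assumptions (it presumes the internal ellpack_page.cuh and xgboost/tree_model.h headers are available), not code from this commit.

// Hedged sketch of the access pattern the split enables: fetch a feature
// value from the ELLPACK matrix and route one row through one tree.
__device__ float PredictRow(const xgboost::EllpackMatrix& matrix,
                            const xgboost::RegTree::Node* nodes,
                            size_t ridx) {
  xgboost::RegTree::Node n = nodes[0];
  while (!n.IsLeaf()) {
    float fvalue = matrix.GetFvalue(ridx, n.SplitIndex());
    if (isnan(fvalue)) {
      n = nodes[n.DefaultChild()];  // missing entry: GetFvalue() returned NaN
    } else {
      n = nodes[fvalue < n.SplitCond() ? n.LeftChild() : n.RightChild()];
    }
  }
  return n.LeafValue();
}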
10 changes: 7 additions & 3 deletions src/data/simple_dmatrix.cc
@@ -61,11 +61,15 @@ BatchSet<SortedCSCPage> SimpleDMatrix::GetSortedColumnBatches() {
}

BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param) {
CHECK_GE(param.gpu_id, 0);
CHECK_GE(param.max_bin, 2);
// ELLPACK page doesn't exist, generate it
if (!ellpack_page_) {
if (!(batch_param_ != BatchParam{})) {
CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
}
if (!ellpack_page_ || (batch_param_ != param && param != BatchParam{})) {
CHECK_GE(param.gpu_id, 0);
CHECK_GE(param.max_bin, 2);
ellpack_page_.reset(new EllpackPage(this, param));
batch_param_ = param;
}
auto begin_iter =
BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_page_.get()));
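With `batch_param_` cached, the ELLPACK page is built on the first request and rebuilt only when a different, non-default parameter is supplied; passing a default `BatchParam{}` reuses whatever page already exists. A hedged usage sketch (`IterateEllpack` is an illustrative helper and the parameter values are arbitrary examples):

#include <xgboost/data.h>

void IterateEllpack(xgboost::DMatrix* dmat) {
  xgboost::BatchParam param{/*device=*/0, /*max_bin=*/256, /*gpu_batch_nrows=*/0};

  // First call: no ELLPACK page exists yet, so one is generated with `param`.
  for (auto const& page : dmat->GetBatches<xgboost::EllpackPage>(param)) {
    (void)page;
  }

  // Later call with the default BatchParam{}: the cached page is reused rather
  // than regenerated, which is what the GPU predictor can now rely on.
  for (auto const& page : dmat->GetBatches<xgboost::EllpackPage>()) {
    (void)page;
  }
}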
8 changes: 8 additions & 0 deletions src/data/simple_dmatrix.h
@@ -48,6 +48,14 @@ class SimpleDMatrix : public DMatrix {
std::unique_ptr<CSCPage> column_page_;
std::unique_ptr<SortedCSCPage> sorted_column_page_;
std::unique_ptr<EllpackPage> ellpack_page_;
BatchParam batch_param_;

bool EllpackExists() const override {
return static_cast<bool>(ellpack_page_);
}
bool SparsePageExists() const override {
return true;
}
};
} // namespace data
} // namespace xgboost
4 changes: 2 additions & 2 deletions src/data/sparse_page_dmatrix.cc
@@ -1,5 +1,5 @@
/*!
* Copyright 2014 by Contributors
* Copyright 2014-2020 by Contributors
* \file sparse_page_dmatrix.cc
* \brief The external memory version of Page Iterator.
* \author Tianqi Chen
@@ -47,7 +47,7 @@ BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& par
CHECK_GE(param.gpu_id, 0);
CHECK_GE(param.max_bin, 2);
// Lazily instantiate
if (!ellpack_source_ || batch_param_ != param) {
if (!ellpack_source_ || (batch_param_ != param && param != BatchParam{})) {
ellpack_source_.reset(new EllpackPageSource(this, cache_info_, param));
batch_param_ = param;
}
7 changes: 7 additions & 0 deletions src/data/sparse_page_dmatrix.h
@@ -58,6 +58,13 @@ class SparsePageDMatrix : public DMatrix {
std::string cache_info_;
// Store column densities to avoid recalculating
std::vector<float> col_density_;

bool EllpackExists() const override {
return static_cast<bool>(ellpack_source_);
}
bool SparsePageExists() const override {
return static_cast<bool>(row_source_);
}
};
} // namespace data
} // namespace xgboost
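For the external-memory `SparsePageDMatrix`, `EllpackExists()` reports whether the ELLPACK page source has been created, which only happens after the first `GetBatches<EllpackPage>(param)` call. A hedged sketch of that ordering; `TouchEllpack` is an illustrative helper and the parameter values are arbitrary.

#include <xgboost/data.h>

void TouchEllpack(xgboost::DMatrix* dmat) {
  // False if no ELLPACK source has been created yet.
  bool before = dmat->PageExists<xgboost::EllpackPage>();

  xgboost::BatchParam param{0, 256, 0};
  for (auto const& page : dmat->GetBatches<xgboost::EllpackPage>(param)) {
    (void)page;  // lazily instantiates the ELLPACK page source
  }

  // Now true: the source exists and can be reused by the predictor.
  bool after = dmat->PageExists<xgboost::EllpackPage>();
  (void)before; (void)after;
}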