Skip to content

Commit

Permalink
Enhance inplace prediction.
Browse files Browse the repository at this point in the history
* Accept array interface for csr and array.
* Accept an optional proxy dmatrix for metainfo.

This constructs an explicit `_ProxyDMatrix` type in Python.
  • Loading branch information
trivialfis committed Feb 1, 2021
1 parent 0ad6e18 commit 9473f5a
Show file tree
Hide file tree
Showing 22 changed files with 897 additions and 475 deletions.
4 changes: 2 additions & 2 deletions include/xgboost/gbm.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2014-2020 by Contributors
* Copyright 2014-2021 by Contributors
* \file gbm.h
* \brief Interface of gradient booster,
* that learns through gradient statistics.
Expand Down Expand Up @@ -118,7 +118,7 @@ class GradientBooster : public Model, public Configurable {
* \param layer_begin (Optional) Begining of boosted tree layer used for prediction.
* \param layer_end (Optional) End of booster layer. 0 means do not limit trees.
*/
virtual void InplacePredict(dmlc::any const &, float,
virtual void InplacePredict(dmlc::any const &, std::shared_ptr<DMatrix>, float,
PredictionCacheEntry*,
uint32_t,
uint32_t) const {
Expand Down
1 change: 1 addition & 0 deletions include/xgboost/json.h
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ struct StringView {
public:
StringView() = default;
StringView(CharT const* str, size_t size) : str_{str}, size_{size} {}
explicit StringView(std::string const& str): str_{str.c_str()}, size_{str.size()} {}
explicit StringView(CharT const* str) : str_{str}, size_{Traits::length(str)} {}

CharT const& operator[](size_t p) const { return str_[p]; }
Expand Down
25 changes: 21 additions & 4 deletions include/xgboost/learner.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2015-2020 by Contributors
* Copyright 2015-2021 by Contributors
* \file learner.h
* \brief Learner interface that integrates objective, gbm and evaluation together.
* This is the user facing XGBoost training module.
Expand Down Expand Up @@ -30,6 +30,15 @@ class ObjFunction;
class DMatrix;
class Json;

enum class PredictionType : std::uint8_t { // NOLINT
kValue = 0,
kMargin = 1,
kContribution = 2,
kApproxContribution = 3,
kInteraction = 4,
kLeaf = 5
};

/*! \brief entry to to easily hold returning information */
struct XGBAPIThreadLocalEntry {
/*! \brief result holder for returning string */
Expand All @@ -42,7 +51,10 @@ struct XGBAPIThreadLocalEntry {
std::vector<bst_float> ret_vec_float;
/*! \brief temp variable of gradient pairs. */
std::vector<GradientPair> tmp_gpair;
/*! \brief Temp variable for returing prediction result. */
PredictionCacheEntry prediction_entry;
/*! \brief Temp variable for returing prediction shape. */
std::vector<bst_ulong> prediction_shape;
};

/*!
Expand Down Expand Up @@ -123,13 +135,17 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
* \brief Inplace prediction.
*
* \param x A type erased data adapter.
* \param p_m An optional Proxy DMatrix object storing meta info like
* base margin. Can be nullptr.
* \param type Prediction type.
* \param missing Missing value in the data.
* \param [in,out] out_preds Pointer to output prediction vector.
* \param layer_begin (Optional) Begining of boosted tree layer used for prediction.
* \param layer_end (Optional) End of booster layer. 0 means do not limit trees.
* \param layer_begin Begining of boosted tree layer used for prediction.
* \param layer_end End of booster layer. 0 means do not limit trees.
*/
virtual void InplacePredict(dmlc::any const& x, std::string const& type,
virtual void InplacePredict(dmlc::any const &x,
std::shared_ptr<DMatrix> p_m,
PredictionType type,
float missing,
HostDeviceVector<bst_float> **out_preds,
uint32_t layer_begin, uint32_t layer_end) = 0;
Expand All @@ -138,6 +154,7 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
* \brief Get number of boosted rounds from gradient booster.
*/
virtual int32_t BoostedRounds() const = 0;
virtual uint32_t Groups() const = 0;

void LoadModel(Json const& in) override = 0;
void SaveModel(Json* out) const override = 0;
Expand Down
12 changes: 8 additions & 4 deletions include/xgboost/predictor.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*!
* Copyright 2017-2020 by Contributors
* Copyright 2017-2021 by Contributors
* \file predictor.h
* \brief Interface of predictor,
* performs predictions for a gradient booster.
Expand Down Expand Up @@ -142,10 +142,14 @@ class Predictor {
* \param [in,out] out_preds The output preds.
* \param tree_begin (Optional) Begining of boosted trees used for prediction.
* \param tree_end (Optional) End of booster trees. 0 means do not limit trees.
*
* \return True if the data can be handled by current predictor, false otherwise.
*/
virtual void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
float missing, PredictionCacheEntry *out_preds,
uint32_t tree_begin = 0, uint32_t tree_end = 0) const = 0;
virtual bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
const gbm::GBTreeModel &model, float missing,
PredictionCacheEntry *out_preds,
uint32_t tree_begin = 0,
uint32_t tree_end = 0) const = 0;
/**
* \brief online prediction function, predict score for one instance at a time
* NOTE: use the batch prediction interface if possible, batch prediction is
Expand Down

0 comments on commit 9473f5a

Please sign in to comment.