Skip to content

Commit

Permalink
[LEARNER] refactor learner
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jan 16, 2016
1 parent 4b4b36d commit 0d95e86
Show file tree
Hide file tree
Showing 14 changed files with 462 additions and 509 deletions.
14 changes: 14 additions & 0 deletions include/xgboost/data.h
Expand Up @@ -14,6 +14,9 @@
#include "./base.h"

namespace xgboost {
// forward declare learner.
class LearnerImpl;

/*! \brief data type accepted by xgboost interface */
enum DataType {
kFloat32 = 1,
Expand Down Expand Up @@ -199,6 +202,8 @@ class DataSource : public dmlc::DataIter<RowBatch> {
*/
class DMatrix {
public:
/*! \brief default constructor */
DMatrix() : cache_learner_ptr_(nullptr) {}
/*! \brief meta information of the dataset */
virtual MetaInfo& info() = 0;
/*! \brief meta information of the dataset */
Expand All @@ -222,13 +227,16 @@ class DMatrix {
* \param subsample subsample ratio when generating column access.
* \param max_row_perbatch auxiliary information, maximum number of rows used in each column batch.
* this is a hint information that can be ignored by the implementation.
* \return Number of column blocks in the column access.
*/
virtual void InitColAccess(const std::vector<bool>& enabled,
float subsample,
size_t max_row_perbatch) = 0;
// the following are column meta data, should be able to answer them fast.
/*! \return whether column access is enabled */
virtual bool HaveColAccess() const = 0;
/*! \return Whether the column data consists of a single column block. */
virtual bool SingleColBlock() const = 0;
/*! \brief get number of non-missing entries in column */
virtual size_t GetColSize(size_t cidx) const = 0;
/*! \brief get column density */
Expand Down Expand Up @@ -279,6 +287,12 @@ class DMatrix {
*/
static DMatrix* Create(dmlc::Parser<uint32_t>* parser,
const char* cache_prefix = nullptr);

private:
// allow learner class to access this field.
friend class LearnerImpl;
/*! \brief private field, accessed by the friend class LearnerImpl, to back ref cached matrix. */
LearnerImpl* cache_learner_ptr_;
};

} // namespace xgboost
Expand Down
17 changes: 16 additions & 1 deletion include/xgboost/gbm.h
Expand Up @@ -25,6 +25,14 @@ class GradientBooster {
public:
/*! \brief virtual destructor */
virtual ~GradientBooster() {}
/*!
* \brief set configuration from pair iterators.
* \param begin The beginning iterator.
* \param end The end iterator.
* \tparam PairIter iterator<std::pair<std::string, std::string> >
*/
template<typename PairIter>
inline void Configure(PairIter begin, PairIter end);
/*!
* \brief Set the configuration of gradient boosting.
* User must call configure once before InitModel and Training.
Expand Down Expand Up @@ -123,9 +131,16 @@ class GradientBooster {
* \brief create a gradient booster from given name
* \param name name of gradient booster
*/
static GradientBooster* Create(const char *name);
static GradientBooster* Create(const std::string& name);
};

// Out-of-class definition of the iterator-range Configure overload:
// copies [begin, end) into a vector of (name, value) pairs and forwards
// it to the vector-based Configure overload.
template<typename PairIter>
inline void GradientBooster::Configure(PairIter begin, PairIter end) {
  typedef std::pair<std::string, std::string> KeyValue;
  std::vector<KeyValue> args(begin, end);
  this->Configure(args);
}

/*!
* \brief Registry entry for tree updater.
*/
Expand Down
42 changes: 29 additions & 13 deletions include/xgboost/learner.h
Expand Up @@ -14,7 +14,7 @@
#include <vector>
#include "./base.h"
#include "./gbm.h"
#include "./meric.h"
#include "./metric.h"
#include "./objective.h"

namespace xgboost {
Expand All @@ -36,6 +36,14 @@ namespace xgboost {
*/
class Learner : public rabit::Serializable {
public:
/*!
* \brief set configuration from pair iterators.
* \param begin The beginning iterator.
* \param end The end iterator.
* \tparam PairIter iterator<std::pair<std::string, std::string> >
*/
template<typename PairIter>
inline void Configure(PairIter begin, PairIter end);
/*!
* \brief Set the configuration of gradient boosting.
* User must call configure once before InitModel and Training.
Expand All @@ -59,27 +67,27 @@ class Learner : public rabit::Serializable {
* \param iter current iteration number
* \param train reference to the data matrix.
*/
void UpdateOneIter(int iter, DMatrix* train);
virtual void UpdateOneIter(int iter, DMatrix* train) = 0;
/*!
* \brief Do customized gradient boosting with in_gpair.
* in_gpair can be mutated after this call.
* \param iter current iteration number
* \param train reference to the data matrix.
* \param in_gpair The input gradient statistics.
*/
void BoostOneIter(int iter,
DMatrix* train,
std::vector<bst_gpair>* in_gpair);
virtual void BoostOneIter(int iter,
DMatrix* train,
std::vector<bst_gpair>* in_gpair) = 0;
/*!
* \brief evaluate the model for specific iteration using the configured metrics.
* \param iter iteration number
* \param data_sets datasets to be evaluated.
* \param data_names name of each dataset
* \return a string corresponding to the evaluation result
*/
std::string EvalOneIter(int iter,
const std::vector<DMatrix*>& data_sets,
const std::vector<std::string>& data_names);
virtual std::string EvalOneIter(int iter,
const std::vector<DMatrix*>& data_sets,
const std::vector<std::string>& data_names) = 0;
/*!
* \brief get prediction given the model.
* \param data input data
Expand All @@ -89,11 +97,11 @@ class Learner : public rabit::Serializable {
* predictor, when it equals 0, this means we are using all the trees
* \param pred_leaf whether to only predict the leaf index of each tree in a boosted tree predictor
*/
void Predict(DMatrix* data,
bool output_margin,
std::vector<float> *out_preds,
unsigned ntree_limit = 0,
bool pred_leaf = false) const;
virtual void Predict(DMatrix* data,
bool output_margin,
std::vector<float> *out_preds,
unsigned ntree_limit = 0,
bool pred_leaf = false) const = 0;
/*!
* \return whether the model allow lazy checkpoint in rabit.
*/
Expand Down Expand Up @@ -151,5 +159,13 @@ inline void Learner::Predict(const SparseBatch::Inst& inst,
obj_->PredTransform(out_preds);
}
}

// Out-of-class definition of the iterator-range Configure overload:
// copies [begin, end) into a vector of (name, value) pairs and forwards
// it to the vector-based Configure overload.
template<typename PairIter>
inline void Learner::Configure(PairIter begin, PairIter end) {
  typedef std::pair<std::string, std::string> KeyValue;
  std::vector<KeyValue> args(begin, end);
  this->Configure(args);
}

} // namespace xgboost
#endif // XGBOOST_LEARNER_H_
3 changes: 2 additions & 1 deletion include/xgboost/metric.h
Expand Up @@ -9,6 +9,7 @@

#include <dmlc/registry.h>
#include <vector>
#include <string>
#include <functional>
#include "./data.h"
#include "./base.h"
Expand Down Expand Up @@ -42,7 +43,7 @@ class Metric {
* and the name will be matched in the registry.
* \return the created metric.
*/
static Metric* Create(const char *name);
static Metric* Create(const std::string& name);
};

/*!
Expand Down
21 changes: 18 additions & 3 deletions include/xgboost/objective.h
Expand Up @@ -22,10 +22,18 @@ class ObjFunction {
/*! \brief virtual destructor */
virtual ~ObjFunction() {}
/*!
* \brief Initialize the objective with the specified parameters.
* \brief set configuration from pair iterators.
* \param begin The beginning iterator.
* \param end The end iterator.
* \tparam PairIter iterator<std::pair<std::string, std::string> >
*/
template<typename PairIter>
inline void Configure(PairIter begin, PairIter end);
/*!
* \brief Configure the objective with the specified parameters.
* \param args arguments to the objective function.
*/
virtual void Init(const std::vector<std::pair<std::string, std::string> >& args) = 0;
virtual void Configure(const std::vector<std::pair<std::string, std::string> >& args) = 0;
/*!
* \brief Get gradient over each of predictions, given existing information.
* \param preds prediction of current round
Expand Down Expand Up @@ -66,9 +74,16 @@ class ObjFunction {
* \brief Create an objective function according to name.
* \param name Name of the objective.
*/
static ObjFunction* Create(const char* name);
static ObjFunction* Create(const std::string& name);
};

// Out-of-class definition of the iterator-range Configure overload:
// copies [begin, end) into a vector of (name, value) pairs and forwards
// it to the virtual vector-based Configure overload.
template<typename PairIter>
inline void ObjFunction::Configure(PairIter begin, PairIter end) {
  typedef std::pair<std::string, std::string> KeyValue;
  std::vector<KeyValue> args(begin, end);
  this->Configure(args);
}

/*!
* \brief Registry entry for objective factory functions.
*/
Expand Down
2 changes: 1 addition & 1 deletion include/xgboost/tree_updater.h
Expand Up @@ -54,7 +54,7 @@ class TreeUpdater {
* \brief Create a tree updater given name
* \param name Name of the tree updater.
*/
static TreeUpdater* Create(const char* name);
static TreeUpdater* Create(const std::string& name);
};

/*!
Expand Down
55 changes: 55 additions & 0 deletions src/common/io.h
Expand Up @@ -9,12 +9,67 @@
#define XGBOOST_COMMON_IO_H_

#include <dmlc/io.h>
#include <string>
#include <cstring>
#include "./sync.h"

namespace xgboost {
namespace common {
typedef rabit::utils::MemoryFixSizeBuffer MemoryFixSizeBuffer;
typedef rabit::utils::MemoryBufferStream MemoryBufferStream;

/*!
* \brief Input stream that support additional PeekRead
* operation, besides read.
*/
class PeekableInStream : public dmlc::Stream {
public:
explicit PeekableInStream(dmlc::Stream* strm)
: strm_(strm), buffer_ptr_(0) {}

size_t Read(void* dptr, size_t size) override {
size_t nbuffer = buffer_.length() - buffer_ptr_;
if (nbuffer == 0) return strm_->Read(dptr, size);
if (nbuffer < size) {
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, nbuffer);
buffer_ptr_ += nbuffer;
return nbuffer + strm_->Read(reinterpret_cast<char*>(dptr) + nbuffer,
size - nbuffer);
} else {
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size);
buffer_ptr_ += size;
return size;
}
}

size_t PeekRead(void* dptr, size_t size) {
size_t nbuffer = buffer_.length() - buffer_ptr_;
if (nbuffer < size) {
buffer_ = buffer_.substr(buffer_ptr_, buffer_.length());
buffer_ptr_ = 0;
buffer_.resize(size);
size_t nadd = strm_->Read(dmlc::BeginPtr(buffer_) + nbuffer, size - nbuffer);
buffer_.resize(nbuffer + nadd);
std::memcpy(dptr, dmlc::BeginPtr(buffer_), buffer_.length());
return buffer_.length();
} else {
std::memcpy(dptr, dmlc::BeginPtr(buffer_) + buffer_ptr_, size);
return size;
}
}

void Write(const void* dptr, size_t size) override {
LOG(FATAL) << "Not implemented";
}

private:
/*! \brief input stream */
dmlc::Stream *strm_;
/*! \brief current buffer pointer */
size_t buffer_ptr_;
/*! \brief internal buffer */
std::string buffer_;
};
} // namespace common
} // namespace xgboost
#endif // XGBOOST_COMMON_IO_H_
53 changes: 0 additions & 53 deletions src/common/metric_set.h

This file was deleted.

2 changes: 2 additions & 0 deletions src/gbm/gbtree.cc
Expand Up @@ -39,6 +39,8 @@ struct GBTreeTrainParam : public dmlc::Parameter<GBTreeTrainParam> {
" This option is used to support boosted random forest");
DMLC_DECLARE_FIELD(updater_seq).set_default("grow_colmaker,prune")
.describe("Tree updater sequence.");
// add alias
DMLC_DECLARE_ALIAS(updater_seq, updater);
}
};

Expand Down

0 comments on commit 0d95e86

Please sign in to comment.