Skip to content

Commit

Permalink
[GBM] remove need to explicit InitModel, rename save/load
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Jan 16, 2016
1 parent 82ceb4d commit 4b4b36d
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 59 deletions.
9 changes: 2 additions & 7 deletions include/xgboost/gbm.h
Expand Up @@ -32,21 +32,16 @@ class GradientBooster {
* \param cfg configurations on both training and model parameters.
*/
virtual void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) = 0;
/*!
* \brief Initialize the model.
* User need to call Configure before calling InitModel.
*/
virtual void InitModel() = 0;
/*!
* \brief load model from stream
* \param fi input stream.
*/
virtual void LoadModel(dmlc::Stream* fi) = 0;
virtual void Load(dmlc::Stream* fi) = 0;
/*!
* \brief save model to stream.
* \param fo output stream
*/
virtual void SaveModel(dmlc::Stream* fo) const = 0;
virtual void Save(dmlc::Stream* fo) const = 0;
/*!
* \brief reset the predict buffer size.
* This will invalidate all the previous cached results
Expand Down
4 changes: 2 additions & 2 deletions include/xgboost/tree_model.h
Expand Up @@ -304,7 +304,7 @@ class TreeModel {
* \brief load model from stream
* \param fi input stream
*/
inline void LoadModel(dmlc::Stream* fi) {
inline void Load(dmlc::Stream* fi) {
CHECK_EQ(fi->Read(&param, sizeof(TreeParam)), sizeof(TreeParam));
nodes.resize(param.num_nodes);
stats.resize(param.num_nodes);
Expand All @@ -327,7 +327,7 @@ class TreeModel {
* \brief save model to stream
* \param fo output stream
*/
inline void SaveModel(dmlc::Stream* fo) const {
inline void Save(dmlc::Stream* fo) const {
CHECK_EQ(param.num_nodes, static_cast<int>(nodes.size()));
CHECK_EQ(param.num_nodes, static_cast<int>(stats.size()));
fo->Write(&param, sizeof(TreeParam));
Expand Down
20 changes: 11 additions & 9 deletions src/gbm/gblinear.cc
Expand Up @@ -90,18 +90,20 @@ class GBLinear : public GradientBooster {
}
param.InitAllowUnknown(cfg);
}
void LoadModel(dmlc::Stream* fi) override {
model.LoadModel(fi);
// Restore the linear model (bias + feature weights) from a serialized stream.
// \param fi input stream positioned at a model previously written by Save().
void Load(dmlc::Stream* fi) override {
model.Load(fi);
}
void SaveModel(dmlc::Stream* fo) const override {
model.SaveModel(fo);
}
void InitModel() override {
model.InitModel();
// Serialize the linear model (bias + feature weights) to a stream.
// \param fo output stream; readable back via Load().
void Save(dmlc::Stream* fo) const override {
model.Save(fo);
}
virtual void DoBoost(DMatrix *p_fmat,
int64_t buffer_offset,
std::vector<bst_gpair> *in_gpair) {
// lazily initialize the model when not ready.
if (model.weight.size() == 0) {
model.InitModel();
}

std::vector<bst_gpair> &gpair = *in_gpair;
const int ngroup = model.param.num_output_group;
const std::vector<bst_uint> &rowset = p_fmat->buffered_rowset();
Expand Down Expand Up @@ -248,12 +250,12 @@ class GBLinear : public GradientBooster {
std::fill(weight.begin(), weight.end(), 0.0f);
}
// save the model to file
inline void SaveModel(dmlc::Stream* fo) const {
// Write the raw parameter struct followed by the weight vector.
// NOTE(review): writes `param` as raw bytes — assumes a fixed, trivially
// copyable layout for on-disk compatibility; confirm before changing fields.
inline void Save(dmlc::Stream* fo) const {
fo->Write(&param, sizeof(param));
fo->Write(weight);
}
// load model from file
inline void LoadModel(dmlc::Stream* fi) {
// Read the raw parameter struct, then the weight vector, in the same
// order Save() wrote them; CHECK fails on a truncated/invalid stream.
inline void Load(dmlc::Stream* fi) {
CHECK_EQ(fi->Read(&param, sizeof(param)), sizeof(param));
fi->Read(&weight);
}
Expand Down
72 changes: 33 additions & 39 deletions src/gbm/gbtree.cc
Expand Up @@ -52,8 +52,8 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
int num_feature;
/*! \brief pad this space, for backward compatibility reason.*/
int pad_32bit;
/*! \brief size of prediction buffer allocated used for buffering */
int64_t num_pbuffer;
/*! \brief deprecated padding space. */
int64_t num_pbuffer_deprecated;
/*!
* \brief how many output group a single instance can produce
* this affects the behavior of number of output we have:
Expand Down Expand Up @@ -82,24 +82,13 @@ struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0)
.describe("Reserved option for vector tree.");
}
/*! \return size of prediction buffer actually needed */
inline size_t PredBufferSize() const {
return num_output_group * num_pbuffer * (size_leaf_vector + 1);
}
/*!
* \brief get the buffer offset given a buffer index and group id
* \return calculated buffer offset
*/
inline int64_t BufferOffset(int64_t buffer_index, int bst_group) const {
if (buffer_index < 0) return -1;
CHECK_LT(buffer_index, num_pbuffer);
return (buffer_index + num_pbuffer * bst_group) * (size_leaf_vector + 1);
}
};

// gradient boosted trees
class GBTree : public GradientBooster {
public:
GBTree() : num_pbuffer(0) {}

void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) override {
this->cfg = cfg;
// initialize model parameters if not yet been initialized.
Expand All @@ -118,52 +107,41 @@ class GBTree : public GradientBooster {
}
}

void LoadModel(dmlc::Stream* fi) override {
void Load(dmlc::Stream* fi) override {
CHECK_EQ(fi->Read(&mparam, sizeof(mparam)), sizeof(mparam))
<< "GBTree: invalid model file";
trees.clear();
for (int i = 0; i < mparam.num_trees; ++i) {
std::unique_ptr<RegTree> ptr(new RegTree());
ptr->LoadModel(fi);
ptr->Load(fi);
trees.push_back(std::move(ptr));
}
tree_info.resize(mparam.num_trees);
if (mparam.num_trees != 0) {
CHECK_EQ(fi->Read(dmlc::BeginPtr(tree_info), sizeof(int) * mparam.num_trees),
sizeof(int) * mparam.num_trees);
}
this->ResetPredBuffer(0);
// clear the predict buffer.
this->ResetPredBuffer(num_pbuffer);
}

void SaveModel(dmlc::Stream* fo) const override {
void Save(dmlc::Stream* fo) const override {
CHECK_EQ(mparam.num_trees, static_cast<int>(trees.size()));
// not save predict buffer.
GBTreeModelParam p = mparam;
p.num_pbuffer = 0;
fo->Write(&p, sizeof(p));
fo->Write(&mparam, sizeof(mparam));
for (size_t i = 0; i < trees.size(); ++i) {
trees[i]->SaveModel(fo);
trees[i]->Save(fo);
}
if (tree_info.size() != 0) {
fo->Write(dmlc::BeginPtr(tree_info), sizeof(int) * tree_info.size());
}
}

void InitModel() override {
CHECK(mparam.num_trees == 0 && trees.size() == 0)
<< "Model has already been initialized.";
pred_buffer.clear();
pred_counter.clear();
pred_buffer.resize(mparam.PredBufferSize(), 0.0f);
pred_counter.resize(mparam.PredBufferSize(), 0);
}

void ResetPredBuffer(size_t num_pbuffer) override {
mparam.num_pbuffer = static_cast<int64_t>(num_pbuffer);
this->num_pbuffer = num_pbuffer;
pred_buffer.clear();
pred_counter.clear();
pred_buffer.resize(mparam.PredBufferSize(), 0.0f);
pred_counter.resize(mparam.PredBufferSize(), 0);
pred_buffer.resize(this->PredBufferSize(), 0.0f);
pred_counter.resize(this->PredBufferSize(), 0);
}

bool AllowLazyCheckPoint() const override {
Expand Down Expand Up @@ -348,7 +326,7 @@ class GBTree : public GradientBooster {
#pragma omp parallel for schedule(static)
for (bst_omp_uint i = 0; i < ndata; ++i) {
const bst_uint ridx = rowset[i];
const int64_t bid = mparam.BufferOffset(buffer_offset + ridx, bst_group);
const int64_t bid = this->BufferOffset(buffer_offset + ridx, bst_group);
const int tid = leaf_position[ridx];
CHECK_EQ(pred_counter[bid], trees.size());
CHECK_GE(tid, 0);
Expand All @@ -372,7 +350,7 @@ class GBTree : public GradientBooster {
float psum = 0.0f;
// sum of leaf vector
std::vector<float> vec_psum(mparam.size_leaf_vector, 0.0f);
const int64_t bid = mparam.BufferOffset(buffer_index, bst_group);
const int64_t bid = this->BufferOffset(buffer_index, bst_group);
// number of valid trees
unsigned treeleft = ntree_limit == 0 ? std::numeric_limits<unsigned>::max() : ntree_limit;
// load buffered results if any
Expand Down Expand Up @@ -452,6 +430,20 @@ class GBTree : public GradientBooster {
}
}
}
/*! \return number of entries the prediction buffer actually needs to hold */
inline size_t PredBufferSize() const {
  // One slot per buffered row, per output group, with room for the
  // optional leaf vector (size_leaf_vector extra values per slot).
  const size_t entry_size = mparam.size_leaf_vector + 1;
  return num_pbuffer * entry_size * mparam.num_output_group;
}
/*!
* \brief get the buffer offset given a buffer index and group id
* \param buffer_index row index into the prediction buffer; negative means
*        "no buffering requested" and yields -1.
* \param bst_group output group id the prediction belongs to.
* \return calculated buffer offset
*/
inline int64_t BufferOffset(int64_t buffer_index, int bst_group) const {
if (buffer_index < 0) return -1;
size_t bidx = static_cast<size_t>(buffer_index);
CHECK_LT(bidx, num_pbuffer);
// NOTE(review): arithmetic below is unsigned (size_t); assumes
// bst_group >= 0, otherwise the offset wraps — confirm at call sites.
return (bidx + num_pbuffer * bst_group) * (mparam.size_leaf_vector + 1);
}

// --- data structure ---
// training parameter
Expand All @@ -462,8 +454,10 @@ class GBTree : public GradientBooster {
std::vector<std::unique_ptr<RegTree> > trees;
/*! \brief some information indicator of the tree, reserved */
std::vector<int> tree_info;
/*! \brief predict buffer size */
size_t num_pbuffer;
/*! \brief prediction buffer */
std::vector<float> pred_buffer;
std::vector<float> pred_buffer;
/*! \brief prediction buffer counter, remember the prediction */
std::vector<unsigned> pred_counter;
// ----training fields----
Expand Down
4 changes: 2 additions & 2 deletions src/tree/updater_sync.cc
Expand Up @@ -29,13 +29,13 @@ class TreeSyncher: public TreeUpdater {
int rank = rabit::GetRank();
if (rank == 0) {
for (size_t i = 0; i < trees.size(); ++i) {
trees[i]->SaveModel(&fs);
trees[i]->Save(&fs);
}
}
fs.Seek(0);
rabit::Broadcast(&s_model, 0);
for (size_t i = 0; i < trees.size(); ++i) {
trees[i]->LoadModel(&fs);
trees[i]->Load(&fs);
}
}
};
Expand Down

0 comments on commit 4b4b36d

Please sign in to comment.