Skip to content

Commit

Permalink
Port changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jan 14, 2022
1 parent 1b3cbc1 commit ad474ed
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 41 deletions.
2 changes: 1 addition & 1 deletion include/xgboost/tree_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ struct TreeParam : public dmlc::Parameter<TreeParam> {
/*! \brief maximum depth, this is a statistic of the tree */
int deprecated_max_depth;
/*! \brief number of features used for tree construction */
int num_feature;
bst_feature_t num_feature;
/*!
* \brief leaf vector size, used for vector tree
* used to store more than one dimensional information in tree
Expand Down
20 changes: 8 additions & 12 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -903,7 +903,7 @@ XGB_DLL int XGBoosterLoadModel(BoosterHandle handle, const char* fname) {
};
if (common::FileExtension(fname) == "json") {
auto str = read_file();
Json in { Json::Load({str.c_str(), str.size()}) };
Json in{Json::Load(StringView{str})};
static_cast<Learner*>(handle)->LoadModel(in);
} else if (common::FileExtension(fname) == "ubj") {
auto str = read_file();
Expand Down Expand Up @@ -931,25 +931,21 @@ XGB_DLL int XGBoosterSaveModel(BoosterHandle handle, const char *c_fname) {
std::unique_ptr<dmlc::Stream> fo(dmlc::Stream::Create(c_fname, "w"));
auto *learner = static_cast<Learner *>(handle);
learner->Configure();
auto save_json = [&]() {
auto save_json = [&](std::ios::openmode mode) {
Json out{Object()};
learner->SaveModel(&out);
std::string str;
Json::Dump(out, &str);
fo->Write(str.c_str(), str.size());
std::vector<char> str;
Json::Dump(out, &str, mode);
fo->Write(str.data(), str.size());
};
if (common::FileExtension(c_fname) == "json") {
save_json();
save_json(std::ios::out);
} else if (common::FileExtension(c_fname) == "ubj") {
Json out{Object()};
learner->SaveModel(&out);
std::vector<char> str;
Json::Dump(out, &str, std::ios::binary);
fo->Write(str.data(), str.size());
save_json(std::ios::binary);
} else if (XGBOOST_VER_MAJOR == 2 && XGBOOST_VER_MINOR >= 2) {
LOG(WARNING) << "Saving model to JSON as default. You can use file extension `json`, `ubj` or "
"`deprecated` to choose between formats.";
save_json();
save_json(std::ios::out);
} else {
WarnOldModel();
auto *bst = static_cast<Learner *>(handle);
Expand Down
4 changes: 2 additions & 2 deletions src/learner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -710,8 +710,8 @@ namespace {
StringView ModelMsg() {
return StringView{
R"doc(
If you are loading a serialized model (like pickle in Python) generated by older
XGBoost, please export the model by calling `Booster.save_model` from that version
If you are loading a serialized model (like pickle in Python, RDS in R) generated by
older XGBoost, please export the model by calling `Booster.save_model` from that version
first, then load it back in current version. See:
https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html
Expand Down
71 changes: 46 additions & 25 deletions src/tree/tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1015,11 +1015,12 @@ void RegTree::SaveCategoricalSplit(Json* p_out) const {
out["categories"] = std::move(categories);
}

template <bool typed,
template <bool typed, bool feature_is_64,
typename FloatArrayT = std::conditional_t<typed, F32Array const, Array const>,
typename U8ArrayT = std::conditional_t<typed, U8Array const, Array const>,
typename I32ArrayT = std::conditional_t<typed, I32Array const, Array const>,
typename I64ArrayT = std::conditional_t<typed, I64Array const, Array const>>
typename I64ArrayT = std::conditional_t<typed, I64Array const, Array const>,
typename IndexArrayT = std::conditional_t<feature_is_64, I64ArrayT, I32ArrayT>>
bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>* p_stats,
std::vector<FeatureType>* p_split_types, std::vector<RegTree::Node>* p_nodes,
std::vector<RegTree::Segment>* p_split_categories_segments) {
Expand All @@ -1045,7 +1046,7 @@ bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>*
CHECK_EQ(rights.size(), n_nodes);
auto const& parents = get<I32ArrayT>(in["parents"]);
CHECK_EQ(parents.size(), n_nodes);
auto const& indices = get<I64ArrayT>(in["split_indices"]);
auto const& indices = get<IndexArrayT>(in["split_indices"]);
CHECK_EQ(indices.size(), n_nodes);
auto const& conds = get<FloatArrayT>(in["split_conditions"]);
CHECK_EQ(conds.size(), n_nodes);
Expand Down Expand Up @@ -1093,12 +1094,19 @@ bool LoadModelImpl(Json const& in, TreeParam* param, std::vector<RTreeNodeStat>*
void RegTree::LoadModel(Json const& in) {
bool has_cat{false};
bool typed = IsA<F32Array>(in["loss_changes"]);
if (typed) {
has_cat = LoadModelImpl<true>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
bool feature_is_64 = IsA<I64Array>(in["split_indices"]);
if (typed && feature_is_64) {
has_cat = LoadModelImpl<true, true>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
} else if (typed && !feature_is_64) {
has_cat = LoadModelImpl<true, false>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
} else if (!typed && feature_is_64) {
has_cat = LoadModelImpl<false, true>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
} else {
has_cat = LoadModelImpl<false>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
has_cat = LoadModelImpl<false, false>(in, &param, &stats_, &split_types_, &nodes_,
&split_categories_segments_);
}

if (has_cat) {
Expand Down Expand Up @@ -1152,27 +1160,40 @@ void RegTree::SaveModel(Json* p_out) const {
I32Array lefts(n_nodes);
I32Array rights(n_nodes);
I32Array parents(n_nodes);
I64Array indices(n_nodes);


F32Array conds(n_nodes);
U8Array default_left(n_nodes);
U8Array split_type(n_nodes);
CHECK_EQ(this->split_types_.size(), param.num_nodes);

for (bst_node_t i = 0; i < n_nodes; ++i) {
auto const& s = stats_[i];
loss_changes.Set(i, s.loss_chg);
sum_hessian.Set(i, s.sum_hess);
base_weights.Set(i, s.base_weight);

auto const& n = nodes_[i];
lefts.Set(i, n.LeftChild());
rights.Set(i, n.RightChild());
parents.Set(i, n.Parent());
indices.Set(i, n.SplitIndex());
conds.Set(i, n.SplitCond());
default_left.Set(i, static_cast<uint8_t>(!!n.DefaultLeft()));

split_type.Set(i, static_cast<uint8_t>(this->NodeSplitType(i)));
auto save_tree = [&](auto* p_indices_array) {
auto& indices_array = *p_indices_array;
for (bst_node_t i = 0; i < n_nodes; ++i) {
auto const& s = stats_[i];
loss_changes.Set(i, s.loss_chg);
sum_hessian.Set(i, s.sum_hess);
base_weights.Set(i, s.base_weight);

auto const& n = nodes_[i];
lefts.Set(i, n.LeftChild());
rights.Set(i, n.RightChild());
parents.Set(i, n.Parent());
indices_array.Set(i, n.SplitIndex());
conds.Set(i, n.SplitCond());
default_left.Set(i, static_cast<uint8_t>(!!n.DefaultLeft()));

split_type.Set(i, static_cast<uint8_t>(this->NodeSplitType(i)));
}
};
if (this->param.num_feature > static_cast<bst_feature_t>(std::numeric_limits<int32_t>::max())) {
I64Array indices_64(n_nodes);
save_tree(&indices_64);
out["split_indices"] = std::move(indices_64);
} else {
I32Array indices_32(n_nodes);
save_tree(&indices_32);
out["split_indices"] = std::move(indices_32);
}

this->SaveCategoricalSplit(&out);
Expand All @@ -1185,7 +1206,7 @@ void RegTree::SaveModel(Json* p_out) const {
out["left_children"] = std::move(lefts);
out["right_children"] = std::move(rights);
out["parents"] = std::move(parents);
out["split_indices"] = std::move(indices);

out["split_conditions"] = std::move(conds);
out["default_left"] = std::move(default_left);
}
Expand Down
2 changes: 1 addition & 1 deletion tests/cpp/tree/test_tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ TEST(Tree, JsonIO) {
ASSERT_EQ(get<I32Array const>(j_tree["left_children"]).size(), 3ul);
ASSERT_EQ(get<I32Array const>(j_tree["right_children"]).size(), 3ul);
ASSERT_EQ(get<I32Array const>(j_tree["parents"]).size(), 3ul);
ASSERT_EQ(get<I64Array const>(j_tree["split_indices"]).size(), 3ul);
ASSERT_EQ(get<I32Array const>(j_tree["split_indices"]).size(), 3ul);
ASSERT_EQ(get<F32Array const>(j_tree["split_conditions"]).size(), 3ul);
ASSERT_EQ(get<U8Array const>(j_tree["default_left"]).size(), 3ul);

Expand Down

0 comments on commit ad474ed

Please sign in to comment.