Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support graphviz plot for multi-target tree. #10093

Merged
merged 2 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/xgboost/tree_model.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2014-2023 by Contributors
* Copyright 2014-2024, XGBoost Contributors
* \file tree_model.h
* \brief model structure for tree
* \author Tianqi Chen
Expand Down Expand Up @@ -688,6 +688,9 @@ class RegTree : public Model {
}
return (*this)[nidx].DefaultLeft();
}
/**
 * \brief Index of the child a row with a missing value is routed to.
 */
[[nodiscard]] bst_node_t DefaultChild(bst_node_t nidx) const {
  if (this->DefaultLeft(nidx)) {
    return this->LeftChild(nidx);
  }
  return this->RightChild(nidx);
}
[[nodiscard]] bool IsRoot(bst_node_t nidx) const {
if (IsMultiTarget()) {
return nidx == kRoot;
Expand Down
154 changes: 95 additions & 59 deletions src/tree/tree_model.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright 2015-2023, XGBoost Contributors
* Copyright 2015-2024, XGBoost Contributors
* \file tree_model.cc
* \brief model structure for tree
*/
Expand All @@ -8,14 +8,15 @@
#include <xgboost/json.h>
#include <xgboost/tree_model.h>

#include <array> // for array
#include <cmath>
#include <iomanip>
#include <limits>
#include <sstream>
#include <type_traits>

#include "../common/categorical.h"
#include "../common/common.h" // for EscapeU8
#include "../common/common.h" // for EscapeU8
#include "../predictor/predict_fn.h"
#include "io_utils.h" // for GetElem
#include "param.h"
Expand All @@ -31,26 +32,50 @@ namespace tree {
DMLC_REGISTER_PARAMETER(TrainParam);
}

namespace {
/**
 * \brief Render a floating point scalar as text.
 *
 * Uses float's max_digits10 so the printed value survives a text -> float ->
 * text round-trip without loss.
 *
 * The enable_if on the return type already restricts this overload to
 * floating point types, so no additional static_assert is needed; use
 * std::to_string for integral values instead.
 */
template <typename Float>
std::enable_if_t<std::is_floating_point_v<Float>, std::string> ToStr(Float value) {
  int32_t constexpr kFloatMaxPrecision = std::numeric_limits<float>::max_digits10;
  std::stringstream ss;
  ss << std::setprecision(kFloatMaxPrecision) << value;
  return ss.str();
}

/**
 * \brief Render a vector of leaf values as text.
 *
 * A single-element vector is printed as a bare scalar.  Longer vectors are
 * printed as `[v0, v1, ..., vn]`, eliding the middle once the size exceeds
 * `limit` so that dumps of wide multi-target trees stay readable.
 */
template <typename Float>
std::string ToStr(linalg::VectorView<Float> value, bst_target_t limit) {
  // max_digits10 guarantees a lossless text round-trip for float.
  int32_t constexpr kFloatMaxPrecision = std::numeric_limits<float>::max_digits10;
  static_assert(std::is_floating_point<Float>::value,
                "Use std::to_string instead for non-floating point values.");
  std::stringstream ss;
  ss << std::setprecision(kFloatMaxPrecision);
  // Scalar case: no brackets, matching the single-target dump format.
  if (value.Size() == 1) {
    ss << value(0);
    return ss.str();
  }
  CHECK_GE(limit, 2);
  // Leading elements printed before the (optional) ellipsis; the last
  // element is always printed.
  auto n_head = std::min(static_cast<bst_target_t>(value.Size() - 1), limit - 1);
  ss << "[";
  for (std::size_t i = 0; i < n_head; ++i) {
    ss << value(i) << ", ";
  }
  if (value.Size() > limit) {
    ss << "..., ";  // elide the middle of long vectors
  }
  ss << value(value.Size() - 1) << "]";
  return ss.str();
}
} // namespace
/*!
* \brief Base class for dump model implementation, modeling closely after code generator.
*/
class TreeGenerator {
protected:
static int32_t constexpr kFloatMaxPrecision =
std::numeric_limits<bst_float>::max_digits10;
FeatureMap const& fmap_;
std::stringstream ss_;
bool const with_stats_;

template <typename Float>
static std::string ToStr(Float value) {
static_assert(std::is_floating_point<Float>::value,
"Use std::to_string instead for non-floating point values.");
std::stringstream ss;
ss << std::setprecision(kFloatMaxPrecision) << value;
return ss.str();
}

static std::string Tabs(uint32_t n) {
std::string res;
for (uint32_t i = 0; i < n; ++i) {
Expand Down Expand Up @@ -258,10 +283,10 @@ class TextGenerator : public TreeGenerator {
kLeafTemplate,
{{"{tabs}", SuperT::Tabs(depth)},
{"{nid}", std::to_string(nid)},
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
{"{leaf}", ToStr(tree[nid].LeafValue())},
{"{stats}", with_stats_ ?
SuperT::Match(kStatTemplate,
{{"{cover}", SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
{{"{cover}", ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
return result;
}

Expand Down Expand Up @@ -311,14 +336,14 @@ class TextGenerator : public TreeGenerator {
static std::string const kQuantitiveTemplate =
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
auto cond = tree[nid].SplitCond();
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, ToStr(cond), depth);
}

std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
auto cond = tree[nid].SplitCond();
static std::string const kNodeTemplate =
"{tabs}{nid}:[{fname}<{cond}] yes={left},no={right},missing={missing}";
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth);
}

std::string Categorical(RegTree const &tree, int32_t nid,
Expand All @@ -336,8 +361,8 @@ class TextGenerator : public TreeGenerator {
static std::string const kStatTemplate = ",gain={loss_chg},cover={sum_hess}";
std::string const result = SuperT::Match(
kStatTemplate,
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
{{"{loss_chg}", ToStr(tree.Stat(nid).loss_chg)},
{"{sum_hess}", ToStr(tree.Stat(nid).sum_hess)}});
return result;
}

Expand Down Expand Up @@ -393,11 +418,11 @@ class JsonGenerator : public TreeGenerator {
std::string result = SuperT::Match(
kLeafTemplate,
{{"{nid}", std::to_string(nid)},
{"{leaf}", SuperT::ToStr(tree[nid].LeafValue())},
{"{leaf}", ToStr(tree[nid].LeafValue())},
{"{stat}", with_stats_ ? SuperT::Match(
kStatTemplate,
{{"{sum_hess}",
SuperT::ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
ToStr(tree.Stat(nid).sum_hess)}}) : ""}});
return result;
}

Expand Down Expand Up @@ -468,7 +493,7 @@ class JsonGenerator : public TreeGenerator {
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
R"I("missing": {missing})I";
bst_float cond = tree[nid].SplitCond();
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, SuperT::ToStr(cond), depth);
return SplitNodeImpl(tree, nid, kQuantitiveTemplate, ToStr(cond), depth);
}

std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t depth) const override {
Expand All @@ -477,16 +502,16 @@ class JsonGenerator : public TreeGenerator {
R"I( "nodeid": {nid}, "depth": {depth}, "split": "{fname}", )I"
R"I("split_condition": {cond}, "yes": {left}, "no": {right}, )I"
R"I("missing": {missing})I";
return SplitNodeImpl(tree, nid, kNodeTemplate, SuperT::ToStr(cond), depth);
return SplitNodeImpl(tree, nid, kNodeTemplate, ToStr(cond), depth);
}

std::string NodeStat(RegTree const& tree, int32_t nid) const override {
static std::string kStatTemplate =
R"S(, "gain": {loss_chg}, "cover": {sum_hess})S";
auto result = SuperT::Match(
kStatTemplate,
{{"{loss_chg}", SuperT::ToStr(tree.Stat(nid).loss_chg)},
{"{sum_hess}", SuperT::ToStr(tree.Stat(nid).sum_hess)}});
{{"{loss_chg}", ToStr(tree.Stat(nid).loss_chg)},
{"{sum_hess}", ToStr(tree.Stat(nid).sum_hess)}});
return result;
}

Expand Down Expand Up @@ -622,11 +647,11 @@ class GraphvizGenerator : public TreeGenerator {

protected:
template <bool is_categorical>
std::string BuildEdge(RegTree const &tree, bst_node_t nid, int32_t child, bool left) const {
std::string BuildEdge(RegTree const &tree, bst_node_t nidx, int32_t child, bool left) const {
static std::string const kEdgeTemplate =
" {nid} -> {child} [label=\"{branch}\" color=\"{color}\"]\n";
// Is this the default child for missing value?
bool is_missing = tree[nid].DefaultChild() == child;
bool is_missing = tree.DefaultChild(nidx) == child;
std::string branch;
if (is_categorical) {
branch = std::string{left ? "no" : "yes"} + std::string{is_missing ? ", missing" : ""};
Expand All @@ -635,7 +660,7 @@ class GraphvizGenerator : public TreeGenerator {
}
std::string buffer =
SuperT::Match(kEdgeTemplate,
{{"{nid}", std::to_string(nid)},
{{"{nid}", std::to_string(nidx)},
{"{child}", std::to_string(child)},
{"{color}", is_missing ? param_.yes_color : param_.no_color},
{"{branch}", branch}});
Expand All @@ -644,68 +669,77 @@ class GraphvizGenerator : public TreeGenerator {

// Only indicator is different, so we combine all different node types into this
// function.
std::string PlainNode(RegTree const& tree, int32_t nid, uint32_t) const override {
auto split_index = tree[nid].SplitIndex();
auto cond = tree[nid].SplitCond();
std::string PlainNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
auto split_index = tree.SplitIndex(nidx);
auto cond = tree.SplitCond(nidx);
static std::string const kNodeTemplate = " {nid} [ label=\"{fname}{<}{cond}\" {params}]\n";

bool has_less =
(split_index >= fmap_.Size()) || fmap_.TypeOf(split_index) != FeatureMap::kIndicator;
std::string result =
SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nid)},
SuperT::Match(kNodeTemplate, {{"{nid}", std::to_string(nidx)},
{"{fname}", GetFeatureName(fmap_, split_index)},
{"{<}", has_less ? "<" : ""},
{"{cond}", has_less ? SuperT::ToStr(cond) : ""},
{"{cond}", has_less ? ToStr(cond) : ""},
{"{params}", param_.condition_node_params}});

result += BuildEdge<false>(tree, nid, tree[nid].LeftChild(), true);
result += BuildEdge<false>(tree, nid, tree[nid].RightChild(), false);
result += BuildEdge<false>(tree, nidx, tree.LeftChild(nidx), true);
result += BuildEdge<false>(tree, nidx, tree.RightChild(nidx), false);

return result;
};

std::string Categorical(RegTree const& tree, int32_t nid, uint32_t) const override {
std::string Categorical(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
static std::string const kLabelTemplate =
" {nid} [ label=\"{fname}:{cond}\" {params}]\n";
auto cats = GetSplitCategories(tree, nid);
auto cats = GetSplitCategories(tree, nidx);
auto cats_str = PrintCatsAsSet(cats);
auto split_index = tree[nid].SplitIndex();
auto split_index = tree.SplitIndex(nidx);

std::string result =
SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nid)},
SuperT::Match(kLabelTemplate, {{"{nid}", std::to_string(nidx)},
{"{fname}", GetFeatureName(fmap_, split_index)},
{"{cond}", cats_str},
{"{params}", param_.condition_node_params}});

result += BuildEdge<true>(tree, nid, tree[nid].LeftChild(), true);
result += BuildEdge<true>(tree, nid, tree[nid].RightChild(), false);
result += BuildEdge<true>(tree, nidx, tree.LeftChild(nidx), true);
result += BuildEdge<true>(tree, nidx, tree.RightChild(nidx), false);

return result;
}

std::string LeafNode(RegTree const& tree, int32_t nid, uint32_t) const override {
static std::string const kLeafTemplate =
" {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
auto result = SuperT::Match(kLeafTemplate, {
{"{nid}", std::to_string(nid)},
{"{leaf-value}", ToStr(tree[nid].LeafValue())},
{"{params}", param_.leaf_node_params}});
return result;
};
std::string LeafNode(RegTree const& tree, bst_node_t nidx, uint32_t) const override {
  static std::string const kLeafTemplate = " {nid} [ label=\"leaf={leaf-value}\" {params}]\n";
  // hardcoded limit to avoid dumping long arrays into dot graph.
  bst_target_t constexpr kLimit{3};
  // Render the leaf value once, then fill the shared template: multi-target
  // trees print a (possibly elided) vector, single-target trees a scalar.
  std::string leaf_value = tree.IsMultiTarget()
                               ? ToStr(tree.GetMultiTargetTree()->LeafValue(nidx), kLimit)
                               : ToStr(tree[nidx].LeafValue());
  return SuperT::Match(kLeafTemplate, {{"{nid}", std::to_string(nidx)},
                                       {"{leaf-value}", leaf_value},
                                       {"{params}", param_.leaf_node_params}});
}

std::string BuildTree(RegTree const& tree, int32_t nid, uint32_t depth) override {
if (tree[nid].IsLeaf()) {
return this->LeafNode(tree, nid, depth);
std::string BuildTree(RegTree const& tree, bst_node_t nidx, uint32_t depth) override {
if (tree.IsLeaf(nidx)) {
return this->LeafNode(tree, nidx, depth);
}
static std::string const kNodeTemplate = "{parent}\n{left}\n{right}";
auto node = tree.GetSplitTypes()[nid] == FeatureType::kCategorical
? this->Categorical(tree, nid, depth)
: this->PlainNode(tree, nid, depth);
auto node = tree.GetSplitTypes()[nidx] == FeatureType::kCategorical
? this->Categorical(tree, nidx, depth)
: this->PlainNode(tree, nidx, depth);
auto result = SuperT::Match(
kNodeTemplate,
{{"{parent}", node},
{"{left}", this->BuildTree(tree, tree[nid].LeftChild(), depth+1)},
{"{right}", this->BuildTree(tree, tree[nid].RightChild(), depth+1)}});
{"{left}", this->BuildTree(tree, tree.LeftChild(nidx), depth+1)},
{"{right}", this->BuildTree(tree, tree.RightChild(nidx), depth+1)}});
return result;
}

Expand Down Expand Up @@ -733,7 +767,9 @@ XGBOOST_REGISTER_TREE_IO(GraphvizGenerator, "dot")
constexpr bst_node_t RegTree::kRoot;

std::string RegTree::DumpModel(const FeatureMap& fmap, bool with_stats, std::string format) const {
CHECK(!IsMultiTarget());
if (this->IsMultiTarget() && format != "dot") {
LOG(FATAL) << format << " tree dump " << MTNotImplemented();
}
std::unique_ptr<TreeGenerator> builder{TreeGenerator::Create(format, fmap, with_stats)};
builder->BuildTree(*this);

Expand Down
37 changes: 34 additions & 3 deletions tests/cpp/tree/test_multi_target_tree_model.cc
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
/**
* Copyright 2023 by XGBoost Contributors
* Copyright 2023-2024, XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/context.h> // for Context
#include <xgboost/multi_target_tree_model.h>
#include <xgboost/tree_model.h> // for RegTree

namespace xgboost {
TEST(MultiTargetTree, JsonIO) {
namespace {
// Builds a small multi-target tree (3 targets, 4 features, one split) shared
// by the tests below.
auto MakeTreeForTest() {
  bst_target_t n_targets{3};
  bst_feature_t n_features{4};
  RegTree tree{n_targets, n_features};
  // CHECK rather than ASSERT_TRUE: gtest ASSERT_* macros expand to `return;`
  // and are only valid in functions returning void, which this helper is not.
  CHECK(tree.IsMultiTarget());
  linalg::Vector<float> base_weight{{1.0f, 2.0f, 3.0f}, {3ul}, DeviceOrd::CPU()};
  linalg::Vector<float> left_weight{{2.0f, 3.0f, 4.0f}, {3ul}, DeviceOrd::CPU()};
  linalg::Vector<float> right_weight{{3.0f, 4.0f, 5.0f}, {3ul}, DeviceOrd::CPU()};
  tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, base_weight.HostView(),
                  left_weight.HostView(), right_weight.HostView());
  return tree;
}
} // namespace

TEST(MultiTargetTree, JsonIO) {
auto tree = MakeTreeForTest();
ASSERT_EQ(tree.NumNodes(), 3);
ASSERT_EQ(tree.NumTargets(), 3);
ASSERT_EQ(tree.GetMultiTargetTree()->Size(), 3);
Expand Down Expand Up @@ -44,4 +51,28 @@ TEST(MultiTargetTree, JsonIO) {
loaded.SaveModel(&jtree1);
check_jtree(jtree1, tree);
}

TEST(MultiTargetTree, DumpDot) {
  auto tree = MakeTreeForTest();
  FeatureMap fmap;
  for (bst_feature_t f = 0; f < tree.NumFeatures(); ++f) {
    auto fname = "feat_" + std::to_string(f);
    fmap.PushBack(f, fname.c_str(), "q");
  }
  // 3 targets fit within the dump limit: leaf vectors are printed in full.
  auto rendered = tree.DumpModel(fmap, true, "dot");
  ASSERT_NE(rendered.find("leaf=[2, 3, 4]"), std::string::npos);
  ASSERT_NE(rendered.find("leaf=[3, 4, 5]"), std::string::npos);

  {
    // 4 targets exceed the hardcoded limit of 3: the middle values are elided.
    bst_target_t n_targets{4};
    bst_feature_t n_features{4};
    RegTree wide_tree{n_targets, n_features};
    linalg::Vector<float> weight{{1.0f, 2.0f, 3.0f, 4.0f}, {4ul}, DeviceOrd::CPU()};
    wide_tree.ExpandNode(RegTree::kRoot, /*split_idx=*/1, 0.5f, true, weight.HostView(),
                         weight.HostView(), weight.HostView());
    auto wide_rendered = wide_tree.DumpModel(fmap, true, "dot");
    ASSERT_NE(wide_rendered.find("leaf=[1, 2, ..., 4]"), std::string::npos);
  }
}
} // namespace xgboost
Loading