Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read json format of gbtree and gblinear #81

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.idea
16 changes: 8 additions & 8 deletions compatibility.md
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@

This file is autogenerated by [compatibility_test.py](testscripts/compatibility_test.py)

## XGBOOST

| Case |0.72.1|0.82|0.90|
|----------------|------|----|----|
|XGIrisMulticlass| V | V | V |

## LIGHTGBM

| Case |2.0.10|2.0.11|2.0.12|2.1.0|2.1.1|2.1.2|2.2.0|2.2.1|2.2.2|2.2.3|2.3.0|
|------------------|------|------|------|-----|-----|-----|-----|-----|-----|-----|-----|
| LGBreastCancer | X | X | V | V | V | V | V | V | V | V | X |
| LGBreastCancer | X | X | V | V | V | V | V | V | V | V | V |
|LGIrisRandomForest| X | X | V | V | V | V | V | V | V | X | X |

## XGBOOST

| Case |0.72.1|0.82|0.90|1.0.0|1.1.0|1.2.0|1.3.1|1.3.2|1.3.3|1.4.0|1.4.1|
|----------------|------|----|----|-----|-----|-----|-----|-----|-----|-----|-----|
|XGIrisMulticlass| V | V | V | V | V | V | V | V | V | V | V |


## Details

X - not passed, V - passed

Generated 2019-11-08 18:06
Generated 2021-04-24 00:00
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
module github.com/dmitryikh/leaves

go 1.12
go 1.16

require github.com/stretchr/testify v1.7.0
11 changes: 11 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
42 changes: 25 additions & 17 deletions internal/xgbin/xgbin_io.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,44 @@ import (
// Note: XGBosst widely use int type which is machine depended. Go's int32 should cover most common case
// Note: Data structures' fields comments are take from original XGBoost source code

// LearnerModelParam - training parameter for regression.
// LearnerModelParamLegacy - training parameter for regression.
// from src/learner.cc
type LearnerModelParam struct {
type LearnerModelParamLegacy struct {
// global bias
BaseScore float32
BaseScore float32 `json:"base_score,string"`
// number of features
NumFeatures uint32
NumFeatures uint32 `json:"num_feature,string"`
// number of classes, if it is multi-class classification
NumClass int32
NumClass int32 `json:"num_class,string"`
// Model contain additional properties
ContainExtraAttrs int32
// Model contain eval metrics
ContainEvalMetrics int32
MajorVersion uint32
MinorVersion uint32
// reserved field
Reserved [29]int32
Reserved [27]int32
}

// GBTreeModelParam - model parameters
// from src/gbm/gbtree_model.h
type GBTreeModelParam struct {
// number of trees
NumTrees int32
NumTrees int32 `json:"num_trees,string"`
// number of roots
NumRoots int32
DeprecatedNumRoots int32
// number of features to be used by trees
NumFeature int32
DeprecatedNumFeature int32
// pad this space, for backward compatibility reason
Pad32bit int32
// deprecated padding space.
NumPbufferDeprecated int64
DeprecatedNumPbufferDeprecated int64
// how many output group a single instance can produce
// this affects the behavior of number of output we have:
// suppose we have n instance and k group, output will be k * n
NumOutputGroup int32
DeprecatedNumOutputGroup int32
// size of leaf vector needed in tree
SizeLeafVector int32
SizeLeafVector int32 `json:"size_leaf_vector,string"`
// reserved parameters
Reserved [32]int32
}
Expand All @@ -57,16 +59,16 @@ type TreeParam struct {
// number of start root
NumRoots int32
// total number of nodes
NumNodes int32
NumNodes int32 `json:"num_nodes,string"`
// number of deleted nodes
NumDeleted int32
// maximum depth, this is a statistics of the tree
MaxDepth int32
// number of features used for tree construction
NumFeature int32
NumFeature int32 `json:"num_feature,string"`
// leaf vector size, used for vector tree
// used to store more than one dimensional information in tree
SizeLeafVector int32
SizeLeafVector int32 `json:"size_leaf_vector,string"`
// reserved part, make sure alignment works for 64bit
Reserved [31]int32
}
Expand Down Expand Up @@ -123,7 +125,7 @@ type TreeModel struct {
Stats []RTreeNodeStat
// // leaf vector, that is used to store additional information
// LeafVector []float32
Param TreeParam
Param TreeParam `json:"tree_param"`
}

// GBTreeModel contains all input data related to gbtree model. Used just as a
Expand All @@ -140,7 +142,7 @@ type GBTreeModel struct {
// file. Used just as a container of input data for go implementation. Objects
// layout could be arbitrary
type ModelHeader struct {
Param LearnerModelParam
Param LearnerModelParamLegacy
NameObj string
NameGbm string
}
Expand Down Expand Up @@ -318,3 +320,9 @@ func ReadGBLinearModel(reader *bufio.Reader) (*GBLinearModel, error) {
}
return gbLinearModel, nil
}

func ReadBinf(reader *bufio.Reader) {
if peek, err := reader.Peek(4); err == nil && string(peek) == "binf" {
_, _ = reader.Read(make([]byte, 4))
}
}
8 changes: 4 additions & 4 deletions internal/xgbin/xgbin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func TestReadGBTree(t *testing.T) {
path := filepath.Join("..", "..", "testdata", "xgagaricus.model")
path := filepath.Join("..", "..", "testdata", "xgagaricus_previous_version.model")
reader, err := os.Open(path)
if err != nil {
t.Fatal(err)
Expand All @@ -33,9 +33,9 @@ func TestReadGBTree(t *testing.T) {
}
trueGBTreeModelParam := GBTreeModelParam{}
trueGBTreeModelParam.NumTrees = 3
trueGBTreeModelParam.NumRoots = 1
trueGBTreeModelParam.NumFeature = 127
trueGBTreeModelParam.NumOutputGroup = 1
trueGBTreeModelParam.DeprecatedNumRoots = 1
trueGBTreeModelParam.DeprecatedNumFeature = 127
trueGBTreeModelParam.DeprecatedNumOutputGroup = 1
if !reflect.DeepEqual(trueGBTreeModelParam, gBTreeModel.Param) {
t.Fatalf("unexpected GBTreeModelParam values (got %v)", gBTreeModel.Param)
}
Expand Down
10 changes: 10 additions & 0 deletions internal/xgjson/common_model.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package xgjson

type Objective struct {
Name string `json:"name"`
RegLossParam RegLossParam `json:"reg_loss_param"`
}

type RegLossParam struct {
ScalePosWeight string `json:"scale_pos_weight"`
}
21 changes: 21 additions & 0 deletions internal/xgjson/gblinear_model.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package xgjson

import "github.com/dmitryikh/leaves/internal/xgbin"

type GBLinearJson struct {
Learner GBLinearLearner `json:"learner"`
Version []int `json:"version"`
}

type GBLinearLearner struct {
FeatureNames []string `json:"feature_names"`
FeatureTypes []string `json:"feature_types"`
GradientBooster GBLinearBooster `json:"gradient_booster"`
Objective Objective `json:"objective"`
LearnerModelParam xgbin.LearnerModelParamLegacy `json:"learner_model_param"`
}

type GBLinearBooster struct {
Model xgbin.GBLinearModel `json:"model"`
Name string `json:"name"`
}
89 changes: 89 additions & 0 deletions internal/xgjson/gbtree_model.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package xgjson

import "github.com/dmitryikh/leaves/internal/xgbin"

type GBTreeJson struct {
Learner GBTreeLearner `json:"learner"`
Version []int `json:"version"`
}

type GBTreeLearner struct {
FeatureNames []string `json:"feature_names"`
FeatureTypes []string `json:"feature_types"`
GradientBooster GBTreeBooster `json:"gradient_booster"`
Objective Objective `json:"objective"`
LearnerModelParam xgbin.LearnerModelParamLegacy `json:"learner_model_param"`
}

type GBTreeBooster struct {
Model GBTreeModel `json:"model"`
WeightDrop []float64 `json:"weight_drop"`
Name string `json:"name"`
}

type GBTreeModel struct {
GbTreeModelParam xgbin.GBTreeModelParam `json:"gbtree_model_param"`
Trees []*Tree `json:"trees"`
TreeInfo []int32 `json:"tree_info"`
}

type Tree struct {
TreeParam xgbin.TreeParam `json:"tree_param"`
Id int `json:"id"`
LossChanges []float32 `json:"loss_changes"`
SumHessian []float32 `json:"sum_hessian"`
BaseWeights []float32 `json:"base_weights"`
LeftChildren []int32 `json:"left_children"`
RightChildren []int32 `json:"right_children"`
Parents []int32 `json:"parents"`
SplitIndices []uint32 `json:"split_indices"`
SplitConditions []float32 `json:"split_conditions"`
SplitType []int32 `json:"split_type"`
DefaultLeft []bool `json:"default_left"`
Categories []int32 `json:"categories"`
CategoriesNodes []int32 `json:"categories_nodes"`
CategoriesSegments []int32 `json:"categories_segments"`
CategoricalSizes []int32 `json:"categorical_sizes"`
}

func (g *GBTreeModel) ToBinGBTreeModel() *xgbin.GBTreeModel {
param := g.GbTreeModelParam
trees := make([]*xgbin.TreeModel, param.NumTrees)
for idx, tree := range g.Trees {
trees[idx] = tree.toBinTreeModel()
}
treeInfo := g.TreeInfo
gbTreeModel := &xgbin.GBTreeModel{
Param: param,
Trees: trees,
TreeInfo: treeInfo,
}
return gbTreeModel
}

func (t *Tree) toBinTreeModel() *xgbin.TreeModel {
nodes := make([]xgbin.Node, t.TreeParam.NumNodes)
rTreeNodeStat := make([]xgbin.RTreeNodeStat, t.TreeParam.NumNodes)
for idx := range nodes {
nodes[idx].CRight = t.RightChildren[idx]
nodes[idx].CLeft = t.LeftChildren[idx]
nodes[idx].Parent = t.Parents[idx]
nodes[idx].Parent = int32(uint32(t.Parents[idx]) | 1 << 31)
if t.DefaultLeft[idx] {
t.SplitIndices[idx] |= 1 << 31
}
nodes[idx].SIndex = t.SplitIndices[idx]
nodes[idx].Info = t.SplitConditions[idx]
rTreeNodeStat[idx].BaseWeight = t.BaseWeights[idx]
rTreeNodeStat[idx].LossChg = t.LossChanges[idx]
rTreeNodeStat[idx].SumHess = t.SumHessian[idx]
}
treeParam := t.TreeParam
treeParam.NumRoots = 1
treeModel := &xgbin.TreeModel{
Nodes: nodes,
Stats: rTreeNodeStat,
Param: treeParam,
}
return treeModel
}
37 changes: 37 additions & 0 deletions internal/xgjson/reader.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package xgjson

import (
"encoding/json"
"fmt"
"io/ioutil"
)

func ReadGBTree(filePath string) (*GBTreeJson, error) {
bytes, err := ioutil.ReadFile(filePath)
if err != nil {
return nil, err
}
gbTree := &GBTreeJson{}
if err := json.Unmarshal(bytes, gbTree); err != nil {
return nil, err
}
if gbTree.Learner.GradientBooster.Name != "gbtree" && gbTree.Learner.GradientBooster.Name != "dart"{
return nil, fmt.Errorf("wrong gbtree format, this reader can only read gbtree or dart")
}
return gbTree, nil
}

func ReadGBLinear(filePath string) (*GBLinearJson, error) {
bytes, err := ioutil.ReadFile(filePath)
if err != nil {
return nil, err
}
gbLinear := &GBLinearJson{}
if err := json.Unmarshal(bytes, gbLinear); err != nil {
return nil, err
}
if gbLinear.Learner.GradientBooster.Name != "gblinear" {
return nil, fmt.Errorf("wrong gblinear format, this reader can only read gblinear")
}
return gbLinear, nil
}
Loading