Run notebook with

```bash
uv run --with uv --with jupyter jupyter lab
```

In [1]:
!uv pip install onnx onnxruntime onnxmltools skl2onnx scikit-learn xmltodict

[2mAudited [1m6 packages[0m [2min 2ms[0m[0m


In [2]:
from dataclasses import dataclass
import base64

import numpy as np
import onnx
import xmltodict

## Parse FastBDT weights

Looking into
- https://github.com/thomaskeck/FastBDT/blob/master/src/Classifier.cxx#L209
- https://github.com/thomaskeck/FastBDT/blob/master/src/FastBDT_IO.cxx
- https://github.com/thomaskeck/FastBDT/blob/master/include/FastBDT_IO.h

In [3]:
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
with open("/home/nikolai/code/basf2/mva/methods/tests/FastBDTv5.xml") as f:
    xml_text = f.read()

# Keep only what follows the first "?>" and wrap it in a synthetic <root> element
# (presumably because the remainder is not a single-rooted XML document).
wrapped_xml = f"<root>{xml_text.split('?>')[1]}</root>"
encoded_weights = xmltodict.parse(wrapped_xml)["root"]["FastBDT_Weightfile"]

# "===" is appended before decoding — presumably so that missing base64 padding
# cannot make the decode fail.
data = base64.b64decode(encoded_weights + "===")

The decoded payload is a sequence of whitespace-separated numbers.

In [4]:
# Split the decoded payload on whitespace; every entry is a bytes token.
tokens = data.split()
# Preview the first 50 tokens as a single string.
str(tokens[:50])

"[b'1', b'200', b'3', b'3', b'8', b'8', b'8', b'0.1', b'1', b'0', b'-1', b'3', b'0', b'0', b'0', b'1', b'3', b'8', b'257', b'1.700299740e+00', b'1.847970366e+00', b'1.823094368e+00', b'1.871596694e+00', b'1.810523272e+00', b'1.835662246e+00', b'1.859904289e+00', b'1.884671807e+00', b'1.803922415e+00', b'1.816829801e+00', b'1.829383850e+00', b'1.841796398e+00', b'1.854024768e+00', b'1.865672946e+00', b'1.877931833e+00', b'1.891820669e+00', b'1.800248265e+00', b'1.807294846e+00', b'1.813672304e+00', b'1.819913387e+00', b'1.826185822e+00', b'1.832525253e+00', b'1.838704467e+00', b'1.844905376e+00', b'1.850979686e+00', b'1.857039809e+00', b'1.862837672e+00', b'1.868579149e+00', b'1.874693990e+00', b'1.881253123e+00', b'1.888190508e+00']"

In [5]:
import logging
# Install the default handler on the root logger, then keep a module-level handle
# to the root logger at INFO level. The parser helpers in this notebook log at
# DEBUG, so they stay silent unless this level is lowered.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [6]:
def read(tokens, conv=int):
    """Consume one token from ``tokens`` and convert it with ``conv``.

    Parameters
    ----------
    tokens : iterator
        Iterator over the raw weightfile tokens; advanced by exactly one.
    conv : callable, default ``int``
        Conversion applied to the token.

    Returns
    -------
    The converted token.

    Raises
    ------
    StopIteration
        If ``tokens`` is exhausted.

    Notes
    -----
    ``conv=bool`` is special-cased: the token is parsed as an integer first.
    Applying ``bool`` directly to a token such as ``b"0"`` yields ``True``,
    because any non-empty bytes/str object is truthy.
    """
    logger.debug(f"read {conv}")
    token = next(tokens)
    if conv is bool:
        # Flags appear as 0/1 tokens in the stream (see the token preview above).
        return bool(int(token))
    return conv(token)

def read_vector(tokens, conv=float):
    """Read a length-prefixed vector from ``tokens``.

    The first token is the number of elements; that many further tokens are
    consumed and converted with ``conv``.

    Parameters
    ----------
    tokens : iterator
        Iterator over the raw weightfile tokens.
    conv : callable, default ``float``
        Conversion applied to every element.

    Returns
    -------
    list
        The converted elements (empty if the size token is 0).

    Raises
    ------
    StopIteration
        If ``tokens`` runs out before the announced number of elements is read.

    Notes
    -----
    ``conv=bool`` is special-cased: elements are parsed as integers first,
    since ``bool(b"0")`` would be ``True`` (non-empty bytes are truthy).
    """
    logger.debug(f"read vector<{conv}>")
    size = int(next(tokens))
    if conv is bool:
        def convert(token):
            return bool(int(token))
    else:
        convert = conv
    return [convert(next(tokens)) for _ in range(size)]

def read_vector_feature_binning(tokens):
    """Read the length-prefixed list of feature binnings.

    The stream holds an entry count followed, per entry, by an integer
    (``n_levels``) and a length-prefixed vector of floats (the binning).

    Returns a list of ``(n_levels, binning)`` tuples.
    """
    logger.debug("read vector of feature binning")
    n_entries = read(tokens, int)
    return [
        # Tuple elements are evaluated left to right: n_levels first, then the binning.
        (read(tokens, int), read_vector(tokens, float))
        for _ in range(n_entries)
    ]

@dataclass
class Cut:
    """A single cut (split) of a tree as stored in the weightfile.

    Note that the field order (``feature, index, gain, valid``) differs from
    the order in the token stream (``feature, index, valid, gain``).
    """
    feature: int
    # Converted with the caller-supplied `conv`: float for the regular forest,
    # int for the binned forest.
    index: float
    gain: float
    valid: int

    @classmethod
    def from_tokens(cls, tokens, conv=float):
        """Read one cut from ``tokens``.

        Parameters
        ----------
        tokens : iterator
            Iterator over the raw weightfile tokens.
        conv : callable, default ``float``
            Conversion used for the cut ``index``.
        """
        logger.debug(f"Read Cut<{conv}>")
        feature = read(tokens, int)
        index = read(tokens, conv)
        # `valid` precedes `gain` in the stream; pass by keyword so the
        # differing field order cannot cause a mix-up.
        valid = read(tokens, int)
        gain = read(tokens, float)
        return cls(feature=feature, index=index, gain=gain, valid=valid)

@dataclass
class Tree:
    """One tree of the forest: its cuts plus three vectors of floats.

    ``boost_weights``, ``purities`` and ``nEntries`` are each read as a
    length-prefixed vector of floats (15 entries each for the 7-cut trees of
    the example file).
    """
    cuts: list[Cut]
    nEntries: list[float]
    purities: list[float]
    boost_weights: list[float]

    @classmethod
    def from_tokens(cls, tokens, conv=float):
        """Read one tree from ``tokens``.

        Parameters
        ----------
        tokens : iterator
            Iterator over the raw weightfile tokens.
        conv : callable, default ``float``
            Conversion used for the cut indices (forwarded to ``Cut.from_tokens``).
        """
        logger.debug(f"Read Tree<{conv}>")
        n_cuts = read(tokens, int)
        cuts = [Cut.from_tokens(tokens, conv) for _ in range(n_cuts)]
        # Stream order is boost weights, purities, entry counts; the dataclass
        # lists them in the opposite order, so construct by keyword.
        boost_weights = read_vector(tokens, float)
        purities = read_vector(tokens, float)
        n_entries = read_vector(tokens, float)
        return cls(
            cuts=cuts,
            nEntries=n_entries,
            purities=purities,
            boost_weights=boost_weights,
        )

@dataclass
class Forest:
    """A forest as stored in the weightfile: three header values plus the trees."""
    f0: float
    shrinkage: float
    # A single flag, read from one 0/1 token.
    transform2probability: bool
    trees: list[Tree]

    @classmethod
    def from_tokens(cls, tokens, conv=float):
        """Read one forest from ``tokens``.

        Parameters
        ----------
        tokens : iterator
            Iterator over the raw weightfile tokens.
        conv : callable, default ``float``
            Conversion used for the cut indices (forwarded to ``Tree.from_tokens``).
        """
        logger.debug(f"Read Forest<{conv}>")
        f0 = read(tokens, float)
        shrinkage = read(tokens, float)
        # Parse the flag as an integer first: bool() of a non-empty token such
        # as b"0" would always be True.
        transform2probability = bool(read(tokens, int))
        n_trees = read(tokens, int)
        trees = [Tree.from_tokens(tokens, conv) for _ in range(n_trees)]
        return cls(
            f0=f0,
            shrinkage=shrinkage,
            transform2probability=transform2probability,
            trees=trees,
        )

In [7]:
read_vector(iter(data.split()[3:]), int)

[8, 8, 8]

In [8]:
def debugging_iter(it):
    """Yield every item of ``it`` unchanged, printing ``Reading <item>`` first.

    Wrap a token iterator with this generator to trace which token a parser
    consumes at each step.
    """
    stream = iter(it)
    while True:
        try:
            item = next(stream)
        except StopIteration:
            return
        print(f"Reading {item}")
        yield item

In [9]:
@dataclass
class BDT:
    """Complete content of a FastBDT weightfile.

    Fields are declared in the order in which they are read from the token
    stream: a header of scalar settings and vectors, followed by the regular
    forest (float cut indices) and the binned forest (integer cut indices).
    """
    version: int
    n_trees: int
    depth: int
    binning: list[int]
    shrinkage: float
    subsample: float
    sPlot: bool
    flatnessLoss: float
    purityTransformation: list[bool]
    transform2probability: bool
    featureBinning: list[tuple[int, list[float]]]
    purityBinning: list[int]
    numberOfFeatures: int
    numberOfFinalFeatures: int
    numberOfFlatnessFeatures: int
    can_use_fast_forest: bool
    forest: Forest
    binned_forest: Forest

    @classmethod
    def from_tokens(cls, tokens):
        """Parse a full weightfile from an iterator over its tokens.

        Keyword arguments are evaluated left to right, so the order of the
        arguments below is the order in which tokens are consumed.

        Boolean flags are parsed as integers and then converted: calling
        ``bool`` on a non-empty token such as ``b"0"`` would always give
        ``True``.
        """
        return cls(
            version=read(tokens, int),
            n_trees=read(tokens, int),
            depth=read(tokens, int),
            binning=read_vector(tokens, int),
            shrinkage=read(tokens, float),
            subsample=read(tokens, float),
            sPlot=bool(read(tokens, int)),
            flatnessLoss=read(tokens, float),
            purityTransformation=[bool(flag) for flag in read_vector(tokens, int)],
            transform2probability=bool(read(tokens, int)),
            featureBinning=read_vector_feature_binning(tokens),
            purityBinning=read_vector(tokens, int),
            numberOfFeatures=read(tokens, int),
            numberOfFinalFeatures=read(tokens, int),
            numberOfFlatnessFeatures=read(tokens, int),
            can_use_fast_forest=bool(read(tokens, int)),
            forest=Forest.from_tokens(tokens, float),
            binned_forest=Forest.from_tokens(tokens, int),
        )

In [10]:
bdt = BDT.from_tokens(iter(data.split()))

In [11]:
bdt.__dict__.keys()

dict_keys(['version', 'n_trees', 'depth', 'binning', 'shrinkage', 'subsample', 'sPlot', 'flatnessLoss', 'purityTransformation', 'transform2probability', 'featureBinning', 'purityBinning', 'numberOfFeatures', 'numberOfFinalFeatures', 'numberOfFlatnessFeatures', 'can_use_fast_forest', 'forest', 'binned_forest'])

In [12]:
len(bdt.forest.trees[0].cuts)

7

In [13]:
bdt.forest.trees[0]

Tree(cuts=[Cut(feature=0, index=1.836806893, gain=1062.871, valid=1), Cut(feature=0, index=1.820293784, gain=67.60223, valid=1), Cut(feature=2, index=1.23523581, gain=153.043, valid=1), Cut(feature=0, index=1.807693601, gain=7.638123, valid=1), Cut(feature=0, index=1.831332564, gain=12.39661, valid=1), Cut(feature=0, index=1.886921883, gain=98.42749, valid=1), Cut(feature=0, index=1.882526875, gain=43.83337, valid=1)], nEntries=[36198.39062, 9302.176758, 26905.32422, 4896.272461, 4401.743164, 19258.10352, 7651.693848, 2047.376343, 2849.948242, 2790.36084, 1611.965332, 16794.15039, 2463.5271, 6486.14502, 1165.553955], purities=[0.5001226068, 0.208625108, 0.6007550359, 0.1277961731, 0.2987351418, 0.553139925, 0.7201112509, 0.0811952427, 0.1612261832, 0.2583843768, 0.3684764504, 0.5804541707, 0.3669680953, 0.7521873713, 0.5416552424], boost_weights=[0.0002458892995, -0.5826275945, 0.2019119263, -0.7445226312, -0.4025033712, 0.1062627062, 0.4403621256, -0.8376281857, -0.6774736047, -0.4831

In [14]:
bdt.numberOfFeatures, bdt.numberOfFinalFeatures, bdt.numberOfFlatnessFeatures

(3, 3, 0)

## Learn how onnx does BDTs

see https://onnx.ai/onnx/operators/onnx_aionnxml_TreeEnsemble.html

Note: seems the conversion tools for sklearn and xgboost still (2025-09) use the "deprecated" `TreeEnsembleClassifier` nodes - so let's try to use the new `TreeEnsemble` here:

(see also https://github.com/onnx/onnx/pull/5874)

In [15]:
from onnx.backend.test.case.node import expect, _extract_value_info

In [16]:
from onnx.helper import make_tensor, make_value_info

In [17]:
# Toy two-target ensemble used to get familiar with the `TreeEnsemble` operator:
# a single tree (root at node 0) with three nodes and four leaf entries.
node = onnx.helper.make_node(
    "TreeEnsemble",
    ["input"],
    ["output"],
    domain="ai.onnx.ml",
    n_targets=2,
    # Optional attributes passed as None — presumably left unset by make_node; TODO confirm.
    membership_values=None,
    nodes_missing_value_tracks_true=None,
    nodes_hitrates=None,
    aggregate_function=1,
    post_transform=0,
    tree_roots=[0],
    # Mode 0 for all three nodes.
    nodes_modes=make_tensor(
        "nodes_modes",
        onnx.TensorProto.UINT8,
        (3,),
        np.array([0, 0, 0], dtype=np.uint8),
    ),
    # All three nodes cut on feature 0.
    nodes_featureids=[0, 0, 0],
    nodes_splits=make_tensor(
        "nodes_splits",
        onnx.TensorProto.DOUBLE,
        (3,),
        np.array([3.14, 1.2, 4.2], dtype=np.float64),
    ),
    # Child ids index into the node arrays when the matching *leafs flag is 0
    # and into the leaf arrays when it is 1.
    nodes_truenodeids=[1, 0, 1],
    nodes_trueleafs=[0, 1, 1],
    nodes_falsenodeids=[2, 2, 3],
    nodes_falseleafs=[0, 1, 1],
    leaf_targetids=[0, 1, 0, 1],
    leaf_weights=make_tensor(
        "leaf_weights",
        onnx.TensorProto.DOUBLE,
        (4,),
        np.array([5.23, 12.12, -12.23, 7.21], dtype=np.float64),
    ),
)

# Three events with two features each, and the reference output they are
# compared against further down.
x = np.array([1.2, 3.4, -0.12, 1.66, 4.14, 1.77], np.float64).reshape(3, 2)
y = np.array([[5.23, 0], [5.23, 0], [0, 12.12]], dtype=np.float64)

In [18]:
# following loosely https://onnx.ai/onnx/intro/python.html

In [19]:
from onnx import TensorProto
from onnx.helper import (
    make_model, make_node, make_graph,
    make_tensor_value_info)
from onnx.checker import check_model

In [20]:
graph = make_graph(
    [node],
    "bla",
    [make_tensor_value_info("input", TensorProto.DOUBLE, [3, 2])],
    [make_tensor_value_info("output", TensorProto.DOUBLE, [3, 2])],
)

In [21]:
model = make_model(graph, opset_imports=[onnx.helper.make_opsetid("ai.onnx.ml", 5)], ir_version=10)

In [22]:
check_model(model)

In [23]:
import onnxruntime

In [24]:
sess = onnxruntime.InferenceSession(model.SerializeToString())

In [25]:
sess.run(["output"], {"input": np.random.rand(3, 2)})

[array([[5.23, 0.  ],
        [5.23, 0.  ],
        [5.23, 0.  ]])]

In [26]:
# `sess.run` returns a list with one array per requested output. Unpack the
# single array explicitly instead of comparing the list against `y`, and use
# numpy's testing helper so that a mismatch reports the differing elements.
(toy_output,) = sess.run(["output"], {"input": x})
np.testing.assert_array_equal(toy_output, y)

## Now convert the FastBDT stuff into what onnx needs

Have to flatten all trees into single lists for nodes and leafs

FastBDT uses trees with fixed depth, so `2**depth` is the number of leafs

In [27]:
n_leafs = 2**(bdt.depth)
n_leafs

8

We have only one target

In [28]:
# One entry per leaf of every tree; all leaves feed the single target with id 0.
leaf_targetids = [0 for _ in range(len(bdt.forest.trees) * n_leafs)]

The last `n_leafs` entries of each tree's `boost_weights` are the weights of its leafs.

**TODO:** handle the "invalid" cuts - this is a bit ugly since we need to add extra leafs for these (FastBDT seems to treat these as terminal nodes and there will be a unique boost weight assigned no matter the value - so maybe we could reference both false and true branch to the same leaf in these cases. If we actually keep the index value `NaN` it will always be sent to the false branch probably.)

**TODO:** Treatment of NaN **inputs** is even more annoying. I'm still not quite sure I have fully understood FastBDT's behavior of "stopping" when a NaN value is seen. As far as I currently understand, every internal node then has a possible leaf where the evaluation might stop. Maybe the only way to implement this is to add, for *every* node, an additional node that decides the leaf in case of NaN input (as far as I understood, there will be one possible leaf for every original node). The comparison value of these nodes would be NaN, such that they evaluate to false for every value; setting `nodes_missing_value_tracks_true` to 1 would then make them evaluate to true for NaN input and send those events to the special leafs.

In [29]:
# Flatten the leaf weights of all trees: per tree, take the last `n_leafs`
# boost weights and scale each of them by the global shrinkage.
scaled_leaf_weights = [
    weight * bdt.shrinkage
    for tree in bdt.forest.trees
    for weight in tree.boost_weights[-n_leafs:]
]
# Pack the flat list into the DOUBLE tensor expected for the `leaf_weights` attribute.
leaf_weights = make_tensor(
    "leaf_weights",
    onnx.TensorProto.DOUBLE,
    (len(scaled_leaf_weights),),
    np.array(scaled_leaf_weights, dtype=np.float64),
)

> * **nodes_falseleafs - INTS** (required): 
> 1 if false branch is leaf for each node and 0 if an interior node. To represent a tree that is a leaf (only has one node), one can do so by having a    single nodes_* entry with true and false branches referencing the same leaf_* entry

The last `n_leafs // 2` nodes should be the terminal nodes, where both false and true branches are leafs.

In [30]:
n_nodes = len(bdt.forest.trees[0].cuts)
n_terminal_nodes = n_leafs // 2
n_internal_nodes = n_nodes - n_terminal_nodes
n_nodes, n_internal_nodes, n_terminal_nodes

(7, 3, 4)

In [31]:
# Flags for one tree: 0 for the internal nodes (children are nodes), 1 for the
# terminal nodes (children are leafs); repeated once per tree.
leaf_flags_per_tree = [0] * n_internal_nodes + [1] * (n_leafs // 2)
nodes_falseleafs = leaf_flags_per_tree * len(bdt.forest.trees)
# Both children of a node are of the same kind, so the true-branch flags are an
# independent copy of the false-branch flags.
nodes_trueleafs = nodes_falseleafs.copy()

> **nodes_falsenodeids - INTS** (required) :
> If `nodes_falseleafs` is false at an entry, this represents the position of the false branch node. This position can be used to index into a `nodes_*` entry. If `nodes_falseleafs` is true, it is an index into the `leaf_*` attributes.

Excerpt from FastBDT code:

```c++
            // Perform the cut of the given node and update the node.
            // Either the event is passed to the left child node (which has
            // the position 2*node in the next layer) or to the right
            // (which has the position 2*node + 1 in the next layer)
            node = (node << 1) + static_cast<unsigned int>(value >= cut.index);
```

So we will use all cut indices such that if `value >= cut.index` (let's call this "true") we go to node `2*node + 1`, `2*node` otherwise.

Here we will have global indices, so need to count offsets up while looping through the trees.

In [32]:
n_internal_nodes, list(range(1, n_internal_nodes + 1)), list(range(n_internal_nodes + 1, n_nodes + 1)), n_nodes

(3, [1, 2, 3], [4, 5, 6, 7], 7)

In [33]:
nodes_falsenodeids = []
nodes_truenodeids = []
node_offset = 0  # position of the current tree's first node in the flat node arrays
i_leaf = 0       # next unused position in the flat leaf arrays
for tree in bdt.forest.trees:
    # Internal nodes: the 1-based node k has children 2k (false) and 2k + 1 (true).
    # Subtract 1 to get 0-based positions and shift by this tree's offset.
    for k in range(1, n_internal_nodes + 1):
        false_child = node_offset + 2 * k - 1
        nodes_falsenodeids.append(false_child)
        nodes_truenodeids.append(false_child + 1)
    # Terminal nodes: each one consumes two consecutive leaf entries,
    # false branch first, then true branch.
    for _ in range(n_nodes - n_internal_nodes):
        nodes_falsenodeids.append(i_leaf)
        nodes_truenodeids.append(i_leaf + 1)
        i_leaf += 2
    node_offset += len(tree.cuts)

`nodes_featureids` is easy:

In [34]:
# Feature index of every cut, tree after tree, in the stored cut order.
nodes_featureids = [cut.feature for tree in bdt.forest.trees for cut in tree.cuts]

> * **nodes_modes - TENSOR** (required) :
> The comparison operation performed by the node. This is encoded as an enumeration of 0 (‘BRANCH_LEQ’), 1 (‘BRANCH_LT’), 2 (‘BRANCH_GTE’), 3 (‘BRANCH_GT’), 4 (‘BRANCH_EQ’), 5 (‘BRANCH_NEQ’), and 6 (‘BRANCH_MEMBER’). Note this is a tensor of type uint8.


In [35]:
# FastBDT sends an event to the "true" child when `value >= cut.index`, i.e. the
# comparison is BRANCH_GTE. In the operator's enumeration quoted above
# (0 LEQ, 1 LT, 2 GTE, 3 GT, 4 EQ, 5 NEQ, 6 MEMBER) that is 2. The value 3 would
# be BRANCH_GT and would route inputs exactly equal to a cut value to the wrong branch.
BRANCH_GTE = 2
# One mode per flattened node; every node uses the same comparison.
shape = (len(nodes_featureids),)
nodes_modes = make_tensor(
    "nodes_modes",
    onnx.TensorProto.UINT8,
    shape,
    np.full(shape, BRANCH_GTE, dtype=np.uint8),
)

`nodes_splits` is also easy, just the cut indices:

In [36]:
# Cut value of every node, tree after tree, in the same order as `nodes_featureids`.
split_values = [cut.index for tree in bdt.forest.trees for cut in tree.cuts]
# Pack the flat list into the DOUBLE tensor expected for the `nodes_splits` attribute.
nodes_splits = make_tensor(
    "nodes_splits",
    onnx.TensorProto.DOUBLE,
    (len(split_values),),
    np.array(split_values, dtype=np.float64),
)

and `tree_roots` as well, first node for every tree is the root node

In [37]:
# The root of each tree sits at the running sum of the node counts of all
# preceding trees.
tree_roots = []
root_index = 0
for tree in bdt.forest.trees:
    tree_roots.append(root_index)
    root_index += len(tree.cuts)

build the onnx tree ensemble:

In [38]:
# Assemble the single `TreeEnsemble` node from the flattened per-node and
# per-leaf arrays built in the cells above.
# NOTE(review): `bdt.forest.f0` and `transform2probability` are parsed from the
# weightfile but not used anywhere in this conversion — confirm whether they
# need to be folded into the model.
node = onnx.helper.make_node(
    "TreeEnsemble",
    ["input"],
    ["output"],
    domain="ai.onnx.ml",
    n_targets=1,
    # Optional attributes passed as None — presumably left unset by make_node; TODO confirm.
    membership_values=None,
    nodes_missing_value_tracks_true=None,
    nodes_hitrates=None,
    # Presumably 1 = SUM and 0 = NONE in the operator's enumerations — verify against the spec.
    aggregate_function=1,
    post_transform=0,
    tree_roots=tree_roots,
    nodes_modes=nodes_modes,
    nodes_featureids=nodes_featureids,
    nodes_splits=nodes_splits,
    nodes_truenodeids=nodes_truenodeids,
    nodes_trueleafs=nodes_trueleafs,
    nodes_falsenodeids=nodes_falsenodeids,
    nodes_falseleafs=nodes_falseleafs,
    leaf_targetids=leaf_targetids,
    leaf_weights=leaf_weights,
)

In [39]:
# Declare the batch dimension as unknown (None) instead of hard-coding it to 1,
# so the exported model can score any number of events per call. The second
# dimension stays fixed: one column per input feature, one output column for
# the single target.
graph = make_graph(
    [node],
    "bla",
    [make_tensor_value_info("input", TensorProto.DOUBLE, [None, bdt.numberOfFeatures])],
    [make_tensor_value_info("output", TensorProto.DOUBLE, [None, 1])],
)

In [40]:
model = make_model(graph, opset_imports=[onnx.helper.make_opsetid("ai.onnx.ml", 5)], ir_version=10)

In [41]:
check_model(model)

In [42]:
import onnxruntime

In [43]:
sess = onnxruntime.InferenceSession(model.SerializeToString())

In [44]:
sess.run(["output"], {"input": np.random.rand(1, 3)*3})

[array([[-1.80847116]])]

In [45]:
bdt.forest.trees[0].boost_weights

[0.0002458892995,
 -0.5826275945,
 0.2019119263,
 -0.7445226312,
 -0.4025033712,
 0.1062627062,
 0.4403621256,
 -0.8376281857,
 -0.6774736047,
 -0.4831543863,
 -0.2630733252,
 0.1607446969,
 -0.2660554945,
 0.5045183301,
 0.08330724388]

## Cross-check with what I think should be happening:

In [46]:
def apply(bdt, values):
    """Evaluate the regular forest of ``bdt`` for one event in pure Python.

    Every tree is walked from the 1-based root node 1: at node ``k`` the event
    moves to ``2k + 1`` if its value for the cut's feature is ``>=`` the cut
    index and to ``2k`` otherwise, until the node number exceeds the number of
    cuts. The boost weight at the final node, scaled by ``bdt.shrinkage``, is
    added to the result.

    Parameters
    ----------
    bdt : object with ``forest.trees`` and ``shrinkage``
    values : sequence indexable by feature number

    Returns
    -------
    The sum of the scaled boost weights over all trees.

    NOTE(review): ``forest.f0`` and ``transform2probability`` are not applied
    here — confirm whether FastBDT's response includes them.
    """
    total = 0
    for tree in bdt.forest.trees:
        n_cuts = len(tree.cuts)
        node = 1
        while node <= n_cuts:
            cut = tree.cuts[node - 1]
            goes_right = values[cut.feature] >= cut.index
            node = 2 * node + (1 if goes_right else 0)
        total += tree.boost_weights[node - 1] * bdt.shrinkage
    return total

In [47]:
inp = np.random.rand(3) * 2
apply(bdt, inp)

-1.9565154768815716

In [48]:
sess.run(["output"], {"input": np.array([inp])})

[array([[-1.95651548]])]

In [49]:
nodes_splits.double_data[0]

1.836806893

In [50]:
def apply2(values):
    """Evaluate the flattened ONNX-style arrays for one event in pure Python.

    Walks every tree starting at its entry in ``tree_roots`` using the
    module-level ``nodes_*`` and ``leaf_weights`` objects, and sums the weight
    of the leaf reached in each tree.

    The comparison is hard-coded as ``value >= split`` ("true" branch) rather
    than read from ``nodes_modes``.
    """
    splits = nodes_splits.double_data
    weights = leaf_weights.double_data
    total = 0
    for root in tree_roots:
        node = root
        while True:
            if values[nodes_featureids[node]] >= splits[node]:
                target = nodes_truenodeids[node]
                is_leaf = nodes_trueleafs[node]
            else:
                target = nodes_falsenodeids[node]
                is_leaf = nodes_falseleafs[node]
            if is_leaf:
                # `target` is a position in the leaf arrays.
                total += weights[target]
                break
            # `target` is a position in the node arrays.
            node = target
    return total

In [51]:
apply2(inp)

-1.9565154768815716

In [52]:
# Serialize the converted model to `model.onnx` in the current working directory.
with open("model.onnx", "wb") as f:
    onnx.save(model, f)