In [1]:
import json
import pickle
import treelite

In [2]:
# Relative path to the pickled scikit-learn model that is loaded and converted below.
model_path = (
    "models/sklearn/daily/10_trees/"
    "sklearn_alternative_10_trees_daily_2016-01-01.pkl"
)

In [3]:
# Deserialize the trained model object from disk.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only
# open model files that come from a trusted source.
with open(model_path, "rb") as model_file:
    model = pickle.load(model_file)

# Convert the scikit-learn model into Treelite's representation, then parse
# Treelite's compact (non-pretty-printed) JSON dump into plain dicts/lists so
# the tree structure can be inspected from Python.
treelite_model = treelite.sklearn.import_model(model)
treelite_model_json = json.loads(treelite_model.dump_as_json(pretty_print=False))

In [18]:
class Node:
    """A single tree node, holding the fields of one entry of a Treelite JSON dump.

    This is a plain attribute container: every constructor argument is stored
    unchanged on the instance under the same name, with no validation.
    """

    def __init__(
        self,
        node_id,
        split_feature_id,
        default_left,
        split_type,
        comparison_op,
        threshold,
        left_child,
        right_child,
        data_count,
        sum_hess,
        gain,
    ):
        # Identity of the node within its tree.
        self.node_id = node_id
        # Split description: which feature is tested, how, and the fallback direction.
        self.split_feature_id, self.default_left = split_feature_id, default_left
        self.split_type, self.comparison_op, self.threshold = (
            split_type,
            comparison_op,
            threshold,
        )
        # Ids of the two child nodes.
        self.left_child, self.right_child = left_child, right_child
        # Per-node statistics recorded in the dump.
        self.data_count, self.sum_hess, self.gain = data_count, sum_hess, gain


def build_tree_from_json(json_data):
    """Build ``Node`` objects from the node entries of one tree in a Treelite JSON dump.

    Parameters
    ----------
    json_data : iterable of dict
        Node entries, e.g. ``treelite_model_json["trees"][0]["nodes"]``.
        Each entry must contain ``"node_id"``; every other field is optional.

    Returns
    -------
    dict
        Mapping ``node_id -> Node``. A field that is absent from an entry is
        stored as ``None`` on the corresponding ``Node``. If two entries share
        a ``node_id``, the later one wins.

    Raises
    ------
    KeyError
        If an entry has no ``"node_id"``.

    Notes
    -----
    Only ``node_id`` is read strictly. Treelite's JSON dump is expected to omit
    the split fields (``split_feature_id``, ``threshold``, ``left_child``, ...)
    on leaf entries. It is also expected to emit ``data_count``, ``sum_hess``
    and ``gain`` only when the model records them. NOTE(review): confirm this
    against the installed Treelite version.

    Reading those fields with ``.get`` lets a *complete* tree be loaded. Leaf
    nodes end up with ``left_child``/``right_child`` set to ``None``, which is
    the "no child" marker a tree traversal can check for. Strict indexing
    raised ``KeyError`` on the first leaf instead.
    """
    nodes = {}
    for node_data in json_data:
        node_id = node_data["node_id"]
        nodes[node_id] = Node(
            node_id=node_id,
            split_feature_id=node_data.get("split_feature_id"),
            default_left=node_data.get("default_left"),
            split_type=node_data.get("split_type"),
            comparison_op=node_data.get("comparison_op"),
            threshold=node_data.get("threshold"),
            left_child=node_data.get("left_child"),
            right_child=node_data.get("right_child"),
            data_count=node_data.get("data_count"),
            sum_hess=node_data.get("sum_hess"),
            gain=node_data.get("gain"),
        )
    return nodes


def get_subtrees(nodes, root=0):
    """Collect, for every node reachable from ``root``, the ids of its subtree.

    Parameters
    ----------
    nodes : dict
        Mapping ``node_id -> node``. Each node must expose ``left_child`` and
        ``right_child`` attributes holding a child node id, or ``None`` / ``-1``
        when that child is absent.
    root : hashable, default 0
        Node id to start the traversal from. The default of 0 assumes the root
        is node 0 (NOTE(review): confirm for the tree format in use).

    Returns
    -------
    list of list
        One entry per visited node, emitted in post-order (children before
        their parent). Each entry lists the node id followed by all of its
        descendants in pre-order (node, then left subtree, then right subtree).
        Returns ``[]`` when ``nodes`` is empty.

    Raises
    ------
    KeyError
        If ``root`` or a referenced child id is not a key of ``nodes``. This
        happens, for example, when ``nodes`` holds only part of a tree.

    Notes
    -----
    Despite the cartesian-product loop, every ``traverse`` call returns exactly
    one subtree. An absent child contributes only ``None``, and a present child
    contributes only its own single (full) subtree. The result is therefore the
    complete subtree rooted at each node, not every pruned/partial subtree.
    NOTE(review): if all connected sub-subtrees were intended, each child would
    also need a "skip this child" (``None``) option.
    """
    if not nodes:
        return []

    all_subtrees = []

    def is_absent(child_id):
        # Both None (missing key in the JSON) and -1 (common "no child"
        # sentinel) mean there is nothing to descend into.
        return child_id is None or child_id == -1

    def traverse(node_id):
        if is_absent(node_id):
            return [None]
        try:
            node = nodes[node_id]
        except KeyError:
            # Give a clearer message than a bare "KeyError: 3" when the tree is
            # incomplete (e.g. built from a slice of the node list).
            raise KeyError(
                f"node id {node_id!r} is referenced but missing from `nodes`"
            ) from None

        left_subtrees = traverse(node.left_child)
        right_subtrees = traverse(node.right_child)

        node_subtrees = []
        for left in left_subtrees:
            for right in right_subtrees:
                subtree = [node_id]
                if left is not None:
                    subtree.extend(left)
                if right is not None:
                    subtree.extend(right)
                node_subtrees.append(subtree)

        all_subtrees.extend(node_subtrees)
        return node_subtrees

    traverse(root)
    return all_subtrees



KeyError: 3

In [22]:
# Build Node objects for the first three JSON entries of tree 0.
# NOTE(review): this yields only a partial tree. Any of these nodes whose child
# id falls outside the slice (e.g. 3) will be missing from `tree_nodes`, which
# is presumably the source of the "KeyError: 3" traceback recorded above.
tree_nodes = build_tree_from_json(
    treelite_model_json["trees"][0]["nodes"][0:3]
)
tree_nodes

{0: <__main__.Node at 0x7f98cfa7fb20>,
 1: <__main__.Node at 0x7f98cfa7fac0>,
 2: <__main__.Node at 0x7f98cfa7fb50>}

In [None]:
# Enumerate the subtree node-id lists and print the header followed by one
# list per line.
all_subtrees = get_subtrees(tree_nodes)
print("All possible subtrees:", *all_subtrees, sep="\n")