In [None]:
import time

from functools import partial

from sklearn.tree import DecisionTreeRegressor
from sklearn.tree._tree import Tree
from sklearn.utils import check_random_state

import numpy as np

import pydot

import jax
import jax.numpy as jnp
from jax import grad, jit

# n_actions: from the environment
def get_discretized_tree(tree_params, n_features_in, n_actions, prune=True):
    """
    Returns a scikit-learn Tree object with the pruned and
    discretized decision tree policy.

    Every node's logits are discretized into a one-hot value vector that
    marks the arg-max action. A node counts as a leaf when its left and
    right child ids are equal.

    Parameters
    ----------
    tree_params : dict
        Mapping whose "params" entry holds the per-node arrays
        "children_left", "children_right", "features", "thresholds" and
        "leaf_logits" (node 0 is the root). Anything accepted by
        ``np.array`` works. The arrays are copied, so the caller's data is
        never modified.
    n_features_in : int
        Number of input features, forwarded to the Tree constructor.
    n_actions : int
        Number of discrete actions, i.e. the length of each one-hot value.
    prune : bool, default=True
        If True, collapse every split whose two children are leaves that
        select the same action into a single leaf.

    Returns
    -------
    sklearn.tree._tree.Tree
        Tree restored from the discretized (and optionally pruned) nodes.
    """
    tree = Tree(n_features_in, np.array([n_actions]), 1)

    params = tree_params["params"]

    # Work on private, writable NumPy copies. A shallow dict copy would still
    # share the arrays with the caller, so pruning would silently rewrite the
    # caller's tree (and in-place writes fail outright on immutable arrays).
    children_left = np.array(params["children_left"])
    children_right = np.array(params["children_right"])
    features = np.array(params["features"])
    thresholds = np.array(params["thresholds"])
    leaf_logits = np.array(params["leaf_logits"])

    def is_leaf(node_id):
        return children_left[node_id] == children_right[node_id]

    def prune_tree_rec(node_id=0):
        # Do nothing if this is a leaf
        if is_leaf(node_id):
            return

        left_id = children_left[node_id]
        right_id = children_right[node_id]

        # Prune bottom-up so that freshly collapsed children count as leaves
        prune_tree_rec(left_id)
        prune_tree_rec(right_id)

        if not (is_leaf(left_id) and is_leaf(right_id)):
            return

        # If both leaf children predict the same action the split is redundant:
        # turn this node into a real leaf that carries the left child's logits.
        # The sentinels are scikit-learn's own (TREE_LEAF = -1 for children,
        # TREE_UNDEFINED = -2 for feature / threshold), so the resulting Tree
        # also reports the node as a leaf.
        if np.argmax(leaf_logits[left_id]) == np.argmax(leaf_logits[right_id]):
            children_left[node_id] = -1
            children_right[node_id] = -1
            features[node_id] = -2
            thresholds[node_id] = -2
            leaf_logits[node_id] = leaf_logits[left_id]

    if prune:
        prune_tree_rec()

    node_count = len(features)

    # impurity, n_node_samples, weighted_n_node_samples and missing_go_to_left
    # are placeholders and simply stay at zero.
    nodes = np.zeros(
        node_count,
        dtype=[
            ("left_child", "<i8"),
            ("right_child", "<i8"),
            ("feature", "<i8"),
            ("threshold", "<f8"),
            ("impurity", "<f8"),
            ("n_node_samples", "<i8"),
            ("weighted_n_node_samples", "<f8"),
            ("missing_go_to_left", "u1")
        ],
    )
    nodes["left_child"] = children_left[:node_count]
    nodes["right_child"] = children_right[:node_count]
    nodes["feature"] = features[:node_count]
    nodes["threshold"] = thresholds[:node_count]

    # One-hot encode the arg-max action of every node; the resulting shape is
    # (node_count, 1, n_actions).
    actions = np.argmax(leaf_logits[:node_count], axis=1)
    values = np.zeros((node_count, 1, n_actions))
    values[np.arange(node_count), 0, actions] = 1

    def find_depth(node_id=0, depth=0):
        if is_leaf(node_id):
            return depth

        return max(
            find_depth(children_left[node_id], depth + 1),
            find_depth(children_right[node_id], depth + 1),
        )

    max_depth = find_depth()

    state = {
        "n_features_": n_features_in,
        "max_depth": max_depth,
        "node_count": node_count,
        "nodes": nodes,
        "values": values,
    }
    tree.__setstate__(state)
    return tree


def export_tree(
    tree,
    filename,
    feature_names,
    action_names,
    integer_features=None,
    colors=None,
    fontname="helvetica",
    continuous_actions=False,
):
    """
    Visualizes the decision tree and exports it using graphviz.

    The output format is selected from the extension of ``filename``:
    ".png", ".pdf" or ".dot" (matched case-insensitively). All other
    arguments are forwarded unchanged to ``sklearn_tree_to_graphviz``.

    Raises
    ------
    ValueError
        If ``filename`` does not end in one of the supported extensions.
        The check happens before any rendering work is done.
    """
    # Validate the target format first so a bad filename fails fast instead
    # of after the DOT string has been generated and parsed.
    lowered = filename.lower()
    if not lowered.endswith((".png", ".pdf", ".dot")):
        raise ValueError(f"Unknown file extension {filename.split('.')[-1]}")

    dot_string = sklearn_tree_to_graphviz(
        tree,
        feature_names,
        action_names,
        integer_features,
        colors,
        fontname,
        continuous_actions,
    )
    # graph_from_dot_data returns a list of graphs; the DOT string holds one
    graph = pydot.graph_from_dot_data(dot_string)[0]

    if lowered.endswith(".png"):
        graph.write_png(filename)
    elif lowered.endswith(".pdf"):
        graph.write_pdf(filename)
    else:
        graph.write_dot(filename)


In [None]:

def sklearn_tree_to_graphviz(
    tree,
    feature_names,
    action_names,
    integer_features=None,
    colors=None,
    fontname="helvetica",
    continuous_actions=False,
):
    """
    Converts a scikit-learn style tree into a graphviz DOT string.

    Parameters
    ----------
    tree : object
        Object exposing the arrays ``feature``, ``threshold``,
        ``children_left``, ``children_right`` and ``value``. Node 0 is the
        root and a node is a leaf when its two child ids are equal.
    feature_names : sequence of str
        Name of each feature, indexed by the feature id stored in the tree.
    action_names : sequence of str
        Name of each action, indexed by the arg-max of a leaf's value.
    integer_features : sequence of bool, optional
        Per-feature flag; thresholds of flagged features are printed as
        integers. Defaults to treating every feature as continuous.
    colors : sequence of str, optional
        Fill colors indexed by action (only ``colors[0]`` is used for
        continuous actions). Defaults to a color-blind friendly palette.
    fontname : str
        Font used for nodes and edges.
    continuous_actions : bool
        If True, leaves are labelled with their numeric values instead of
        an action name.

    Returns
    -------
    str
        The DOT source of the tree.
    """
    # If no features are specified as integer then assume they are continuous.
    # this means that if you have integers and don't specify it splits will
    # be printed as <= 4.500 instead of <= 4
    if integer_features is None:
        integer_features = [False] * len(feature_names)

    # If no colors are defined then create a default palette
    if colors is None:
        # Seaborn color blind palette
        palette = [
            "#0173b2",
            "#de8f05",
            "#029e73",
            "#d55e00",
            "#cc78bc",
            "#ca9161",
            "#fbafe4",
            "#949494",
            "#ece133",
            "#56b4e9",
        ]
        if continuous_actions:
            colors = palette
        else:
            # Cycle through the palette when there are more actions than colors
            colors = [palette[i % len(palette)] for i in range(len(action_names))]

    header = f"""digraph Tree {{
node [shape=box, style="filled, rounded", color="gray", fillcolor="white" fontname="{fontname}"] ;
edge [fontname="{fontname}"] ;
"""

    feature = tree.feature
    threshold = tree.threshold
    children_left = tree.children_left
    children_right = tree.children_right
    value = tree.value

    def leaf_to_graphviz(node_id):
        # Render a single leaf as a colored box
        if continuous_actions:
            label = ", ".join(f"{x[0]:.2f}" for x in value[node_id])
            color = colors[0]
        else:
            action_i = np.argmax(value[node_id])
            label = f"{action_names[action_i]}"
            color = colors[action_i]
        return f'{node_id} [label="{label}", fillcolor="{color}", color="{color}", fontcolor=white] ;\n'

    def sklearn_tree_to_graphviz_rec(node_id=0):
        left_id = children_left[node_id]
        right_id = children_right[node_id]
        if left_id == right_id:
            return leaf_to_graphviz(node_id)

        left_dot = sklearn_tree_to_graphviz_rec(left_id)
        right_dot = sklearn_tree_to_graphviz_rec(right_id)

        # Only the root's edges are labelled; deeper edges follow the same
        # left = yes / right = no convention.
        if node_id == 0:
            edge_label_left = "yes"
            edge_label_right = "no"
        else:
            edge_label_left = ""
            edge_label_right = ""

        feature_i = feature[node_id]
        threshold_value = threshold[node_id]
        feature_name = feature_names[feature_i]

        if integer_features[feature_i]:
            # For an integer x, "x <= t" is equivalent to "x <= floor(t)".
            # Plain int() truncates toward zero and would print e.g. -0.5
            # as "<= 0" although the split really means "<= -1".
            split_condition = int(np.floor(threshold_value))
        else:
            split_condition = f"{threshold_value:.3f}"

        predicate = (
            f'{node_id} [label="if {feature_name} <= {split_condition}"] ;\n'
        )
        edge_left = (
            f'{node_id} -> {left_id} [label="{edge_label_left}", fontcolor="gray"] ;\n'
        )
        edge_right = (
            f'{node_id} -> {right_id} [label="{edge_label_right}", fontcolor="gray"] ;\n'
        )

        return f"{predicate}{left_dot}{right_dot}{edge_left}{edge_right}"

    body = sklearn_tree_to_graphviz_rec()

    footer = "}"

    return header + body.strip() + footer

In [None]:
#### Demo values: environment name used further down to look up the
#### feature and action names for the visualization.
env_name = "CartPole-v1"

#### Parameters (best_params) of the final decision tree trained by DTPO.
#### The keys of this dict come from the DTPO DecisionTreePolicy implementation:
#### https://github.com/tudelft-cda-lab/DTPO/blob/main/dtpo/dtpo.py#L257
#### Every array holds one entry per node (31 nodes here). In this data the
#### nodes whose children are -1/-1 also carry feature -2 and threshold -2.0,
#### which appears to be the leaf marker convention.
best_params = {'params': {'features': np.array([ 3,  2,  2, -2,  2,  2,  3,  3,  3,  0, -2,  2,  2, -2,  1, -2, -2,
       -2, -2, -2, -2,  2,  0, -2, -2, -2,  2, -2, -2, -2, -2]), 'thresholds': np.array([ 1.3415609e-02,  1.8931221e-02, -4.4121705e-02, -2.0000000e+00,
        1.9445941e-01, -1.7292500e-01, -1.2987262e-01, -5.2921700e-01,
       -1.3231094e+00, -7.7489749e-02, -2.0000000e+00, -1.1993183e-01,
       -5.4093454e-02, -2.0000000e+00,  1.4171259e+00, -2.0000000e+00,
       -2.0000000e+00, -2.0000000e+00, -2.0000000e+00, -2.0000000e+00,
       -2.0000000e+00,  1.6916471e-03, -1.3157675e-01, -2.0000000e+00,
       -2.0000000e+00, -2.0000000e+00, -1.0693744e-01, -2.0000000e+00,
       -2.0000000e+00, -2.0000000e+00, -2.0000000e+00]), 'children_left': np.array([ 1,  5,  3, -1, 21,  7, 19,  9, 11, 13, -1, 25, 17, -1, 15, -1, -1,
       -1, -1, -1, -1, 29, 23, -1, -1, -1, 27, -1, -1, -1, -1]), 'children_right': np.array([ 2,  6,  4, -1, 22,  8, 20, 10, 12, 14, -1, 26, 18, -1, 16, -1, -1,
       -1, -1, -1, -1, 30, 24, -1, -1, -1, 28, -1, -1, -1, -1]), 'leaf_logits': np.array([[-0.7023319 , -0.6840421 ],
       [-0.62846714, -0.7622975 ],
       [-0.79016674, -0.60470986],
       [-0.6604398 , -0.7269565 ],
       [-0.8074505 , -0.5905756 ],
       [-0.6114613 , -0.7820996 ],
       [-0.70477957, -0.6816447 ],
       [-0.74220663, -0.64637864],
       [-0.5997462 , -0.79617465],
       [-0.8327886 , -0.5706341 ],
       [-0.43410772, -1.043671  ],
       [-0.7115638 , -0.6750596 ],
       [-0.5893445 , -0.8089819 ],
       [-0.29663196, -1.3599076 ],
       [-0.8627936 , -0.5481461 ],
       [-0.8938114 , -0.52609414],
       [-0.29232538, -1.3724844 ],
       [-0.5520163 , -0.85751647],
       [-0.6164308 , -0.7762372 ],
       [-0.6671731 , -0.7198099 ],
       [-0.81143004, -0.5873834 ],
       [-0.81324255, -0.5859371 ],
       [-0.6420122 , -0.74703425],
       [-0.88283926, -0.5337615 ],
       [-0.5336269 , -0.88303006],
       [-0.8158626 , -0.5838546 ],
       [-0.6181537 , -0.7742197 ],
       [-0.3887487 , -1.1329013 ],
       [-0.70033467, -0.68600696],
       [-0.75799584, -0.63224524],
       [-0.83012414, -0.5726889 ]])}}


#### Dimensions of the demo policy.
#### n_features_in corresponds to random_observations.shape[1], i.e. the number
#### of values returned in https://github.com/tudelft-cda-lab/DTPO/blob/main/dtpo/dtpo.py#L97
#### (from doing a random rollout); n_actions corresponds to env.num_actions.
n_features_in = 4
n_actions = 2

discretized_tree = get_discretized_tree(
    best_params,
    n_features_in,
    n_actions,
    prune=True,
)

env_to_feature_action_names = {
    "Pendulum-v1": (["cos theta", "sin theta", "theta dot"], ["left", "right"]),
    "MountainCar-v0": (["position", "velocity"], ["left", "don't accelerate", "right"]),
    "MountainCarContinuous-v0": (["position", "velocity"], ["force"]),
    "CartPole-v1": (
        ["cart position", "cart velocity", "pole angle", "pole angular velocity"],
        ["left", "right"],
    ),
    "Acrobot-v1": (
        [
            "cos joint 1",
            "sin joint 1",
            "cos joint 2",
            "sin joint 2",
            "velocity 1",
            "velocity 2",
        ],
        ["left torque", "no torque", "right torque"],
    ),
}

#### env is of type gymnax.environments.environment.Environment
#### so feel free to replace with your env.
#### With a live env at hand, prefer env.feature_names / env.action_names when
#### the env defines them, and derive the dimensions from the env itself:
#### env.observation_space(env_params).shape[0] for gymnax environments,
#### env.observation_space.shape[0] otherwise, and env.num_actions.
if env_name in env_to_feature_action_names:
    feature_names, action_names = env_to_feature_action_names[env_name]
else:
    # Unknown environment: fall back to generic names so the export below
    # still works instead of failing on undefined variables.
    feature_names = [f"feature_{i}" for i in range(n_features_in)]
    action_names = [f"action_{i}" for i in range(n_actions)]

# Export the same tree in every supported format
filename = "discretized_tree"
for extension in (".dot", ".pdf", ".png"):
    export_tree(
        discretized_tree,
        filename + extension,
        feature_names,
        action_names,
    )