# Inspecting trees
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/yggdrasil-decision-forests/blob/main/documentation/public/docs/tutorial/inspecting_trees.ipynb)


## Setup

In [None]:
pip install ydf -U

In [4]:
import ydf
import numpy as np

## What does it mean to inspect trees?

A decision forest model, such as Random Forest or Gradient Boosted Decision Trees, is a collection of decision trees. A decision tree has "internal nodes" (nodes with child nodes, each holding a condition) and "leaf nodes" (nodes without children, each holding a value). Using the `get_tree` and `print_tree` methods, you can inspect the structure of the trees, their conditions, and their leaf values.

In this notebook, we train a simple CART model on a synthetic dataset and inspect its tree structure.
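
To make the distinction between internal nodes and leaf nodes concrete, here is a minimal sketch of how an example is routed through a tree. The `Leaf` and `Split` classes, the `predict` function, and the feature names `a` and `b` are hypothetical stand-ins written for illustration; they are not YDF classes.

```python
from dataclasses import dataclass
from typing import Dict, Union


@dataclass
class Leaf:
    # Hypothetical stand-in: a leaf only stores the value it predicts.
    value: float


@dataclass
class Split:
    # Hypothetical stand-in for an internal node: a test
    # "feature >= threshold" plus one child per outcome.
    feature: str
    threshold: float
    pos_child: "Node"  # followed when the test is true
    neg_child: "Node"  # followed when the test is false


Node = Union[Leaf, Split]


def predict(node: Node, example: Dict[str, float]) -> float:
    """Routes one example from the root down to a leaf and returns its value."""
    while isinstance(node, Split):
        if example[node.feature] >= node.threshold:
            node = node.pos_child
        else:
            node = node.neg_child
    return node.value


# A toy tree with two internal nodes and three leaves.
toy_tree = Split("a", 0.5, Split("b", 0.5, Leaf(1.0), Leaf(0.0)), Leaf(0.0))

print(predict(toy_tree, {"a": 1, "b": 1}))  # 1.0
print(predict(toy_tree, {"a": 1, "b": 0}))  # 0.0
print(predict(toy_tree, {"a": 0, "b": 1}))  # 0.0
```

Each prediction evaluates one condition per internal node on the path, so the cost of a prediction grows with the depth of the tree rather than with its total number of nodes.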


## Synthetic dataset

Our dataset is composed of two input features and six examples.

In [5]:
dataset = {
    "x1": np.array([0, 0, 0, 1, 1, 1]),
    "x2": np.array([1, 1, 0, 0, 1, 1]),
    "y": np.array([0, 0, 0, 0, 1, 1]),
}

dataset

{'x1': array([0, 0, 0, 1, 1, 1]),
 'x2': array([1, 1, 0, 0, 1, 1]),
 'y': array([0, 0, 0, 0, 1, 1])}
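
Before training, it can help to look at how the label depends on the features. The sketch below groups the six labels by their `(x1, x2)` combination. It uses plain lists instead of NumPy arrays so that it runs on its own; the variable names are chosen for illustration.

```python
from collections import defaultdict

# Same values as the dataset above, written as plain lists.
x1 = [0, 0, 0, 1, 1, 1]
x2 = [1, 1, 0, 0, 1, 1]
y = [0, 0, 0, 0, 1, 1]

# Collect the labels observed for each (x1, x2) combination.
labels_by_features = defaultdict(list)
for a, b, label in zip(x1, x2, y):
    labels_by_features[(a, b)].append(label)

for features, labels in sorted(labels_by_features.items()):
    print(features, labels)

# Check whether the label is 1 exactly when both features are 1.
is_logical_and = all(label == (a and b) for a, b, label in zip(x1, x2, y))
print(is_logical_and)  # True
```

Since the label is the logical AND of the two features, no single threshold on one feature separates the labels perfectly, so a tree needs conditions on both features to fit this dataset.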

## Training a model

In [8]:
model = ydf.CartLearner(label="y", min_examples=1, task=ydf.Task.REGRESSION).train(dataset)

model.describe()

Train model on 6 examples
Model trained in 0:00:00.000728


## Plotting the model

The tree of the model is visible in the "structure" tab of `model.describe()`. You can also print trees with the `print_tree` method.

In [9]:
model.print_tree()

'x1' >= 0.5 [score=0.11111 missing=True]
    ├─(pos)─ 'x2' >= 0.5 [score=0.22222 missing=True]
    │        ├─(pos)─ value=1 sd=0
    │        └─(neg)─ value=0 sd=0
    └─(neg)─ value=0 sd=0
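
The `score` values printed above are numerically consistent with a variance reduction: the variance of the labels reaching a node, minus the size-weighted variance of the labels in its two children. The sketch below recomputes both scores by hand under that reading. This is an interpretation of the printed numbers, not a statement about how YDF computes scores internally, and `variance_reduction` is a hypothetical helper.

```python
from statistics import pvariance

# Same values as the dataset above, written as plain lists.
x1 = [0, 0, 0, 1, 1, 1]
x2 = [1, 1, 0, 0, 1, 1]
y = [0, 0, 0, 0, 1, 1]


def variance_reduction(labels, goes_positive):
    """Parent variance minus the size-weighted variance of the two children.

    Assumes both children receive at least one example.
    """
    pos = [v for v, p in zip(labels, goes_positive) if p]
    neg = [v for v, p in zip(labels, goes_positive) if not p]
    weighted_children = (
        len(pos) * pvariance(pos) + len(neg) * pvariance(neg)
    ) / len(labels)
    return pvariance(labels) - weighted_children


# Root condition: 'x1' >= 0.5, evaluated on all six examples.
root_score = variance_reduction(y, [v >= 0.5 for v in x1])

# Second condition: 'x2' >= 0.5, evaluated only on examples with x1 >= 0.5.
y_pos = [label for label, v in zip(y, x1) if v >= 0.5]
x2_pos = [b for b, v in zip(x2, x1) if v >= 0.5]
child_score = variance_reduction(y_pos, [v >= 0.5 for v in x2_pos])

print(round(root_score, 5))   # 0.11111
print(round(child_score, 5))  # 0.22222
```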


## Accessing the tree structure

The `get_tree` and `get_all_trees` methods give programmatic access to the structure of the trees.

**Note:** A CART model only has one tree, so the `tree_idx` argument is set to `0`. For models with multiple trees, the number of trees is available with `model.num_trees()`.

In [11]:
tree = model.get_tree(tree_idx=0)

tree

Tree(root=NonLeaf(value=RegressionValue(num_examples=6.0, value=0.3333333432674408, standard_deviation=0.4714045207910317), condition=NumericalHigherThanCondition(missing=True, score=0.1111111119389534, attribute=1, threshold=0.5), pos_child=NonLeaf(value=RegressionValue(num_examples=3.0, value=0.6666666865348816, standard_deviation=0.4714045207910317), condition=NumericalHigherThanCondition(missing=True, score=0.2222222238779068, attribute=2, threshold=0.5), pos_child=Leaf(value=RegressionValue(num_examples=2.0, value=1.0, standard_deviation=0.0)), neg_child=Leaf(value=RegressionValue(num_examples=1.0, value=0.0, standard_deviation=0.0))), neg_child=Leaf(value=RegressionValue(num_examples=3.0, value=0.0, standard_deviation=0.0))))

Do you recognize the structure of the tree printed above? You can access parts of the tree. For example, you can access the condition on `x2`:

In [12]:
tree.root.pos_child.condition

NumericalHigherThanCondition(missing=True, score=0.2222222238779068, attribute=2, threshold=0.5)
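
Because every non-leaf node exposes `pos_child` and `neg_child`, a tree can be walked recursively. The sketch below counts leaves and measures depth. To stay runnable without a trained model, it uses hypothetical `Leaf` and `NonLeaf` stand-ins that mirror only the field names visible in the representation above; that the same functions work on a real `tree.root` is an assumption based on that representation.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Leaf:
    # Stand-in mirroring the `value` field shown in the repr above.
    value: Any


@dataclass
class NonLeaf:
    # Stand-in mirroring the fields of non-leaf nodes shown in the repr above.
    condition: Any
    pos_child: Any
    neg_child: Any


def is_leaf(node) -> bool:
    # Duck typing: only non-leaf nodes carry children.
    return not hasattr(node, "pos_child")


def count_leaves(node) -> int:
    """Returns the number of leaves below (and including) `node`."""
    if is_leaf(node):
        return 1
    return count_leaves(node.pos_child) + count_leaves(node.neg_child)


def depth(node) -> int:
    """Returns the number of conditions on the longest root-to-leaf path."""
    if is_leaf(node):
        return 0
    return 1 + max(depth(node.pos_child), depth(node.neg_child))


# Same shape as the tree printed earlier; conditions are plain strings here.
root = NonLeaf(
    condition="x1 >= 0.5",
    pos_child=NonLeaf(
        condition="x2 >= 0.5",
        pos_child=Leaf(value=1.0),
        neg_child=Leaf(value=0.0),
    ),
    neg_child=Leaf(value=0.0),
)

print(count_leaves(root))  # 3
print(depth(root))         # 2
```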

To show the tree in a more readable form, you can use the `pretty` method of the tree, which takes the model's data spec as its argument.

In [14]:
print(tree.pretty(model.data_spec()))

'x1' >= 0.5 [score=0.11111 missing=True]
    ├─(pos)─ 'x2' >= 0.5 [score=0.22222 missing=True]
    │        ├─(pos)─ value=1 sd=0
    │        └─(neg)─ value=0 sd=0
    └─(neg)─ value=0 sd=0

