# Model understanding

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/yggdrasil-decision-forests/blob/main/documentation/public/docs/tutorial/model_understanding.ipynb)

## Setup

```shell
pip install ydf -U
```

```python
import ydf
import pandas as pd
```

## What is model understanding?

**Model understanding** or **model interpretation** aims to explain how a model behaves as a whole. These methods can be used to identify modeling and data issues and to improve decision-making.

**Model understanding** is different from **prediction understanding**, which aims to explain a single model prediction. To explain predictions, see the [prediction understanding](../prediction_understanding) tutorial instead.

In YDF, model understanding can be done in 3 ways:

- `model.describe()`: This method shows the input features of the model, variable importances, training logs, individual trees, and other model-specific information.
- `model.analyze()`: This method runs and displays a deep analysis of the model on a given dataset. For instance, it shows additional variable importances (including SHAP values) and partial dependence plots.
- `model.print_tree()` and `model.plot_tree()`: These methods print or plot the actual decision trees.


**Important:** While you can look at the decision trees individually, it is not a good way to interpret models unless your model only contains a single tree. For models with multiple trees like Random Forest and Gradient Boosted Trees, using other methods (notably `model.analyze`) is recommended.

**Information:** Counterfactual analysis is a complementary and more complex way to understand a model. See the [counterfactual notebook](../counterfactual) for details.

## Gathering the dataset and training the model

```python
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")

# Print the first 5 training examples (only the last expression of a cell is displayed automatically)
print(train_ds.head(5))

# Train a model
model = ydf.GradientBoostedTreesLearner(label="income").train(train_ds)
```

```
Train model on 22792 examples
Model trained in 0:00:01.089565
```


## Model description

The **model description** contains two key pieces of information for understanding the model:

**Variable importances**: Variable importances show which features are most important to the model. A variable with a high score is more useful for the model to make its predictions than a variable with a low score.

There are several measures of variable importance. For instance, "num_nodes" counts how many nodes in the decision trees use a particular feature.
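
To make the "num_nodes" idea concrete, here is a minimal sketch that counts, for each feature, the non-leaf nodes testing it and sums the counts over all trees of a forest. The nested-dict tree format below is a hypothetical representation chosen for illustration, not YDF's internal format, and YDF's exact computation may differ.

```python
from collections import Counter

# Hypothetical tree format (assumption, not YDF's): a non-leaf node is a dict with
# "feature", "pos" and "neg" entries; a leaf is a dict with a "value" entry.
tree = {
    "feature": "relationship",
    "pos": {
        "feature": "education_num",
        "pos": {"value": [0.27, 0.73]},
        "neg": {
            "feature": "capital_gain",
            "pos": {"value": [0.03, 0.97]},
            "neg": {"value": [0.70, 0.30]},
        },
    },
    "neg": {
        "feature": "capital_gain",
        "pos": {"value": [0.04, 0.96]},
        "neg": {"value": [0.95, 0.05]},
    },
}


def count_nodes_per_feature(node, counts):
    """Recursively counts how many non-leaf nodes test each feature."""
    if "feature" in node:
        counts[node["feature"]] += 1
        count_nodes_per_feature(node["pos"], counts)
        count_nodes_per_feature(node["neg"], counts)
    return counts


def num_nodes_importance(forest):
    """Sums the per-feature node counts over all the trees of a forest."""
    total = Counter()
    for t in forest:
        count_nodes_per_feature(t, total)
    return total


print(num_nodes_importance([tree]))
```

In this toy tree, `capital_gain` is tested by two nodes while `relationship` and `education_num` are tested by one node each, so `capital_gain` gets the highest "num_nodes" score.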

Another type of variable importance is "mean {decrease,increase} in [metric name]". These show how much the quality of the model would decrease if the feature were removed (in practice, removing a feature and retraining the model is expensive, so the feature values are shuffled instead).
For example, a feature with a "mean decrease in accuracy" of 0.1 means that shuffling this feature reduces the accuracy of the model by about 0.1.
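
The shuffling procedure can be sketched in plain Python as follows. This is a minimal illustration of permutation importance, assuming a model is any function that maps a row (a dict) to a class label; `toy_model` and the rows are hypothetical, and YDF's implementation details (number of repetitions, metrics) may differ.

```python
import random


def accuracy(model, rows, labels):
    """Fraction of rows for which `model(row)` equals the true label."""
    correct = sum(model(row) == label for row, label in zip(rows, labels))
    return correct / len(rows)


def mean_decrease_in_accuracy(model, rows, labels, feature, num_repeats=10, seed=0):
    """Average accuracy drop after randomly shuffling the values of `feature`."""
    rng = random.Random(seed)
    baseline = accuracy(model, rows, labels)
    drops = []
    for _ in range(num_repeats):
        values = [row[feature] for row in rows]
        rng.shuffle(values)
        # Copy each row, overwriting only the shuffled feature.
        shuffled_rows = [{**row, feature: value} for row, value in zip(rows, values)]
        drops.append(baseline - accuracy(model, shuffled_rows, labels))
    return sum(drops) / len(drops)


# Hypothetical toy model: it only looks at "capital_gain".
def toy_model(row):
    return ">50K" if row["capital_gain"] >= 5000 else "<=50K"


rows = [
    {"capital_gain": g, "age": a}
    for g, a in [(0, 25), (8000, 40), (0, 60), (9000, 35), (100, 50), (7000, 30)]
]
labels = [toy_model(row) for row in rows]

print(mean_decrease_in_accuracy(toy_model, rows, labels, "age"))  # 0.0: the model ignores "age"
print(mean_decrease_in_accuracy(toy_model, rows, labels, "capital_gain"))  # positive
```

Shuffling a feature the model never uses leaves its accuracy unchanged, while shuffling the feature it relies on breaks the link between that feature and the label, which is what the importance measures.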

The variable importance measures shown by YDF complement each other; no single measure gives a full understanding of the model.

**Structure**: Decision forest models are made up of decision trees. The *structure* tab of the model description shows the first decision tree in the model. This can be helpful for understanding how the model is making predictions in general.


```python
model.describe()
```

## Model analysis

In contrast to the model description, the **model analysis** requires a test dataset. The most informative results of the analysis are [partial dependence plots (PDP)](https://christophm.github.io/interpretable-ml-book/pdp.html), which show how the model's average prediction changes as a function of each feature, marginalizing over the other features. The model analysis also shows variable importances computed on the provided dataset.
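
The partial dependence computation can be sketched as follows: for each candidate value of a feature, overwrite that feature in every example and average the model's predictions. This is a minimal sketch of the standard PDP definition on a hypothetical model (`toy_probability`) and hypothetical rows; how YDF chooses the grid of feature values is not shown here and may differ.

```python
def partial_dependence(predict, rows, feature, grid):
    """For each grid value, sets `feature` to that value in every row and averages the predictions."""
    curve = []
    for value in grid:
        modified = [{**row, feature: value} for row in rows]
        predictions = [predict(row) for row in modified]
        curve.append(sum(predictions) / len(predictions))
    return curve


# Hypothetical model returning the probability of ">50K".
def toy_probability(row):
    p = 0.1 + 0.02 * (row["education_num"] - 9)
    if row["capital_gain"] >= 5000:
        p += 0.5
    return min(max(p, 0.0), 1.0)


rows = [
    {"education_num": 9, "capital_gain": 0},
    {"education_num": 13, "capital_gain": 0},
    {"education_num": 10, "capital_gain": 8000},
    {"education_num": 16, "capital_gain": 0},
]
print(partial_dependence(toy_probability, rows, "capital_gain", [0, 10000]))
```

Here the average predicted probability rises from about 0.16 to about 0.66 when `capital_gain` crosses the model's threshold, which is the kind of step a PDP makes visible.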

**Note:** Model analysis is computationally expensive. On large datasets, you can use the `sampling` parameter to run the analysis more quickly on a random subset of the data (for example, `sampling=0.1` uses about 10% of the examples).

```python
model.analyze(test_ds, sampling=0.1)
```

## Training and interpreting a weaker model

To maximize the quality of your predictions, you likely trained a random forest or a gradient boosted trees model. While these models are powerful, they can be hard to interpret. For example, only the first tree of a gradient boosted trees model is easily interpretable. To understand such complex models, you can train a simpler model that behaves similarly: it is less accurate, but easier to interpret. A decision tree, for instance, is a weaker model whose single tree structure can provide valuable insights.
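
Whether the simpler model "behaves similarly" can be checked by measuring how often it predicts the same class as the complex model on the same examples. The helper below is a hypothetical sketch, assuming both models' predictions have already been converted to class labels (for example, by thresholding predicted probabilities at 0.5); the label lists are made up for illustration.

```python
def agreement_rate(complex_predictions, surrogate_predictions):
    """Fraction of examples on which the simpler model predicts the same class as the complex one."""
    if len(complex_predictions) != len(surrogate_predictions):
        raise ValueError("Both prediction lists must have the same length.")
    matches = sum(a == b for a, b in zip(complex_predictions, surrogate_predictions))
    return matches / len(complex_predictions)


# Hypothetical class predictions for the same four examples.
complex_labels = [">50K", "<=50K", "<=50K", ">50K"]
surrogate_labels = [">50K", "<=50K", ">50K", ">50K"]
print(agreement_rate(complex_labels, surrogate_labels))  # 0.75
```

A low agreement rate suggests that insights drawn from the weaker model may not transfer to the model you actually use.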

```python
weak_model = ydf.DecisionTreeLearner(label="income", max_depth=5, min_examples=2).train(train_ds)

# The class labels help interpret the leaf values. For instance, `value=[0.2, 0.8]` indicates that
# the probability of the first class label is 20%, and the probability of the second one is 80%.
print("class labels:", weak_model.label_classes())

weak_model.print_tree()  # Print the tree (print_tree() prints directly and returns None)
# Or
weak_model.plot_tree()  # Plot the tree
```

```
Train model on 22792 examples
Model trained in 0:00:00.020898
class labels: ['<=50K', '>50K']
'relationship' in ['Husband', 'Wife'] [score=0.10683 missing=False]
    ├─(pos)─ 'education_num' >= 12.5 [score=0.070882 missing=False]
    │        ├─(pos)─ value=[0.2672196177425171, 0.7327803822574829]
    │        └─(neg)─ 'capital_gain' >= 5095.5 [score=0.047968 missing=False]
    │                 ├─(pos)─ value=[0.026402640264026403, 0.9735973597359736]
    │                 └─(neg)─ value=[0.7032752159035359, 0.2967247840964641]
    └─(neg)─ 'capital_gain' >= 7073.5 [score=0.04532 missing=False]
             ├─(pos)─ value=[0.04020100502512563, 0.9597989949748744]
             └─(neg)─ value=[0.950544015825915, 0.04945598417408507]
```

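The printed tree can be read as nested if-statements: each condition sends an example to its `(pos)` or `(neg)` branch until a leaf is reached. The function below is a hand translation of the tree printed above, assuming the example has no missing values; the example itself is hypothetical.

```python
def weak_tree_probability(example):
    """Returns the leaf value [P('<=50K'), P('>50K')] of the printed tree for `example`."""
    if example["relationship"] in ("Husband", "Wife"):
        if example["education_num"] >= 12.5:
            return [0.2672196177425171, 0.7327803822574829]
        if example["capital_gain"] >= 5095.5:
            return [0.026402640264026403, 0.9735973597359736]
        return [0.7032752159035359, 0.2967247840964641]
    if example["capital_gain"] >= 7073.5:
        return [0.04020100502512563, 0.9597989949748744]
    return [0.950544015825915, 0.04945598417408507]


class_labels = ["<=50K", ">50K"]  # Same order as weak_model.label_classes() above.
example = {"relationship": "Husband", "education_num": 10, "capital_gain": 0}  # Hypothetical.
probabilities = weak_tree_probability(example)
print(dict(zip(class_labels, probabilities)))
# {'<=50K': 0.7032752159035359, '>50K': 0.2967247840964641}
```

For this married example with a moderate education level and no capital gain, the tree predicts a probability of about 30% for `>50K`.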

## Counterfactual clusters

Counterfactual examples are the training examples that the model considers most similar to a given example. Examining clusters of counterfactual examples can provide insight into how the model sees and segments the examples.
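
This section does not say how "similar" is measured. One common choice for tree ensembles, assumed here for illustration and not necessarily what YDF computes, is the random forest proximity: two examples are close if they land in the same leaf in many trees. The sketch below uses hypothetical per-tree leaf indices to find the nearest training examples to a query.

```python
def proximity_distance(leaves_a, leaves_b):
    """1 - fraction of trees in which the two examples reach the same leaf."""
    same = sum(a == b for a, b in zip(leaves_a, leaves_b))
    return 1.0 - same / len(leaves_a)


def nearest_training_examples(query_leaves, train_leaves, k):
    """Indices of the k training examples closest to the query, closest first."""
    distances = [proximity_distance(query_leaves, leaves) for leaves in train_leaves]
    return sorted(range(len(train_leaves)), key=lambda i: distances[i])[:k]


# Hypothetical leaf indices: one entry per tree, for a forest of 4 trees.
train_leaves = [
    [0, 3, 1, 2],
    [0, 3, 1, 5],
    [4, 1, 0, 2],
    [4, 2, 6, 7],
]
query_leaves = [0, 3, 1, 7]
print(nearest_training_examples(query_leaves, train_leaves, k=2))  # [0, 1]
```

The first two training examples share a leaf with the query in three of the four trees, so they would be its closest counterfactual candidates under this distance.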

For more information, see the standalone [counterfactual notebook](../counterfactual).