# Understanding Your Model

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/yggdrasil-decision-forests/blob/main/documentation/public/docs/tutorial/model_understanding.ipynb)

## Setup

First, let's install YDF and import the necessary libraries.

In [None]:
pip install ydf -U

In [None]:
import ydf
import pandas as pd

## What is Model Interpretation?

**Model interpretation** (or model understanding) is the process of explaining how a machine learning model works. These methods are crucial for validating and debugging your model, ensuring it aligns with domain knowledge, and building trust in its decisions.

This is different from **prediction interpretation**, which focuses on explaining a *single* prediction. To learn how to explain individual predictions, see the [prediction understanding tutorial](../prediction_understanding).

YDF offers several powerful tools for model interpretation:

* `model.describe()`: Provides a high-level summary of the model, including its input features, variable importances, and training logs.
* `model.analyze()`: Performs a deep analysis of the model's behavior on a dataset, generating rich visualizations like Partial Dependence Plots and SHAP-based variable importances.
* `model.print_tree()` and `model.plot_tree()`: Let you visualize the structure of individual decision trees, which is most useful for simple models.


## Dataset and Model Training

We'll use the "Adult" census dataset to predict whether an individual's income is over $50k.

In [None]:
# Load the training and testing datasets
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")

# Display the first few rows
train_ds.head(3)

|   | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 44 | Private | 228057 | 7th-8th | 4 | Married-civ-spouse | Machine-op-inspct | Wife | White | Female | 0 | 0 | 40 | Dominican-Republic | <=50K |
| 1 | 20 | Private | 299047 | Some-college | 10 | Never-married | Other-service | Not-in-family | White | Female | 0 | 0 | 20 | United-States | <=50K |
| 2 | 40 | Private | 342164 | HS-grad | 9 | Separated | Adm-clerical | Unmarried | White | Female | 0 | 0 | 37 | United-States | <=50K |


Now, let's train a Gradient Boosted Trees model. This is a powerful model composed of many trees. Because it is too large to read directly, it is a good candidate for the interpretation techniques below.

In [None]:
# Train a Gradient Boosted Trees model
model = ydf.GradientBoostedTreesLearner(label="income").train(train_ds)

Train model on 22792 examples
Model trained in 0:00:01.874314


## High-Level Model Summary with `describe()`

The `model.describe()` method is your first stop for understanding a model. It provides a rich, interactive report.

Key sections of the report include:

* **Input Features**: A list of all the features the model was trained on.
* **Variable Importances**: Scores that rank features by their importance to the model's predictions. Several types of importance scores are reported, each giving a different perspective on feature utility. For example, `NUM_NODES` counts how many split nodes, across all trees, test a given feature.
* **Training Logs**: Detailed logs from the training process, which are useful for debugging.

> **Note:** `model.describe()` only shows variable importances that can be computed without a test dataset. More advanced variable importances, such as SHAP-based ones, are available in the model analysis report (see below).

In [None]:
model.describe()
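To make the idea behind `NUM_NODES` concrete, here is a minimal sketch that counts, for each feature, the split nodes that test it across several trees. The nested-dict tree format and the helper names (`count_split_features`, `num_nodes_importance`) are hypothetical illustrations, not YDF's internal representation; YDF computes and reports this score for you.

```python
from collections import Counter


def count_split_features(node, counts):
    """Adds one to `counts[feature]` for every split node that tests `feature`."""
    if "feature" in node:  # internal (split) node
        counts[node["feature"]] += 1
        count_split_features(node["pos"], counts)
        count_split_features(node["neg"], counts)
    return counts


def num_nodes_importance(trees):
    """Counts split nodes per feature across all trees, highest count first."""
    counts = Counter()
    for tree in trees:
        count_split_features(tree, counts)
    return counts.most_common()


# Hypothetical toy ensemble: leaves carry a value, split nodes carry a feature.
leaf = {"value": 0.0}
tree_1 = {
    "feature": "relationship",
    "pos": {"feature": "education_num",
            "pos": leaf,
            "neg": {"feature": "capital_gain", "pos": leaf, "neg": leaf}},
    "neg": {"feature": "capital_gain", "pos": leaf, "neg": leaf},
}
tree_2 = {"feature": "age",
          "pos": leaf,
          "neg": {"feature": "capital_gain", "pos": leaf, "neg": leaf}}

ranking = num_nodes_importance([tree_1, tree_2])
print(ranking)
```

Here `capital_gain` ranks first because three split nodes test it. A count like this ignores how much each split improves the predictions, so it is best read alongside the other importance types in the report.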

## In-Depth Analysis with `analyze()`

While `describe()` summarizes the model itself, `model.analyze()` computes how the model behaves on a specific dataset. This is essential for understanding the *relationships* the model has learned.

The analysis generates **Partial Dependence Plots (PDPs)**, which show how the model's average prediction changes as a single feature varies. For each candidate value of the feature, that value is substituted into every example while the example's other features keep their observed values, and the resulting predictions are averaged. It also calculates **additional variable importances**, including SHAP-based ones.

> **Note:** Model analysis can be computationally expensive on large datasets. Use the `sampling` parameter to run a faster analysis on a random subset of your data.
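The averaging behind a partial dependence plot can be sketched in a few lines of NumPy. This is a model-agnostic illustration under simple assumptions (a `predict_fn` that takes a 2-D array, a hand-picked grid, and a hypothetical `toy_predict` model); it is not YDF's implementation.

```python
import numpy as np


def partial_dependence(predict_fn, X, feature_idx, grid):
    """Average prediction when column `feature_idx` is set to each grid value.

    In every row of X, only the chosen feature is overwritten; all other
    features keep their observed values. Predictions are then averaged.
    """
    averages = []
    for value in grid:
        X_modified = X.copy()
        X_modified[:, feature_idx] = value
        averages.append(predict_fn(X_modified).mean())
    return np.array(averages)


# Hypothetical model with an interaction between features 0 and 1.
def toy_predict(X):
    return 2.0 * X[:, 0] + X[:, 0] * X[:, 1]


rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))
grid = np.linspace(-2.0, 2.0, 5)
pdp = partial_dependence(toy_predict, X, feature_idx=0, grid=grid)
print(np.round(pdp, 3))
```

Because `toy_predict` contains an interaction term, the partial dependence of feature 0 is only approximately `2 * value`: the interaction is averaged out over the observed values of feature 1.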


In [None]:
# Analyze the model's behavior on the test dataset
# We use a 10% sample for speed.
analysis = model.analyze(test_ds, sampling=0.1)

# The analysis is interactive in a notebook environment.
# Display it by making it the last expression of the cell:
analysis

# You can also save it to a file:
# analysis.to_file("analysis.html")
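SHAP-based importances build on Shapley values. As a rough, model-agnostic illustration (not how YDF computes them), the sketch below computes exact interventional Shapley values by brute force over all feature coalitions, then aggregates them into a global importance as the mean absolute value per feature. The background data, the `linear_predict` model, and the helper names are hypothetical. Brute force is exponential in the number of features, so it is only practical for a handful of them.

```python
import itertools
import math

import numpy as np


def shapley_values(predict_fn, x, background):
    """Exact interventional Shapley values for a single example `x`.

    The value of a coalition S of features is the mean prediction over the
    `background` rows after their S columns are overwritten with x's values.
    """
    n = x.shape[0]

    def coalition_value(subset):
        data = background.copy()
        cols = list(subset)
        if cols:
            data[:, cols] = x[cols]
        return predict_fn(data).mean()

    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions of this size.
            weight = math.factorial(size) * math.factorial(n - size - 1) / math.factorial(n)
            for subset in itertools.combinations(others, size):
                phi[i] += weight * (coalition_value(subset + (i,)) - coalition_value(subset))
    return phi


def mean_abs_shap_importance(predict_fn, X_explain, background):
    """Global importance: mean absolute Shapley value of each feature."""
    all_phi = np.array([shapley_values(predict_fn, x, background) for x in X_explain])
    return np.abs(all_phi).mean(axis=0)


# Hypothetical linear model: feature 2 is ignored entirely.
weights = np.array([3.0, 1.0, 0.0])


def linear_predict(X):
    return X @ weights


rng = np.random.default_rng(0)
background = rng.normal(size=(200, 3))
X_explain = rng.normal(size=(50, 3))
importance = mean_abs_shap_importance(linear_predict, X_explain, background)
print(np.round(importance, 3))
```

For a linear model, the Shapley value of feature `i` is `weights[i] * (x[i] - background[:, i].mean())`, so feature 2 gets zero importance and feature 0 ranks above feature 1.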

## Interpreting Simpler Models: The Decision Tree

Powerful models like Random Forests and Gradient Boosted Trees are ensembles of many (often hundreds of) trees.

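> **Warning:** It is misleading to interpret a multi-tree model by looking at just one of its trees. Each tree is only a small part of the ensemble. In GBTs in particular, each tree is trained to correct the mistakes of the trees before it (it fits the negative gradient of the loss with respect to the current predictions), so no single tree predicts the final outcome.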
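The warning above can be illustrated with a deliberately tiny gradient boosting sketch: squared-loss boosting with one-split trees (stumps) on a single synthetic feature. Everything here (`fit_stump`, the sine-shaped data, the learning rate, and the number of rounds) is a hypothetical simplification, not YDF's training algorithm, which uses deeper trees and, for a label like `income`, a classification loss.

```python
import numpy as np


def fit_stump(x, target):
    """Fits a one-split regression tree on a 1-D feature by minimizing squared error."""
    best = None
    for threshold in np.unique(x)[:-1]:  # keep at least one point on each side
        left = x <= threshold
        left_mean, right_mean = target[left].mean(), target[~left].mean()
        sse = ((target[left] - left_mean) ** 2).sum() + ((target[~left] - right_mean) ** 2).sum()
        if best is None or sse < best[0]:
            best = (sse, threshold, left_mean, right_mean)
    _, threshold, left_mean, right_mean = best
    return lambda z: np.where(z <= threshold, left_mean, right_mean)


def mse(a, b):
    return float(np.mean((a - b) ** 2))


rng = np.random.default_rng(0)
x = rng.uniform(0, 10, size=200)
y = np.sin(x) + 0.1 * rng.normal(size=200)

learning_rate = 0.3
prediction = np.full_like(y, y.mean())  # start from a constant prediction
trees = []
for _ in range(100):
    residual = y - prediction  # negative gradient of the loss 0.5 * (y - prediction) ** 2
    tree = fit_stump(x, residual)
    trees.append(tree)
    prediction = prediction + learning_rate * tree(x)

print("baseline MSE:        ", mse(y, np.full_like(y, y.mean())))
print("ensemble MSE:        ", mse(y, prediction))
print("last tree alone, MSE:", mse(y, trees[-1](x)))
```

The last tree only models what the earlier trees got wrong, so its output alone is a poor predictor of `y`, while the sum of all trees fits the data well.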

To gain intuition about the fundamental structure of your data, it's often helpful to train a simpler, "weaker" model. A single Decision Tree is perfect for this. It will have lower performance but is fully transparent.

Let's train a shallow Decision Tree and visualize it.


In [None]:
# Train a single, shallow decision tree (CART) for interpretability
weak_model = ydf.CartLearner(label="income", max_depth=4).train(train_ds)

# Get the class labels; the leaf values below follow the same order
print("Class labels:", weak_model.label_classes())

# Print the tree as text (print_tree() prints directly and returns None)
weak_model.print_tree()

Train model on 22792 examples
Model trained in 0:00:00.718672
Class labels: ['<=50K', '>50K']
'relationship' in ['Husband', 'Wife'] [score=0.10683 missing=False]
    ├─(pos)─ 'education_num' >= 12.5 [score=0.070882 missing=False]
    │        ├─(pos)─ value=[0.2672196177425171, 0.7327803822574829]
    │        └─(neg)─ 'capital_gain' >= 5095.5 [score=0.047968 missing=False]
    │                 ├─(pos)─ value=[0.026402640264026403, 0.9735973597359736]
    │                 └─(neg)─ value=[0.7032752159035359, 0.2967247840964641]
    └─(neg)─ 'capital_gain' >= 7073.5 [score=0.04532 missing=False]
             ├─(pos)─ value=[0.04020100502512563, 0.9597989949748744]
             └─(neg)─ value=[0.950544015825915, 0.04945598417408507]


For a classification model, the `value` at each leaf is the predicted probability of each class, in the order returned by `label_classes()`. For example, `value=[0.8, 0.2]` means the model predicts the first class (`<=50K`) with 80% probability and the second (`>50K`) with 20%.
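To see how the printed tree turns an example into a prediction, the sketch below transcribes it by hand into nested `if` statements, with leaf values rounded to three decimals and class names taken from the `label_classes()` output. The helper names are hypothetical, the transcription assumes no missing values, and it only mirrors this particular tree; for real predictions, use the model's own `predict()` method.

```python
# Class order as printed by weak_model.label_classes().
CLASSES = ["<=50K", ">50K"]


def shallow_tree_probabilities(example):
    """Hand transcription of the printed tree (leaf values rounded to 3 decimals)."""
    if example["relationship"] in ("Husband", "Wife"):
        if example["education_num"] >= 12.5:
            return [0.267, 0.733]
        if example["capital_gain"] >= 5095.5:
            return [0.026, 0.974]
        return [0.703, 0.297]
    if example["capital_gain"] >= 7073.5:
        return [0.040, 0.960]
    return [0.951, 0.049]


def predict_label(example):
    """Returns the most probable class name and its probability."""
    probs = shallow_tree_probabilities(example)
    best = max(range(len(CLASSES)), key=lambda i: probs[i])
    return CLASSES[best], probs[best]


# Only the features tested by the tree are needed (values from row 0 of train_ds).
first_row = {"relationship": "Wife", "education_num": 4, "capital_gain": 0}
print(predict_label(first_row))
```

Each path from the root to a leaf reads as one rule; for example, a married person with `education_num >= 12.5` lands in the leaf that predicts `>50K` with about 73% probability.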

A text printout is useful, but a plot is often clearer. `plot_tree()` generates a visual representation where you can trace the decision paths.

In [None]:
# Or plot the tree for a visual representation
weak_model.plot_tree()