# Example distance

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/yggdrasil-decision-forests/blob/main/documentation/public/docs/tutorial/example_distance.ipynb)

## Setup


```shell
pip install ydf scikit-learn plotly -U
```

```python
import ydf  # Yggdrasil Decision Forests
import pandas as pd  # We use Pandas to load small datasets
import numpy as np
```

## What is an example distance?

Decision forest models define an **implicit measure of proximity or similarity between two examples**, referred to as **distance**. The distance reflects how similarly the model treats two examples. Informally, **two examples are close if the model assigns them to the same class for the same reasons**.

This distance is useful for understanding models and their predictions. For example, we can use it for clustering, manifold learning, or simply to look at the training examples that are nearest to a test example. This can help us to understand why the model made its predictions.

Keep in mind that a decision forest's distance measure is just one of many reasonable distance metrics on a dataset. One of its many advantages is that it allows comparing features on different scales and with different semantics.

In this notebook, we will train a model and use its distance to:

- Find training examples that are neighbors of a test example and use them to explain the model's predictions.

- Map all the examples onto an interactive two-dimensional plot (also known as a 2D manifold) and visually identify clusters of examples that behave similarly.

- Apply hierarchical clustering to explain how the model works as a whole.

**The More You Know:** [Leo Breiman](https://en.wikipedia.org/wiki/Leo_Breiman), the author of the [random forest](https://developers.google.com/machine-learning/glossary#random-forest) learning algorithm, [proposed](https://www.stat.berkeley.edu/~breiman/RandomForests/cc_home.htm#prox) a method to measure the *proximity* between two examples using a pre-trained Random Forest (RF) model. He describes this method as *"[...] one of the most useful tools in random forests."* When using Random Forest models, this is the distance used by YDF.
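Breiman's proximity between two examples is the fraction of trees in which both examples reach the same leaf. The sketch below turns that proximity into a distance (one minus the proximity). It is a minimal illustration under stated assumptions, not YDF's implementation: it assumes a hypothetical matrix of leaf indices (one row per example, one column per tree) and uses toy values. The function name `forest_distance` is also hypothetical.

```python
import numpy as np


def forest_distance(leaves_a: np.ndarray, leaves_b: np.ndarray) -> np.ndarray:
    """Breiman-style distance between two sets of examples (illustrative sketch).

    leaves_a: integer array of shape [n_a, num_trees]; entry [i, t] is the
      index of the leaf reached by example i in tree t (hypothetical input).
    leaves_b: integer array of shape [n_b, num_trees].
    Returns a float array of shape [n_a, n_b].
    """
    # Compare leaf indices tree by tree: shape [n_a, n_b, num_trees].
    same_leaf = leaves_a[:, None, :] == leaves_b[None, :, :]
    # Proximity: fraction of trees in which both examples reach the same leaf.
    proximity = same_leaf.mean(axis=2)
    # Distance: one minus the proximity.
    return 1.0 - proximity


# Hypothetical leaf indices of 3 examples in a forest of 4 trees.
leaves = np.array([
    [0, 2, 1, 3],
    [0, 2, 1, 0],
    [1, 0, 2, 3],
])
print(forest_distance(leaves, leaves))
```

Here, examples 0 and 1 share a leaf in 3 of the 4 trees, so their distance is 0.25. Examples 1 and 2 never share a leaf, so their distance is 1. The broadcast comparison builds an `n_a x n_b x num_trees` array, so for large datasets you would process the examples in batches.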


## Find closest training examples to a test example

Let's download a classification dataset.

```python
ds_path = "https://raw.githubusercontent.com/google/yggdrasil-decision-forests/main/yggdrasil_decision_forests/test_data/dataset"
train_ds = pd.read_csv(f"{ds_path}/adult_train.csv")
test_ds = pd.read_csv(f"{ds_path}/adult_test.csv")

# Print the first 5 training examples
train_ds.head(5)
```

| | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 44 | Private | 228057 | 7th-8th | 4 | Married-civ-spouse | Machine-op-inspct | Wife | White | Female | 0 | 0 | 40 | Dominican-Republic | <=50K |
| 1 | 20 | Private | 299047 | Some-college | 10 | Never-married | Other-service | Not-in-family | White | Female | 0 | 0 | 20 | United-States | <=50K |
| 2 | 40 | Private | 342164 | HS-grad | 9 | Separated | Adm-clerical | Unmarried | White | Female | 0 | 0 | 37 | United-States | <=50K |
| 3 | 30 | Private | 361742 | Some-college | 10 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 50 | United-States | <=50K |
| 4 | 67 | Self-emp-inc | 171564 | HS-grad | 9 | Married-civ-spouse | Prof-specialty | Wife | White | Female | 20051 | 0 | 30 | England | >50K |


We train a random forest on this dataset.

```python
model = ydf.RandomForestLearner(label="income").train(train_ds)
```

```text
Train model on 22792 examples
Model trained in 0:00:01.064247
```


We need to select an example to explain. Let's select the first example of the test dataset.

```python
selected_example = test_ds[:1]
selected_example
```

| | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 39 | State-gov | 77516 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 2174 | 0 | 40 | United-States | <=50K |


For this example, the model predicts:

```python
model.predict(selected_example)
```

```text
array([0.01], dtype=float32)
```

The model returns the probability of the positive class `>50K`. In other words, it predicts the negative class `<=50K` with a probability of $1-0.01=0.99$, i.e., 99%.

Now, we compute the distance between the selected test example and all the training examples.

```python
distances = model.distance(train_ds, selected_example).squeeze()

print("distances:", distances)
```

```text
distances: [1.         1.         1.         ... 0.99333334 0.99666667 1.        ]
```
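Assume that, for a Random Forest, the distance between two examples is one minus the fraction of the $T$ trees in which they reach the same leaf. Also assume that the forest has $T=300$ trees, which we believe is the default `num_trees` of `RandomForestLearner`. Under these assumptions, the printed values correspond to small counts of shared leaves. Writing $\ell_t(x)$ for the leaf reached by example $x$ in tree $t$:

$$
d(x_i, x_j) = 1 - \frac{1}{T} \sum_{t=1}^{T} \mathbb{1}\left[\ell_t(x_i) = \ell_t(x_j)\right]
$$

$$
1 - \frac{2}{300} \approx 0.99333, \qquad 1 - \frac{1}{300} \approx 0.99667
$$

A distance of exactly 1 then means that the two examples never reach the same leaf in any tree.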


Let's find the five training examples with the smallest distance to our chosen example.

```python
# Indices of the 5 training examples closest to the selected test example.
close_train_idxs = np.argsort(distances)[:5]
print("close_train_idxs:", close_train_idxs)

print("Closest training examples:")
train_ds.iloc[close_train_idxs]
```

```text
close_train_idxs: [16596 21845 10321  7299 14721]
Closest training examples:
```

| | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 16596 | 41 | State-gov | 26892 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 0 | 0 | 40 | United-States | <=50K |
| 21845 | 37 | State-gov | 60227 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 0 | 0 | 38 | United-States | <=50K |
| 10321 | 40 | Private | 82161 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 0 | 0 | 40 | United-States | <=50K |
| 7299 | 30 | State-gov | 158291 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 0 | 0 | 40 | United-States | <=50K |
| 14721 | 32 | State-gov | 171111 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 0 | 0 | 37 | United-States | <=50K |


**Observations:**

- For the chosen example, the model predicted the class `<=50K`. The five closest training examples are all labeled `<=50K` as well.
- The closest examples share many feature values with the chosen example, such as `education`, `marital_status`, `occupation`, `relationship`, `race`, and `sex`. They also work between 37 and 40 `hours_per_week`. This explains why the model treats them as close.
- The examples' `age`s range between 30 and 41, suggesting that, for these examples, the model treats this age range as roughly equivalent.
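The observations above compare the neighbors' labels by eye. The following minimal sketch automates this check. It selects the `k` nearest examples with `np.argpartition`, which avoids fully sorting all 22792 training distances, and then measures how many of those neighbors carry a given label. The helper names are hypothetical and not part of YDF, and the inputs are toy values.

```python
import numpy as np


def nearest_neighbors(distances: np.ndarray, k: int) -> np.ndarray:
    """Returns the indices of the k smallest distances, closest first (1 <= k <= len(distances))."""
    # argpartition moves the k smallest values to the front without a full sort.
    candidates = np.argpartition(distances, k - 1)[:k]
    # Only the k candidates are sorted.
    return candidates[np.argsort(distances[candidates], kind="stable")]


def neighbor_label_share(distances: np.ndarray, labels: np.ndarray, k: int, label: str) -> float:
    """Fraction of the k nearest examples whose label equals `label`."""
    idxs = nearest_neighbors(distances, k)
    return float(np.mean(labels[idxs] == label))


# Hypothetical distances from one test example to 6 training examples.
toy_distances = np.array([0.9, 0.2, 1.0, 0.5, 0.1, 0.7])
toy_labels = np.array(["<=50K", "<=50K", ">50K", "<=50K", "<=50K", ">50K"])
print(nearest_neighbors(toy_distances, 3))  # [4 1 3]
print(neighbor_label_share(toy_distances, toy_labels, 3, "<=50K"))  # 1.0
```

In the notebook, you could call `neighbor_label_share(distances, train_ds["income"].to_numpy(), 5, "<=50K")` right after computing the distances to the selected example. Many training examples can share the same distance, so ties may be ordered differently than with `np.argsort`.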


## Two dimensional projections of the examples

Our next use of the distance is to project the examples onto a two-dimensional plane. For that, we use [t-SNE](https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding).

```python
from sklearn.manifold import TSNE  # For 2d projections
from plotly.offline import iplot  # For interactive plots
import plotly.graph_objs as go
```

```python
# Pairwise distance between all testing examples
distances = model.distance(test_ds, test_ds)
```

```python
# Find 2d projection
t_sne = TSNE(
    # Number of dimensions to display. 3d is also possible.
    n_components=2,
    # Controls the shape of the projection. Higher values create more
    # distinct but also more collapsed clusters. Typical values are
    # between 5 and 50.
    perplexity=20,
    # Use the model distances computed above instead of a feature-space metric.
    metric="precomputed",
    # PCA initialization cannot be used with precomputed distances.
    init="random",
    verbose=1,
    learning_rate="auto").fit_transform(distances)
```

```text
[t-SNE] Computing 61 nearest neighbors...
[t-SNE] Indexed 9769 samples in 0.059s...
[t-SNE] Computed neighbors for 9769 samples in 0.798s...
[t-SNE] Computed conditional probabilities for sample 1000 / 9769
[t-SNE] Computed conditional probabilities for sample 2000 / 9769
[t-SNE] Computed conditional probabilities for sample 3000 / 9769
[t-SNE] Computed conditional probabilities for sample 4000 / 9769
[t-SNE] Computed conditional probabilities for sample 5000 / 9769
[t-SNE] Computed conditional probabilities for sample 6000 / 9769
[t-SNE] Computed conditional probabilities for sample 7000 / 9769
[t-SNE] Computed conditional probabilities for sample 8000 / 9769
[t-SNE] Computed conditional probabilities for sample 9000 / 9769
[t-SNE] Computed conditional probabilities for sample 9769 / 9769
[t-SNE] Mean sigma: 0.178857
[t-SNE] KL divergence after 250 iterations with early exaggeration: 75.697197
[t-SNE] KL divergence after 1000 iterations: 1.115830
```
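With `metric="precomputed"`, t-SNE treats its input as a distance matrix, so the matrix should be square and non-negative. We would also expect a forest distance to be symmetric with a zero diagonal, because an example always reaches the same leaves as itself. These last two properties are our expectation rather than a documented requirement. The following minimal sketch, built around a hypothetical helper, lists any of these properties that a matrix violates. You could run it before the projection.

```python
import numpy as np


def distance_matrix_problems(d: np.ndarray, atol: float = 1e-6) -> list:
    """Lists the reasons why `d` does not look like a pairwise distance matrix (empty if none)."""
    if d.ndim != 2 or d.shape[0] != d.shape[1]:
        return [f"expected a square matrix, got shape {d.shape}"]
    problems = []
    if np.any(d < 0):
        problems.append("some distances are negative")
    if not np.allclose(d, d.T, atol=atol):
        problems.append("the matrix is not symmetric")
    if not np.allclose(np.diag(d), 0.0, atol=atol):
        problems.append("some examples have a non-zero distance to themselves")
    return problems


good = np.array([[0.0, 0.25], [0.25, 0.0]])
bad = np.array([[0.1, 1.0], [0.5, 0.0]])
print(distance_matrix_problems(good))  # []
print(distance_matrix_problems(bad))
```

In the notebook, this would be `distance_matrix_problems(distances)`. On the full 9769 x 9769 matrix, the symmetry check allocates temporary arrays of the same size.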


Let's create an interactive plot that shows the features of each example when you hover over its point.

```python
def example_to_html(example):
    # Formats the features of one example as HTML for the hover tooltip.
    return "<br>".join([f"<b>{k}:</b> {v}" for k, v in example.items()])


def interactive_plot(dataset, projections):
    # Blue for ">50K", red for "<=50K".
    colors = (dataset["income"] == ">50K").map(lambda x: ["red", "blue"][x])
    labels = list(dataset.apply(example_to_html, axis=1).values)
    args = {
        "data": [
            go.Scatter(
                x=projections[:, 0],
                y=projections[:, 1],
                text=labels,
                mode="markers",
                marker={"color": colors, "size": 3},
            )
        ],
        "layout": go.Layout(width=500, height=500, template="simple_white"),
    }
    iplot(args)


interactive_plot(test_ds, t_sne)
```

**Note:** Move your mouse over the plot to see the values of the examples.

The colors represent the labels: blue for `>50K` and red for `<=50K`. We can see clusters of uniform color, where all the examples have the same label, and clusters of mixed colors, where the model likely has difficulty making good predictions.

Can you make sense of those clusters?


## Cluster examples

We can also cluster the examples using the same distance matrix. [Many methods](https://scikit-learn.org/stable/modules/clustering.html) are available. Let's use `AgglomerativeClustering` with average linkage.

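```python
from sklearn.cluster import AgglomerativeClustering

num_clusters = 6
clustering = AgglomerativeClustering(
    n_clusters=num_clusters,
    # Use the model distances instead of a feature-space metric.
    metric="precomputed",
    # Distance between two clusters = average distance between their examples.
    linkage="average",
).fit(distances)
```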
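With `linkage="average"`, agglomerative clustering starts with one cluster per example. It then repeatedly merges the two clusters with the smallest average pairwise distance. The naive sketch below illustrates this merge loop on a toy precomputed matrix. It is not scikit-learn's implementation: it is much slower (at least cubic in the number of examples), so it is only suitable for tiny inputs, and its cluster numbering may differ. The function name is hypothetical.

```python
import numpy as np


def average_linkage(d: np.ndarray, n_clusters: int) -> np.ndarray:
    """Naive average-linkage agglomerative clustering on a precomputed distance matrix."""
    # Start with one cluster per example.
    clusters = [[i] for i in range(d.shape[0])]
    while len(clusters) > n_clusters:
        best = None
        for a in range(len(clusters)):
            for b in range(a + 1, len(clusters)):
                # Average linkage: mean distance over all cross-cluster pairs.
                link = d[np.ix_(clusters[a], clusters[b])].mean()
                if best is None or link < best[0]:
                    best = (link, a, b)
        _, a, b = best
        # Merge cluster b into cluster a (b > a, so deleting b keeps index a valid).
        clusters[a] = clusters[a] + clusters[b]
        del clusters[b]
    labels = np.empty(d.shape[0], dtype=int)
    for label, members in enumerate(clusters):
        labels[members] = label
    return labels


# Toy 1D points: three groups {0, 1, 2}, {50, 51}, and {100}.
x = np.array([0, 1, 2, 50, 51, 100])
toy_d = np.abs(x[:, None] - x[None, :]).astype(float)
print(average_linkage(toy_d, 3))  # [0 0 0 1 1 2]
```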

Next, for each cluster, we print summary statistics of the numerical features and the first example of the cluster.

```python
from IPython.display import display

for cluster_idx in range(num_clusters):
    # Test examples assigned to this cluster.
    selected_examples = test_ds[clustering.labels_ == cluster_idx]
    print(f"Cluster #{cluster_idx} with {len(selected_examples)} examples")
    print("=============================")
    # Summary statistics of the numerical features.
    display(selected_examples.describe())
    # First example of the cluster.
    display(selected_examples.iloc[:1])
```

Cluster #0 with 2879 examples


| | age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week |
|---|---|---|---|---|---|---|
| count | 2879.0 | 2879.0 | 2879.0 | 2879.0 | 2879.0 | 2879.0 |
| mean | 42.860021 | 184706.465439 | 8.879125 | 200.963876 | 32.425842 | 42.555054 |
| std | 12.426582 | 99424.684674 | 1.92907 | 852.256462 | 231.362238 | 11.910265 |
| min | 18.0 | 19395.0 | 1.0 | 0.0 | 0.0 | 1.0 |
| 25% | 33.0 | 115465.5 | 9.0 | 0.0 | 0.0 | 40.0 |
| 50% | 41.0 | 176681.0 | 9.0 | 0.0 | 0.0 | 40.0 |
| 75% | 51.0 | 231872.5 | 10.0 | 0.0 | 0.0 | 46.0 |
| max | 90.0 | 671292.0 | 12.0 | 5013.0 | 2179.0 | 99.0 |

| | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | 40 | Private | 121772 | Assoc-voc | 11 | Married-civ-spouse | Craft-repair | Husband | Asian-Pac-Islander | Male | 0 | 0 | 40 | | >50K |


Cluster #1 with 5131 examples


| | age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week |
|---|---|---|---|---|---|---|
| count | 5131.0 | 5131.0 | 5131.0 | 5131.0 | 5131.0 | 5131.0 |
| mean | 34.026895 | 193176.8 | 9.726954 | 103.289222 | 57.424479 | 37.824401 |
| std | 13.371512 | 105519.6 | 2.395434 | 642.138022 | 328.764194 | 12.40154 |
| min | 17.0 | 19214.0 | 1.0 | 0.0 | 0.0 | 1.0 |
| 25% | 23.0 | 120586.5 | 9.0 | 0.0 | 0.0 | 35.0 |
| 50% | 31.0 | 181721.0 | 10.0 | 0.0 | 0.0 | 40.0 |
| 75% | 42.0 | 241685.5 | 11.0 | 0.0 | 0.0 | 40.0 |
| max | 90.0 | 1038553.0 | 16.0 | 7443.0 | 3770.0 | 99.0 |

| | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 39 | State-gov | 77516 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 2174 | 0 | 40 | United-States | <=50K |


Cluster #2 with 220 examples


| | age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week |
|---|---|---|---|---|---|---|
| count | 220.0 | 220.0 | 220.0 | 220.0 | 220.0 | 220.0 |
| mean | 44.863636 | 182932.690909 | 11.977273 | 0.0 | 1996.745455 | 46.7 |
| std | 11.372463 | 89132.990647 | 2.314227 | 0.0 | 174.160632 | 11.490357 |
| min | 22.0 | 20953.0 | 9.0 | 0.0 | 1825.0 | 12.0 |
| 25% | 36.75 | 125575.25 | 10.0 | 0.0 | 1887.0 | 40.0 |
| 50% | 43.0 | 169627.5 | 13.0 | 0.0 | 1902.0 | 41.0 |
| 75% | 51.0 | 213384.5 | 14.0 | 0.0 | 1977.0 | 50.0 |
| max | 83.0 | 530099.0 | 16.0 | 0.0 | 2603.0 | 99.0 |

| | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 25 | 48 | Self-emp-not-inc | 191277 | Doctorate | 16 | Married-civ-spouse | Prof-specialty | Husband | White | Male | 0 | 1902 | 60 | United-States | >50K |


Cluster #3 with 1012 examples


| | age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week |
|---|---|---|---|---|---|---|
| count | 1012.0 | 1012.0 | 1012.0 | 1012.0 | 1012.0 | 1012.0 |
| mean | 43.610672 | 186296.0 | 13.541502 | 119.073123 | 14.341897 | 44.18083 |
| std | 11.334174 | 107433.3 | 0.874553 | 675.355176 | 152.606687 | 12.260441 |
| min | 23.0 | 22328.0 | 13.0 | 0.0 | 0.0 | 1.0 |
| 25% | 35.0 | 114815.8 | 13.0 | 0.0 | 0.0 | 40.0 |
| 50% | 43.0 | 175648.0 | 13.0 | 0.0 | 0.0 | 40.0 |
| 75% | 50.0 | 230122.8 | 14.0 | 0.0 | 0.0 | 50.0 |
| max | 90.0 | 1097453.0 | 16.0 | 5013.0 | 1977.0 | 99.0 |

| | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2 | 40 | Private | 193524 | Doctorate | 16 | Married-civ-spouse | Prof-specialty | Husband | White | Male | 0 | 0 | 60 | United-States | >50K |


Cluster #4 with 46 examples


| | age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week |
|---|---|---|---|---|---|---|
| count | 46.0 | 46.0 | 46.0 | 46.0 | 46.0 | 46.0 |
| mean | 47.913043 | 171906.630435 | 15.543478 | 280.413043 | 252.26087 | 43.021739 |
| std | 10.897148 | 81143.023865 | 0.50361 | 1088.008529 | 679.865949 | 16.043675 |
| min | 32.0 | 33155.0 | 15.0 | 0.0 | 0.0 | 6.0 |
| 25% | 39.0 | 115998.0 | 15.0 | 0.0 | 0.0 | 40.0 |
| 50% | 48.5 | 163298.0 | 16.0 | 0.0 | 0.0 | 40.0 |
| 75% | 53.0 | 211152.25 | 16.0 | 0.0 | 0.0 | 50.0 |
| max | 79.0 | 345259.0 | 16.0 | 4787.0 | 2824.0 | 99.0 |

| | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 618 | 36 | Private | 103110 | Doctorate | 16 | Never-married | Prof-specialty | Not-in-family | White | Male | 0 | 0 | 40 | England | <=50K |


Cluster #5 with 481 examples


| | age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week |
|---|---|---|---|---|---|---|
| count | 481.0 | 481.0 | 481.0 | 481.0 | 481.0 | 481.0 |
| mean | 45.621622 | 191274.322245 | 11.806653 | 19103.544699 | 0.0 | 46.636175 |
| std | 11.010141 | 103664.004053 | 2.507912 | 25872.3371 | 0.0 | 11.647901 |
| min | 20.0 | 19302.0 | 1.0 | 5178.0 | 0.0 | 2.0 |
| 25% | 38.0 | 119793.0 | 10.0 | 7298.0 | 0.0 | 40.0 |
| 50% | 44.0 | 175232.0 | 13.0 | 10520.0 | 0.0 | 45.0 |
| 75% | 52.0 | 235786.0 | 14.0 | 15024.0 | 0.0 | 50.0 |
| max | 78.0 | 617021.0 | 16.0 | 99999.0 | 0.0 | 99.0 |

| | age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | income |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 21 | 44 | Private | 343591 | HS-grad | 9 | Divorced | Craft-repair | Not-in-family | White | Female | 14344 | 0 | 40 | United-States | >50K |
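The per-cluster tables describe the features but not how the label is distributed. The following minimal sketch uses a hypothetical helper, not part of YDF, that computes each cluster's size and its share of `>50K` examples with a pandas `groupby`. The demo uses toy inputs.

```python
import numpy as np
import pandas as pd


def label_share_per_cluster(cluster_ids, labels, positive: str = ">50K") -> pd.DataFrame:
    """For each cluster, returns its size and the fraction of examples labeled `positive`."""
    df = pd.DataFrame({
        "cluster": np.asarray(cluster_ids),
        "is_positive": np.asarray(labels) == positive,
    })
    # Named aggregation: one output column per (name, function) pair.
    return df.groupby("cluster")["is_positive"].agg(n_examples="count", positive_share="mean")


# Hypothetical cluster assignments and labels for 5 examples.
summary = label_share_per_cluster(
    [0, 0, 1, 1, 1], [">50K", "<=50K", ">50K", ">50K", "<=50K"])
print(summary)
```

In the notebook, the corresponding call would be `label_share_per_cluster(clustering.labels_, test_ds["income"])`. Clusters with a share close to 0 or 1 are the most homogeneous in terms of label.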
