# Lab assignment: getting explanations from ensemble models

<img src="https://albarji-labs-materials.s3-eu-west-1.amazonaws.com/madridAutomata.jpg">

<div style="float: right;">(MadriD: Automata, picture by <a href=https://www.deviantart.com/albarji/art/MadriD-Automata-721505521>Albarji</a>)</div>

In this assignment we will learn how to obtain explanations from ensemble models, which allow us to get some intuition about how the model is making its decisions. We will make use of the inspection methods available in scikit-learn, together with the explanations library <a href=https://github.com/slundberg/shap#citations>SHAP</a>.

## Guidelines

Throughout this notebook you will find empty cells that you will need to fill with your own code. Follow the instructions in the notebook and pay special attention to the following symbols.

<img src="https://albarji-labs-materials.s3-eu-west-1.amazonaws.com/question.png" height="80" width="80" style="float: right;"/>

***

<font color=#ad3e26>
You will need to solve a question by writing your own code or answer in the cell immediately below or in a different file, as instructed.</font>

***

<img src="https://albarji-labs-materials.s3-eu-west-1.amazonaws.com/exclamation.png" height="80" width="80" style="float: right;"/>

***
<font color=#2655ad>
This is a hint or useful observation that can help you solve this assignment. You should pay attention to these hints to better understand the assignment.
</font>

***

<img src="https://albarji-labs-materials.s3-eu-west-1.amazonaws.com/pro.png" height="80" width="80" style="float: right;"/>

***
<font color=#259b4c>
This is an advanced exercise that can help you gain deeper knowledge of the topic. Good luck!</font>

***

To avoid missing packages and compatibility issues you should run this notebook under one of the [recommended Ensembles environment files](https://github.com/albarji/teaching-environments-ensembles).

The following code will embed any plots into the notebook instead of generating a new window:

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

Lastly, if you need any help on the usage of a Python function you can place the writing cursor over its name and press Shift+Tab to produce a pop-out with related documentation. This will only work inside code cells. 

Let's go!

## Data loading

For this exercise we will work with the [Adult Dataset](https://archive.ics.uci.edu/ml/datasets/Adult). This dataset is made of 1994 census data of the United States of America. The objective is to predict whether each individual earns more than 50K USD per year, using demographic information. This dataset is readily available as part of the SHAP library, both in a form amenable to scikit-learn models and in its original form.

In [None]:
import shap
shap.initjs()

X,y = shap.datasets.adult()
X_display,y_display = shap.datasets.adult(display=True)

Since we are going to test several model interpretation methods, let's add a random column to simulate a useless feature.

In [None]:
import numpy as np

rng = np.random.RandomState(seed=42)
X['Random Number'] = X_display['Random Number'] = rng.randn(X.shape[0])

In [None]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test, X_display_train, X_display_test = train_test_split(X, y, X_display, test_size=0.2, random_state=7)

## Random Forest feature importance (Mean Decrease in Impurity)

Let's start with the interpretation methods already bundled in scikit-learn. Tree-based methods can compute the importance of each feature as the reduction in impurity obtained by the splits that use it in the tree. Similarly, an ensemble of trees can estimate the importance of a feature by computing the mean impurity decrease among trees.
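As a rough illustration of this idea, the sketch below recomputes the impurity-based importances of a single decision tree by hand and compares them with the values reported by scikit-learn. It uses synthetic data from `make_classification`, and the names `X_demo`, `y_demo`, `tree_demo` and `manual_mdi` are hypothetical, chosen only for this example and not part of the assignment.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic data, used only for this illustration
X_demo, y_demo = make_classification(n_samples=500, n_features=5,
                                     n_informative=3, random_state=0)
tree_demo = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_demo, y_demo)

t = tree_demo.tree_
manual_mdi = np.zeros(X_demo.shape[1])
for node in range(t.node_count):
    left, right = t.children_left[node], t.children_right[node]
    if left == -1:  # leaf node: no split, so no impurity decrease
        continue
    # Weighted impurity of the node minus the weighted impurity of its children
    decrease = (t.weighted_n_node_samples[node] * t.impurity[node]
                - t.weighted_n_node_samples[left] * t.impurity[left]
                - t.weighted_n_node_samples[right] * t.impurity[right])
    manual_mdi[t.feature[node]] += decrease

# Normalize so that the importances add up to 1, as scikit-learn does
manual_mdi /= manual_mdi.sum()

print(manual_mdi)
print(tree_demo.feature_importances_)
```

Both printed arrays should agree. Features that are never used in a split receive an importance of zero.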

<img src="https://albarji-labs-materials.s3-eu-west-1.amazonaws.com/question.png" height="80" width="80" style="float: right;"/>

***

<font color=#ad3e26>
Train a RandomForest model over the training data. You can use the default parameters. Name the trained model <b>rf</b>.
</font>

***

In [None]:
####### INSERT YOUR CODE HERE

Let's check that the model is able to obtain reasonable accuracy on the train and test data.

In [None]:
print("RF train accuracy: %0.3f" % rf.score(X_train, y_train))
print("RF test accuracy: %0.3f" % rf.score(X_test, y_test))

Since we are not using any pruning strategies, most likely your model has overfitted the data. But we will nevertheless use this model and see how different inspection methods behave.

To obtain feature importances we can use the `feature_importances_` attribute of the model.

In [None]:
rf.feature_importances_

The importances are an array with the Mean Decrease in Impurity of each feature in the whole forest. For easier understanding we can create a barplot showing these importances.

In [None]:
tree_feature_importances = rf.feature_importances_
sorted_idx = tree_feature_importances.argsort()

feature_names = X_train.columns
y_ticks = np.arange(0, len(feature_names))
fig, ax = plt.subplots()
ax.barh(y_ticks, tree_feature_importances[sorted_idx])
ax.set_yticks(y_ticks)  # set the tick positions first, so the labels below attach to them
ax.set_yticklabels(feature_names[sorted_idx])
ax.set_title("Random Forest Feature Importances (MDI)")
fig.tight_layout()
plt.show()

Note how the most important feature is *Random Number*, even though this is a useless feature! Since the feature is continuous and takes a different value for every data point, the trees can keep splitting on it until the training data are perfectly separated, which produces a large decrease in impurity. Basing our model on a random feature is clearly overfitting, but since the Mean Decrease in Impurity is computed over the data used to train the model, it cannot reveal that this is a bad feature. That is the main reason why this method does not provide a good metric to assess feature relevance. Let's move on to a better approach!

## Permutation importances

A more suitable and general method for measuring relevances is Permutation Importances. To analyze the relevance of a feature, the column holding its values is shuffled through a random permutation, and the resulting loss in model performance is taken as the relevance of that feature. We can perform this analysis over the test data to obtain less biased interpretations. Also, the permutation can be repeated a number of times (*n_repeats*) to obtain more accurate relevance estimates.
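To make the procedure concrete, here is a minimal hand-written sketch of it. It assumes a synthetic two-feature problem in which only the first feature determines the label, and all names (`X_demo`, `model_demo`, `manual_permutation_importance`) are hypothetical. It is meant to illustrate the idea, not to mirror the scikit-learn implementation.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Synthetic data: feature 0 decides the label, feature 1 is pure noise
rng_demo = np.random.RandomState(0)
X_demo = rng_demo.randn(1000, 2)
y_demo = (X_demo[:, 0] > 0).astype(int)
X_tr, X_te, y_tr, y_te = X_demo[:800], X_demo[800:], y_demo[:800], y_demo[800:]
model_demo = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_tr, y_tr)

def manual_permutation_importance(model, X, y, n_repeats=10, seed=0):
    """Accuracy drop after shuffling each column, for every repetition."""
    rng = np.random.RandomState(seed)
    baseline = model.score(X, y)
    drops = np.zeros((X.shape[1], n_repeats))
    for j in range(X.shape[1]):
        for r in range(n_repeats):
            X_perm = X.copy()
            # Shuffle only column j, breaking its link with the target
            X_perm[:, j] = rng.permutation(X_perm[:, j])
            drops[j, r] = baseline - model.score(X_perm, y)
    return drops

demo_importances = manual_permutation_importance(model_demo, X_te, y_te)
print(demo_importances.mean(axis=1))
```

Shuffling the informative feature should cost a large share of the accuracy, while shuffling the noise feature should leave the score unchanged.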

In [None]:
from sklearn.inspection import permutation_importance

result = permutation_importance(rf, X_test, y_test, n_repeats=10, random_state=42, n_jobs=8)
result

`permutation_importance` returns the importance obtained for each feature and repetition (`importances`), along with the per-feature means and standard deviations across repetitions (`importances_mean` and `importances_std`). A good way to visualize these is by making use of boxplots:

In [None]:
sorted_idx = result.importances_mean.argsort()

fig, ax = plt.subplots()
ax.boxplot(result.importances[sorted_idx].T,
           vert=False, labels=X_test.columns[sorted_idx])
ax.set_title("Permutation Importances (test set)")
fig.tight_layout()
plt.show()

The Permutation Importance analysis reveals how the *Random Number* feature we introduced is not relevant at all, while other, more reasonable features such as *Capital Gain* are now very relevant.

## Partial dependence plots

A more fine-grained analysis of the influence of each feature on the target can be obtained through **partial dependence plots**. These plots show how the probability of predicting the positive class changes, on average, as the value of a feature varies. For instance, to compute this plot for the first feature in the data (*Age*) we just need to run:

In [None]:
from sklearn.inspection import plot_partial_dependence

column = X_test.columns.get_loc('Age')
plot_partial_dependence(rf, X_test, features=[column])
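For intuition about what such a plot represents, the following sketch computes a partial dependence curve by brute force: the feature of interest is forced to each value of a grid for all data points, and the predicted probabilities are averaged. The synthetic data and the names `model_demo`, `manual_partial_dependence` and `grid_demo` are hypothetical and serve only this illustration.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

# Synthetic data: feature 0 decides the label, feature 1 is pure noise
rng_demo = np.random.RandomState(0)
X_demo = rng_demo.randn(500, 2)
y_demo = (X_demo[:, 0] > 0).astype(int)
model_demo = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X_demo, y_demo)

def manual_partial_dependence(model, X, feature, grid):
    """Average positive-class probability when `feature` is forced to each grid value."""
    averages = []
    for value in grid:
        X_forced = X.copy()
        X_forced[:, feature] = value  # overwrite the feature for every data point
        averages.append(model.predict_proba(X_forced)[:, 1].mean())
    return np.array(averages)

grid_demo = np.linspace(-2, 2, 9)
pd_informative = manual_partial_dependence(model_demo, X_demo, 0, grid_demo)
pd_noise = manual_partial_dependence(model_demo, X_demo, 1, grid_demo)
print(pd_informative)
print(pd_noise)
```

The curve for the informative feature should rise from low to high probability along the grid, whereas the curve for the noise feature should stay flat.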

<img src="https://albarji-labs-materials.s3-eu-west-1.amazonaws.com/question.png" height="80" width="80" style="float: right;"/>

***

<font color=#ad3e26>
Generate partial dependence plots for the random feature we introduced, as well as for another relevant feature.</font>

***

In [None]:
####### INSERT YOUR CODE HERE

## SHAP local explanations

A non-obvious problem with the approaches above is that these explanation methods **are not consistent** when explaining the relevance of each feature for a specific data point. Suppose we have two models, A and B, and that removing a certain feature $x_i$ changes the predictions of model A more than those of model B. If our explanation method is not consistent, the computed relevances might nevertheless assign a larger relevance to $x_i$ in model B than in model A!

The SHAP library makes use of the [Shapley values](https://en.wikipedia.org/wiki/Shapley_value) from game theory to guarantee consistency, among other useful properties such as local accuracy and missingness. It can also be shown mathematically that Shapley values are the only ones that meet all of these properties simultaneously.
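The sketch below computes exact Shapley values for a tiny hand-made model by enumerating all feature coalitions, and then checks the local accuracy property. It is only an illustration under simplifying assumptions: `toy_model`, `coalition_value` and `exact_shapley` are hypothetical names, "missing" features are simulated by averaging over a small background sample, and the enumeration is exponential in the number of features, which is why SHAP relies on specialized algorithms in practice.

```python
from itertools import combinations
from math import factorial

import numpy as np

def toy_model(X):
    # Hypothetical model: a linear term plus an interaction term
    return 2.0 * X[:, 0] + X[:, 1] * X[:, 2]

def coalition_value(model, x, background, subset):
    """Average output when the features in `subset` take the values of x
    and the remaining features are drawn from the background data."""
    X_mix = background.copy()
    idx = list(subset)
    if idx:
        X_mix[:, idx] = x[idx]
    return model(X_mix).mean()

def exact_shapley(model, x, background):
    n = x.shape[0]
    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight of a coalition of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for S in combinations(others, size):
                gain = (coalition_value(model, x, background, S + (i,))
                        - coalition_value(model, x, background, S))
                phi[i] += weight * gain
    return phi

background_demo = np.random.RandomState(0).randn(20, 3)
x_demo = np.array([1.0, 2.0, -1.0])

phi_demo = exact_shapley(toy_model, x_demo, background_demo)
base_demo = toy_model(background_demo).mean()

# Local accuracy: base value plus all Shapley values recovers the prediction
print(base_demo + phi_demo.sum(), toy_model(x_demo[None, :])[0])
```

The two printed numbers should coincide, which is the local accuracy property in action.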

SHAP integrates well with many ensemble method libraries. For instance, it can efficiently obtain accurate explanations for Extreme Gradient Boosting models.

<img src="https://albarji-labs-materials.s3-eu-west-1.amazonaws.com/question.png" height="80" width="80" style="float: right;"/>

***

<font color=#ad3e26>
Train an Extreme Gradient Boosting model over the training data. You can use the default parameters. Name the trained model <b>xgb</b>.</font>

***

In [None]:
####### INSERT YOUR CODE HERE

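The SHAP library includes several explaining methods for different kinds of models. For a tree ensemble the **TreeExplainer** is the best choice; the generic `shap.Explainer` used here selects it automatically when given a tree-based model.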

In [None]:
explainer = shap.Explainer(xgb)

With this explainer we can obtain SHAP values for every feature in every data point, as follows:

In [None]:
explanation = explainer(X_test)
explanation

The returned object is of type `Explanation`, and contains three elements:
* The SHAP `values` for each data point
* The `base_values`, that is, the $\phi_0$ bias coefficient for each data point. This represents the model prediction when no features at all are used.
* The original `data` points we are explaining.

An `Explanation` can be indexed as if it were an array. So for instance, we can get the information about the first data point as

In [None]:
explanation[0]

These values must be interpreted as follows: by default the prediction of the model will be the `base_value`. However, after analyzing each of the values of the input features, this default prediction is shifted by each one of the SHAP `values`. Positive SHAP values mean the corresponding feature increases the probability of predicting the positive class, while negative values reduce this probability. Note that for an XGBoost classifier these quantities are expressed by default in log-odds units rather than as probabilities. This can be visualized with a force plot:

In [None]:
shap.plots.force(explanation[0])

The plot shows as `base value` the default value predicted by the model, and in <font color="red">red</font> and <font color="blue">blue</font> the features that increase or decrease the output value, respectively. The final **output value** shown in bold is the prediction made by the model after considering all the features. This representation is useful to understand why the model decided to produce its classification.

Alternatively, we can also use a waterfall plot to represent the same information:

In [None]:
shap.plots.waterfall(explanation[0])
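When reading these plots it can help to translate the output value into a probability. The sketch below assumes the explanation is expressed in log-odds, as is typically the case for an XGBoost binary classifier explained through its raw output. The numbers are made up for illustration, and `base_value_demo`, `shap_values_demo` and `log_odds_to_probability` are hypothetical names.

```python
import numpy as np

def log_odds_to_probability(log_odds):
    # Logistic (sigmoid) function
    return 1.0 / (1.0 + np.exp(-log_odds))

# Made-up values for a single data point (assumption: log-odds units)
base_value_demo = -1.5
shap_values_demo = np.array([0.8, -0.3, 1.2])

# The model output is the base value shifted by every SHAP value
output_log_odds = base_value_demo + shap_values_demo.sum()
output_probability = log_odds_to_probability(output_log_odds)
print(output_log_odds, output_probability)
```

An output above zero in log-odds corresponds to a probability above 0.5 for the positive class.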

## SHAP feature relevances

By grouping all the computed SHAP values by features and drawing a scatterplot we can visualize the global relevance of each feature in the model:

In [None]:
shap.plots.beeswarm(explanation)
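One common way to condense these values into a single relevance score per feature is the mean absolute SHAP value. The sketch below applies it to a small made-up matrix with one row per data point and one column per feature; `shap_matrix_demo` and `feature_names_demo` are hypothetical. With a real `Explanation` object the matrix would come from its `values`.

```python
import numpy as np

feature_names_demo = ["feature_a", "feature_b", "feature_c"]

# Made-up SHAP values: rows are data points, columns are features
shap_matrix_demo = np.array([
    [ 0.5, -0.1,  0.0],
    [-0.7,  0.2,  0.1],
    [ 0.6, -0.3, -0.1],
    [-0.4,  0.1,  0.0],
])

# Average magnitude of the contribution of each feature, regardless of sign
global_relevance = np.abs(shap_matrix_demo).mean(axis=0)

# Features sorted from most to least relevant
ranking = [feature_names_demo[i] for i in np.argsort(global_relevance)[::-1]]
print(global_relevance, ranking)
```

Taking absolute values matters here: a feature that pushes predictions strongly in both directions would otherwise average out to a relevance near zero.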

## SHAP dependence plots

We can also obtain partial dependence plots from SHAP values. For instance, we can plot again how *Age* influences the probability of high earnings:

In [None]:
shap.plots.scatter(explanation[:,"Age"], color=explanation)

The plot shows a scatterplot with a dot for each datapoint in the test set, positioned according to its *Age* and corresponding SHAP value. To make the plot easier to understand, SHAP automatically selects the feature that appears to interact most strongly with *Age* and uses it to color the dots. This provides more information than a scikit-learn partial dependence plot.

<img src="https://albarji-labs-materials.s3-eu-west-1.amazonaws.com/question.png" height="80" width="80" style="float: right;"/>

***

<font color=#ad3e26>
Generate SHAP partial dependence plots for the random feature we introduced, as well as for the other relevant feature you plotted in the partial dependence exercise above. Can you spot differences between the scikit-learn and SHAP plots?</font>

***

In [None]:
####### INSERT YOUR CODE HERE