In [1]:
import pickle
import pandas as pd
import sklearn
from sklearn.utils import resample
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from comparers import LogisticRegressionVsXGBComparer, make_path, X, y, display_data
import numpy as np
import shap
from IPython.display import Markdown, display
def printmd(string):
    display(Markdown(string))
notebook_name = 'Explain Differences of Log Odds with SHAP'

In [2]:
%matplotlib inline

In [3]:
comparer = LogisticRegressionVsXGBComparer()
comparer.load_or_train()

Loaded model: models/LogisticRegressionModel.pickle
Loaded model: models/XGBModel.json


Reduce dataset to be explained from 32561 to 1000 samples:

In [4]:
X, y, display_data = resample(X, y, display_data, n_samples=1000, replace=False, stratify=y, random_state=0)

In [5]:
masker = shap.maskers.Independent(X, max_samples=100)

In [6]:
def generate_explanation(model, filename):
    explainer = shap.Explainer(model, masker)
    shap_values = explainer(X)
    shap_values.display_data = display_data
    with open(filename, 'wb') as f:
        pickle.dump((explainer, shap_values), f)
    return explainer, shap_values
        
def load_or_generate_explanation(model, filename):
    try:
        with open(filename, 'rb') as f:
            return pickle.load(f)
    except FileNotFoundError:
        return generate_explanation(model, filename)

In [7]:
explainer, shap_values = load_or_generate_explanation(
    lambda X: comparer.model_b.predict_logit(X) - comparer.model_a.predict_logit(X),
    make_path(f'{notebook_name}.shap'))

Permutation explainer: 1001it [01:06, 13.32it/s]                          


PicklingError: Can't pickle <function <lambda> at 0x12714b790>: attribute lookup <lambda> on __main__ failed

## Scatter plots
They are generalizations of the partial dependence plots.

With the parameter color, we instruct SHAP to pick the feature that has the biggest interaction effects with the explained feature and to colour the observations according to its' shapeley values.

In [None]:
def scatter_plot(feature_name):
    printmd(f'### {feature_name}')
    shap.plots.scatter(shap_values[:,feature_name], color=shap_values)

In [None]:
for feature_name in X.columns:
    scatter_plot(feature_name)

Notice the strong interaction effect in the plots of model A. Logistic regression is better explained using the log of odds instead of probabilities.

## Bar plots
> By default a SHAP bar plot will take the mean absolute value of each feature over all the instances (rows) of the dataset.

In [None]:
shap.plots.bar(shap_values, max_display=len(X.columns))

In [None]:
shap.plots.bar(shap_values.percentile(25, 0), max_display=len(X.columns))

In [None]:
shap.plots.bar(shap_values.percentile(75, 0), max_display=len(X.columns))

> But the mean absolute value is not the only way to create a global measure of feature importance, we can use any number of transforms. Here we show how using the max absolute value highights the Capital Gain and Capital Loss features, since they have infrewuent but high magnitude effects.

In [None]:
shap.plots.bar(shap_values.abs.max(0), max_display=len(X.columns))

But by using the 95th percentile it is more robust against outliers

In [None]:
shap.plots.bar(shap_values.abs.percentile(95, 0), max_display=len(X.columns))

We can also compare feature importance for subsets separately, like for men and women:

In [None]:
sex = pd.DataFrame(display_data, columns=X.columns)['Sex'].str.strip().to_list()
shap.plots.bar(shap_values.cohorts(sex).abs.mean(0), max_display=len(X.columns))

In [None]:
sex = pd.DataFrame(display_data, columns=X.columns)['Sex'].str.strip().to_list()
shap.plots.bar(shap_values.cohorts(sex).mean(0), max_display=len(X.columns))

TODO:
- use any algorithm to determine interesting subsets
- make a local barplot with instances just of this subset

## Beeswarm plots
They show the Shapeley value distribution per feature.
> If we are willing to deal with a bit more complexity we can use a beeswarm plot to summarize the entire distribution of SHAP values for each feature.

In [None]:
shap.plots.beeswarm(shap_values)

> By taking the absolute value and using a solid color we get a compromise between the complexity of the bar plot and the full beeswarm plot. Note that the bar plots above are just summary statistics from the values shown in the beeswarm plots below.

In [None]:
shap.plots.beeswarm(shap_values.abs, color="shap_red")

Or we  could simplify the visualization of the distribution by using violin plots:

In [None]:
shap.plots.violin(shap_values.values, features=X)

## Forceplot

In [None]:
shap.initjs()

In [None]:
shap.plots.force(
    base_value=shap_values.abs.mean(0).base_values,
    shap_values=shap_values.values,
    features=display_data,
    feature_names=X.columns)

## Plot embeddings
Use the SHAP values as an embedding which we project to 2D for visualization, using PCA.

In [None]:
shap.plots.embedding('Relationship', shap_values.values, feature_names=X.columns)

## Heatmap
Here we use a supervised hierarchical clustering method to visualize the SHAP values

In [None]:
shap.plots.heatmap(shap_values, max_display=len(X.columns))