## Interpretability - Tabular SHAP explainer

In this example, we use Kernel SHAP to explain a tabular classification model built from the Adult Census Income dataset.

First we import the packages and define some UDFs we will need later.

In [0]:
import pyspark
from mmlspark.explainers import *
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler
from pyspark.sql.types import *
from pyspark.sql.functions import *
import pandas as pd

# Spark UDF: read the element at `position` from a vector column and return it as a float.
vec_access = udf(
    lambda vector, position: float(vector[position]),
    FloatType(),
)

# Spark UDF: turn a vector column into an array-of-float column via toArray().tolist().
vec2array = udf(
    lambda vector: vector.toArray().tolist(),
    ArrayType(FloatType()),
)

Now let's read the data and train a simple binary classification model.

In [0]:
# Read the Adult Census Income parquet file.
df = spark.read.parquet(
    "wasbs://publicwasb@mmlspark.blob.core.windows.net/AdultCensusIncome.parquet"
)

# Index the string "income" column into a numeric "label" column,
# assigning indices in ascending alphabetical order of the label strings.
labelIndexer = StringIndexer(
    inputCol="income",
    outputCol="label",
    stringOrderType="alphabetAsc",
).fit(df)
label_assignment = set(zip(labelIndexer.labels, [0, 1]))
print("Label index assigment: " + str(label_assignment))

training = labelIndexer.transform(df)
display(training)

# Raw input columns, split by how they are fed to the model.
categorical_features = [
    "workclass",
    "education",
    "marital-status",
    "occupation",
    "relationship",
    "race",
    "sex",
    "native-country",
]
numeric_features = [
    "age",
    "education-num",
    "capital-gain",
    "capital-loss",
    "hours-per-week",
]

# Derived column names for the indexed and one-hot encoded categorical columns.
categorical_features_idx = [f"{name}_idx" for name in categorical_features]
categorical_features_enc = [f"{name}_enc" for name in categorical_features]

# Pipeline: string index -> one-hot encode -> assemble feature vector -> logistic regression.
strIndexer = StringIndexer(
    inputCols=categorical_features,
    outputCols=categorical_features_idx,
)
onehotEnc = OneHotEncoder(
    inputCols=categorical_features_idx,
    outputCols=categorical_features_enc,
)
vectAssem = VectorAssembler(
    inputCols=categorical_features_enc + numeric_features,
    outputCol="features",
)
lr = LogisticRegression(
    featuresCol="features",
    labelCol="label",
    weightCol="fnlwgt",
)
pipeline = Pipeline(stages=[strIndexer, onehotEnc, vectAssem, lr])
model = pipeline.fit(training)

age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,income,label
45,Private,362883,HS-grad,9,Married-civ-spouse,Craft-repair,Husband,White,Male,5013,0,40.0,United-States,<=50K,0.0
43,Private,182757,HS-grad,9,Married-civ-spouse,Craft-repair,Husband,White,Male,0,0,40.0,United-States,>50K,1.0
20,Private,50397,HS-grad,9,Never-married,Adm-clerical,Own-child,Black,Male,0,0,20.0,United-States,<=50K,0.0
43,Federal-gov,101709,Some-college,10,Divorced,Handlers-cleaners,Not-in-family,Asian-Pac-Islander,Male,0,0,40.0,United-States,<=50K,0.0
21,Private,202570,12th,8,Never-married,Adm-clerical,Other-relative,Black,Male,0,0,48.0,?,<=50K,0.0
40,Private,145649,HS-grad,9,Separated,Sales,Unmarried,Black,Female,0,0,25.0,United-States,<=50K,0.0
36,Private,136343,Bachelors,13,Married-civ-spouse,Prof-specialty,Husband,White,Male,0,0,40.0,United-States,>50K,1.0
64,Self-emp-inc,142166,Bachelors,13,Divorced,Sales,Not-in-family,White,Male,0,0,45.0,United-States,<=50K,0.0
19,?,242001,Some-college,10,Never-married,?,Own-child,White,Female,0,0,40.0,United-States,<=50K,0.0
46,Private,127089,Some-college,10,Married-civ-spouse,Prof-specialty,Husband,White,Male,5178,0,38.0,United-States,>50K,1.0


After the model is trained, we randomly select some observations to be explained.

In [0]:
# Score the training data, order it randomly, keep five rows to explain,
# then spread them over 200 partitions and cache the result.
explain_instances = (
    model.transform(training)
    .orderBy(rand())
    .limit(5)
    .repartition(200)
    .cache()
)
display(explain_instances)

age,workclass,fnlwgt,education,education-num,marital-status,occupation,relationship,race,sex,capital-gain,capital-loss,hours-per-week,native-country,income,label,native-country_idx,race_idx,occupation_idx,marital-status_idx,workclass_idx,relationship_idx,education_idx,sex_idx,sex_enc,workclass_enc,relationship_enc,race_enc,education_enc,native-country_enc,occupation_enc,marital-status_enc,features,rawPrediction,probability,prediction
26,Private,280093,Some-college,10,Never-married,Handlers-cleaners,Own-child,White,Male,0,0,40.0,United-States,<=50K,0.0,0.0,0.0,9.0,1.0,0.0,2.0,1.0,0.0,"Map(vectorType -> sparse, length -> 1, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 8, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 5, indices -> List(2), values -> List(1.0))","Map(vectorType -> sparse, length -> 4, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 15, indices -> List(1), values -> List(1.0))","Map(vectorType -> sparse, length -> 41, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 14, indices -> List(9), values -> List(1.0))","Map(vectorType -> sparse, length -> 6, indices -> List(1), values -> List(1.0))","Map(vectorType -> sparse, length -> 99, indices -> List(0, 9, 24, 38, 45, 48, 52, 53, 94, 95, 98), values -> List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 26.0, 10.0, 40.0))","Map(vectorType -> dense, length -> 2, values -> List(5.295233373488859, -5.295233373488859))","Map(vectorType -> dense, length -> 2, values -> List(0.9950095853766281, 0.004990414623371908))",0.0
50,Self-emp-inc,163921,HS-grad,9,Married-civ-spouse,Sales,Husband,White,Male,0,0,48.0,United-States,>50K,1.0,0.0,0.0,4.0,0.0,5.0,0.0,0.0,0.0,"Map(vectorType -> sparse, length -> 1, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 8, indices -> List(5), values -> List(1.0))","Map(vectorType -> sparse, length -> 5, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 4, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 15, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 41, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 14, indices -> List(4), values -> List(1.0))","Map(vectorType -> sparse, length -> 6, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 99, indices -> List(5, 8, 23, 33, 43, 48, 52, 53, 94, 95, 98), values -> List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 50.0, 9.0, 48.0))","Map(vectorType -> dense, length -> 2, values -> List(0.16037282720802892, -0.16037282720802892))","Map(vectorType -> dense, length -> 2, values -> List(0.5400074959907509, 0.45999250400924907))",0.0
34,?,24504,HS-grad,9,Never-married,?,Not-in-family,White,Male,0,0,40.0,United-States,<=50K,0.0,0.0,0.0,7.0,1.0,3.0,1.0,0.0,0.0,"Map(vectorType -> sparse, length -> 1, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 8, indices -> List(3), values -> List(1.0))","Map(vectorType -> sparse, length -> 5, indices -> List(1), values -> List(1.0))","Map(vectorType -> sparse, length -> 4, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 15, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 41, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 14, indices -> List(7), values -> List(1.0))","Map(vectorType -> sparse, length -> 6, indices -> List(1), values -> List(1.0))","Map(vectorType -> sparse, length -> 99, indices -> List(3, 8, 24, 36, 44, 48, 52, 53, 94, 95, 98), values -> List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 34.0, 9.0, 40.0))","Map(vectorType -> dense, length -> 2, values -> List(4.185333108161747, -4.185333108161747))","Map(vectorType -> dense, length -> 2, values -> List(0.9850109542052289, 0.014989045794771116))",0.0
40,Private,168071,Assoc-acdm,12,Divorced,Prof-specialty,Not-in-family,White,Male,0,0,50.0,United-States,<=50K,0.0,0.0,0.0,0.0,2.0,0.0,1.0,6.0,0.0,"Map(vectorType -> sparse, length -> 1, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 8, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 5, indices -> List(1), values -> List(1.0))","Map(vectorType -> sparse, length -> 4, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 15, indices -> List(6), values -> List(1.0))","Map(vectorType -> sparse, length -> 41, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 14, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 6, indices -> List(2), values -> List(1.0))","Map(vectorType -> sparse, length -> 99, indices -> List(0, 14, 25, 29, 44, 48, 52, 53, 94, 95, 98), values -> List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 40.0, 12.0, 50.0))","Map(vectorType -> dense, length -> 2, values -> List(1.424067541793566, -1.424067541793566))","Map(vectorType -> dense, length -> 2, values -> List(0.8059752867803827, 0.19402471321961734))",0.0
34,State-gov,61431,12th,8,Never-married,Adm-clerical,Not-in-family,Black,Female,0,0,40.0,United-States,<=50K,0.0,0.0,1.0,3.0,1.0,4.0,1.0,11.0,1.0,"Map(vectorType -> sparse, length -> 1, indices -> List(), values -> List())","Map(vectorType -> sparse, length -> 8, indices -> List(4), values -> List(1.0))","Map(vectorType -> sparse, length -> 5, indices -> List(1), values -> List(1.0))","Map(vectorType -> sparse, length -> 4, indices -> List(1), values -> List(1.0))","Map(vectorType -> sparse, length -> 15, indices -> List(11), values -> List(1.0))","Map(vectorType -> sparse, length -> 41, indices -> List(0), values -> List(1.0))","Map(vectorType -> sparse, length -> 14, indices -> List(3), values -> List(1.0))","Map(vectorType -> sparse, length -> 6, indices -> List(1), values -> List(1.0))","Map(vectorType -> sparse, length -> 99, indices -> List(4, 19, 24, 32, 44, 49, 53, 94, 95, 98), values -> List(1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 34.0, 8.0, 40.0))","Map(vectorType -> dense, length -> 2, values -> List(5.021090526377174, -5.021090526377174))","Map(vectorType -> dense, length -> 2, values -> List(0.9934459112501833, 0.006554088749816667))",0.0


We create a TabularSHAP explainer, set the input columns to all the features the model takes, specify the model and the target output column we are trying to explain. In this case, we are trying to explain the "probability" output which is a vector of length 2, and we are only looking at class 1 probability. Specify targetClasses to `[0, 1]` if you want to explain class 0 and 1 probability at the same time. Finally we sample 100 rows from the training data for background data, which is used for integrating out features in Kernel SHAP.

In [0]:
# 100 randomly ordered training rows, cached, used as the explainer's background data.
background_sample = training.orderBy(rand()).limit(100).cache()

# Explain element 1 of the model's "probability" output, using the raw input columns.
shap = TabularSHAP(
    model=model,
    inputCols=categorical_features + numeric_features,
    targetCol="probability",
    targetClasses=[1],
    outputCol="shapValues",
    numSamples=5000,
    backgroundData=background_sample,
)

shap_df = shap.transform(explain_instances)


Once we have the resulting dataframe, we extract the class 1 probability of the model output, the SHAP values for the target class, the original features and the true label. Then we convert it to a pandas dataframe for visualization.
For each observation, the first element in the SHAP values vector is the base value (the mean output of the background dataset), and each of the following elements is the SHAP value of the corresponding feature.

In [0]:
# Column expressions: element 1 of the probability vector, and the first
# SHAP vector (index 0 of "shapValues") converted to an array of floats.
class1_probability = vec_access(col("probability"), lit(1))
target_shap_values = vec2array(col("shapValues").getItem(0))
selected_columns = ["shapValues", "probability", "label"] + categorical_features + numeric_features

shaps = (
    shap_df
    .withColumn("probability", class1_probability)
    .withColumn("shapValues", target_shap_values)
    .select(selected_columns)
)

# Collect to pandas and order by descending probability with a fresh 0..n-1 index.
shaps_local = shaps.toPandas().sort_values("probability", ascending=False, ignore_index=True)

# Show full cell contents so the SHAP value lists are not truncated.
pd.set_option('display.max_colwidth', None)
shaps_local

Unnamed: 0,shapValues,probability,label,workclass,education,marital-status,occupation,relationship,race,sex,native-country,age,education-num,capital-gain,capital-loss,hours-per-week
0,"[0.21021755, 0.036331102, 0.011407855, 0.18172061, 0.039502986, -0.047949668, 0.0033071144, 0.03732453, 0.011162714, 0.037939783, -0.06340959, -0.01767566, -0.008496406, 0.028608594]",0.459992,1.0,Self-emp-inc,HS-grad,Married-civ-spouse,Sales,Husband,White,Male,United-States,50,9,0,0,48.0
1,"[0.21021883, 0.0034574957, -0.044558, -0.16193147, 0.047279127, 0.05946539, 0.002313033, 0.024235502, 0.0066087, -0.0020558704, 0.056080315, -0.02217087, -0.0067982804, 0.021881111]",0.194025,0.0,Private,Assoc-acdm,Divorced,Prof-specialty,Not-in-family,White,Male,United-States,40,12,0,0,50.0
2,"[0.2102186, 0.04852501, 0.006356042, -0.12734048, -0.09719812, 0.03810386, 0.0027049792, 0.011517685, 0.004807474, -0.0142795425, -0.03725752, -0.020403938, -0.0029868914, -0.00777806]",0.014989,0.0,?,HS-grad,Never-married,?,Not-in-family,White,Male,United-States,34,9,0,0,40.0
3,"[0.21021777, -0.014681821, 0.012115142, -0.103285804, -0.008804426, 0.032152824, -0.0064477073, -0.03024225, 0.004430357, -0.009691941, -0.048759993, -0.021556646, -0.002236587, -0.006655595]",0.006554,0.0,State-gov,12th,Never-married,Adm-clerical,Not-in-family,Black,Female,United-States,34,8,0,0,40.0
4,"[0.2102176, 0.00085895706, 0.010095258, -0.09567099, -0.03868283, -0.030505812, 0.0016090366, 0.0092368545, 0.0045897323, -0.019838404, -0.016353566, -0.022075318, -0.002163665, -0.006327383]",0.00499,0.0,Private,Some-college,Never-married,Handlers-cleaners,Own-child,White,Male,United-States,26,10,0,0,40.0


We use plotly subplot to visualize the SHAP values.

In [0]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import pandas as pd

# Bar labels: one "Base" entry followed by every explained feature,
# in the same order as the inputCols given to the explainer.
features = categorical_features + numeric_features
features_with_base = ["Base"] + features

rows = shaps_local.shape[0]

# One title per observation. make_subplots documents `subplot_titles` as a
# list of strings, so the pandas Series is converted to a plain list.
subplot_titles = (
    "Probability: "
    + shaps_local["probability"].apply("{:.2%}".format)
    + "; Label: "
    + shaps_local["label"].astype(str)
).tolist()

fig = make_subplots(rows=rows, cols=1, subplot_titles=subplot_titles)

# Use the 1-based iteration position for the subplot row instead of the
# DataFrame index label, so the plot stays correct even if `shaps_local`
# does not carry a 0..n-1 index.
for position, (_, row) in enumerate(shaps_local.iterrows(), start=1):
  # The base entry has no feature value of its own; 0 is used as a placeholder.
  feature_values = [0] + [row[feature] for feature in features]
  shap_values = row["shapValues"]
  shap_pdf = pd.DataFrame(
    list(zip(features_with_base, feature_values, shap_values)),
    columns=['name', 'value', 'shap'],
  )
  fig.add_trace(
    go.Bar(
      x=shap_pdf["name"],
      y=shap_pdf["shap"],
      hovertext='value: ' + shap_pdf["value"].astype(str),
    ),
    row=position, col=1
  )

# Same fixed y-range on every subplot so bar heights are comparable across observations.
fig.update_yaxes(range=[-1, 1], fixedrange=True, zerolinecolor='black')
fig.update_xaxes(type='category', tickangle=45, fixedrange=True)
fig.update_layout(height=400 * rows, title_text="SHAP explanations")
fig.show()
