# SHAP-Style Plots Documentation

This notebook demonstrates how to use `pymint` to create [SHAP-style](https://github.com/slundberg/shap) plots within MintPy. For more information on the dataset and on initializing `InterpretToolkit`, see the permutation importance notebook.

In [1]:
import sys, os 
current_dir = os.getcwd()
path = os.path.dirname(current_dir)
sys.path.append(path)
import numpy as np
import xarray as xr
import pandas as pd
import shap

In [2]:
import pymint
import plotting_config 

### Loading the training data and pre-fit models 


In [3]:
estimators = pymint.load_models()
X,y = pymint.load_data()

In [4]:
X_subset = shap.sample(X, 200, random_state=22)
explainer = pymint.InterpretToolkit(estimators, X=X_subset)

### Compute the SHAP Values

In [None]:
background_dataset = shap.sample(X, 100)
results = explainer.shap(background_dataset=background_dataset)

In [None]:
shap_values, bias = results['Random Forest']
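
To make the `shap_values, bias` pair more concrete, the sketch below computes exact Shapley values by brute force for a tiny hand-written model. It assumes the usual SHAP convention, in which the bias is the average model output over the background data and the per-feature contributions plus the bias add up to the prediction. The toy `model`, the `background` rows, and the helper names are hypothetical and are not part of `pymint` or `shap`.

```python
from itertools import combinations
from math import factorial


def model(x):
    # Hypothetical toy model: one additive term plus an interaction term.
    return 2.0 * x[0] + x[1] * x[2]


def expected_output(x, known, background):
    # Average model output when the features in `known` are fixed to the
    # values in `x` and the remaining features are taken from each background row.
    total = 0.0
    for row in background:
        mixed = [x[j] if j in known else row[j] for j in range(len(x))]
        total += model(mixed)
    return total / len(background)


def shapley_values(x, background):
    # Exact Shapley values: weighted average of each feature's marginal
    # contribution over all subsets of the other features.
    n = len(x)
    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                known = set(subset)
                gain = (expected_output(x, known | {i}, background)
                        - expected_output(x, known, background))
                phi[i] += weight * gain
    # The bias is the average prediction over the background data.
    bias = expected_output(x, set(), background)
    return phi, bias


background = [[0.0, 1.0, 1.0], [1.0, 0.0, 2.0], [2.0, 2.0, 0.0], [1.0, 1.0, 1.0]]
x = [3.0, 2.0, 1.0]
phi, bias = shapley_values(x, background)

print("contributions:", phi)
print("bias:", bias)
print("bias + sum(contributions):", bias + sum(phi), "prediction:", model(x))
```

Brute-force enumeration scales exponentially with the number of features, which is why practical SHAP implementations rely on approximations and a sampled background dataset such as the one used above.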

## Summary Plot

Once we have computed SHAP values for a large number of examples, we can look for patterns across them. In the plot below, predictors are ranked by the sum of their absolute SHAP values. Each point is color-coded by the normalized predictor value, where red indicates a higher value and blue indicates a lower value. In this case, surface temperature ($T_{sfc}$) has the largest absolute sum, and lower values increase the probability of freezing road surface temperatures.



In [None]:
explainer.plot_shap(
                    plot_type = 'summary',
                    shap_values=shap_values,
                    display_feature_names=plotting_config.display_feature_names,
)                           
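
The ranking behind the summary plot can be reproduced by hand. The following is a minimal sketch on a made-up SHAP matrix (rows are examples, columns are features). The array values and the feature names `feat_a`, `feat_b`, and `feat_c` are illustrative only and do not come from the road surface dataset.

```python
import numpy as np

# Toy SHAP matrix: 4 examples (rows) x 3 features (columns); values are made up.
toy_shap = np.array([
    [ 0.10, -0.40,  0.05],
    [-0.20,  0.50,  0.00],
    [ 0.15, -0.30, -0.05],
    [-0.05,  0.60,  0.10],
])
toy_names = np.array(["feat_a", "feat_b", "feat_c"])

# Sum of absolute SHAP values per feature, then sort from largest to smallest.
abs_sum = np.abs(toy_shap).sum(axis=0)
order = np.argsort(abs_sum)[::-1]

for name, score in zip(toy_names[order], abs_sum[order]):
    print(f"{name}: {score:.2f}")
```

Taking absolute values first matters: `feat_b` has contributions of both signs that would largely cancel in a plain sum, yet it moves the prediction the most.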

Instead of the summary plot, we can show the SHAP-based ranking with the bar-style importance plot used in PyMint.

In [None]:
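def shap_values_to_importance(shap_values, estimator_name, feature_names, method='sum'):
    """
    Convert SHAP values into a feature importance ranking.

    method='sum' ranks features by the sum of absolute SHAP values;
    method='std' ranks them by the standard deviation of the SHAP values.
    """
    if method == 'std':
        # Standard deviation of the SHAP values for each feature
        shap_rank = np.std(shap_values, axis=0)
    elif method == 'sum':
        # Sum of the absolute SHAP values for each feature
        shap_rank = np.sum(np.absolute(shap_values), axis=0)
    else:
        raise ValueError(f"method must be 'sum' or 'std', got {method!r}")

    # Sort features from most to least important
    ranked_indices = np.argsort(shap_rank)[::-1]
    scores_ranked = np.array(shap_rank[ranked_indices])
    features_ranked = np.array(feature_names)[ranked_indices]

    # Package the ranking as an xarray Dataset for plot_importance
    data = {}
    data[f"shap_rankings__{estimator_name}"] = (
        ["n_vars_shap"],
        features_ranked,
    )
    data[f"shap_scores__{estimator_name}"] = (
        ["n_vars_shap", "n_bootstrap"],
        scores_ranked.reshape(len(scores_ranked), 1),
    )
    return xr.Dataset(data)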

data = shap_values_to_importance(shap_values, estimator_name='Random Forest', feature_names=X.columns)
explainer.plot_importance(data=data, 
                          estimator_names = 'Random Forest', 
                          method='shap')

## Dependence Plot

SHAP values can also be displayed like an ALE/PD curve, where the values are presented as a function of the predictor value. For these plots, we can also include `histdata`, which is a combination of `X` and `y`. For classification problems, the user can provide the name of the target variable, and the background histogram will be color-coded by class. Lastly, each dot is color-coded by the value of the feature that approximately interacts most strongly with the feature being plotted.
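
At its core, a dependence plot is a scatter of predictor value against SHAP value. As a rough numerical analogue, the sketch below bins a synthetic predictor and averages the SHAP values in each bin. The `feature` and `toy_shap` arrays are randomly generated for illustration (a contribution that decreases as the predictor increases, plus noise) and are not output from `explainer.shap`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical predictor values and made-up SHAP values for that predictor:
# the contribution decreases as the predictor increases, plus a little noise.
feature = rng.uniform(-10.0, 10.0, size=500)
toy_shap = -0.02 * feature + rng.normal(0.0, 0.01, size=500)

# Split the predictor range into 4 equal-width bins and average the SHAP
# values in each bin, which traces the trend a dependence plot would show.
edges = np.linspace(-10.0, 10.0, 5)
bin_index = np.digitize(feature, edges[1:-1])  # integers 0..3
binned_mean = np.array([toy_shap[bin_index == b].mean() for b in range(4)])

for b in range(4):
    print(f"[{edges[b]:6.1f}, {edges[b + 1]:6.1f}): mean SHAP = {binned_mean[b]: .3f}")
```

The vertical spread of the dots within a bin, which the binned mean hides, is what the interaction color-coding helps explain.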


In [None]:
X.columns

In [None]:
features = ['sat_irbt', 'd_rad_d', 'temp2m', 'hrrr_dT']

histdata=X.copy()
histdata['target'] = y

explainer.plot_shap(features=features,
                    plot_type = 'dependence',
                    shap_values=shap_values/100,
                    display_feature_names=plotting_config.display_feature_names,
                    display_units = plotting_config.display_units,
                    histdata=histdata,
                    target='target',
                    interaction_index='auto'
)

### No Color-Coding of Dots

Set `interaction_index=None` to remove the color-coding. 

In [None]:
features = ['tmp2m_hrs_bl_frez', 'sat_irbt', 'sfcT_hrs_ab_frez', 'tmp2m_hrs_ab_frez', 'd_rad_d', 'temp2m']

explainer.plot_shap(features=features,
                    plot_type = 'dependence',
                    shap_values=shap_values,
                    display_feature_names=plotting_config.display_feature_names,
                    display_units = plotting_config.display_units,
                    histdata=histdata,
                    interaction_index=None,
)

## SHAP for Regression

In [None]:
from sklearn.datasets import fetch_california_housing
from sklearn.ensemble import RandomForestRegressor

In [None]:
data = fetch_california_housing()
X = data['data']
y = data['target']
feature_names = data['feature_names']
model = RandomForestRegressor()
model.fit(X,y)

In [None]:
X_subset = shap.sample(X, 100, random_state=42)
explainer = pymint.InterpretToolkit(('Random Forest', model), X=X_subset, feature_names=feature_names)

In [None]:
background_dataset = shap.sample(X, 100)
results = explainer.shap(background_dataset=background_dataset)
shap_values, bias = results['Random Forest']
explainer.plot_shap(
                    plot_type = 'summary',
                    shap_values=shap_values,
                    display_feature_names=plotting_config.display_feature_names,
)                           