In [None]:
%matplotlib agg

import pandas as pd
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from models import gen_pred_df, MODELS, X_COLS, TARGET_Y_COLS
MODELS

In [None]:
# Read the dataset CSV from disk; the bare name on the last line makes the
# notebook render the loaded frame as the cell output.
dataset_df = pd.read_csv(filepath_or_buffer='data/dataset.csv')
dataset_df

In [None]:
import shap

# Selection of the X_COLS columns from the dataset. get_shap hands this frame
# to the SHAP explainer both as its data argument and as the rows to explain.
X_shap = dataset_df[X_COLS]

In [None]:
# Per-target result containers. Both names are rebound from the parallel
# results at the end of this cell, so these are only empty placeholders.
explainers, shap_values = dict(), dict()

import torch

def get_shap(y, yg):
    """Compute sampling-based SHAP values for a single prediction target.

    Parameters
    ----------
    y
        Name of the column to pick from the frame returned by ``gen_pred_df``.
    yg
        Target-group key; forwarded to ``gen_pred_df`` as ``targets=[yg]``.

    Returns
    -------
    tuple
        ``(y, explainer, shap_value)`` where ``explainer`` is the
        ``shap.explainers.Sampling`` instance built on ``X_shap`` and
        ``shap_value`` is the result of calling it on ``X_shap``.
    """
    import warnings

    # Limit torch to one thread per call. This function is dispatched through
    # joblib, so this presumably avoids CPU oversubscription between workers.
    # Each limit is applied independently and stays best-effort: torch raises
    # RuntimeError if the inter-op pool can no longer be resized, and that must
    # not prevent the intra-op limit from being applied or abort the job.
    thread_controls = (
        (torch.get_num_interop_threads, torch.set_num_interop_threads),
        (torch.get_num_threads, torch.set_num_threads),
    )
    for get_count, set_count in thread_controls:
        try:
            if get_count() > 1:
                set_count(1)
        except RuntimeError as exc:
            warnings.warn(
                f"could not limit torch threads via {set_count.__name__}: {exc}",
                RuntimeWarning,
            )

    # Restore the builtin aliases that NumPy 1.24 removed (presumably still
    # referenced by the installed shap release -- TODO confirm). Only missing
    # names are added, so NumPy 2's real ``np.bool`` scalar type is not
    # replaced by the Python builtin.
    for alias, builtin in (("int", int), ("float", float), ("bool", bool)):
        if alias not in vars(np):
            setattr(np, alias, builtin)

    def predict_target(x):
        """Model function for SHAP: predictions for target ``y`` only."""
        return gen_pred_df(x, targets=[yg])[y]

    explainer = shap.explainers.Sampling(predict_target, X_shap, seed=0)
    shap_value = explainer(X_shap, nsamples=8192 * 4)
    return y, explainer, shap_value


from joblib import Parallel, delayed

# One job per individual target: (target column, group key) pairs, in the
# iteration order of TARGET_Y_COLS.
jobs = []
for yg, group_targets in TARGET_Y_COLS.items():
    for y in group_targets:
        jobs.append((y, yg))

# Run get_shap for every job on 6 loky workers; results come back in job order.
parallel = Parallel(n_jobs=6, backend='loky')
results = parallel(delayed(get_shap)(*job) for job in jobs)

# Index the explainers and SHAP values by target column name.
explainers, shap_values = {}, {}
for target, explainer, shap_value in results:
    explainers[target] = explainer
    shap_values[target] = shap_value


### Plot SHAP values for Top 5 and Bottom 5 features

In [None]:
from patched_summary_plot import summary_legacy

# Drop any figures left over from earlier cells before restyling.
plt.close('all')

# Base matplotlib settings for the SHAP figures.
# NOTE(review): 'font.size' is 14 here but is overridden to 12 by the second
# update below, so 12 is the value in effect -- confirm which one is intended.
params = {
    'axes.labelsize': 14,
    'font.size': 14,
    'font.family': "Arial",
    'legend.fontsize': 12,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'figure.dpi': 300,
}
mpl.rcParams.update(params)

# Overrides applied on top of the base settings.
mpl.rcParams.update({
    'font.size': 12,
    'font.family': 'Arial',
    'mathtext.default': 'it',
})


# Display label for each target column, used in the x-axis label of the SHAP
# plots. Labels wrapped in $...$ are matplotlib mathtext. Those containing a
# backslash are raw strings: "\s" and "\e" are invalid escape sequences in
# ordinary string literals and trigger a warning at compile time.
Y_MATH_SYMS = {
    'Detachability': 'Detachability',
    'FlatnessUni': 'Flatness',
    'Feasibility': 'Feasibility',
    'TensileStrength': r'$\sigma_{u}$',
    'TensileStrain': r'$\epsilon_f$',
    'TensileModulusLog10': '$E$',
    'TensileSED': '$g_{se}$',
    'TransVis': '$T_{Vis}$',
    'TransIR': '$T_{IR}$',
    'TransUV': '$T_{UV}$',
    'FireRR': '$RR$',
}

import copy

# For every target: normalise its SHAP values, pick the 5 features with the
# largest positive and the 5 with the most negative mean contribution, and save
# a summary plot with the features in that order.
for group_name, ys in TARGET_Y_COLS.items():
    for y in ys:
        # Fall back to the raw column name when no display label is registered,
        # so a missing entry does not abort the whole plotting run.
        y_sym = Y_MATH_SYMS.get(y, y)

        # Work on a copy so the stored SHAP results keep their original scale.
        sub_shap_values = copy.deepcopy(shap_values[y])

        # Scale by the largest absolute SHAP value so values lie in [-1, 1].
        # Skipped when every value is zero, which would otherwise divide by
        # zero and turn the whole explanation into NaN.
        vmax = np.abs(sub_shap_values.values).max()
        if vmax != 0:
            sub_shap_values.values = sub_shap_values.values / vmax
            sub_shap_values.base_values = sub_shap_values.base_values / vmax

        feature_names = sub_shap_values.feature_names
        shap_df = pd.DataFrame(sub_shap_values.values, columns=feature_names)

        # Per-feature mean over only the entries beyond +/-0.01 (1% of the
        # largest magnitude after scaling); entries inside the band are masked
        # to NaN and ignored by mean().
        # NOTE(review): a feature with entries on both sides of the band can
        # rank in both lists and then appears twice in feature_order.
        pos_5 = shap_df[shap_df > 0.01].mean().sort_values(ascending=False)[:5]
        neg_5 = shap_df[shap_df < -0.01].mean().sort_values(ascending=True)[:5]

        # Column positions: bottom-5 first (most negative mean first), then
        # top-5 (largest mean first).
        feature_order = [
            feature_names.index(name)
            for name in list(neg_5.index) + list(pos_5.index)
        ]

        # Discard the previous iteration's figure before drawing a new one.
        plt.close('all')
        summary_legacy(sub_shap_values, show=False, sort=True, max_display=5,
                       color_bar_label='Material Composition (%)',
                       feature_order=feature_order)
        plt.xlabel(f'SHAP Value on {y_sym}')
        plt.gcf().savefig(f'shap.{group_name}.{y}.val-order.top-5-bot-5.pdf')