# SHAP analysis for energy models
- Generate SHAP plots to help interpret the energy prediction models.

In [2]:
import json
import pandas as pd 
import numpy as np
import os 
import shap
import matplotlib.pyplot as plt
import warnings
# Suppress every warning for the rest of the session, presumably to keep
# notebook output uncluttered.
# NOTE(review): this is process-global and also hides useful warnings
# (deprecations, invalid escape sequences, dtype issues) — consider narrowing
# with category=/module= filters.
warnings.filterwarnings("ignore")

In [6]:
# Get the model's input feature names: the OSDA descriptor columns followed by
# the zeolite descriptor columns, in the key order of the two feature configs.
with open("../src/configs/osda_v1_phys.json", "r", encoding="utf-8") as f:
    o_X_cols = json.load(f).keys()
with open("../src/configs/zeolite_v1_phys_short.json", "r", encoding="utf-8") as f:
    z_X_cols = json.load(f).keys()
Xcols = list(o_X_cols) + list(z_X_cols)

# Human-readable plot labels (matplotlib mathtext) for each feature column.
# Raw strings keep LaTeX backslashes such as \AA from being parsed as (invalid)
# Python escape sequences, which newer Python versions warn about.
Xcol_dict = {
    'mol_weight': r"Molecular weight (Da)",  # a mass, not a length: Å was wrong
    'mol_volume': r"Molecular volume ($\AA^3$)",
    'asphericity': r"Asphericity",
    'eccentricity': r"Eccentricity",
    'inertial_shape_factor': r"Inertial shape factor",
    'spherocity_index': r"Spherocity index",
    'gyration_radius': r"Gyration radius ($\AA$)",
    'pmi1': r"Principal moment of inertia 1",
    'pmi2': r"Principal moment of inertia 2",
    'pmi3': r"Principal moment of inertia 3",
    'npr1': r"Normalized principal moment of inertia 1",
    'npr2': r"Normalized principal moment of inertia 2",
    'free_sasa': r"Free solvent-accessible surface area ($\AA^2$)",
    'bertz_ct': r"Bertz CT",
    'num_rot_bonds': r"Number of rotatable bonds",
    'num_bonds': r"Number of bonds",
    'formal_charge': r"Formal charge",
    'a': r"Lattice vector a ($\AA$)",
    'b': r"Lattice vector b ($\AA$)",
    'c': r"Lattice vector c ($\AA$)",
    'alpha': r"Lattice angle alpha",
    'beta': r"Lattice angle beta",
    'gamma': r"Lattice angle gamma",
    'num_atoms_per_vol': r"Number of atoms per volume ($\AA^{-3}$)",
    'num_atoms': r"Number of framework atoms",
    'volume': r"Framework volume ($\AA^3$)",
    'largest_free_sphere': r"$D_{LFS}$ ($\AA$)",
    'largest_free_sphere_a': r"$D_{LFS}$ along a ($\AA$)",
    'largest_free_sphere_b': r"$D_{LFS}$ along b ($\AA$)",
    'largest_free_sphere_c': r"$D_{LFS}$ along c ($\AA$)",
    'largest_included_sphere': r"$D_{LIS}$ ($\AA$)",
    'largest_included_sphere_a': r"$D_{LIS}$ along a ($\AA$)",
    'largest_included_sphere_b': r"$D_{LIS}$ along b ($\AA$)",
    'largest_included_sphere_c': r"$D_{LIS}$ along c ($\AA$)",
    'largest_included_sphere_fsp': r"$D_{LIS}$ along free sphere path ($\AA$)",
}

# Pretty labels in the same order as Xcols. A column without an entry above
# falls back to its raw name, so adding a descriptor to a config file does not
# break the plotting cells with a KeyError.
Xcols_pretty = [Xcol_dict.get(col, col) for col in Xcols]

In [7]:
# Load the OSDA and zeolite prior tables. Later cells look up per-SMILES and
# per-zeolite descriptor values in them with .loc.
# NOTE(review): read_pickle unpickles arbitrary objects — only point these
# paths at trusted, locally generated files.
oprior_file, zprior_file = (
    "../../data/datasets/training_data/osda_priors_0.pkl",
    "../../data/datasets/training_data/zeolite_priors_0.pkl",
)

opriors, zpriors = (pd.read_pickle(path) for path in (oprior_file, zprior_file))

In [8]:
# Load the test-set SMILES of data split 1.
# NOTE(review): smiles_test is not used by any cell visible here; presumably a
# later step relies on it — confirm before removing.
split_dir = "../../data/datasets/training_data/splits/1/"
# split_dir already ends with "/", so join the path instead of the old
# f"{split_dir}/..." which produced a doubled "//" separator.
smiles_test = np.load(os.path.join(split_dir, "smiles_test.npy"))

In [12]:
# Directory holding the per-framework Deep SHAP value CSVs for split 1
# (the beeswarm cell reads them from here); create it if missing.
op_dir = "../../data/publication/shap/split_1/"
os.makedirs(name=op_dir, exist_ok=True)

In [15]:
# Plot and save one SHAP beeswarm plot per zeolite framework.

fws = ["LTA", "UFI", "RHO", "KFI"]
beeswarm_fig_op_dir = "../../data/publication/shap/split_1_beeswarm/"
os.makedirs(beeswarm_fig_op_dir, exist_ok=True)

for fw in fws:
    # Deep SHAP values for this framework. The first two CSV columns form a
    # (SMILES, Zeolite) row index; the remaining columns are per-feature values.
    file = os.path.join(op_dir, f"deep_shap_values_{fw}.csv")
    shap_vals = pd.read_csv(file, index_col=[0, 1])

    # Fetch the raw descriptor values for every SHAP row, preserving row order.
    # Dropping the looked-up index and concatenating positionally keeps the
    # OSDA and zeolite halves aligned row-by-row, whatever the prior tables'
    # index happens to be named. Each half selects only its own columns, so
    # identically named columns in the two tables cannot collide.
    smiles = shap_vals.index.get_level_values('SMILES')
    zeolites = shap_vals.index.get_level_values('Zeolite')
    o_features = opriors.loc[smiles, list(o_X_cols)].reset_index(drop=True)
    z_features = zpriors.loc[zeolites, list(z_X_cols)].reset_index(drop=True)
    features = pd.concat([o_features, z_features], axis=1)
    features.index = shap_vals.index

    # NOTE(review): SHAP columns are assumed to follow the Xcols order
    # (OSDA features, then zeolite features) — confirm against the script that
    # wrote the CSVs. At minimum, the shapes must agree.
    if shap_vals.shape != features.shape:
        raise ValueError(
            f"{fw}: SHAP values have shape {shap_vals.shape} but the feature "
            f"matrix has shape {features.shape}"
        )

    # Bundle SHAP values with their feature values (used for point colouring)
    # and the human-readable feature labels.
    explanation = shap.Explanation(
        shap_vals.to_numpy(),
        data=features.to_numpy(),
        feature_names=Xcols_pretty,
    )

    # Draw on the current figure without showing it, so it can be annotated
    # and saved below.
    shap.plots.beeswarm(
        explanation,
        max_display=7,
        show=False,
        color='viridis_r',
    )
    plt.xlim(-1.5, 2.5)  # shared x-range so the framework plots are comparable
    # Framework label near the lower-right corner of the figure.
    plt.annotate(fw, xy=(0.82, 0.2), xycoords='figure fraction', fontsize=15, weight='bold')
    plt.savefig(os.path.join(beeswarm_fig_op_dir, f"{fw}_beeswarm.png"), dpi=300, bbox_inches='tight')
    plt.close()  # release the figure before the next framework