In [None]:
import math
import os

import joblib
import matplotlib
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL.Image
import seaborn as sns
import shap
from mordred import Calculator, descriptors
from rdkit import Chem
from sklearn.preprocessing import MinMaxScaler, RobustScaler

In [None]:
import warnings

# Install global 'ignore' filters for deprecation and user warnings.
for _ignored_category in (DeprecationWarning, UserWarning):
    warnings.filterwarnings('ignore', category=_ignored_category)

In [None]:
# Plot styling.
# Matplotlib 3.6 renamed its bundled seaborn style sheets to
# 'seaborn-v0_8-<name>' and later releases removed the old 'seaborn-<name>'
# aliases. Use the new name when this installation provides it and fall back
# to the old one otherwise, so the cell runs on either side of the rename.
plt.style.use([
    f'seaborn-v0_8-{style}' if f'seaborn-v0_8-{style}' in plt.style.available
    else f'seaborn-{style}'
    for style in ('white', 'paper')])
plt.rc('font', family='sans-serif')
# Default colour cycle used by the plots in this notebook.
sns.set_palette(['#6da7de', '#9e0059', '#dee000', '#d82222', '#5ea15d',
                 '#943fa6', '#63c5b5', '#ff38ba', '#eb861e', '#ee266d'])
sns.set_context('paper', font_scale=1.3)

In [None]:
# Seed NumPy's global random number generator so that code drawing from it
# (e.g. the jitter added to the scatter plot further down) is repeatable.
np.random.seed(42)

## Generate features

In [None]:
# Descriptor calculator shared by every feature computation in this notebook;
# 3D descriptors are skipped (ignore_3D=True).
mordred_calculator = Calculator(descriptors, ignore_3D=True)

In [None]:
# Get the original feature labels used during training.
compounds = pd.read_csv('../data/processed/compound_smiles.csv')
mols = compounds['smiles'].apply(Chem.MolFromSmiles)
descriptors_train = mordred_calculator.pandas(mols)
# Drop the object-typed columns and store the remaining ones as float32.
numeric_descriptors_train = descriptors_train.select_dtypes(exclude='object')
features_train = pd.DataFrame(numeric_descriptors_train.astype(np.float32))
# These column labels are used below to select the same features for the
# other compound sets.
feature_labels = features_train.columns
# Class label per compound, taken from the 'skin' column.
clss = compounds['skin'].astype(np.int32)

In [None]:
# Descriptors of the approved drugs, restricted to the training feature
# columns (same labels, same order) and stored as float32.
approved_drugs = pd.read_csv('../data/processed/fda.csv')
mols = approved_drugs['smiles'].apply(Chem.MolFromSmiles)
descriptors_approved_drugs = mordred_calculator.pandas(mols)
features_approved_drugs = pd.DataFrame(
    descriptors_approved_drugs[feature_labels].astype(np.float32))

In [None]:
# Descriptors of the biotransformation products, restricted to the training
# feature columns (same labels, same order) and stored as float32.
biotransformations = pd.read_csv(
    '../data/processed/fda_biotransformations.csv')
mols = biotransformations['smiles'].apply(Chem.MolFromSmiles)
descriptors_biotransformations = mordred_calculator.pandas(mols)
features_biotransformations = pd.DataFrame(
    descriptors_biotransformations[feature_labels].astype(np.float32))

## SHAP feature importances

In [None]:
# Feature importances.
# NOTE(review): joblib.load unpickles the file, so only load model files that
# come from a trusted source.
classifier = joblib.load('../data/processed/rf.joblib')


def predict_proba(x):
    """Return column 1 of the classifier's `predict_proba` output for `x`."""
    return classifier.predict_proba(x)[:, 1]


# Background data for the explainer: a shap.kmeans summary (k=50) of the
# training feature matrix.
background = shap.kmeans(features_train.values, 50)
explainer = shap.KernelExplainer(
    predict_proba, background, model_output='probability')

In [None]:
# SHAP values of the training compounds. The plotting cells below index the
# result as [compound, feature].
shap_train = explainer.shap_values(features_train)

In [None]:
# SHAP values of the approved drugs, computed with the same explainer
# (i.e. against the training-data background).
shap_approved_drugs = explainer.shap_values(features_approved_drugs)

In [None]:
# SHAP values of the biotransformation products, computed with the same
# explainer (i.e. against the training-data background).
shap_biotransformations = explainer.shap_values(features_biotransformations)

In [None]:
# Persist the three SHAP results together as one tuple, in the order
# (training, approved drugs, biotransformations). The return value of dump is
# bound to `_` so that it is not echoed as the cell's output.
_ = joblib.dump((shap_train, shap_approved_drugs, shap_biotransformations),
                '../data/processed/feature_importance.joblib')

## Plotting

In [None]:
def rand_jitter(arr, scale=.075):
    """Return `arr` with zero-mean Gaussian noise added to it.

    Used to spread out points that share the same coordinate so that they do
    not fully overlap in a scatter plot.

    Parameters
    ----------
    arr : array-like
        Values to perturb; ``len(arr)`` noise values are drawn and added.
    scale : float, optional
        Factor applied to the standard-normal noise, i.e. its standard
        deviation. Defaults to 0.075.

    Returns
    -------
    numpy.ndarray
        The jittered values.

    Notes
    -----
    The noise is drawn from NumPy's global random state, so the result
    depends on (and advances) that state.
    """
    return arr + np.random.randn(len(arr)) * scale


width = 7
height = width / 1.618    # Golden ratio.
# Note: figsize is (height, width) here, i.e. a portrait figure.
fig, ax = plt.subplots(figsize=(height, width))

# Global importance of a feature = mean absolute SHAP value over all
# compounds. The absolute value is taken *before* averaging so that positive
# and negative contributions of the same feature do not cancel each other.
importances = np.mean(np.abs(shap_train), axis=0)
order = np.argsort(importances)[::-1]
# Show the 20 most important features (or all of them if there are fewer).
n_features = min(20, len(order))
for i in range(n_features):
    # Point colour: the feature's own values, rescaled to [0, 1].
    c = features_train[feature_labels[order[i]]].values.reshape(-1, 1)
    c = RobustScaler().fit_transform(c)
    c = MinMaxScaler().fit_transform(c)
    sc = ax.scatter(
        shap_train[:, order[i]],
        # Row n_features - i puts the most important feature at the top;
        # the jitter spreads the points vertically within their row.
        rand_jitter(np.repeat(n_features - i, c.shape[0])),
        c=c,
        marker='.',
        alpha=0.5,
        cmap='viridis',
        zorder=10)
ax.axvline(0, c='lightgray')

# Colour bar in its own axes to the right of the plot, built from the last
# scatter; every scatter uses colours rescaled to [0, 1].
cbar_ax = fig.add_axes([0.95, 0.2, 0.025, 0.6])
colorbar = fig.colorbar(sc, ticks=[0, 1], cax=cbar_ax)
# The scatter points are semi-transparent; draw the colour bar fully opaque.
colorbar.solids.set(alpha=1)
colorbar.ax.set_yticklabels(['Low', 'High'])
colorbar.set_label('Feature value', labelpad=-15)

ax.set_xlabel('SHAP feature importance')
ax.set_yticks(np.arange(1, n_features + 1))
# Reversed so that the label order matches the rows (row 1 = least important
# of the displayed features).
ax.set_yticklabels(feature_labels[order][:n_features][::-1])

sns.despine()

plt.savefig('feature_importance_train.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
width = 7
height = width / 1.618    # Golden ratio.
fig, ax = plt.subplots(figsize=(width, height))

# Global importance of a feature = mean absolute SHAP value over all approved
# drugs. The absolute value is taken *before* averaging so that positive and
# negative contributions of the same feature do not cancel each other.
importances = np.mean(np.abs(shap_approved_drugs), axis=0)
order = np.argsort(importances)[::-1]
# Show the 20 most important features (or all of them if there are fewer).
n_top = min(20, len(order))
sns.barplot(x=np.arange(n_top), y=importances[order][:n_top],
            color='#6da7de', ax=ax)

ax.set_xticklabels(feature_labels[order][:n_top], rotation=90)
ax.set_ylabel('SHAP feature importance')

sns.despine()

plt.savefig('feature_importance_fda.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

In [None]:
width = 7
height = width / 1.618    # Golden ratio.
fig, ax = plt.subplots(figsize=(width, height))

# Global importance of a feature = mean absolute SHAP value over all
# biotransformation products. The absolute value is taken *before* averaging
# so that positive and negative contributions of the same feature do not
# cancel each other.
importances = np.mean(np.abs(shap_biotransformations), axis=0)
order = np.argsort(importances)[::-1]
# Show the 20 most important features (or all of them if there are fewer).
n_top = min(20, len(order))
sns.barplot(x=np.arange(n_top), y=importances[order][:n_top],
            color='#6da7de', ax=ax)

ax.set_xticklabels(feature_labels[order][:n_top], rotation=90)
ax.set_ylabel('SHAP feature importance')

sns.despine()

plt.savefig('feature_importance_biotransformations.png', dpi=300,
            bbox_inches='tight')
plt.show()
plt.close()

In [None]:
# Scatter of one feature's values against its SHAP values on the training set.
feature = 'ATSC7v'
width = 7
height = width / 1.618    # Golden ratio.
fig, ax = plt.subplots(figsize=(width, height))

# Position of the feature's column, used to pick the matching SHAP column.
column_i = features_train.columns.get_loc(feature)
feature_values = features_train[feature]
feature_shap_values = shap_train[:, column_i]
ax.scatter(feature_values, feature_shap_values)

ax.set(xlabel=feature, ylabel='SHAP value')

sns.despine()

plt.savefig(f'feature_importance_{feature}.png', dpi=300, bbox_inches='tight')
plt.show()
plt.close()

## Prediction case studies

In [None]:
# Case-study compounds (name and SMILES) and their descriptors, restricted to
# the training feature columns and stored as float32.
test_compounds = pd.DataFrame({
    'compound_name': [
        'Diphenhydramine',
        'Diphenhydramine N-hexose',
        'Citalopram',
        'Tacrolimus',
    ],
    'smiles': [
        'CN(CCOC(c1ccccc1)c1ccccc1)C',
        'O[C@H]([C@H]([C@@H]([C@H](O1)CO)O)O)C1[N+](C)(C)CCOC(C2=CC=CC=C2)C3=CC=CC=C3',
        'CN(C)CCCC1(C2=C(CO1)C=C(C=C2)C#N)C3=CC=C(C=C3)F',
        'CC1CC(C2C(CC(C(O2)(C(=O)C(=O)N3CCCCC3C(=O)OC(C(C(CC(=O)C(C=C(C1)C)CC=C)O)C)C(=CC4CCC(C(C4)OC)O)C)O)C)OC)OC',
    ],
})
test_mols = test_compounds['smiles'].apply(Chem.MolFromSmiles)
test_descriptors = mordred_calculator.pandas(test_mols, nproc=1, quiet=True)
test_features = pd.DataFrame(
    test_descriptors[features_train.columns].astype(np.float32))

# One force plot per case-study compound. Each compound is paired with its
# own contribution threshold and a subfigure letter.
# NOTE: this loop rebinds the notebook-level names `classifier`,
# `predict_proba` and `explainer`.
for i, (compound, contribution_threshold, label) in enumerate(
        zip(test_compounds['compound_name'], (0.02, 0.03, 0.02, 0.03), 'ABCD')):
    # Train the classifier and SHAP without the test compound.
    # A boolean mask selects rows by position, so the feature matrix and the
    # class labels stay aligned regardless of the index of `compounds`.
    train_mask = (compounds['compound_name'] != compound).values
    classifier = joblib.load('../data/processed/rf.joblib').fit(
        features_train.values[train_mask], clss.values[train_mask])
    # The lambda looks up `classifier` when called; it is only used within
    # this iteration, so it always sees the model fitted just above.
    predict_proba = lambda x: classifier.predict_proba(x)[:, 1]
    explainer = shap.KernelExplainer(
        predict_proba, shap.kmeans(features_train.values[train_mask], 50),
        model_output='probability')
    # Get prediction and feature importances for the test compound.
    test_pred = classifier.predict_proba(
        test_features.iloc[i].values.reshape(1, -1))[0, 1]
    test_shap = explainer.shap_values(test_features.iloc[i])
    # Create a force plot using SHAP.
    with sns.plotting_context('paper', font_scale=2):
        shap.plots.force(
            explainer.expected_value, test_shap,
            test_features.iloc[i].round(2).astype(str),
            plot_cmap=sns.color_palette(['#EE266D', '#00AEEF'], as_cmap=True),
            matplotlib=True, show=False, text_rotation=90,
            contribution_threshold=contribution_threshold)
    # Modify the force plot because it has annoying hard-coded settings.
    # Update text font size and nudge selected labels upwards. The labels are
    # identified by their text; the fixed labels are listed after the
    # prediction label so they take precedence if the strings ever coincide.
    y_offsets = {
        f'{test_pred:.2f}': 0.05,
        'f(x)': 0.08,
        'higher': 0.1,
        'lower': 0.1,
        '$\\leftarrow$': 0.1,
        '$\\rightarrow$': 0.1,
    }
    ax = plt.gca()
    for child in ax.get_children():
        if isinstance(child, matplotlib.text.Text):
            child.set_fontsize(20)
            y_offset = y_offsets.get(child.get_text())
            if y_offset is not None:
                # Move via the public API instead of the private `_y`.
                child.set_y(child.get_position()[1] + y_offset)
    # Add title and subfigure label.
    ax.set_title(compound, {'fontsize': 30, 'fontweight': 'bold'}, pad=120)
    ax.annotate(label, xy=(-0.05, 1.75), xycoords='axes fraction',
                fontsize=36, weight='bold')
    # Saved per compound; these files are merged into one image afterwards.
    plt.savefig(f'shap_{compound}.png', dpi=300, bbox_inches='tight')
    plt.close()
# Combine all figures into a single PNG.
imgs = []
for compound in test_compounds['compound_name']:
    # np.array copies the pixel data, so the file can be closed right away
    # (it is deleted at the end of this cell).
    with PIL.Image.open(f'shap_{compound}.png') as img:
        imgs.append(np.array(img))
# Target size = per-axis maximum over all images. Calling max() on the shape
# tuples themselves compares them lexicographically and returns the width of
# the tallest image, which need not be the largest width; a wider but shorter
# image would then produce a negative pad width and make np.pad fail.
shape = tuple(max(dims) for dims in zip(*(img.shape for img in imgs)))
padded_imgs = []
for img in imgs:
    extra_rows = shape[0] - img.shape[0]
    extra_cols = shape[1] - img.shape[1]
    # Centre the image on the target canvas: half of the missing rows/columns
    # go before it (rounded down), the rest after it. Pad value 255 on every
    # channel; the channel axis itself is not padded.
    padded_imgs.append(np.pad(
        img,
        ((extra_rows // 2, extra_rows - extra_rows // 2),
         (extra_cols // 2, extra_cols - extra_cols // 2),
         (0, 0)),
        constant_values=(255,)))
# Stack the equally sized images on top of each other.
imgs_comb = np.vstack(padded_imgs)
PIL.Image.fromarray(imgs_comb).save('shap_compounds.png')
# Remove the per-compound images now that they have been merged.
for compound in test_compounds['compound_name']:
    os.remove(f'shap_{compound}.png')