In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
import pickle
import random
import seaborn as sns
from scipy.stats import f_oneway

from sklearn.metrics import pairwise_distances
from sklearn.cluster import KMeans

### Input

Folder and dataset names

In [None]:
# Experiment folder holding the per-split results of this run
folder_name = 'withhyperopt_4'
# One row per input modality: (dataset id, display label, plot colour).
# Declaring them together keeps the three derived lists aligned by position.
_dataset_specs = [
    ('clinical_noninvasive', 'Non-Invasive Clinical', '#FFD700'),
    ('objscreen_blood', 'Blood Metabolites', '#DC143C'),
]
individual_datasets = [spec[0] for spec in _dataset_specs]
individual_datasets_labels = [spec[1] for spec in _dataset_specs]
colors = [spec[2] for spec in _dataset_specs]
del _dataset_specs
# Number of splits (split_1 ... split_<n_splits>) whose outputs are aggregated
n_splits = 20

Output folder name

In [None]:
# Sub-folder of <folder_name>/<dataset_name>/ that receives every figure written below
output_folder = "_shap"

Analysis options

In [None]:
# Toggles for the analyses (they also decide which output sub-folders are created)
analyze_stacker = analyze_shap_values = analyze_shap_interactions = True

### Create output folders

In [None]:
# name of dataset
dataset_name = '+'.join(individual_datasets)

# Cancer cohorts that receive their own per-cohort SHAP summary folder
_cohorts = ['ACC', 'BLCA', 'BRCA', 'CESC', 'COAD', 'DLBC', 'GBM', 'HNSC', 'LGG', 'LIHC', 'LUAD', 'LUSC', 'PRAD', 'READ', 'SKCM', 'STAD', 'THCA', 'UCEC', 'UCS']

# output folders
# Every directory needed by the enabled analyses is listed, then created with
# os.makedirs(exist_ok=True). This is idempotent. Re-running the cell after enabling another
# analysis still creates that analysis' sub-folders. Previously all creation was skipped once
# the top-level output folder existed, and later savefig calls then failed.
_output_root = os.path.join(folder_name, dataset_name, output_folder)
_output_dirs = [_output_root]

# stacker
if analyze_stacker:
    _output_dirs += [os.path.join(_output_root, 'stacker', sub) for sub in ('summary', 'clinical_features')]

# shap values
if analyze_shap_values:
    _summary_root = os.path.join(_output_root, 'shap_values', 'summary')
    _output_dirs.append(os.path.join(_summary_root, 'stacker'))
    _output_dirs += [os.path.join(_summary_root, 'stacker', dataset) for dataset in individual_datasets]
    _output_dirs.append(os.path.join(_summary_root, 'cohort'))
    _output_dirs += [os.path.join(_summary_root, 'cohort', cohort) for cohort in _cohorts]
    _output_dirs += [os.path.join(_output_root, 'shap_values', sub) for sub in ('features', 'patients')]

# shap interactions
if analyze_shap_interactions:
    _output_dirs.append(os.path.join(_output_root, 'shap_interactions'))

for _output_dir in _output_dirs:
    os.makedirs(_output_dir, exist_ok=True)

### Load dataset

In [None]:
# create merged dataset
def _load_prefixed_dataset(name):
    """Load '_datasets/<name>.pickle' and prefix every feature name with '<name> # '.

    The pickle unpacks into (feature DataFrame, label vector, categorical conversion mapping
    keyed by feature name). Returns that triple with prefixed column names and mapping keys.
    NOTE(review): pickle.load can execute arbitrary code - only load trusted dataset files.
    """
    with open('_datasets/%s.pickle' % name, 'rb') as f:
        X, y, conversion_old = pickle.load(f, encoding='latin1')
    X.columns = ['%s # %s' % (name, x) for x in X.columns.tolist()]
    conversion = {'%s # %s' % (name, key): conversion_old[key] for key in conversion_old}
    return X, y, conversion

X_matrix, y_vector, categorical_conversion = _load_prefixed_dataset(individual_datasets[0])
for dataset in individual_datasets[1:]:
    # Only the first dataset's label vector is kept. Presumably the labels are identical across
    # datasets (this is not checked).
    X_matrix_, _unused_labels, categorical_conversion_ = _load_prefixed_dataset(dataset)
    if X_matrix.index.tolist() != X_matrix_.index.tolist():
        raise Exception("Dataset sample lists don't match")
    X_matrix = pd.concat([X_matrix, X_matrix_], axis=1)
    categorical_conversion = {**categorical_conversion, **categorical_conversion_}
samples = X_matrix.index.tolist()
features = X_matrix.columns.tolist()

# list of merged features: the one-hot columns '<feature> | <value>' of a categorical feature
# collapse into a single '<feature>' entry. First-seen order is preserved, and a set gives O(1)
# duplicate checks.
if len(categorical_conversion) > 0:
    merged_features = []
    _merged_seen = set()
    for feature in features:
        base = feature.split(' | ')[0]
        if base not in categorical_conversion:
            merged_features.append(feature)
            _merged_seen.add(feature)
        elif base not in _merged_seen:
            merged_features.append(base)
            _merged_seen.add(base)
else:
    merged_features = features.copy()

# merge feature matrix: each categorical feature becomes one object column holding the category
# whose one-hot column equals 1. It is NaN when no column does, and the last matching category
# wins. Vectorised per category instead of looping over every sample with .at.
X_matrix_ = X_matrix.copy()
for feature in categorical_conversion:
    X_matrix_[feature] = np.nan
    X_matrix_[feature] = X_matrix_[feature].astype(object)
    for val in categorical_conversion[feature]:
        is_this_value = X_matrix['%s | %s' % (feature, val)].eq(1).values
        X_matrix_.loc[is_this_value, feature] = val
X_matrix = X_matrix_[merged_features]

### Load performance

In [None]:
# Row labels of the per-split results: split_1 ... split_<n_splits>
_split_rows = ['split_%d' % (k + 1) for k in range(n_splits)]

# stacker - log loss. Each split is weighted by its inverse log loss, so better splits count more.
if analyze_stacker:
    df = pd.read_csv('%s/stacker_logloss.csv' % folder_name, index_col=0)
    stacker_log_loss = df.loc[_split_rows][dataset_name].values.tolist()
    stacker_performance_weights = [1 / loss for loss in stacker_log_loss]

# classifier - weighted log loss (again inverted into per-split weights)
df = pd.read_csv('%s/weightedlogloss.csv' % folder_name, index_col=0)
classifier_weighted_log_loss = df.loc[_split_rows][dataset_name].values.tolist()
classifier_performance_weights = [1 / loss for loss in classifier_weighted_log_loss]

### SHAP values

#### Load data

In [None]:
if analyze_shap_values:

    # load shap values: each split's pickle unpacks into (expected values, SHAP DataFrame indexed by sample)
    shap_expected = []
    shap_values = []
    for i in range(n_splits):
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted result files.
        with open('%s/%s/shap_values_%d.pickle' % (folder_name, dataset_name, i+1) ,'rb') as f:
            shap_expected_, shap_values_ = pickle.load(f)
        shap_expected.append(shap_expected_)
        shap_values.append(shap_values_)

    # Row position of every sample in each split (first occurrence wins, like list.index).
    # Built once so the lookups below are O(1). Previously index.tolist() was rebuilt and
    # scanned for every (sample, split) pair.
    split_positions = []
    for split_frame in shap_values:
        positions = {}
        for pos, name in enumerate(split_frame.index.tolist()):
            positions.setdefault(name, pos)
        split_positions.append(positions)

    # remove samples that aren't analyzed (absent from every split)
    keep_index = [i for i, sample in enumerate(samples) if any(sample in positions for positions in split_positions)]
    samples = [samples[i] for i in keep_index]
    y_vector = [y_vector[i] for i in keep_index]
    X_matrix = X_matrix.iloc[keep_index]

    # combine expected values: a performance-weighted mean over the splits that contain each sample
    _expected_value_sum = [0 for sample in samples]
    _shap_weight = [0 for sample in samples]
    for i, sample in enumerate(samples):
        for j in range(n_splits):
            pos = split_positions[j].get(sample)
            if pos is not None:
                _expected_value_sum[i] += classifier_performance_weights[j] * shap_expected[j][pos]
                _shap_weight[i] += classifier_performance_weights[j]
    expected_value = np.divide(_expected_value_sum, _shap_weight)

    # combine shap values: align every split to (samples x merged_features), using 0 where a
    # sample or feature is absent from that split, then take the same weighted mean per sample
    for i in range(n_splits):
        shap_values[i] = shap_values[i].reindex(index=samples, columns=merged_features, fill_value=0)
    shap_value = pd.DataFrame(data=0, index=samples, columns=merged_features)
    for i in range(n_splits):
        shap_value = shap_value + classifier_performance_weights[i] * shap_values[i]
    shap_value = shap_value.div(_shap_weight, axis=0)

    # calculate absolute shap values
    shap_value_abs = shap_value.abs()

    # zero shap values for imputed features (entries missing in the input feature matrix)
    shap_value_abs = shap_value_abs.mask(X_matrix.loc[samples, merged_features].isna().values, 0)

    # normalize each sample's row by its total |SHAP| so contributions sum to 1.
    # skipna=False matches the previous np.sum; a zero total yields a NaN row, as before.
    shap_value_abs = shap_value_abs.div(shap_value_abs.sum(axis=1, skipna=False), axis=0)

    # mean and standard error for each feature
    shap_value_abs_mean = shap_value_abs.mean(axis=0).values.tolist()
    shap_value_abs_mean_cumsum = np.cumsum(sorted(shap_value_abs_mean)[::-1])
    shap_value_abs_std = shap_value_abs.std(axis=0).values.tolist()
    shap_value_abs_sterr = [x/np.sqrt(len(samples)) for x in shap_value_abs_std]

#### All features

In [None]:
if analyze_shap_values:

    # Features ranked by mean |ΔP| (largest first), plus the counts used by the guide lines:
    # how many top features reach 95% of the cumulative mean |ΔP|, and how many are non-zero
    ranked_index = np.argsort(shap_value_abs_mean)[::-1]
    n_top95 = len([v for v in shap_value_abs_mean_cumsum if v <= 0.95])
    n_nonzero = len([v for v in shap_value_abs_mean if v > 0])
    n_features = len(merged_features)

    # top 95% features
    overall_top_features = [merged_features[k] for k in ranked_index[:n_top95]]
    overall_top_values = [shap_value_abs_mean[k] for k in ranked_index[:n_top95]]

    # top N features - mean (every feature, in ranked order)
    top_N_index = ranked_index
    top_N_mean = [shap_value_abs_mean[k] for k in top_N_index]
    top_N_sterr = [shap_value_abs_sterr[k] for k in top_N_index]
    top_N_datasets = [merged_features[k].split(' # ')[0] for k in top_N_index]
    top_N_features = [merged_features[k].split(' # ')[-1] for k in top_N_index]

    # plot - all shap values (left axis) with their cumulative sum (right axis)
    plt.rcParams['font.family'] = 'Arial'
    plt.rc('xtick', labelsize=18)
    plt.rc('ytick', labelsize=18)
    fig = plt.figure(figsize=(6,5))
    ax = fig.add_subplot(111)
    ax.set_xscale('log')
    positions = range(1, n_features+1)
    ax.plot(positions, sorted(shap_value_abs_mean)[::-1], 'k.-', markersize=5)
    plt.ylabel(r'Mean |$\Delta$P|', fontsize=20)
    plt.ylim(-0.0024,0.24)
    plt.yticks([0,0.04,0.08,0.12,0.16,0.2,0.24])
    plt.xlabel('Features', fontsize=20)
    ax2 = ax.twinx()
    ax2.plot(positions, shap_value_abs_mean_cumsum, '.-', color='#808080', markersize=5)
    # dashed guides: 95% cut-off (vertical and horizontal) and the last non-zero feature
    ax2.plot([n_top95, n_top95], [0, 0.95], '--', color='#808080')
    ax2.plot([n_top95, 1.1*n_features], [0.95, 0.95], '--', color='#808080')
    ax2.plot([n_nonzero, n_nonzero], [0, 1], '--', color='#808080')
    ax2.tick_params(axis='y', labelcolor='#808080')
    ax2.set_ylabel('Cumulative Sum', color='#808080', fontsize=20)
    ax2.set_ylim(0,1.01)
    ax2.set_yticks([0,0.25,0.5,0.75,0.95,1])
    num_int = len(str(n_features))
    x_tick_positions = [10**p for p in range(num_int)] + [n_features]
    plt.xticks(x_tick_positions, x_tick_positions)
    plt.xlim(0.95, 1.05*n_features)
    plt.text(n_top95, 0.03, '%d' % n_top95, weight='bold', ha='right', fontsize=20)
    plt.text(n_nonzero, 0.03, '%d' % n_nonzero, ha='right', fontsize=20)

    # Percentage tick labels on both y axes, using as many decimals as the tick values need
    percent_formats = {0: '{:,.0%}', 1: '{:,.1%}', 2: '{:,.2%}'}
    for axis_ in (ax, ax2):
        tick_values = axis_.get_yticks()
        decimals = [str(round(v*100, 6)).split('.')[1] for v in tick_values]
        if any(d != '0' for d in decimals):
            number_of_places = np.max([len(d) for d in decimals])
        else:
            number_of_places = np.max([len(d) for d in decimals]) - 1
        if number_of_places not in percent_formats:
            raise Exception('Number of decimal places = %d' % number_of_places)
        axis_.set_yticklabels([percent_formats[number_of_places].format(v) for v in tick_values])
    plt.savefig('%s/%s/%s/shap_values/summary/all_features.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
    plt.close()

#### Dataset contributions

In [None]:
if analyze_shap_values:

    # Dataset of each merged feature (prefix before ' # ') and the overall top-95% feature set
    feature_datasets = [feature.split(' # ')[0] for feature in merged_features]
    top_feature_set = set(overall_top_features)

    # plot - dataset contribution: each dataset's share of the summed mean |ΔP| over the top-95% features
    dataset_contribution = []
    for dataset in individual_datasets:
        dataset_contribution.append(np.sum([x for i,x in enumerate(shap_value_abs_mean) if feature_datasets[i]==dataset and merged_features[i] in top_feature_set]))
    dataset_contribution = [x/np.sum(dataset_contribution) for x in dataset_contribution]
    plt.rcParams['font.family'] = 'Arial'
    fig = plt.figure(figsize=(8,8))
    ax = fig.add_subplot(111)
    patches, texts, autotexts = plt.pie(dataset_contribution, labels=individual_datasets_labels, colors=colors, autopct='%1.1f%%')
    for i in range(len(texts)):
        texts[i].set_fontsize(20)
        autotexts[i].set_fontsize(20)
    plt.savefig('%s/%s/%s/shap_values/summary/dataset_contribution.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
    plt.close()

    # plot - individual patient dataset contribution
    dataset_contribution_individual = [[] for dataset in individual_datasets]
    # per dataset: samples where that dataset supplies more than an equal share (1 / #datasets) of |ΔP|
    dataset_contribution_individual_top = [[] for dataset in individual_datasets]
    for j,sample in enumerate(samples):
        sample_values = shap_value_abs.loc[sample].values.tolist()
        sample_total = np.sum(sample_values)
        for i,dataset in enumerate(individual_datasets):
            in_dataset = [sample_values[a] for a in range(len(merged_features)) if feature_datasets[a]==dataset]
            in_dataset_top = [sample_values[a] for a in range(len(merged_features)) if feature_datasets[a]==dataset and merged_features[a] in top_feature_set]
            dataset_contribution_individual[i].append(np.sum(in_dataset_top)/sample_total)
            # Evaluated for every dataset. The check used to sit after the dataset loop, so only
            # the last dataset (via the leftover `i` / `dataset`) was ever tested.
            if np.sum(in_dataset)/sample_total > (1/len(individual_datasets)):
                dataset_contribution_individual_top[i].append(sample)
    plt.rcParams['font.family'] = 'Arial'
    fig = plt.figure(figsize=(3*len(individual_datasets),10))
    ax = fig.add_subplot(111)
    plt.rc('xtick', labelsize=20)
    plt.rc('ytick', labelsize=20)
    sns.boxplot(data=dataset_contribution_individual, whis=1, fliersize=0, palette=colors, ax=ax)
    sns.swarmplot(data=dataset_contribution_individual, color='black', alpha=0.25, ax=ax)
    plt.xlim([-0.5,len(individual_datasets)-0.5])
    ax.set(xticklabels=individual_datasets_labels)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    plt.ylabel('Percent Contribution', fontsize=20)
    vals = ax.get_yticks()
    decimals = [str(round(x*100,6)).split('.')[1] for x in vals]
    if len([x for x in decimals if x != '0']) > 0:
        number_of_places = np.max([len(x) for x in decimals])
    else:
        number_of_places = np.max([len(x) for x in decimals])-1
    if number_of_places not in (0, 1, 2):
        raise Exception('Number of decimal places = %d' % number_of_places)
    # A formatter labels whatever ticks are drawn. set_yticklabels on unfixed ticks triggers a
    # matplotlib warning and can misplace labels if the ticks move.
    ax.yaxis.set_major_formatter(matplotlib.ticker.StrMethodFormatter('{x:,.%d%%}' % number_of_places))
    plt.savefig('%s/%s/%s/shap_values/summary/dataset_contribution_individual.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
    plt.close()

#### Clinical Multimodal

In [None]:
# functions
seed_ = 1
def compute_inertia(a, X):
    """Mean within-cluster dispersion of a clustering.

    Parameters
    ----------
    a : array of cluster labels, one per row of X.
    X : 2-D array of observations (one row per sample).

    Returns
    -------
    The mean, over the distinct labels in `a`, of the mean pairwise distance between the rows
    with that label (the sklearn `pairwise_distances` matrix, diagonal zeros included).
    """
    W = [np.mean(pairwise_distances(X[a == c, :])) for c in np.unique(a)]
    return np.mean(W)

def compute_gap(clustering, data, k_max=5, n_references=100):
    """Gap statistic of `clustering` on `data` for k = 1 .. k_max.

    Parameters
    ----------
    clustering : estimator with a settable `n_clusters` attribute and `fit_predict`
        (e.g. sklearn KMeans). Its `n_clusters` is overwritten as a side effect.
    data : 1-D or 2-D array; 1-D input is treated as a single column.
    k_max : largest number of clusters evaluated.
    n_references : number of uniform reference datasets averaged for each k.

    Returns
    -------
    (gap, log reference inertia, log on-data inertia), each of length k_max.

    Seeds the global NumPy RNG with `seed_` so the references are reproducible.
    """
    np.random.seed(seed_)
    if len(data.shape) == 1:
        data = data.reshape(-1, 1)
    # Draw n_references independent reference datasets once and reuse them for every k.
    # Previously a single reference was drawn and refit n_references times. With a fixed
    # random_state every refit returned the same inertia, so nothing was actually averaged.
    # NOTE(review): the references are uniform on [0, 1). That suits data already in [0, 1], such
    # as the percent contributions clustered in this notebook. The textbook gap statistic samples
    # over the data's own range instead.
    references = [np.random.rand(*data.shape) for _ in range(n_references)]
    reference_inertia = []
    for k in range(1, k_max+1):
        clustering.n_clusters = k
        local_inertia = [compute_inertia(clustering.fit_predict(reference), reference) for reference in references]
        reference_inertia.append(np.mean(local_inertia))

    ondata_inertia = []
    for k in range(1, k_max+1):
        clustering.n_clusters = k
        assignments = clustering.fit_predict(data)
        ondata_inertia.append(compute_inertia(assignments, data))

    gap = np.log(reference_inertia)-np.log(ondata_inertia)
    return gap, np.log(reference_inertia), np.log(ondata_inertia)

# calculate optimal number of distributions: k-means on each patient's clinical contribution,
# with k chosen by the gap statistic
k_max = 10
clinical_contribution = np.array(dataset_contribution_individual[0]).reshape(-1,1)
gap, reference_inertia, ondata_inertia = compute_gap(KMeans(random_state=seed_), clinical_contribution, k_max=k_max)
optimal_number_of_distributions = np.argmax(gap)+1
plt.rcParams['font.family'] = 'Arial'
plt.rc('xtick', labelsize=20)
plt.rc('ytick', labelsize=20)
fig = plt.figure(figsize=(5,5))
ax = fig.add_subplot(111)
ax.plot(range(1,k_max+1), gap, 'k.-', markersize=5)
ax.plot([optimal_number_of_distributions, optimal_number_of_distributions],[0.05,np.max(gap)], 'k--')
plt.xlabel('Number of Distributions', fontsize=20)
plt.ylabel('Gap Statistic', fontsize=20)
plt.xticks(list(range(1,k_max+1)))
plt.ylim([0.05,0.35])
plt.yticks([.05,.15,.25,.35])
plt.savefig('%s/%s/%s/shap_values/summary/clinical_kmeans_gap.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
plt.close()

# distribution names, ordered by increasing cluster centre
distribution_names = ['low','high']
distribution_names_labels = ['Low','High']
# Everything below (names, histogram colours, legend, per-distribution pies) assumes exactly one
# "low" and one "high" group. Fail clearly instead of mislabelling or indexing past the lists.
if optimal_number_of_distributions != len(distribution_names):
    raise ValueError('Expected %d distributions (low/high) but the gap statistic selected %d' % (len(distribution_names), optimal_number_of_distributions))

# separate samples into distributions (explicit random_state for reproducibility)
kmeans = KMeans(n_clusters=optimal_number_of_distributions, random_state=seed_).fit(clinical_contribution)
labels = kmeans.labels_
centers = kmeans.cluster_centers_.reshape(-1,)
# rank of each cluster by its centre (0 = lowest), so distribution label 0 always means "low"
center_rank = {cluster: rank for rank, cluster in enumerate(np.argsort(centers))}
distribution_labels = [center_rank[val] for val in labels]

# histogram
values = dataset_contribution_individual[0]
n_bins = 25
_, bin_edges = np.histogram(values, bins=n_bins)
plt.rcParams['font.family'] = 'Arial'
plt.rc('xtick', labelsize=20)
plt.rc('ytick', labelsize=20)
fig = plt.figure(figsize=(5,5))
ax = fig.add_subplot(111)
colors_ = ['#FFFACD','#F0E68C']
for i in range(optimal_number_of_distributions):
    # NOTE: sns.distplot is deprecated from seaborn 0.11; kept for the installed seaborn version.
    sns.distplot([values[a] for a in range(len(values)) if distribution_labels[a]==i], hist=True, kde=True, bins=bin_edges, color=colors_[i], hist_kws={'edgecolor':'black'}, kde_kws={'linewidth': 4})
plt.legend(['Low Clinical - %0.1f%%' % (len([x for x in distribution_labels if x==0])/len(distribution_labels)*100),'High Clinical - %0.1f%%' % (len([x for x in distribution_labels if x==1])/len(distribution_labels)*100)], fontsize=16)
plt.xlabel('Percent Contribution - Clinical', fontsize=20)
# With kde=True distplot normalises the histogram to a density, so the y axis is not a count.
plt.ylabel('Density', fontsize=20)
plt.xlim(0,1)
plt.xticks([0,0.25,0.5,0.75,1],['0%','25%','50%','75%','100%'])
plt.savefig('%s/%s/%s/shap_values/summary/clinical_multimodal.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
plt.close()

# dataset contribution - low/high clinical
for a in range(len(distribution_names)):

    # subset shap values
    shap_value_abs_distribution = shap_value_abs.loc[[samples[i] for i in range(len(samples)) if distribution_labels[i]==a]]
    shap_value_abs_distribution_mean = shap_value_abs_distribution.mean(axis=0).values.tolist()

    # new top features: the fewest features reaching 95% of this group's cumulative mean |ΔP|
    shap_value_abs_distribution_cumsum = np.cumsum(sorted(shap_value_abs_distribution_mean)[::-1])
    top_N_index = np.argsort(shap_value_abs_distribution_mean)[::-1][:len([x for x in shap_value_abs_distribution_cumsum if x<=0.95])]
    new_overall_top_features = [merged_features[i] for i in top_N_index]
    new_top_feature_set = set(new_overall_top_features)

    # dataset contribution
    dataset_contribution_distribution = []
    for dataset in individual_datasets:
        dataset_contribution_distribution.append(np.sum([x for i,x in enumerate(shap_value_abs_distribution_mean) if merged_features[i].split(' # ')[0]==dataset and merged_features[i] in new_top_feature_set]))
    dataset_contribution_distribution = [x/np.sum(dataset_contribution_distribution) for x in dataset_contribution_distribution]
    plt.rcParams['font.family'] = 'Arial'
    fig = plt.figure(figsize=(8,8))
    ax = fig.add_subplot(111)
    patches, texts, autotexts = plt.pie(dataset_contribution_distribution, labels=individual_datasets_labels, colors=colors, autopct='%1.1f%%')
    plt.title('%s Clinical' % distribution_names_labels[a], fontsize=20)
    for i in range(len(texts)):
        texts[i].set_fontsize(20)
        autotexts[i].set_fontsize(20)
    plt.savefig('%s/%s/%s/shap_values/summary/dataset_contribution_clinical_%s.png' % (folder_name, dataset_name, output_folder, distribution_names[a]), bbox_inches='tight', dpi=400)
    plt.close()

In [None]:
# dataset contribution - low clinical patients only (distribution 0), with the cancer-type
# feature (COHORT) split out from the remaining non-invasive clinical features
for a in range(len(distribution_names))[:1]:

    # subset shap values to this distribution's samples
    member_samples = [sample for sample, label in zip(samples, distribution_labels) if label == a]
    shap_value_abs_distribution = shap_value_abs.loc[member_samples]
    shap_value_abs_distribution_mean = shap_value_abs_distribution.mean(axis=0).values.tolist()

    # new top features: the fewest features reaching 95% of the cumulative mean |ΔP|
    shap_value_abs_distribution_cumsum = np.cumsum(sorted(shap_value_abs_distribution_mean)[::-1])
    n_top_features = sum(1 for running in shap_value_abs_distribution_cumsum if running <= 0.95)
    top_N_index = np.argsort(shap_value_abs_distribution_mean)[::-1][:n_top_features]
    new_overall_top_features = [merged_features[i] for i in top_N_index]
    top_feature_lookup = set(new_overall_top_features)

    # dataset contribution, split into three slices
    contribution_names = ['Other Non-Invasive Clinical','Cancer Type','Blood Metabolites']
    mean_by_feature = list(zip(merged_features, shap_value_abs_distribution_mean))
    other_clinical = np.sum([value for feature, value in mean_by_feature
                             if feature.split(' # ')[0] == 'clinical_noninvasive'
                             and feature in top_feature_lookup
                             and feature != 'clinical_noninvasive # COHORT'])
    # NOTE(review): COHORT is counted whether or not it is among the top-95% features, unlike the
    # other two slices - confirm this is intended.
    cancer_type = shap_value_abs_distribution_mean[merged_features.index('clinical_noninvasive # COHORT')]
    blood = np.sum([value for feature, value in mean_by_feature
                    if feature.split(' # ')[0] == 'objscreen_blood' and feature in top_feature_lookup])
    grand_total = np.sum([other_clinical, cancer_type, blood])
    contribution_values = [part / grand_total for part in (other_clinical, cancer_type, blood)]

    plt.rcParams['font.family'] = 'Arial'
    fig = plt.figure(figsize=(8,8))
    ax = fig.add_subplot(111)
    patches, texts, autotexts = plt.pie(contribution_values, colors=['#FFD700','#e6c200','#DC143C'], autopct='%1.1f%%')
    plt.title('%s Clinical' % distribution_names_labels[a], fontsize=20)
    for text, autotext in zip(texts, autotexts):
        text.set_fontsize(26)
        autotext.set_fontsize(26)
    plt.savefig('%s/%s/%s/shap_values/summary/dataset_contribution_clinical_%s_cohort.png' % (folder_name, dataset_name, output_folder, distribution_names[a]), bbox_inches='tight', dpi=400)
    plt.close()

#### Top features

In [None]:
if analyze_shap_values:

    # N value: number of top-ranked features to plot
    N = 50

    # top N features - mean
    top_N_index = np.argsort(shap_value_abs_mean)[::-1][:N]
    top_N_mean = [shap_value_abs_mean[i] for i in top_N_index]
    top_N_sterr = [shap_value_abs_sterr[i] for i in top_N_index]
    top_N_datasets = [merged_features[i].split(' # ')[0] for i in top_N_index]
    top_N_features = [merged_features[i].split(' # ')[-1] for i in top_N_index]
    # fewer than N bars when there are fewer than N features
    n_shown = len(top_N_index)
    colors_ = [colors[individual_datasets.index(dataset)] for dataset in top_N_datasets]
    plt.rcParams['font.family'] = 'Arial'
    plt.rc('xtick', labelsize=12)
    plt.rc('ytick', labelsize=20)
    fig = plt.figure(figsize=(20,5))
    ax = fig.add_subplot(111)
    barlist = plt.bar(range(n_shown), top_N_mean, yerr=top_N_sterr, capsize=5)
    for bar, bar_color in zip(barlist, colors_):
        bar.set_color(bar_color)
    plt.xticks(range(n_shown), top_N_features, rotation=90)
    plt.xlim(-1,N)
    plt.ylabel(r'Mean |$\Delta$P|', fontsize=20)
    plt.title('Top %d Features' % n_shown, fontsize=24)
    vals = ax.get_yticks()
    decimals = [str(round(x*100,6)).split('.')[1] for x in vals]
    if len([x for x in decimals if x != '0']) > 0:
        number_of_places = np.max([len(x) for x in decimals])
    else:
        number_of_places = np.max([len(x) for x in decimals])-1
    if number_of_places not in (0, 1, 2):
        raise Exception('Number of decimal places = %d' % number_of_places)
    # A formatter labels whatever ticks are drawn. set_yticklabels on unfixed ticks can misplace labels.
    ax.yaxis.set_major_formatter(matplotlib.ticker.StrMethodFormatter('{x:,.%d%%}' % number_of_places))
    # Legend only for datasets that have at least one bar. Looking up an absent colour with
    # colors_.index raised ValueError before.
    legend_handles, legend_labels = [], []
    for dataset_color, dataset_label in zip(colors, individual_datasets_labels):
        if dataset_color in colors_:
            legend_handles.append(barlist[colors_.index(dataset_color)])
            legend_labels.append(dataset_label)
    plt.legend(legend_handles, legend_labels, fontsize=16)
    plt.savefig('%s/%s/%s/shap_values/summary/top%d.png' % (folder_name, dataset_name, output_folder, N), bbox_inches='tight', dpi=400)
    plt.close()
    
    

In [None]:
# top N features - mean - low clinical patients only (distribution 0).
# N is the value set in the overall top-features cell.
for a in range(len(distribution_names))[0:1]:

    # get values
    member_count = sum(1 for y in distribution_labels if y == a)
    shap_value_abs_distribution = shap_value_abs.loc[[samples[i] for i in range(len(samples)) if distribution_labels[i]==a]]
    shap_value_abs_distribution_mean = shap_value_abs_distribution.mean(axis=0).values.tolist()
    shap_value_abs_distribution_std = shap_value_abs_distribution.std(axis=0).values.tolist()
    shap_value_abs_distribution_sterr = [x/np.sqrt(member_count) for x in shap_value_abs_distribution_std]

    # plot
    top_N_index = np.argsort(shap_value_abs_distribution_mean)[::-1][:N]
    top_N_mean = [shap_value_abs_distribution_mean[i] for i in top_N_index]
    top_N_sterr = [shap_value_abs_distribution_sterr[i] for i in top_N_index]
    top_N_datasets = [merged_features[i].split(' # ')[0] for i in top_N_index]
    top_N_features = [merged_features[i].split(' # ')[-1] for i in top_N_index]
    n_shown = len(top_N_index)
    colors_ = [colors[individual_datasets.index(dataset)] for dataset in top_N_datasets]
    plt.rcParams['font.family'] = 'Arial'
    plt.rc('xtick', labelsize=12)
    plt.rc('ytick', labelsize=20)
    fig = plt.figure(figsize=(20,5))
    ax = fig.add_subplot(111)
    barlist = plt.bar(range(n_shown), top_N_mean, yerr=top_N_sterr, capsize=5)
    for bar, bar_color in zip(barlist, colors_):
        bar.set_color(bar_color)
    plt.xticks(range(n_shown), top_N_features, rotation=90)
    plt.xlim(-1,N)
    plt.ylabel(r'Mean |$\Delta$P|', fontsize=20)
    # capitalised label ('Low'), matching the titles of the other per-distribution figures
    plt.title('%s Clinical Patients' % distribution_names_labels[a], fontsize=24)
    vals = ax.get_yticks()
    decimals = [str(round(x*100,6)).split('.')[1] for x in vals]
    if len([x for x in decimals if x != '0']) > 0:
        number_of_places = np.max([len(x) for x in decimals])
    else:
        number_of_places = np.max([len(x) for x in decimals])-1
    if number_of_places not in (0, 1, 2):
        raise Exception('Number of decimal places = %d' % number_of_places)
    # A formatter labels whatever ticks are drawn. set_yticklabels on unfixed ticks can misplace labels.
    ax.yaxis.set_major_formatter(matplotlib.ticker.StrMethodFormatter('{x:,.%d%%}' % number_of_places))
    # Legend only for datasets that have at least one bar (colors_.index raised ValueError otherwise)
    legend_handles, legend_labels = [], []
    for dataset_color, dataset_label in zip(colors, individual_datasets_labels):
        if dataset_color in colors_:
            legend_handles.append(barlist[colors_.index(dataset_color)])
            legend_labels.append(dataset_label)
    plt.legend(legend_handles, legend_labels, fontsize=16)
    plt.savefig('%s/%s/%s/shap_values/summary/top%d_clinical_%s.png' % (folder_name, dataset_name, output_folder, N, distribution_names[a]), bbox_inches='tight', dpi=400)
    plt.close()

In [None]:
# top features - low clinical - other non-invasive clinical (every clinical feature except COHORT)
for a in range(len(distribution_names))[0:1]:

    # get values
    keep_features_id = [i for i,x in enumerate(merged_features) if x.split(' # ')[0]=='clinical_noninvasive' and x!='clinical_noninvasive # COHORT']
    keep_features = [merged_features[i] for i in keep_features_id]
    shap_value_abs_distribution = shap_value_abs.loc[[samples[i] for i in range(len(samples)) if distribution_labels[i]==a]]
    # Per-feature mean/std and the group size are computed once. Previously the whole-frame
    # .mean()/.std() and the count were recomputed inside the comprehensions for every feature.
    distribution_mean_all = shap_value_abs_distribution.mean(axis=0).values.tolist()
    distribution_std_all = shap_value_abs_distribution.std(axis=0).values.tolist()
    member_count = sum(1 for y in distribution_labels if y == a)
    shap_value_abs_distribution_mean = [distribution_mean_all[i] for i in keep_features_id]
    shap_value_abs_distribution_std = [distribution_std_all[i] for i in keep_features_id]
    shap_value_abs_distribution_sterr = [x/np.sqrt(member_count) for x in shap_value_abs_distribution_std]

    # plot
    N = 5
    top_N_index = np.argsort(shap_value_abs_distribution_mean)[::-1][:N]
    top_N_mean = [shap_value_abs_distribution_mean[i] for i in top_N_index]
    top_N_sterr = [shap_value_abs_distribution_sterr[i] for i in top_N_index]
    top_N_datasets = [keep_features[i].split(' # ')[0] for i in top_N_index]
    top_N_features = [keep_features[i].split(' # ')[-1] for i in top_N_index]
    # Hand-written display names for the ranked top features, matched to bars by position only.
    # NOTE(review): check them against top_N_features whenever the data or models change,
    # otherwise bars are silently mislabelled.
    labels = ['Tumor location','Clinical Stage','Patient age','Patient race','Patient Gender']

    plt.rcParams['font.family'] = 'Arial'
    plt.rc('xtick', labelsize=12)
    plt.rc('ytick', labelsize=20)
    fig = plt.figure(figsize=(20/50*N*1.5,5))
    ax = fig.add_subplot(111)
    barlist = plt.bar(range(N), top_N_mean, yerr=top_N_sterr, capsize=5, color=['#FFD700'])
    plt.xticks(range(N), labels, rotation=90, fontsize=16)
    plt.xlim(-1,N)
    plt.ylabel(r'Mean |$\Delta$P|', fontsize=20)
    plt.title('Other Non-Invasive Clinical', fontsize=20)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    vals = ax.get_yticks()
    decimals = [str(round(x*100,6)).split('.')[1] for x in vals]
    if len([x for x in decimals if x != '0']) > 0:
        number_of_places = np.max([len(x) for x in decimals])
    else:
        number_of_places = np.max([len(x) for x in decimals])-1
    if number_of_places not in (0, 1, 2):
        raise Exception('Number of decimal places = %d' % number_of_places)
    # A formatter labels whatever ticks are drawn. set_yticklabels on unfixed ticks can misplace labels.
    ax.yaxis.set_major_formatter(matplotlib.ticker.StrMethodFormatter('{x:,.%d%%}' % number_of_places))
    plt.savefig('%s/%s/%s/shap_values/summary/top%d_clinical_%s_otherclinical.png' % (folder_name, dataset_name, output_folder, N, distribution_names[a]), bbox_inches='tight', dpi=400)
    plt.close()

In [None]:
# top features - low clinical - blood metabolites
for a in range(len(distribution_names))[0:1]:

    # get values
    keep_features_id = [i for i,x in enumerate(merged_features) if x.split(' # ')[0]=='objscreen_blood']
    keep_features = [merged_features[i] for i in keep_features_id]
    shap_value_abs_distribution = shap_value_abs.loc[[samples[i] for i in range(len(samples)) if distribution_labels[i]==a]]
    # Per-feature mean/std and the group size are computed once. Previously the whole-frame
    # .mean()/.std() and the count were recomputed inside the comprehensions for every feature.
    distribution_mean_all = shap_value_abs_distribution.mean(axis=0).values.tolist()
    distribution_std_all = shap_value_abs_distribution.std(axis=0).values.tolist()
    member_count = sum(1 for y in distribution_labels if y == a)
    shap_value_abs_distribution_mean = [distribution_mean_all[i] for i in keep_features_id]
    shap_value_abs_distribution_std = [distribution_std_all[i] for i in keep_features_id]
    shap_value_abs_distribution_sterr = [x/np.sqrt(member_count) for x in shap_value_abs_distribution_std]

    # plot
    N = 11
    top_N_index = np.argsort(shap_value_abs_distribution_mean)[::-1][:N]
    top_N_mean = [shap_value_abs_distribution_mean[i] for i in top_N_index]
    top_N_sterr = [shap_value_abs_distribution_sterr[i] for i in top_N_index]
    top_N_datasets = [keep_features[i].split(' # ')[0] for i in top_N_index]
    top_N_features = [keep_features[i].split(' # ')[-1] for i in top_N_index]
    # Hand-written display names for the ranked top metabolites, matched to bars by position only.
    # NOTE(review): check them against top_N_features whenever the data or models change,
    # otherwise bars are silently mislabelled.
    labels = ['Butyric acid','Prostaglandin J2','Phenylacetic acid','5Z-Tetradecenoic acid','Aminoadipic acid','Pyroglutamic acid','GDP','Allysine','Capric acid','Prostaglandin D2','Adrenic acid']

    plt.rcParams['font.family'] = 'Arial'
    plt.rc('xtick', labelsize=12)
    plt.rc('ytick', labelsize=20)
    fig = plt.figure(figsize=(20/50*N*1.5,5))
    ax = fig.add_subplot(111)
    barlist = plt.bar(range(N), top_N_mean, yerr=top_N_sterr, capsize=5, color=['#DC143C'])
    plt.xticks(range(N), labels, rotation=90, fontsize=16)
    plt.xlim(-1,N)
    plt.ylabel(r'Mean |$\Delta$P|', fontsize=20)
    plt.title('Blood Metabolites', fontsize=20)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    vals = ax.get_yticks()
    decimals = [str(round(x*100,6)).split('.')[1] for x in vals]
    if len([x for x in decimals if x != '0']) > 0:
        number_of_places = np.max([len(x) for x in decimals])
    else:
        number_of_places = np.max([len(x) for x in decimals])-1
    if number_of_places not in (0, 1, 2):
        raise Exception('Number of decimal places = %d' % number_of_places)
    # A formatter labels whatever ticks are drawn. set_yticklabels on unfixed ticks can misplace labels.
    ax.yaxis.set_major_formatter(matplotlib.ticker.StrMethodFormatter('{x:,.%d%%}' % number_of_places))
    plt.savefig('%s/%s/%s/shap_values/summary/top%d_clinical_%s_blood.png' % (folder_name, dataset_name, output_folder, N, distribution_names[a]), bbox_inches='tight', dpi=400)
    plt.close()

### SHAP plots

In [None]:
# plotting function
def plot_patient(patient_name, actual_value, expected_value, features_, shap_, original_):
    """Draw a waterfall of one patient's SHAP contributions on a 0-100% probability line.

    Starting from the prior (`expected_value`), each feature's SHAP value is drawn as a
    horizontal arrow one row below the previous one: green when it lowers the probability,
    red otherwise. Arrows are labelled '<feature> = <value>' when the original value is a
    string and '<feature> (Imputed)' otherwise. The patient name is green when
    `actual_value == 0` and red otherwise. The end point is marked as the posterior.
    Draws on a new figure and returns nothing.
    """
    green, red = '#5bc2ae', '#d93f20'
    top_centered = dict(horizontalalignment='center', verticalalignment='bottom')

    # initialize figure
    plt.rcParams['font.family'] = 'Arial'
    plt.figure(figsize=(20, .25*len(features_)))
    plt.xticks([])
    plt.yticks([])

    # patient name, coloured by the true class
    name_color = green if actual_value == 0 else red
    plt.text(0.5, 0.6, patient_name, horizontalalignment='center', verticalalignment='center',
             color=name_color, fontsize=14, weight='bold')

    # number line with ticks at 0%, 50% and 100%, plus the dashed prior marker
    plt.plot([0, 1], [0, 0], 'k-', linewidth=1)
    for tick in (0, 0.5, 1):
        plt.plot([tick, tick], [-0.1, 0.1], 'k-', linewidth=1)
    plt.plot([expected_value, expected_value], [-0.3, 0.1], 'k--', linewidth=1)
    for x_pos, tick_label, tick_color in ((0, '0%', green), (0.5, '50%', 'black'), (1, '100%', red)):
        plt.text(x_pos, 0.1, tick_label, color=tick_color, fontsize=12, **top_centered)
    plt.text(expected_value, 0.1, '%0.1f%%' % (expected_value*100), color='black', fontsize=12, **top_centered)
    plt.text(expected_value, 0.3, 'Prior', color='black', fontsize=12, **top_centered)

    # one arrow per feature, each row 0.3 below the previous one
    position = expected_value
    row = -0.3
    for idx, name in enumerate(features_):
        delta = shap_[idx]
        raw = original_[idx]
        lowers = delta < 0
        plt.arrow(position, row, delta, 0, width=0.05, head_length=0.2*np.abs(delta),
                  length_includes_head=True, color=green if lowers else red)
        label = '%s = %s' % (name, raw) if type(raw) == str else '%s (Imputed)' % name
        if lowers:
            plt.text(position+0.005, row, label, horizontalalignment='left', verticalalignment='center', fontsize=12)
        else:
            plt.text(position-0.005, row, label, horizontalalignment='right', verticalalignment='center', fontsize=12)
        position += delta
        row -= 0.3

    # posterior marker, coloured by which side of 50% it lands on
    plt.plot([position, position], [row+0.15, 0.1], 'k--', linewidth=1)
    posterior_color = green if position < 0.5 else red
    plt.text(position, 0.1, '%0.1f%%' % (position*100), color=posterior_color, fontsize=12, weight='bold', **top_centered)
    plt.text(position, 0.3, 'Posterior', color='black', fontsize=12, **top_centered)

    # limits
    plt.xlim(-0.02, 1.02)
    plt.ylim(row-0.05, 0.8)
    plt.axis('off')

In [None]:
# iterate over patients: print each patient's posterior probability, i.e. the prior
# expected_value[a] plus the sum of the patient's nonzero SHAP values
for a,sample in enumerate(samples):
    # optional filter used while browsing (misclassified, non-LGG resistant patients):
    #if not (distribution_labels[a]==0 and y_vector[a]==1 and X_matrix.at[sample,'clinical_noninvasive # COHORT']!='LGG'):
    #    continue

    # get nonzero shap and feature values (row lookups hoisted out of the comprehension)
    shap_row = shap_value.loc[sample]
    features_ = [feature for feature in merged_features if shap_row[feature] != 0]
    shap_ = shap_row[features_].tolist()
    original_ = X_matrix.loc[sample][features_].tolist()

    # convert numerical values to strings; missing (imputed) values are left untouched.
    # pd.isna also handles None, and the isfinite check avoids int(inf) raising OverflowError.
    for j, value in enumerate(original_):
        if isinstance(value, str) or pd.isna(value):
            continue
        if isinstance(value, bool):
            original_[j] = str(value)
        elif np.isfinite(value) and value == int(value):
            original_[j] = str(int(value))
        else:
            original_[j] = str(value)

    # order features based on absolute shap value
    sort_index = np.argsort(np.abs(shap_))[::-1]
    features_ = [features_[i] for i in sort_index]
    shap_ = [shap_[i] for i in sort_index]
    original_ = [original_[i] for i in sort_index]

    # (disabled) group imputed values into a single '(Imputed)' entry:
    #imputed_id = [i for i,x in enumerate(original_) if pd.isna(x)]
    #imputed_sum = np.sum([shap_[i] for i in imputed_id])
    #for i in imputed_id[::-1]:
    #    del features_[i]
    #    del shap_[i]
    #    del original_[i]
    #features_.append('(Imputed)')
    #shap_.append(imputed_sum)
    #original_.append(np.nan)

    # (disabled) save a waterfall figure per patient predicted above 50%:
    #if np.sum(shap_)+expected_value[a] > 0.5:
    #    plot_patient(sample, y_vector[a], expected_value[a], features_, shap_, original_)
    #    plt.savefig('%s/%s/%s/shap_values/patients/%s.png' % (folder_name, dataset_name, output_folder, sample), bbox_inches='tight', dpi=400)
    #    plt.close()
    print(expected_value[a]+np.sum(shap_))

In [None]:
# color functions
def adjust_color_lightness(r, g, b, factor):
    """Scale the HLS lightness of an 8-bit RGB colour by `factor`.

    Args:
        r, g, b: channel values in 0-255.
        factor: multiplier for the HLS lightness; the result is clamped to [0, 1].

    Returns:
        (r, g, b) tuple of ints in 0-255. Channels are rounded rather than truncated,
        so factor == 1.0 reproduces the input despite floating-point round-trip error.
    """
    # local import: colorsys is not imported in the notebook's import cell
    from colorsys import rgb_to_hls, hls_to_rgb

    h, l, s = rgb_to_hls(r / 255.0, g / 255.0, b / 255.0)
    l = max(min(l * factor, 1.0), 0.0)
    r, g, b = hls_to_rgb(h, l, s)
    return int(round(r * 255)), int(round(g * 255)), int(round(b * 255))

def lighten_color(r, g, b, factor=0.1):
    """Return the colour with its HLS lightness increased by `factor` (0.1 = +10%)."""
    return adjust_color_lightness(r, g, b, 1 + factor)

def darken_color(r, g, b, factor=0.1):
    """Return the colour with its HLS lightness decreased by `factor` (0.1 = -10%)."""
    return adjust_color_lightness(r, g, b, 1 - factor)

In [None]:
# Example patient: bar plot of the SHAP values (ΔP) that make up this patient's prediction.
sample = 'TCGA-S9-A7IY-01A'
a = samples.index(sample)
# Hand-curated display names, one per bar, in the exact order the bars are drawn below
# (ascending SHAP value after the cutoff and imputed-feature removal). They are only valid
# for this sample; the assertion before plotting guards against silent mislabelling.
feature_names = ['Karnofsky Score','12,13-DHOME','Patient Age',r'6-Keto-prostaglandin F1$\alpha$','D-Fructose','Urea','Valeric acid','Eicosatrienoic acid','Palmitic acid','Allopregnanolone','Behenic acid','Oxoadipic acid','Phosphorylcholine','Sphingosine','N-Acetylputrescine','Thromboxane A2','Lanosterin','Adipic acid','2-Methylcitric acid','Glyceraldehyde','Adenine','Hydroperoxylinoleic acid','Oleic acid','Dihydrobiopterin','4-Hydroxynonenal','Eicosenoic acid','Phenylacetylglutamine','Sphinganine','L-Arabitol','Xanthurenic acid','GTP','5-Hydroxyindoleacetic acid','Guanosine','Patient Race','4-Hydroxyproline','Quinolinic acid','Capric acid','Eicosatetraenoic acid',r'20$\alpha$-Dihydroprogesterone','Prostaglandin D2','Argininosuccinic acid','Elaidic acid','GDP','Butyric acid','Pyroglutamic acid','Allysine','Prostaglandin J2','Aminoadipic acid','Phenylacetic acid','5Z-Tetradecenoic acid','Cancer Type']

# get nonzero shap and feature values
shap_row = shap_value.loc[sample]
features_ = [feature for feature in merged_features if shap_row[feature] != 0]
shap_ = shap_row[features_].tolist()
original_ = X_matrix.loc[sample][features_].tolist()

# convert numerical values to strings; missing (imputed) values stay NaN/None so they are removed below
for j, value in enumerate(original_):
    if isinstance(value, str) or pd.isna(value):
        continue
    if isinstance(value, bool):
        original_[j] = str(value)
    elif np.isfinite(value) and value == int(value):
        original_[j] = str(int(value))
    else:
        original_[j] = str(value)

# order features based on absolute shap value
sort_index = np.argsort(np.abs(shap_))[::-1]
features_ = [features_[i] for i in sort_index]
shap_ = [shap_[i] for i in sort_index]
original_ = [original_[i] for i in sort_index]

# only keep top cumulative 95%
# NOTE(review): the slice stops *before* the last feature whose cumulative share is < 0.95, so
# one feature fewer is kept than the comment suggests. It is left unchanged because
# feature_names above was curated against exactly this selection.
total_abs_shap = np.sum(np.abs(shap_))  # hoisted: the original recomputed this sum for every element
shap_normalized = [np.abs(x)/total_abs_shap for x in shap_]
below_cutoff = [i for i,x in enumerate(np.cumsum(shap_normalized)) if x<0.95]
threshold_index = below_cutoff[-1] if below_cutoff else 0  # avoid IndexError when nothing is under 95%
features_ = features_[:threshold_index]
shap_ = shap_[:threshold_index]
original_ = original_[:threshold_index]

# order features based on shap value
sort_index = np.argsort(shap_)
features_ = [features_[i] for i in sort_index]
shap_ = [shap_[i] for i in sort_index]
original_ = [original_[i] for i in sort_index]

# remove imputed features
keep_index = [i for i,x in enumerate(original_) if not pd.isna(x)]
features_ = [features_[i] for i in keep_index]
shap_ = [shap_[i] for i in keep_index]
original_ = [original_[i] for i in keep_index]

# the curated labels are positional, so they must line up one-to-one with the bars
assert len(feature_names) == len(features_), (
    'feature_names has %d labels but %d bars are plotted for %s; re-curate the labels'
    % (len(feature_names), len(features_), sample))

# plot
plt.rcParams['font.family'] = 'Arial'
plt.rc('xtick', labelsize=20)
plt.rc('ytick', labelsize=20)
fig = plt.figure(figsize=(15,5))
ax = fig.add_subplot(111)
barlist = plt.bar([x+1 for x in list(range(len(shap_)))], shap_)
for i in range(len(barlist)):
    # cancer type gets a darker yellow; every other bar takes its dataset colour
    if features_[i]=='clinical_noninvasive # COHORT':
        barlist[i].set_color('#e6c200')
    else:
        barlist[i].set_color(colors[individual_datasets.index(features_[i].split(' # ')[0])])
    # labels go on the opposite side of the zero line from the bar
    if shap_[i] < 0:
        plt.text(i+1,0+0.0004,feature_names[i], fontsize=12, rotation=90, horizontalalignment='center', verticalalignment='bottom')
    else:
        plt.text(i+1,0-0.0001,feature_names[i], fontsize=12, rotation=90, horizontalalignment='center', verticalalignment='top')
plt.plot([0,len(shap_)+1],[0,0],'k-',linewidth=1)
plt.xlim([0,len(shap_)+1])
plt.ylim([-0.02,0.05])
plt.xticks([])
# NOTE(review): the 0.052 / '13%' tick lies outside ylim and is not drawn; confirm intent
plt.yticks([-0.02,-0.01,0,0.01,0.02,0.03,0.04,0.052],labels=['-2%','-1%','0%','1%','2%','3%','4%','13%'])
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_visible(False)
plt.ylabel(r'$\Delta$P', fontsize=20)
#plt.text(1,0.04,'%s\n%d y.o. %s %s\nCancer Type: %s' % (sample[:12],X_matrix.at[sample,'clinical_noninvasive # AGE'],X_matrix.at[sample,'clinical_noninvasive # RACE'].title(),X_matrix.at[sample,'clinical_noninvasive # GENDER'].title(),X_matrix.at[sample,'clinical_noninvasive # COHORT']), fontsize=17)
plt.text(1,0.043,'%s\nClass: ' % (sample[:12]), fontsize=17)
if y_vector[a]==0:
    plt.text(4.1,0.043,'Sensitive', fontsize=17, weight='bold', color='#5bc2ae')
else:
    plt.text(4.1,0.043,'Resistant', fontsize=17, weight='bold', color='#d93f20')
plt.savefig('%s/%s/%s/shap_values/example/bar.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
plt.close()

In [None]:
# shap plots: dependence plot (feature value vs. SHAP value) for every blood metabolite, plus
# Karnofsky score and age, among the example patient's plotted features; the patient is in red
patient_position = samples.index(sample)  # loop-invariant
for feature in features_:
    if feature.split(' # ')[0]=='objscreen_blood' or feature.split(' # ')[1] in ['KARNOFSKY','AGE']:

        # get values, aligned to `samples` by label so positions agree with samples.index(sample)
        val = X_matrix.loc[samples, feature].tolist()
        sh = shap_value.loc[samples, feature].tolist()
        keep_index = [i for i,x in enumerate(val) if not pd.isna(x)]
        val = [val[i] for i in keep_index]
        sh = [sh[i] for i in keep_index]
        patient_index = keep_index.index(patient_position)

        # plot
        plt.rcParams['font.family'] = 'Arial'
        plt.rc('xtick', labelsize=20)
        plt.rc('ytick', labelsize=20)
        fig = plt.figure(figsize=(5,5))
        ax = fig.add_subplot(111)
        ax.plot(val, sh, 'k.')
        ax.plot(val[patient_index], sh[patient_index], 'r.')
        ax.set_title(feature)
        if feature.split(' # ')[0]=='objscreen_blood':
            ax.set_xscale('log')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across iterations

In [None]:
feature = 'objscreen_blood # M02745'
title = '5Z-Tetradecenoic acid'

# get values, aligned to `samples` by label so positions agree with samples.index(sample)
val = X_matrix.loc[samples, feature].tolist()
sh = shap_value.loc[samples, feature].tolist()
patient_index = samples.index(sample)
keep_index = [i for i,x in enumerate(val) if not pd.isna(x)]
val = [val[i] for i in keep_index]
sh = [sh[i] for i in keep_index]
patient_index = keep_index.index(patient_index)

# plot: every patient in black, the example patient highlighted
plt.rcParams['font.family'] = 'Arial'
plt.rc('xtick', labelsize=16)
plt.rc('ytick', labelsize=16)
fig = plt.figure(figsize=(4,4))
ax = fig.add_subplot(111)
plt.plot(val, sh, 'k.', markersize=5)
plt.plot(val[patient_index], sh[patient_index], '.', markersize=12, color='#d93f20')
plt.plot([0.001,0.1],[0,0],'k--',linewidth=1)
plt.title(title, fontsize=18)
ax.set_xscale('log')
plt.xlim([0.001,0.1])
plt.ylim([-0.02,0.04])
# ΔP is a probability difference: label the fixed ticks as whole percentages
y_ticks = [-0.02,0,0.02,0.04]
plt.yticks(y_ticks, ['{:,.0%}'.format(t) for t in y_ticks])
plt.xlabel(r'Metabolite Production [mmol gDW$^{-1}$ hr$^{-1}$]', fontsize=16)
plt.ylabel(r'$\Delta$P', fontsize=20)
plt.savefig('%s/%s/%s/shap_values/example/met1.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
plt.close()

In [None]:
feature = 'objscreen_blood # pheacgln'
title = 'Phenylacetylglutamine'

# get values, aligned to `samples` by label so positions agree with samples.index(sample)
val = X_matrix.loc[samples, feature].tolist()
sh = shap_value.loc[samples, feature].tolist()
patient_index = samples.index(sample)
keep_index = [i for i,x in enumerate(val) if not pd.isna(x)]
val = [val[i] for i in keep_index]
sh = [sh[i] for i in keep_index]
patient_index = keep_index.index(patient_index)

# plot: every patient in black, the example patient highlighted
plt.rcParams['font.family'] = 'Arial'
plt.rc('xtick', labelsize=16)
plt.rc('ytick', labelsize=16)
fig = plt.figure(figsize=(4,4))
ax = fig.add_subplot(111)
plt.plot(val, sh, 'k.', markersize=5)
plt.plot(val[patient_index], sh[patient_index], '.', markersize=12, color='#d93f20')
plt.plot([0.00001,0.1],[0,0],'k--',linewidth=1)
plt.title(title, fontsize=18)
ax.set_xscale('log')
plt.xlim([0.00001,0.1])
plt.ylim([-0.01,0.025])
# ΔP is a probability difference: label the fixed ticks as whole percentages
y_ticks = [-0.01,0,0.01,0.02]
plt.yticks(y_ticks, ['{:,.0%}'.format(t) for t in y_ticks])
plt.xlabel(r'Metabolite Production [mmol gDW$^{-1}$ hr$^{-1}$]', fontsize=16)
plt.ylabel(r'$\Delta$P', fontsize=20)
plt.savefig('%s/%s/%s/shap_values/example/met2.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
plt.close()

In [None]:
# FIXME(review): this cell is identical to the previous one (same feature, title and output
# file) and simply rewrites met2.png. If a third metabolite was intended, update
# feature/title and the output name (e.g. met3.png); otherwise delete this cell.
feature = 'objscreen_blood # pheacgln'
title = 'Phenylacetylglutamine'

# get values, aligned to `samples` by label so positions agree with samples.index(sample)
val = X_matrix.loc[samples, feature].tolist()
sh = shap_value.loc[samples, feature].tolist()
patient_index = samples.index(sample)
keep_index = [i for i,x in enumerate(val) if not pd.isna(x)]
val = [val[i] for i in keep_index]
sh = [sh[i] for i in keep_index]
patient_index = keep_index.index(patient_index)

# plot: every patient in black, the example patient highlighted
plt.rcParams['font.family'] = 'Arial'
plt.rc('xtick', labelsize=16)
plt.rc('ytick', labelsize=16)
fig = plt.figure(figsize=(4,4))
ax = fig.add_subplot(111)
plt.plot(val, sh, 'k.', markersize=5)
plt.plot(val[patient_index], sh[patient_index], '.', markersize=12, color='#d93f20')
plt.plot([0.00001,0.1],[0,0],'k--',linewidth=1)
plt.title(title, fontsize=18)
ax.set_xscale('log')
plt.xlim([0.00001,0.1])
plt.ylim([-0.01,0.025])
# ΔP is a probability difference: label the fixed ticks as whole percentages
y_ticks = [-0.01,0,0.01,0.02]
plt.yticks(y_ticks, ['{:,.0%}'.format(t) for t in y_ticks])
plt.xlabel(r'Metabolite Production [mmol gDW$^{-1}$ hr$^{-1}$]', fontsize=16)
plt.ylabel(r'$\Delta$P', fontsize=20)
plt.savefig('%s/%s/%s/shap_values/example/met2.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
plt.close()

In [None]:
# arrow plot: the example patient's prediction as prior -> three grouped contributions -> posterior.
# The grouped values use their own names: the original overwrote features_/shap_ here, which
# made the following contribution cell find no matching features and divide 0 by 0.
cohort_feature = 'clinical_noninvasive # COHORT'
p1 = np.sum([x for i,x in enumerate(shap_) if features_[i].split(' # ')[0]=='clinical_noninvasive' and features_[i]!=cohort_feature])
p2 = np.sum([x for i,x in enumerate(shap_) if features_[i]==cohort_feature])
p3 = np.sum([x for i,x in enumerate(shap_) if features_[i].split(' # ')[0]=='objscreen_blood'])
group_shap = [p1,p2,p3]
group_labels = ['Other Non-Invasive Clinical','Cancer Type','Blood Metabolites']
colors_ = ['#FFD700','#e6c200','#DC143C']
prior = expected_value[a]

# initialize figure
plt.rcParams['font.family'] = 'Arial'
fig = plt.figure(figsize=(6,2))
plt.xticks([])
plt.yticks([])

# number line
plt.plot([0,1],[0,0],'k-',linewidth=1)
for x in [0,0.5,1]:
    plt.plot([x,x],[-0.1,0.1],'k-',linewidth=1)
plt.plot([prior,prior],[-0.3,0.1],'k--',linewidth=1)
plt.text(0,0.1,'0%', horizontalalignment='center', verticalalignment='bottom', color='#5bc2ae', fontsize=12)
plt.text(0.5,0.1,'50%', horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=12)
plt.text(1,0.1,'100%', horizontalalignment='center', verticalalignment='bottom', color='#d93f20', fontsize=12)
plt.text(prior,0.1,'%0.1f%%' % (prior*100), horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=12)
plt.text(prior,0.3,'Prior', horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=12)

# data: one arrow per group, stacked downwards, with the group name and a signed percentage
current_value = prior
y_value = -0.3
for i in range(len(group_labels)):
    plt.arrow(current_value, y_value, group_shap[i], 0, width=0.05, head_length= 0.2*np.abs(group_shap[i]), length_includes_head=True, color=colors_[i])
    if group_shap[i] < 0:
        plt.text(current_value+0.005, y_value, '%s  ' % group_labels[i], horizontalalignment='right', verticalalignment='center', fontsize=12)
        plt.text(current_value+group_shap[i]/2.3, y_value+0.04, '  %0.1f%%' % (group_shap[i]*100), horizontalalignment='left', fontsize=8)
    else:
        plt.text(current_value-0.005, y_value, '%s' % group_labels[i], horizontalalignment='right', verticalalignment='center', fontsize=12)
        plt.text(current_value+group_shap[i]/2.3, y_value+0.04, '+%0.1f%%' % (group_shap[i]*100), horizontalalignment='center', fontsize=8)
    current_value += group_shap[i]
    y_value += -0.3

# end line (posterior) and the 50% decision threshold
plt.plot([current_value,current_value],[y_value+0.15,0.1],'k--',linewidth=1)
plt.plot([0.5,0.5],[y_value+0.15,0.1],'k--',linewidth=1)
if current_value < 0.5:
    plt.text(current_value,0.1,'%0.1f%%' % (current_value*100), horizontalalignment='center', verticalalignment='bottom', color='#5bc2ae', fontsize=12, weight='bold')
else:
    plt.text(current_value,0.1,'%0.1f%%' % (current_value*100), horizontalalignment='center', verticalalignment='bottom', color='#d93f20', fontsize=12, weight='bold')
plt.text(current_value,0.3,'Posterior', horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=12)

# limits
plt.xlim(-0.02,1.02)
plt.ylim(y_value-0.05, 0.8)
plt.axis('off')
plt.savefig('%s/%s/%s/shap_values/example/arrows.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
plt.close()

In [None]:
# dataset contribution: share of the example patient's total |ΔP| coming from each feature group.
# Relies on features_/shap_ still holding the example patient's raw '<dataset> # <feature>'
# entries from the bar-plot cell.
cohort_feature = 'clinical_noninvasive # COHORT'
contribution_names = ['Other Non-Invasive Clinical','Cancer Type','Blood Metabolites']
contribution_values = [
    np.sum([np.abs(x) for i,x in enumerate(shap_) if features_[i].split(' # ')[0]=='clinical_noninvasive' and features_[i]!=cohort_feature]),
    np.sum([np.abs(x) for i,x in enumerate(shap_) if features_[i]==cohort_feature]),
    np.sum([np.abs(x) for i,x in enumerate(shap_) if features_[i].split(' # ')[0]=='objscreen_blood']),
]
total_contribution = np.sum(contribution_values)
# a zero total would silently yield NaN shares (0/0); fail with an actionable message instead
assert total_contribution > 0, (
    'features_/shap_ contain no %s SHAP values; re-run the example-patient bar-plot cell first'
    % ' / '.join(contribution_names))
contribution_values = [x/total_contribution for x in contribution_values]
plt.rcParams['font.family'] = 'Arial'
fig = plt.figure(figsize=(8,8))
ax = fig.add_subplot(111)
patches, texts, autotexts = ax.pie(contribution_values, colors=['#FFD700','#e6c200','#DC143C'], autopct='%1.1f%%')
for text, autotext in zip(texts, autotexts):
    text.set_fontsize(26)
    autotext.set_fontsize(26)
plt.show()

In [None]:
# Search resistant patients (y_vector == 1) for cases where blood metabolites flip the prediction:
# metabolites contribute more than the clinical features, the clinical-only posterior is below
# 50%, and the full posterior is above 50%. Each match prints the sample, its cohort, p1 (other
# clinical), p2 (cancer type), p3 (metabolites), the clinical-only posterior and the full posterior.
for a,sample in enumerate(samples):
    if y_vector[a] != 1:
        continue

    # get nonzero shap and feature values
    shap_row = shap_value.loc[sample]
    features_ = [feature for feature in merged_features if shap_row[feature] != 0]
    shap_ = shap_row[features_].tolist()
    original_ = X_matrix.loc[sample][features_].tolist()

    # convert numerical values to strings; missing (imputed) values stay NaN/None so they are removed below
    for j, value in enumerate(original_):
        if isinstance(value, str) or pd.isna(value):
            continue
        if isinstance(value, bool):
            original_[j] = str(value)
        elif np.isfinite(value) and value == int(value):
            original_[j] = str(int(value))
        else:
            original_[j] = str(value)

    # order features based on absolute shap value
    sort_index = np.argsort(np.abs(shap_))[::-1]
    features_ = [features_[i] for i in sort_index]
    shap_ = [shap_[i] for i in sort_index]
    original_ = [original_[i] for i in sort_index]

    # only keep top cumulative 95%. The cutoff matches the example-patient cell: the slice stops
    # before the last feature under 95%. When no feature is under the cutoff (one feature holds
    # >= 95%, or the patient has no nonzero SHAP values), nothing is kept instead of raising IndexError.
    total_abs_shap = np.sum(np.abs(shap_))
    shap_normalized = [np.abs(x)/total_abs_shap for x in shap_]
    below_cutoff = [i for i,x in enumerate(np.cumsum(shap_normalized)) if x<0.95]
    threshold_index = below_cutoff[-1] if below_cutoff else 0
    features_ = features_[:threshold_index]
    shap_ = shap_[:threshold_index]
    original_ = original_[:threshold_index]

    # order features based on shap value
    sort_index = np.argsort(shap_)
    features_ = [features_[i] for i in sort_index]
    shap_ = [shap_[i] for i in sort_index]
    original_ = [original_[i] for i in sort_index]

    # remove imputed features
    keep_index = [i for i,x in enumerate(original_) if not pd.isna(x)]
    features_ = [features_[i] for i in keep_index]
    shap_ = [shap_[i] for i in keep_index]
    original_ = [original_[i] for i in keep_index]

    # grouped contributions: other clinical, cancer type, blood metabolites
    p1 = np.sum([x for i,x in enumerate(shap_) if features_[i].split(' # ')[0]=='clinical_noninvasive' and features_[i]!='clinical_noninvasive # COHORT'])
    p2 = np.sum([x for i,x in enumerate(shap_) if features_[i]=='clinical_noninvasive # COHORT'])
    p3 = np.sum([x for i,x in enumerate(shap_) if features_[i].split(' # ')[0]=='objscreen_blood'])
    if p3>(p1+p2) and (expected_value[a]+p1+p2+p3)>0.5 and (expected_value[a]+p1+p2)<0.5:
        print(sample, X_matrix.at[sample,'clinical_noninvasive # COHORT'], p1, p2, p3, (expected_value[a]+p1+p2), (expected_value[a]+p1+p2+p3))

### Stacker

#### Load data

In [None]:
if analyze_stacker:

    # load stacker data from each split. A sample can be in the test set of several splits, so
    # every split's stacker weights and that split's entry of stacker_performance_weights (which
    # is indexed by split and defined in an earlier cell) are collected per sample.
    stacker_samples = []
    stacker_weights_ = []
    stacker_performance = []
    sample_position = {}  # sample -> index into the lists above (replaces O(n) list.index lookups)
    for i in range(n_splits):
        # NOTE: pickle.load can execute arbitrary code; only load pickles written by this pipeline
        with open('%s/%s/stacker_%d.pickle' % (folder_name, dataset_name, i+1) ,'rb') as f:
            #X_test, X_test_all, y_test, test_predictions, testing_pred, weights = pickle.load(f)
            X_test, X_test_all, y_test, y_pred, test_predictions, weights, test_best_classifier, bst = pickle.load(f)
        for j,sample in enumerate(X_test[0].index.tolist()):
            if sample not in sample_position:
                sample_position[sample] = len(stacker_samples)
                stacker_samples.append(sample)
                stacker_weights_.append([[] for dataset in individual_datasets])
                stacker_performance.append([])
            position = sample_position[sample]
            for k in range(len(individual_datasets)):
                stacker_weights_[position][k].append(weights[j,k])
            stacker_performance[position].append(stacker_performance_weights[i])

    # per sample and dataset: average of the stacker weights, weighted by split performance
    stacker_weights = [
        [np.average(stacker_weights_[i][j], weights=stacker_performance[i]) for j in range(len(individual_datasets))]
        for i in range(len(stacker_samples))
    ]

    # plot of stacker weights: one box/swarm per dataset
    stacker_weights_grouped = [[x[i] for x in stacker_weights] for i in range(len(individual_datasets))]
    plt.rcParams['font.family'] = 'Arial'
    plt.rc('xtick', labelsize=20)
    plt.rc('ytick', labelsize=20)
    fig = plt.figure(figsize=(3*len(individual_datasets),10))
    ax = fig.add_subplot(111)
    sns.boxplot(data=stacker_weights_grouped, whis=1, fliersize=0, palette=colors, ax=ax)
    sns.swarmplot(data=stacker_weights_grouped, color='black', alpha=0.25, ax=ax)
    ax.set_xlim([-0.5,len(individual_datasets)-0.5])
    ax.set(xticklabels=individual_datasets_labels)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_ylabel('Stacker Weight', fontsize=20)
    plt.savefig('%s/%s/%s/stacker/summary/weights.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
    plt.close()

#### Separate based on clinical factors

In [None]:
def _anova_pvalue(groups):
    """Return the one-way ANOVA p-value across `groups`, or NaN when fewer than two groups exist.

    scipy's f_oneway raises when given a single group, which happens for a clinical factor
    with only one observed value.
    """
    if len(groups) < 2:
        return np.nan
    return f_oneway(*groups).pvalue


def _plot_top_weights(weight_mean, weight_values, feature_value, N, palette, title, path):
    """Box/swarm plot of stacker weights for the N feature values with the highest mean weight.

    The three parallel lists are reordered by decreasing `weight_mean` and returned in that
    order. When there are no values, nothing is drawn or saved.
    """
    sort_index = np.argsort(weight_mean)[::-1]
    weight_mean = [weight_mean[s] for s in sort_index]
    weight_values = [weight_values[s] for s in sort_index]
    feature_value = [feature_value[s] for s in sort_index]
    if not weight_values:
        return weight_mean, weight_values, feature_value

    plt.rcParams['font.family'] = 'Arial'
    plt.rc('xtick', labelsize=12)
    plt.rc('ytick', labelsize=20)
    fig = plt.figure(figsize=(20,5))
    ax = fig.add_subplot(111)
    sns.boxplot(data=weight_values[:N], whis=1, fliersize=0, palette=palette, ax=ax)
    sns.swarmplot(data=weight_values[:N], color='black', alpha=0.25, ax=ax)
    ax.set_xticklabels(labels=feature_value[:N], rotation=90)
    plt.xlim(-1,N)
    plt.ylabel(r'Stacker Weight', fontsize=20)
    plt.title(title, fontsize=24)
    plt.savefig(path, bbox_inches='tight', dpi=400)
    plt.close()
    return weight_mean, weight_values, feature_value


if analyze_stacker:

    # separate based on clinical factors (only runs when a dataset named exactly 'clinical' is stacked)
    if 'clinical' in individual_datasets:

        # initialize feature and pvalue lists
        separation_feature = []
        separation_pvalue = []
        separation_feature_values = []
        separation_weight_factors = []

        # iterate over categorical clinical factors; AGE, PACK YEARS and ALCOHOL PER DAY are skipped
        for feature in merged_features:
            if feature.split(' # ')[0] != 'clinical' or feature.split(' # ')[1] in ['AGE','PACK YEARS','ALCOHOL PER DAY']:
                continue

            # weight_factors[value][dataset] collects the stacker weights of samples with that value
            if feature in categorical_conversion:
                # one-hot encoded: membership is read from the '<feature> | <value>' indicator columns
                unique_values = sorted(categorical_conversion[feature])
                weight_factors = [[[] for dataset in individual_datasets] for value in unique_values]
                for i,sample in enumerate(stacker_samples):
                    for j,value in enumerate(unique_values):
                        if X_matrix.at[sample, '%s | %s' % (feature, value)] == 1:
                            for k in range(len(individual_datasets)):
                                weight_factors[j][k].append(stacker_weights[i][k])
            else:
                # not one-hot encoded: group samples by the raw column value
                unique_values = sorted(list(set([x for x in X_matrix[feature].values.tolist() if not pd.isna(x)])))
                weight_factors = [[[] for dataset in individual_datasets] for value in unique_values]
                for i,sample in enumerate(stacker_samples):
                    if not pd.isna(X_matrix.at[sample,feature]):
                        for k in range(len(individual_datasets)):
                            weight_factors[unique_values.index(X_matrix.at[sample,feature])][k].append(stacker_weights[i][k])

            # record the factor, its per-value weight groups, and the ANOVA p-value(s) of the weights
            # across values: one per dataset with more than two datasets, otherwise only dataset 0
            separation_feature.append(feature.split(' # ')[1])
            for j,value in enumerate(unique_values):
                separation_feature_values.append('%s | %s' % (feature, value))
                separation_weight_factors.append(weight_factors[j])
            if len(individual_datasets) > 2:
                separation_pvalue.append([_anova_pvalue([weight_factors[j][k] for j in range(len(unique_values))])
                                          for k in range(len(individual_datasets))])
            else:
                separation_pvalue.append(_anova_pvalue([weight_factors[j][0] for j in range(len(unique_values))]))

        # find top feature value for each dataset (values with fewer than 3 samples are skipped)
        N = 20
        for i,dataset in enumerate(individual_datasets):
            weight_mean = []
            weight_values = []
            feature_value = []
            for j,value in enumerate(separation_feature_values):
                if len(separation_weight_factors[j][i]) >= 3:
                    feature_value.append(value.split(' # ')[1])
                    weight_mean.append(np.mean(separation_weight_factors[j][i]))
                    weight_values.append(separation_weight_factors[j][i])
            weight_mean, weight_values, feature_value = _plot_top_weights(
                weight_mean, weight_values, feature_value, N, [colors[i] for _ in range(N)],
                'Top %d Features - %s' % (N, individual_datasets_labels[i]),
                '%s/%s/%s/stacker/summary/top%d_%s.png' % (folder_name, dataset_name, output_folder, N, dataset))

        # find top feature value for all combined except clinical: per feature value, sum the
        # non-clinical datasets' mean weights and per-sample weights
        if (len(individual_datasets) >= 3) and ('clinical' in individual_datasets):
            weight_mean = []
            weight_values = []
            feature_value = []
            for j,value in enumerate(separation_feature_values):
                if len(separation_weight_factors[j][0]) >= 3:
                    feature_value.append(value.split(' # ')[1])
                    for i, dataset in enumerate(individual_datasets):
                        if dataset != 'clinical':
                            if len(weight_mean) < len(feature_value):
                                weight_mean.append(np.mean(separation_weight_factors[j][i]))
                                weight_values.append(np.array(separation_weight_factors[j][i]))
                            else:
                                weight_mean[-1] += np.mean(separation_weight_factors[j][i])
                                weight_values[-1] = np.add(weight_values[-1],separation_weight_factors[j][i])
            N = 20
            weight_mean, weight_values, feature_value = _plot_top_weights(
                weight_mean, weight_values, feature_value, N, ['#D3D3D3' for _ in range(N)],
                'Top %d Features - All Datasets but Clinical' % (N),
                '%s/%s/%s/stacker/summary/top%d_allbutclinical.png' % (folder_name, dataset_name, output_folder, N))