In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import pyplot, lines
import pickle
import random
import seaborn as sns
from scipy.stats import f_oneway
import urllib
import json
import time
from scipy.stats import ttest_ind, spearmanr
import glob
import collections

from sklearn.metrics import pairwise_distances, roc_curve, jaccard_score
from sklearn.cluster import KMeans
from scipy.stats import pearsonr, f_oneway, linregress
from sklearn.metrics import silhouette_score, silhouette_samples
from scipy.stats import f_oneway, linregress, chi2_contingency
from statsmodels.stats.multitest import multipletests

### Input

Folder and dataset names

In [None]:
# folder holding the per-split results read below (log-loss CSVs, SHAP pickles)
folder_name = 'withhyperopt_3'

# datasets to merge (pickle stems under _datasets/); labels and colors are listed in the same order
individual_datasets = [
    'clinical',
    'gene_all',
    'mutation_onehot_all',
    'objscreen_kegg',
]
individual_datasets_labels = [
    'Clinical',
    'Gene Expression',
    'Mutation',
    'Metabolites',
]
colors = [
    '#FFD700',
    '#3DAEC5',
    '#2E8B57',
    '#DC143C',
]

# number of splits (rows split_1 ... split_N in the results files)
n_splits = 20

Output folder name

In [None]:
# sub-folder of folder_name/dataset_name that receives the figures written by this notebook
output_folder = '_shap'

Analysis options

In [None]:
# which analyses to run; each flag also decides which output sub-folders are created
analyze_stacker, analyze_shap_values, analyze_shap_interactions = True, True, False

### Create output folders

In [None]:
# name of dataset
dataset_name = '+'.join(individual_datasets)

# output folders: every required folder is ensured to exist (exist_ok), so enabling an analysis
# later, after the root output folder was already created, still gets its sub-folders
output_root = '%s/%s/%s' % (folder_name, dataset_name, output_folder)
output_subfolders = []

# stacker
if analyze_stacker:
    output_subfolders += ['stacker', 'stacker/summary', 'stacker/clinical_features']

# shap values ('example' receives the single-feature scatter plots saved further down)
if analyze_shap_values:
    output_subfolders += [
        'shap_values',
        'shap_values/summary',
        'shap_values/summary/cohort',
        'shap_values/features',
        'shap_values/patients',
        'shap_values/example',
    ]

# shap interactions
if analyze_shap_interactions:
    output_subfolders += ['shap_interactions']

os.makedirs(output_root, exist_ok=True)
for subfolder in output_subfolders:
    os.makedirs('%s/%s' % (output_root, subfolder), exist_ok=True)

### Load dataset

In [None]:
# create merged dataset: load each dataset's pickle, prefix its columns with "dataset # ",
# and concatenate column-wise (every dataset must list the same samples in the same order)
# NOTE(review): pickle.load can execute arbitrary code; these files must come from a trusted source.
with open('_datasets/%s.pickle' % individual_datasets[0], 'rb') as f:
    X_matrix, y_vector, categorical_conversion_old = pickle.load(f, encoding='latin1')
X_matrix.columns = ['%s # %s' % (individual_datasets[0], x) for x in X_matrix.columns.tolist()]
categorical_conversion = {'%s # %s' % (individual_datasets[0], key): categorical_conversion_old[key] for key in categorical_conversion_old}
for b in range(1, len(individual_datasets)):
    with open('_datasets/%s.pickle' % individual_datasets[b], 'rb') as f:
        X_matrix_, y_vector_, categorical_conversion_old = pickle.load(f, encoding='latin1')
    X_matrix_.columns = ['%s # %s' % (individual_datasets[b], x) for x in X_matrix_.columns.tolist()]
    categorical_conversion_ = {'%s # %s' % (individual_datasets[b], key): categorical_conversion_old[key] for key in categorical_conversion_old}
    if X_matrix.index.tolist() != X_matrix_.index.tolist():
        raise Exception("Dataset sample lists don't match")
    X_matrix = pd.concat([X_matrix, X_matrix_], axis=1)
    categorical_conversion = {**categorical_conversion, **categorical_conversion_}
samples = X_matrix.index.tolist()
features = X_matrix.columns.tolist()

# list of merged features: the one-hot columns "feature | value" of a categorical feature collapse
# into a single "feature" entry (placed at its first one-hot column); other features pass through.
# A set tracks already-emitted categorical names (list membership here was quadratic).
merged_features = []
seen_categorical = set()
for feature in features:
    base_name = feature.split(' | ')[0]
    if base_name not in categorical_conversion:
        merged_features.append(feature)
    elif base_name not in seen_categorical:
        seen_categorical.add(base_name)
        merged_features.append(base_name)

# merge feature matrix: each categorical feature becomes one object column holding the value whose
# one-hot column equals 1 (NaN if none; if several equal 1, the last listed value wins)
X_matrix_ = X_matrix.copy()
for feature in categorical_conversion:
    merged_column = pd.Series(np.nan, index=X_matrix.index, dtype=object)
    for val in categorical_conversion[feature]:
        merged_column.loc[X_matrix['%s | %s' % (feature, val)] == 1] = val
    X_matrix_[feature] = merged_column
X_matrix = X_matrix_[merged_features]

### Load performance

In [None]:
# row labels of the per-split results: split_1 ... split_N
split_rows = ['split_%d' % (i+1) for i in range(n_splits)]

# stacker - log loss; each split is weighted by its inverse log loss
if analyze_stacker:
    df = pd.read_csv('%s/stacker_logloss.csv' % folder_name, index_col=0)
    stacker_log_loss = df.loc[split_rows, dataset_name].tolist()
    stacker_performance_weights = [1/loss for loss in stacker_log_loss]

# classifier - weighted log loss, inverted the same way
df = pd.read_csv('%s/weightedlogloss.csv' % folder_name, index_col=0)
classifier_weighted_log_loss = df.loc[split_rows, dataset_name].tolist()
classifier_performance_weights = [1/loss for loss in classifier_weighted_log_loss]

### SHAP values

#### Load data

In [None]:
if analyze_shap_values:

    # load shap values: one pickle per split holding (expected values, samples x features SHAP frame)
    # NOTE(review): pickle.load can execute arbitrary code; only load locally produced, trusted files.
    shap_expected = []
    shap_values = []
    for i in range(n_splits):
        with open('%s/%s/shap_values_%d.pickle' % (folder_name, dataset_name, i+1), 'rb') as f:
            shap_expected_, shap_values_ = pickle.load(f)
        shap_expected.append(shap_expected_)
        shap_values.append(shap_values_)

    # per split: sample name -> first row position. Dict lookups replace rebuilding and scanning
    # index lists for every sample (previously O(samples^2 x splits)).
    split_positions = []
    for split_shap in shap_values:
        positions = {}
        for position, name in enumerate(split_shap.index.tolist()):
            positions.setdefault(name, position)
        split_positions.append(positions)

    # remove samples that aren't analyzed in any split
    keep_index = [i for i, sample in enumerate(samples) if any(sample in positions for positions in split_positions)]
    samples = [samples[i] for i in keep_index]
    y_vector = [y_vector[i] for i in keep_index]
    X_matrix = X_matrix.iloc[keep_index]

    # combine expected values: performance-weighted mean over the splits that contain each sample
    _expected_value_sum = [0 for sample in samples]
    _shap_weight = [0 for sample in samples]
    for i, sample in enumerate(samples):
        for j in range(n_splits):
            position = split_positions[j].get(sample)
            if position is not None:
                _expected_value_sum[i] += classifier_performance_weights[j] * shap_expected[j][position]
                _shap_weight[i] += classifier_performance_weights[j]
    expected_value = np.divide(_expected_value_sum, _shap_weight)

    # combine shap values: align every split to samples x merged_features (absent samples/features
    # contribute 0), take the performance-weighted sum, then divide each sample by its total weight
    for i in range(n_splits):
        shap_values[i] = shap_values[i].reindex(index=samples, columns=merged_features, fill_value=0)
    shap_value = pd.DataFrame(data=0, index=samples, columns=merged_features)
    for i in range(n_splits):
        shap_value = np.add(shap_value, classifier_performance_weights[i]*shap_values[i])
    shap_value = shap_value.div(np.asarray(_shap_weight, dtype=float), axis=0)

    # calculate absolute shap values
    shap_value_abs = shap_value.abs()

    # zero shap values for imputed features (missing in the merged feature matrix)
    imputed = X_matrix.loc[shap_value_abs.index, shap_value_abs.columns].isna().to_numpy()
    shap_value_abs = shap_value_abs.mask(imputed, 0)

    # normalize each sample's |SHAP| values to sum to 1 (relative contribution of each feature);
    # skipna=False keeps the NaN propagation of a plain np.sum over the row
    shap_value_abs = shap_value_abs.div(shap_value_abs.sum(axis=1, skipna=False), axis=0)

    # mean and standard error for each feature
    shap_value_abs_mean = shap_value_abs.mean(axis=0).values.tolist()
    shap_value_abs_mean_cumsum = np.cumsum(sorted(shap_value_abs_mean)[::-1])
    shap_value_abs_std = shap_value_abs.std(axis=0).values.tolist()
    shap_value_abs_sterr = [x/np.sqrt(len(samples)) for x in shap_value_abs_std]

In [None]:
# combined SHAP values with features as rows and samples as columns
shap_value.transpose().to_csv('dataset4.csv')

In [None]:
# feature names without their "dataset # " prefix, one per line, in the row order of dataset4.csv
pd.DataFrame(
    [column_name.split(' # ')[-1] for column_name in shap_value.columns]
).to_csv('dataset4_features.csv', header=False, index=False)

In [None]:
# dataset each feature comes from (the "dataset # " prefix), in the row order of dataset4.csv
pd.DataFrame(
    [column_name.split(' # ')[0] for column_name in shap_value.columns]
).to_csv('dataset4_sets.csv', header=False, index=False)

#### All features

In [None]:
if analyze_shap_values:

    # rank features once by mean normalized |SHAP| (descending)
    ranked_index = np.argsort(shap_value_abs_mean)[::-1]
    n_features = len(merged_features)
    # features before the cumulative share passes 95%, and features with any nonzero share
    n_top_95 = len([c for c in shap_value_abs_mean_cumsum if c <= 0.95])
    n_nonzero = len([m for m in shap_value_abs_mean if m > 0])

    # top 95% features
    top_N_index = ranked_index[:n_top_95]
    overall_top_features = [merged_features[k] for k in top_N_index]
    overall_top_values = [shap_value_abs_mean[k] for k in top_N_index]

    # all features in ranked order - mean, standard error, dataset and bare feature name
    top_N_index = ranked_index
    top_N_mean = [shap_value_abs_mean[k] for k in top_N_index]
    top_N_sterr = [shap_value_abs_sterr[k] for k in top_N_index]
    top_N_datasets = [merged_features[k].split(' # ')[0] for k in top_N_index]
    top_N_features = [merged_features[k].split(' # ')[-1] for k in top_N_index]

    # plot - mean |dP| per rank (left axis, black) and its cumulative sum (right axis, grey)
    plt.rcParams['font.family'] = 'Arial'
    plt.rc('xtick', labelsize=18)
    plt.rc('ytick', labelsize=18)
    fig = plt.figure(figsize=(6,5))
    ax = fig.add_subplot(111)
    ax.set_xscale('log')
    ranks = range(1, n_features+1)
    ax.plot(ranks, sorted(shap_value_abs_mean)[::-1], 'k.-', markersize=5)
    ax.set_ylabel(r'Mean |$\Delta$P|', fontsize=20)
    ax.set_ylim(-0.0001, 0.12)
    ax.set_yticks([0,0.02,0.04,0.06,0.08,0.1,0.12])
    ax.set_xlabel('Features', fontsize=20)

    ax2 = ax.twinx()
    ax2.plot(ranks, shap_value_abs_mean_cumsum, '.-', color='#808080', markersize=5)
    ax2.plot([n_top_95, n_top_95], [0, 0.95], '--', color='#808080')
    ax2.plot([n_top_95, 1.1*n_features], [0.95, 0.95], '--', color='#808080')
    ax2.plot([n_nonzero, n_nonzero], [0, 1], '--', color='#808080')
    ax2.tick_params(axis='y', labelcolor='#808080')
    ax2.set_ylabel('Cumulative Sum', color='#808080', fontsize=20)
    ax2.set_ylim(0, 1.01)
    ax2.set_yticks([0,0.25,0.5,0.75,0.95,1])

    # x ticks at the powers of ten below the feature count, plus the feature count itself
    num_int = len(str(n_features))
    x_ticks = [10**p for p in range(num_int - 1)] + [n_features]
    ax2.set_xticks(x_ticks)
    ax2.set_xticklabels(x_ticks)
    ax2.set_xlim(0.95, 1.05*n_features)
    ax2.text(n_top_95, 0.03, '%d' % n_top_95, weight='bold', ha='right', fontsize=20)
    ax2.text(n_nonzero, 0.03, '%d' % n_nonzero, ha='left', fontsize=20)

    # percent tick labels on both y axes, using as many decimals as the ticks need
    percent_formats = {0: '{:,.0%}', 1: '{:,.1%}', 2: '{:,.2%}'}
    for axis in (ax, ax2):
        tick_values = axis.get_yticks()
        tick_decimals = [str(round(v*100, 6)).split('.')[1] for v in tick_values]
        widest = np.max([len(d) for d in tick_decimals])
        number_of_places = widest if any(d != '0' for d in tick_decimals) else widest - 1
        if number_of_places not in percent_formats:
            raise Exception('Number of decimal places = %d' % number_of_places)
        axis.set_yticklabels([percent_formats[number_of_places].format(v) for v in tick_values])

    plt.savefig('%s/%s/%s/shap_values/summary/all_features.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
    plt.close()

#### Dataset contributions

In [None]:
if analyze_shap_values:

    # set copy of the top-95% features for O(1) membership tests inside the feature scans below
    top_feature_set = set(overall_top_features)

    # plot - number of top features in each dataset
    number_of_top_features = []
    for dataset in individual_datasets:
        number_of_top_features.append(len([x for x in overall_top_features if x.split(' # ')[0]==dataset]))
    def func(pct, allvals):
        """Turn a pie-wedge percentage back into the absolute count it represents out of ``allvals``."""
        absolute = int(np.round(pct/100.*np.sum(allvals)))
        return "{:d}".format(absolute)
    plt.rcParams['font.family'] = 'Arial'
    fig = plt.figure(figsize=(8,8))
    ax = fig.add_subplot(111)
    patches, texts, autotexts = plt.pie(number_of_top_features, colors=colors, autopct=lambda pct: func(pct, number_of_top_features))
    for i in range(len(texts)):
        texts[i].set_fontsize(24)
        autotexts[i].set_fontsize(30)
    plt.title('Number of Features', fontsize=30)
    plt.savefig('%s/%s/%s/shap_values/summary/dataset_number.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
    plt.close()

    # plot - dataset contribution: share of the top features' mean |SHAP| per dataset
    dataset_contribution = []
    for dataset in individual_datasets:
        dataset_contribution.append(np.sum([x for i,x in enumerate(shap_value_abs_mean) if merged_features[i].split(' # ')[0]==dataset and merged_features[i] in top_feature_set]))
    dataset_contribution = [x/np.sum(dataset_contribution) for x in dataset_contribution]
    plt.rcParams['font.family'] = 'Arial'
    fig = plt.figure(figsize=(8,8))
    ax = fig.add_subplot(111)
    patches, texts, autotexts = plt.pie(dataset_contribution, colors=colors, autopct='%1.1f%%')
    for i in range(len(texts)):
        texts[i].set_fontsize(24)
        autotexts[i].set_fontsize(26)
    plt.title('Contribution of Features', fontsize=30)
    plt.savefig('%s/%s/%s/shap_values/summary/dataset_contribution.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
    plt.close()

In [None]:
# plot - individual patient dataset contribution
# For each sample: the share of its normalized |SHAP| mass carried by each dataset's top-95% features.
top_feature_set = set(overall_top_features)
# boolean column mask per dataset (the columns of shap_value_abs are merged_features)
dataset_masks = [
    np.array([name.split(' # ')[0] == dataset and name in top_feature_set for name in merged_features], dtype=bool)
    for dataset in individual_datasets
]
dataset_contribution_individual = [[] for dataset in individual_datasets]
# samples whose contribution from a dataset exceeds an equal share (1 / number of datasets)
dataset_contribution_individual_top = [[] for dataset in individual_datasets]
for sample in samples:
    sample_values = shap_value_abs.loc[sample].to_numpy()
    sample_total = np.sum(sample_values)
    for i, dataset_mask in enumerate(dataset_masks):
        share = np.sum(sample_values[dataset_mask]) / sample_total
        dataset_contribution_individual[i].append(share)
        # checked inside the dataset loop so that every dataset, not only the last one, is evaluated
        if share > (1/len(individual_datasets)):
            dataset_contribution_individual_top[i].append(sample)

# move clinical to end
dataset_contribution_individual_ = dataset_contribution_individual.copy()
dataset_contribution_individual_.append(dataset_contribution_individual_.pop(0))
individual_datasets_labels_ = individual_datasets_labels.copy()
individual_datasets_labels_.append(individual_datasets_labels_.pop(0))
colors_ = colors.copy()
colors_.append(colors_.pop(0))

plt.rcParams['font.family'] = 'Arial'
# set tick label sizes before the axes are created so they apply to its ticks
plt.rc('xtick', labelsize=22)
plt.rc('ytick', labelsize=30)
fig = plt.figure(figsize=(3*len(individual_datasets),10))
ax = fig.add_subplot(111)
sns.boxplot(data=dataset_contribution_individual_, whis=1, fliersize=0, palette=colors_, ax=ax)
sns.swarmplot(data=dataset_contribution_individual_, color='black', alpha=0.25, ax=ax)
ax.set_xlim([-0.5,len(individual_datasets)-0.5])
ax.set_xticklabels(individual_datasets_labels_)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.set_ylabel('Percent Contribution', fontsize=32)

# percent tick labels, with as many decimals as the current ticks need
vals = ax.get_yticks()
decimals = [str(round(x*100,6)).split('.')[1] for x in vals]
if len([x for x in decimals if x != '0']) > 0:
    number_of_places = np.max([len(x) for x in decimals])
else:
    number_of_places = np.max([len(x) for x in decimals])-1
percent_formats = {0: '{:,.0%}', 1: '{:,.1%}', 2: '{:,.2%}'}
if number_of_places not in percent_formats:
    raise Exception('Number of decimal places = %d' % number_of_places)
# a formatter keeps each label on its tick; set_yticklabels on unpinned ticks triggers a FixedFormatter warning
ax.yaxis.set_major_formatter(lambda x, pos, fmt=percent_formats[number_of_places]: fmt.format(x))
plt.savefig('%s/%s/%s/shap_values/summary/dataset_contribution_individual.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
plt.close()

In [None]:
# Palette in box-plot order, i.e. clinical moved to the end (same order as individual_datasets_labels_).
# NOTE(review): this overwrites the config-cell `colors`; any later cell that pairs `colors` with
# `individual_datasets` order will color the clinical slot with the gene-expression color.
colors = [
    '#3DAEC5',  # Gene Expression
    '#2E8B57',  # Mutation
    '#DC143C',  # Metabolites
    '#FFD700',  # Clinical
]

In [None]:
# Re-draw of the per-patient contribution box plot (larger x tick labels, palette from `colors`);
# it is saved to the same file, overwriting the earlier version.
plt.rcParams['font.family'] = 'Arial'
# set tick label sizes before the axes are created so they apply to its ticks
plt.rc('xtick', labelsize=26)
plt.rc('ytick', labelsize=30)
fig = plt.figure(figsize=(3*len(individual_datasets),10))
ax = fig.add_subplot(111)
sns.boxplot(data=dataset_contribution_individual_, whis=1, fliersize=0, palette=colors, ax=ax)
sns.swarmplot(data=dataset_contribution_individual_, color='black', alpha=0.25, ax=ax)
ax.set_xlim([-0.5,len(individual_datasets)-0.5])
ax.set_xticklabels(individual_datasets_labels_)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.set_ylabel('Percent Contribution', fontsize=32)

# percent tick labels, with as many decimals as the current ticks need
vals = ax.get_yticks()
decimals = [str(round(x*100,6)).split('.')[1] for x in vals]
if len([x for x in decimals if x != '0']) > 0:
    number_of_places = np.max([len(x) for x in decimals])
else:
    number_of_places = np.max([len(x) for x in decimals])-1
percent_formats = {0: '{:,.0%}', 1: '{:,.1%}', 2: '{:,.2%}'}
if number_of_places not in percent_formats:
    raise Exception('Number of decimal places = %d' % number_of_places)
# a formatter keeps each label on its tick; set_yticklabels on unpinned ticks triggers a FixedFormatter warning
ax.yaxis.set_major_formatter(lambda x, pos, fmt=percent_formats[number_of_places]: fmt.format(x))
plt.savefig('%s/%s/%s/shap_values/summary/dataset_contribution_individual.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
plt.close()

#### Clinical Multimodal

In [None]:
# functions
seed_ = 1  # random seed for the KMeans clustering in this section


def compute_inertia(a, X):
    """Mean within-cluster spread of a clustering.

    For every distinct label in ``a``, take the mean pairwise distance between the rows of ``X``
    carrying that label; the result is the mean of these per-cluster values.
    """
    per_cluster_spread = []
    for cluster in np.unique(a):
        members = X[a == cluster, :]
        per_cluster_spread.append(np.mean(pairwise_distances(members)))
    return np.mean(per_cluster_spread)

def compute_gap(clustering, data, k_max=5, n_references=100, random_state=None):
    """Gap statistic for k = 1..k_max clusters.

    Compares the log within-cluster spread (``compute_inertia``) of ``clustering`` on ``data``
    with its mean spread on reference datasets drawn uniformly over the bounding box of ``data``.

    Parameters
    ----------
    clustering : estimator with a writable ``n_clusters`` attribute and ``fit_predict`` (e.g. KMeans)
    data : ndarray of shape (n_samples,) or (n_samples, n_features)
    k_max : largest number of clusters evaluated
    n_references : number of independent reference datasets averaged per k
    random_state : int or None; seeds the reference draws. None uses the global ``np.random`` state.

    Returns
    -------
    gap, log_reference_inertia, log_data_inertia : arrays of length k_max (entry k-1 is for k clusters)
    """
    if len(data.shape) == 1:
        data = data.reshape(-1, 1)
    rng = np.random if random_state is None else np.random.RandomState(random_state)
    # reference distribution: uniform over the observed range of each feature
    low = data.min(axis=0)
    high = data.max(axis=0)

    reference_inertia = []
    for k in range(1, k_max+1):
        clustering.n_clusters = k
        local_inertia = []
        for _ in range(n_references):
            # a fresh reference per iteration; reusing a single draw makes the average meaningless
            reference = rng.uniform(low, high, size=data.shape)
            assignments = clustering.fit_predict(reference)
            local_inertia.append(compute_inertia(assignments, reference))
        reference_inertia.append(np.mean(local_inertia))

    ondata_inertia = []
    for k in range(1, k_max+1):
        clustering.n_clusters = k
        assignments = clustering.fit_predict(data)
        ondata_inertia.append(compute_inertia(assignments, data))

    gap = np.log(reference_inertia)-np.log(ondata_inertia)
    return gap, np.log(reference_inertia), np.log(ondata_inertia)

# calculate optimal number of distributions (gap statistic over k = 1..k_max on the clinical contribution)
k_max = 10
# seed the global NumPy RNG so the uniform reference draws inside compute_gap are reproducible
np.random.seed(seed_)
gap, reference_inertia, ondata_inertia = compute_gap(KMeans(random_state=seed_), np.array(dataset_contribution_individual[0]).reshape(-1,1), k_max=k_max)
optimal_number_of_distributions = np.argmax(gap)+1

# y range: the default 0.15-0.30 window, widened in 0.05 steps if gap values fall outside it
gap_ylim = [min(0.15, np.floor(np.min(gap)*20)/20), max(0.3, np.ceil(np.max(gap)*20)/20)]

plt.rcParams['font.family'] = 'Arial'
plt.rc('xtick', labelsize=20)
plt.rc('ytick', labelsize=20)
fig = plt.figure(figsize=(5,5))
ax = fig.add_subplot(111)
ax.plot(range(1,k_max+1), gap, 'k.-', markersize=5)
ax.plot([optimal_number_of_distributions, optimal_number_of_distributions],[gap_ylim[0],np.max(gap)], 'k--')
plt.xlabel('Number of Distributions', fontsize=20)
plt.ylabel('Gap Statistic', fontsize=20)
plt.xticks(list(range(1, k_max+1)))
plt.ylim(gap_ylim)
plt.yticks(np.round(np.arange(gap_ylim[0], gap_ylim[1] + 0.025, 0.05), 2))
plt.savefig('%s/%s/%s/shap_values/summary/clinical_kmeans_gap.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
plt.close()

# separate samples into distributions: 1-D k-means on the clinical contribution
distribution_names = ['low','medium','high']
distribution_names_labels = ['Low','Medium','High']
# the names, colors and legends downstream assume exactly three distributions; other counts would
# otherwise fail later with an IndexError or produce empty "high" plots
if optimal_number_of_distributions != len(distribution_names):
    raise ValueError('Expected %d clinical-contribution distributions (low/medium/high), got %d' % (len(distribution_names), optimal_number_of_distributions))
kmeans = KMeans(n_clusters=optimal_number_of_distributions, random_state=seed_).fit(np.array(dataset_contribution_individual[0]).reshape(-1,1))
labels = kmeans.labels_
centers = kmeans.cluster_centers_.reshape(-1,)
# relabel clusters by ascending center value: 0 = lowest center, 1 = next, ...
center_rank = np.argsort(np.argsort(centers))
distribution_labels = [int(center_rank[val]) for val in labels]

# histogram of the clinical contribution, one colored histogram + KDE per distribution
values = dataset_contribution_individual[0]
n_bins = 25
_, bin_edges = np.histogram(values, bins=n_bins)
plt.rcParams['font.family'] = 'Arial'
plt.rc('xtick', labelsize=20)
plt.rc('ytick', labelsize=20)
fig = plt.figure(figsize=(5,5))
ax = fig.add_subplot(111)
# own name so the dataset palette `colors_` used by the box-plot cell is not overwritten
clinical_colors = ['#FFFACD','#F0E68C','#BFBB99']
for i in range(optimal_number_of_distributions):
    # NOTE(review): sns.distplot is deprecated since seaborn 0.11 (histplot/kdeplot replace it)
    sns.distplot([values[a] for a in range(len(values)) if distribution_labels[a]==i], hist=True, kde=True, bins=bin_edges, color=clinical_colors[i], hist_kws={'edgecolor':'black'}, kde_kws={'linewidth': 4})
# legend entries "Low/Medium/High Clinical - share of samples"
legend_labels = []
for d, label in enumerate(distribution_names_labels):
    share = len([x for x in distribution_labels if x==d])/len(distribution_labels)*100
    legend_labels.append('%s Clinical - %0.1f%%' % (label, share))
plt.legend(legend_labels, fontsize=16)
plt.xlabel('Percent Contribution - Clinical', fontsize=20)
plt.ylabel('Density', fontsize=20)
plt.xlim(0,1)
plt.xticks([0,0.25,0.5,0.75,1],['0%','25%','50%','75%','100%'])
plt.savefig('%s/%s/%s/shap_values/summary/clinical_multimodal.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
plt.close()

# dataset contribution - low/med/high clinical
# Pie colors in `individual_datasets` order (clinical first): the config-cell palette, as used by the
# cohort-level pies. The global `colors` is not used because an earlier cell reassigns it with clinical
# moved last for the box plots, which would give every wedge another dataset's color.
# Keep in sync with the config cell.
distribution_pie_colors = ['#FFD700','#3DAEC5','#2E8B57','#DC143C']
for a in range(len(distribution_names)):

    # subset shap values to the samples of this distribution
    shap_value_abs_distribution = shap_value_abs.loc[[samples[i] for i in range(len(samples)) if distribution_labels[i]==a]]
    shap_value_abs_distribution_mean = shap_value_abs_distribution.mean(axis=0).values.tolist()

    # new top features: those covering 95% of this distribution's mean |SHAP|
    shap_value_abs_distribution_cumsum = np.cumsum(sorted(shap_value_abs_distribution_mean)[::-1])
    top_N_index = np.argsort(shap_value_abs_distribution_mean)[::-1][:len([x for x in shap_value_abs_distribution_cumsum if x<=0.95])]
    new_overall_top_features = [merged_features[i] for i in top_N_index]
    new_top_feature_set = set(new_overall_top_features)  # O(1) membership in the scan below

    # dataset contribution
    dataset_contribution_distribution = []
    for dataset in individual_datasets:
        dataset_contribution_distribution.append(np.sum([x for i,x in enumerate(shap_value_abs_distribution_mean) if merged_features[i].split(' # ')[0]==dataset and merged_features[i] in new_top_feature_set]))
    dataset_contribution_distribution = [x/np.sum(dataset_contribution_distribution) for x in dataset_contribution_distribution]
    plt.rcParams['font.family'] = 'Arial'
    fig = plt.figure(figsize=(8,8))
    ax = fig.add_subplot(111)
    patches, texts, autotexts = plt.pie(dataset_contribution_distribution, colors=distribution_pie_colors, autopct='%1.1f%%')
    plt.title('%s Clinical' % distribution_names_labels[a], fontsize=30)
    for i in range(len(texts)):
        texts[i].set_fontsize(24)
        autotexts[i].set_fontsize(26)
    plt.savefig('%s/%s/%s/shap_values/summary/dataset_contribution_clinical_%s.png' % (folder_name, dataset_name, output_folder, distribution_names[a]), bbox_inches='tight', dpi=400)
    plt.close()

#### Top features

In [None]:
feature = 'objscreen_kegg # HC02065'
title = 'Triglyceride pool'

# get values (drop samples where the feature is missing); X_matrix and shap_value share the sample order
val = X_matrix[feature].values.tolist()
sh = shap_value[feature].values.tolist()
keep_index = [i for i,x in enumerate(val) if not pd.isna(x)]
val = [val[i] for i in keep_index]
sh = [sh[i] for i in keep_index]

# spearman correlation between feature value and its SHAP value (spearmanr is imported in the first cell)
r = spearmanr(val,sh).correlation

# plot
plt.rcParams['font.family'] = 'Arial'
plt.rc('xtick', labelsize=16)
plt.rc('ytick', labelsize=16)
fig = plt.figure(figsize=(4,4))
ax = fig.add_subplot(111)
plt.plot(val, sh, 'k.', markersize=5)
plt.plot([0.00001,0.1],[0,0],'k--',linewidth=1)
plt.title(title, fontsize=18)
ax.set_xscale('log')
plt.xlim([1e-4,1e-2])

# percent tick labels, with as many decimals as the current ticks need
vals = ax.get_yticks()
decimals = [str(round(x*100,6)).split('.')[1] for x in vals]
if len([x for x in decimals if x != '0']) > 0:
    number_of_places = np.max([len(x) for x in decimals])
else:
    number_of_places = np.max([len(x) for x in decimals])-1
percent_formats = {0: '{:,.0%}', 1: '{:,.1%}', 2: '{:,.2%}'}
if number_of_places not in percent_formats:
    raise Exception('Number of decimal places = %d' % number_of_places)
# y ticks are not fixed here, so use a formatter rather than set_yticklabels (FixedFormatter warning)
ax.yaxis.set_major_formatter(lambda x, pos, fmt=percent_formats[number_of_places]: fmt.format(x))
plt.xlabel(r'Metabolite Production [mmol gDW$^{-1}$ hr$^{-1}$]', fontsize=15)
plt.ylabel(r'$\Delta$P', fontsize=20)

# the example folder may not exist yet; create it so savefig does not fail
example_folder = '%s/%s/%s/shap_values/example' % (folder_name, dataset_name, output_folder)
os.makedirs(example_folder, exist_ok=True)
plt.savefig('%s/met2.png' % example_folder, bbox_inches='tight', dpi=400)
plt.close()

In [None]:
feature = 'objscreen_kegg # but'
title = 'Butyric acid'

# get values (drop samples where the feature is missing); X_matrix and shap_value share the sample order
val = X_matrix[feature].values.tolist()
sh = shap_value[feature].values.tolist()
keep_index = [i for i,x in enumerate(val) if not pd.isna(x)]
val = [val[i] for i in keep_index]
sh = [sh[i] for i in keep_index]

# spearman correlation between feature value and its SHAP value (spearmanr is imported in the first cell)
r = spearmanr(val,sh).correlation

# plot
plt.rcParams['font.family'] = 'Arial'
plt.rc('xtick', labelsize=16)
plt.rc('ytick', labelsize=16)
fig = plt.figure(figsize=(4,4))
ax = fig.add_subplot(111)
plt.plot(val, sh, 'k.', markersize=5)
plt.plot([1e-5,1e1],[0,0],'k--',linewidth=1)
plt.title(title, fontsize=18)
ax.set_xscale('log')
plt.xlim([1e-5,1e1])
plt.ylim([-0.01,0.03])
plt.yticks([-0.01,0,0.01,0.02,0.03])

# percent tick labels (y ticks are fixed just above, so set_yticklabels stays aligned)
vals = ax.get_yticks()
decimals = [str(round(x*100,6)).split('.')[1] for x in vals]
if len([x for x in decimals if x != '0']) > 0:
    number_of_places = np.max([len(x) for x in decimals])
else:
    number_of_places = np.max([len(x) for x in decimals])-1
if number_of_places == 0:
    ax.set_yticklabels(['{:,.0%}'.format(x) for x in vals])
elif number_of_places == 1:
    ax.set_yticklabels(['{:,.1%}'.format(x) for x in vals])
elif number_of_places == 2:
    ax.set_yticklabels(['{:,.2%}'.format(x) for x in vals])
else:
    raise Exception('Number of decimal places = %d' % number_of_places)
plt.xlabel(r'Metabolite Production [mmol gDW$^{-1}$ hr$^{-1}$]', fontsize=15)
plt.ylabel(r'$\Delta$P', fontsize=20)

# the example folder may not exist yet; create it so savefig does not fail
example_folder = '%s/%s/%s/shap_values/example' % (folder_name, dataset_name, output_folder)
os.makedirs(example_folder, exist_ok=True)
plt.savefig('%s/met.png' % example_folder, bbox_inches='tight', dpi=400)
plt.close()

In [None]:
# Hand-curated display names for the 50 bars of the all-patient mean-|SHAP| plot.
# They are used as x-tick labels in list order, so presumably they must follow the
# computed feature ranking exactly. TODO: re-curate if the ranking changes.
top_features_names = [
    'Histology', 'Cisplatin response', 'Cancer type', 'Pathologic M', 'Clinical stage',
    'Temozolomide response', 'IDH1 SNP', 'Fluorouracil response', 'Cyclophosphamide response', 'Tumor location',
    'Patient age', 'Pathologic T', 'Carboplatin response', 'Pathologic N', 'Pathologic stage',
    'Paclitaxel response', 'Butyric acid', 'Prostaglandin D2', 'Heparan sulfate', 'Patient race',
    'BRAF SNP', 'Patient gender', 'EGFR SNP', 'O-D-Mannosylprotein', 'Saccharopine',
    'Digalactosylceramidesulfate', 'S-Glutaryldihydrolipoamide', 'Clinical grade', 'Estrone sulfate', 'CDH10',
    'Pyroglutamic acid', '3-Sulfinoalanine', 'CDK5R2', '4-Phosphopantothenoylcysteine', 'LY75',
    'dGTP', 'Clinical T', 'Prostaglandin J2', 'Phytosphingosine', "Pantetheine 4'-phosphate",
    'Elaidic acid', 'Smoking history', 'Triglyceride pool', 'Alcohol per day', 'HOMER2',
    'Adrenic acid', 'P2RX6', 'Capric acid', 'DNAH9 SNP', 'Pack years',
]

In [None]:
if analyze_shap_values:

    # N value
    N = 50

    # top N features ranked by mean |SHAP| over all patients
    top_N_index = np.argsort(shap_value_abs_mean)[::-1][:N]
    top_N_mean = [shap_value_abs_mean[i] for i in top_N_index]
    top_N_sterr = [shap_value_abs_sterr[i] for i in top_N_index]
    top_N_datasets = [merged_features[i].split(' # ')[0] for i in top_N_index]
    top_N_features = [merged_features[i].split(' # ')[-1] for i in top_N_index]
    # bar colour = colour of the dataset the feature comes from
    colors_ = [colors[individual_datasets.index(dataset)] for dataset in top_N_datasets]

    # x-tick labels: the hand-curated names when there is exactly one per bar,
    # otherwise fall back to the raw feature names instead of failing in xticks
    # NOTE(review): curated names are positional and must match the current ranking
    tick_labels = top_features_names if len(top_features_names) == N else top_N_features

    plt.rcParams['font.family'] = 'Arial'
    plt.rc('xtick', labelsize=12)
    plt.rc('ytick', labelsize=20)
    fig = plt.figure(figsize=(20,5))
    ax = fig.add_subplot(111)
    barlist = plt.bar(range(N), top_N_mean, yerr=top_N_sterr, capsize=5)
    for i in range(len(barlist)):
        barlist[i].set_color(colors_[i])
    plt.xticks(range(N), tick_labels, rotation=90, fontsize=14)
    plt.xlim(-1,N)
    plt.yticks([0,0.02,0.04,0.06,0.08,0.1,0.12])
    plt.ylabel(r'Mean |$\Delta$P|', fontsize=22)
    plt.title('All Patients', fontsize=24)

    # label y ticks as percentages with the number of decimal places needed
    vals = ax.get_yticks()
    decimals = [str(round(x*100,6)).split('.')[1] for x in vals]
    if len([x for x in decimals if x != '0']) > 0:
        number_of_places = np.max([len(x) for x in decimals])
    else:
        number_of_places = np.max([len(x) for x in decimals])-1
    if number_of_places == 0:
        ax.set_yticklabels(['{:,.0%}'.format(x) for x in vals])
    elif number_of_places == 1:
        ax.set_yticklabels(['{:,.1%}'.format(x) for x in vals])
    elif number_of_places == 2:
        ax.set_yticklabels(['{:,.2%}'.format(x) for x in vals])
    else:
        raise Exception('Number of decimal places = %d' % number_of_places)

    # legend: one entry per dataset that actually has a bar in the top N
    # (colors_.index would raise ValueError for a dataset with no bar)
    legend_handles = []
    legend_labels = []
    for color, label in zip(colors, individual_datasets_labels):
        if color in colors_:
            legend_handles.append(barlist[colors_.index(color)])
            legend_labels.append(label)
    plt.legend(legend_handles, legend_labels, fontsize=18, frameon=False)
    plt.savefig('%s/%s/%s/shap_values/summary/top%d.png' % (folder_name, dataset_name, output_folder, N), bbox_inches='tight', dpi=400)
    plt.close()

In [None]:
# Hand-curated display names for the 50 bars of the low-clinical mean-|SHAP| plot.
# They are used as x-tick labels in list order, so presumably they must follow the
# computed feature ranking exactly. TODO: re-curate if the ranking changes.
# NOTE(review): 'Pathologic Stage' is capitalised differently from the all-patient
# list ('Pathologic stage'); confirm which spelling the figures should use.
top_features_names = [
    'Histology', 'IDH1 SNP', 'Cancer type', 'Pathologic M', 'Clinical stage',
    'Pathologic T', 'Patient age', 'Butyric acid', 'Prostaglandin D2', 'Pathologic N',
    'EGFR SNP', 'BRAF SNP', 'Heparan sulfate', 'O-D-Mannosylprotein', 'Pathologic Stage',
    'Saccharopine', 'Digalactosylceramidesulfate', 'S-Glutaryldihydrolipoamide', 'CDH10', 'Estrone sulfate',
    'Tumor location', '3-Sulfinoalanine', 'Pyroglutamic acid', "Pantetheine 4'-phosphate", 'CDK5R2',
    'LY75', '4-Phosphopantothenoylcysteine', 'Triglyceride pool', 'dGTP', 'Patient race',
    'Prostaglandin J2', 'Patient gender', 'Clinical T', 'HOMER2', 'Elaidic acid',
    'Phytosphingosine', 'P2RX6', 'DNAH9 SNP', 'Capric acid', 'TSC22D4',
    '3-Sulfinylpyruvic acid', 'Adrenic acid', 'Phenylacetylglutamine', 'Clinical grade', 'Indole-5,6-quinone',
    'INPP5J', 'GTP', 'Carboplatin response', 'SEMA5B', 'Phosphatidylinositol pool',
]

In [None]:
# top N features - mean - low clinical
# (only the first entry of distribution_names; bars use the curated top_features_names labels)
if analyze_shap_values:

    # one bar per curated label, so the cell does not depend on an N left over from another cell
    N = len(top_features_names)

    for a in range(len(distribution_names))[0:1]:

        # mean and standard error of |SHAP| over the samples of distribution a
        shap_value_abs_distribution = shap_value_abs.loc[[samples[i] for i in range(len(samples)) if distribution_labels[i]==a]]
        shap_value_abs_distribution_mean = shap_value_abs_distribution.mean(axis=0).values.tolist()
        shap_value_abs_distribution_std = shap_value_abs_distribution.std(axis=0).values.tolist()
        n_distribution = len([y for y in distribution_labels if y==a])
        shap_value_abs_distribution_sterr = [x/np.sqrt(n_distribution) for x in shap_value_abs_distribution_std]

        # plot
        top_N_index = np.argsort(shap_value_abs_distribution_mean)[::-1][:N]
        top_N_mean = [shap_value_abs_distribution_mean[i] for i in top_N_index]
        top_N_sterr = [shap_value_abs_distribution_sterr[i] for i in top_N_index]
        top_N_datasets = [merged_features[i].split(' # ')[0] for i in top_N_index]
        top_N_features = [merged_features[i].split(' # ')[-1] for i in top_N_index]
        # bar colour = colour of the dataset the feature comes from
        colors_ = [colors[individual_datasets.index(dataset)] for dataset in top_N_datasets]
        plt.rcParams['font.family'] = 'Arial'
        plt.rc('xtick', labelsize=12)
        plt.rc('ytick', labelsize=20)
        fig = plt.figure(figsize=(22,5))
        ax = fig.add_subplot(111)
        barlist = plt.bar(range(N), top_N_mean, yerr=top_N_sterr, capsize=5)
        for i in range(len(barlist)):
            barlist[i].set_color(colors_[i])
        plt.xticks(range(N), top_features_names, rotation=90, fontsize=14)
        plt.xlim(-1,N)
        plt.ylabel(r'Mean |$\Delta$P|', fontsize=20)
        plt.title('%s Clinical Patients' % distribution_names_labels[a], fontsize=24)

        # label y ticks as percentages with the number of decimal places needed
        vals = ax.get_yticks()
        decimals = [str(round(x*100,6)).split('.')[1] for x in vals]
        if len([x for x in decimals if x != '0']) > 0:
            number_of_places = np.max([len(x) for x in decimals])
        else:
            number_of_places = np.max([len(x) for x in decimals])-1
        if number_of_places == 0:
            ax.set_yticklabels(['{:,.0%}'.format(x) for x in vals])
        elif number_of_places == 1:
            ax.set_yticklabels(['{:,.1%}'.format(x) for x in vals])
        elif number_of_places == 2:
            ax.set_yticklabels(['{:,.2%}'.format(x) for x in vals])
        else:
            raise Exception('Number of decimal places = %d' % number_of_places)

        # legend: one entry per dataset that actually has a bar in the top N
        # (colors_.index would raise ValueError for a dataset with no bar)
        legend_handles = []
        legend_labels = []
        for color, label in zip(colors, individual_datasets_labels):
            if color in colors_:
                legend_handles.append(barlist[colors_.index(color)])
                legend_labels.append(label)
        plt.legend(legend_handles, legend_labels, fontsize=18, frameon=False)
        plt.savefig('%s/%s/%s/shap_values/summary/top%d_clinical_%s.png' % (folder_name, dataset_name, output_folder, N, distribution_names[a]), bbox_inches='tight', dpi=400)
        plt.close()

#### Drug Response

In [None]:
# list of drugs
drugs = ['Cisplatin','Temozolomide','Fluorouracil','Cyclophosphamide','Carboplatin','Paclitaxel']
responses = ['Complete Response','Partial Response','Stable Disease','Clinical Progressive Disease']
responses_color = ['#5bc2ae','#a5d98f','#e6da81','#d93f20']

# initialize drug responses: one count per response category for each drug,
# split by label (y_vector == 0 -> sensitive, anything else -> resistant)
response_sensitive = [[0]*len(responses) for drug in drugs]
response_resistant = [[0]*len(responses) for drug in drugs]

# iterate over drugs
for i,drug in enumerate(drugs):
    column = 'clinical # RESPONSE DRUG %s' % drug

    # iterate over samples (y_vector is indexed in the same order as samples)
    for j,sample in enumerate(samples):

        # record drug response; values outside `responses` (e.g. missing) are ignored
        response = X_matrix.at[sample, column]
        if response in responses:
            counts = response_sensitive[i] if y_vector[j]==0 else response_resistant[i]
            counts[responses.index(response)] += 1

# convert to fraction; a drug with no recorded responses in a group gets
# zero-height bars instead of 0/0 = nan
for group in (response_sensitive, response_resistant):
    for i in range(len(group)):
        total = np.sum(group[i])
        group[i] = [x/total if total > 0 else 0.0 for x in group[i]]

# stacked bar plot: sensitive bars to the left of each drug position, resistant to the right
width=0.4
plt.rcParams['font.family'] = 'Arial'
plt.rc('xtick', labelsize=12)
plt.rc('ytick', labelsize=20)
fig = plt.figure(figsize=(15,5))
ax = fig.add_subplot(111)
plots_sensitive = []
plots_resistant = []
for i in range(len(responses)):
    plots_sensitive.append(plt.bar(np.arange(len(drugs)), [response_sensitive[a][i] for a in range(len(drugs))], bottom=[np.sum(response_sensitive[a][:i]) for a in range(len(drugs))], align='edge', width=-width, color=responses_color[i]))
    plots_resistant.append(plt.bar(np.arange(len(drugs)), [response_resistant[a][i] for a in range(len(drugs))], bottom=[np.sum(response_resistant[a][:i]) for a in range(len(drugs))], align='edge', width=width, color=responses_color[i]))
plt.xticks([-width/2+x for x in range(len(drugs))]+[width/2+x for x in range(len(drugs))], ['Sensitive' for x in range(len(drugs))]+['Resistant' for x in range(len(drugs))], fontsize=18, rotation=90)
plt.text(5.8,0.95,'Drug Response', fontsize=18)
plt.legend(plots_sensitive,responses, fontsize=16, bbox_to_anchor=(1.0,0.96), frameon=False)
plt.text(-0.5,-0.23,'Radiation\nResponse', fontsize=18, horizontalalignment='right')
for i in range(len(drugs)):
    plt.text(i,1.01,drugs[i], fontsize=16, horizontalalignment='center')
plt.yticks([0,0.2,0.4,0.6,0.8,1],['0%','20%','40%','60%','80%','100%'])
plt.ylabel('Percentage of Patients', fontsize=20)
plt.savefig('%s/%s/%s/shap_values/summary/drug.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
plt.close()

#### Mutations

In [None]:
# list of genes
genes_ = ['IDH1','BRAF','EGFR']

# initialize mutations, one entry per gene (same order as genes_):
#   mutations[g][protein_change]        -> [sensitive count, resistant count] (missense/nonsense only)
#   mutations_cohort[g][protein_change] -> {cohort: count}
#   mutations_count[g]                  -> [sensitive count, resistant count] over all SNP records
# NOTE(review): counts are per SNP record, not per unique patient, although the
# mutation bar chart labels them "Number of Patients"; confirm this is intended
mutations = [{} for gene in genes_]
mutations_cohort = [{} for gene in genes_]
mutations_count = [[0,0] for gene in genes_]

# count column per sample: 0 (sensitive) when y_vector == 0, else 1 (resistant).
# The first occurrence wins, as samples.index() would give.
sample_column = {}
for sample_id, label in zip(samples, y_vector):
    sample_column.setdefault(sample_id, 0 if label==0 else 1)

# iterate over cancer types (one MAF file each)
for fn in glob.glob('_datasets/mutation/_data_/input/*.maf'):
    # os.path.basename handles the platform separator; splitting on '\\' only worked on Windows
    cohort = os.path.basename(fn).split('.maf')[0]

    # load file, skipping the first 5 lines
    df_maf = pd.read_table(fn,skiprows=[0,1,2,3,4],header=0)

    # per-row sample ID (first 16 barcode characters) and SNP flag, computed once per file
    barcode_prefix = df_maf['Tumor_Sample_Barcode'].str[:16]
    is_snp = df_maf['Variant_Type'] == 'SNP'

    # iterate over samples present in this file
    available_samples = set(barcode_prefix.tolist())
    for sample in [x for x in samples if x in available_samples]:
        column = sample_column[sample]

        # SNP records of this sample
        df_subset = df_maf[(barcode_prefix == sample) & is_snp]

        # iterate over genes
        for g, gene in enumerate(genes_):
            df_gene = df_subset[df_subset['Hugo_Symbol'] == gene]
            for variant_class, hgvsp in zip(df_gene['Variant_Classification'], df_gene['HGVSp_Short']):

                # add to count
                mutations_count[g][column] += 1

                # collect missense and nonsense mutations keyed by protein change (text after 'p.')
                if variant_class in ['Missense_Mutation','Nonsense_Mutation']:
                    if not isinstance(hgvsp, str):
                        # no protein annotation (NaN): nothing to key the mutation on
                        continue
                    change = hgvsp.split('p.')[-1]
                    mutations[g].setdefault(change, [0,0])[column] += 1
                    cohort_counts = mutations_cohort[g].setdefault(change, {})
                    cohort_counts[cohort] = cohort_counts.get(cohort, 0) + 1

In [None]:
# bar chart: IDH1 and BRAF counts (all SNP records and the R132H / V600E hotspot)
# for the sensitive and resistant groups; mutations[0] / [1] are IDH1 / BRAF per genes_.
# A hotspot absent from the data is plotted as zero instead of raising KeyError.
idh1_hotspot = mutations[0].get('R132H', [0,0])
braf_hotspot = mutations[1].get('V600E', [0,0])

plt.rcParams['font.family'] = 'Arial'
plt.rc('xtick', labelsize=20)
plt.rc('ytick', labelsize=20)
fig = plt.figure(figsize=(6,4))
ax = fig.add_subplot(111)
barWidth = 0.3
r1 = np.arange(4)
r2 = [x + barWidth for x in r1]
plt.bar(r1, [mutations_count[0][0], idh1_hotspot[0], mutations_count[1][0], braf_hotspot[0]], color='#5bc2ae', width=barWidth, edgecolor='white', label='Sensitive')
plt.bar(r2, [mutations_count[0][1], idh1_hotspot[1], mutations_count[1][1], braf_hotspot[1]], color='#d93f20', width=barWidth, edgecolor='white', label='Resistant')
# tick centred between each sensitive/resistant pair
plt.xticks([x + barWidth/2 for x in r1],['All SNP','R132H','All SNP','V600E'])
plt.ylabel('Number of Patients', fontsize=20)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
plt.legend(fontsize=18, bbox_to_anchor=(0.95,1), frameon=False)

# gene brackets below the x axis: the lines use axes-fraction y (via trans)
# NOTE(review): the gene-name text y (-28) is in data units, so it only sits under
# the bracket at the current count scale
trans = ax.get_xaxis_transform()
plt.plot([-0.3,1.55],[-0.14,-0.14], color='k', transform=trans, clip_on=False)
plt.text(0.65,-28,'IDH1',fontsize=22,horizontalalignment='center')
plt.plot([1.7,3.55],[-0.14,-0.14], color='k', transform=trans, clip_on=False)
plt.text(2.65,-28,'BRAF',fontsize=22,horizontalalignment='center')
plt.savefig('%s/%s/%s/shap_values/summary/mutation.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
plt.close()

#### Metabolites

In [None]:
# load Recon3D metabolite list (index entries look like 'met[x]': base ID plus a
# one-letter compartment at position -2)
df_metabolites = pd.read_csv('../FBA-beta/data/recon/metabolites.tsv', sep='\t', index_col=0)

# first compartment listed for each base metabolite ID, built once instead of
# rescanning the whole index for every feature
first_compartment = {}
for met_id in df_metabolites.index.tolist():
    first_compartment.setdefault(met_id[:-3], met_id[-2])

# list of kegg metabolites among the top features
top_mets = []
top_kegg = []
for feature in overall_top_features:
    parts = feature.split(' # ')
    if parts[0] == 'objscreen_kegg':
        met = parts[1]
        top_mets.append(met)
        if met not in first_compartment:
            print('Metabolite %s not found in Recon3D list; KEGG ID skipped' % met)
            continue
        kegg_id = df_metabolites.at['%s[%s]' % (met, first_compartment[met]), 'KEGG']
        # metabolites without a KEGG mapping are NaN; mixing NaN with strings breaks sorted()
        if not pd.isna(kegg_id):
            top_kegg.append(kegg_id)
top_mets = sorted(set(top_mets))
top_kegg = sorted(set(top_kegg))

# output lists
with open('%s/%s/%s/shap_values/summary/top_mets.txt' % (folder_name, dataset_name, output_folder), 'w') as f:
    for met in top_mets:
        f.write('%s\n' % met)
with open('%s/%s/%s/shap_values/summary/top_kegg.txt' % (folder_name, dataset_name, output_folder), 'w') as f:
    for met in top_kegg:
        f.write('%s\n' % met)

In [None]:
# metabolites whose higher production is associated with an increased (r > 0) or
# decreased (r <= 0) predicted probability, by Spearman correlation with their SHAP values
increase = []
increase_r = []
decrease = []
decrease_r = []
for feature in overall_top_features:
    if feature.split(' # ')[0] == 'objscreen_kegg':

        # production and shap values, restricted to patients with a measured value:
        # spearmanr propagates NaN, which previously filed every metabolite with a
        # missing value under "decrease" with r = nan
        expression = X_matrix[feature].values.tolist()
        shapvalue = shap_value[feature].values.tolist()
        keep_index = [i for i,x in enumerate(expression) if not pd.isna(x)]
        expression = [expression[i] for i in keep_index]
        shapvalue = [shapvalue[i] for i in keep_index]
        if len(expression) < 2:
            continue

        # correlation coefficient; undefined (nan) for constant input, so no direction
        r = spearmanr(expression, shapvalue).correlation
        if pd.isna(r):
            continue
        if r > 0:
            increase.append(feature.split(' # ')[1])
            increase_r.append(r)
        else:
            decrease.append(feature.split(' # ')[1])
            decrease_r.append(r)

# output lists
with open('%s/%s/%s/shap_values/summary/top_mets_increase.txt' % (folder_name, dataset_name, output_folder), 'w') as f:
    for met in increase:
        f.write('%s\n' % met)
with open('%s/%s/%s/shap_values/summary/top_mets_increase.csv' % (folder_name, dataset_name, output_folder), 'w') as f:
    f.write('MET,R\n')
    for met, met_r in zip(increase, increase_r):
        f.write('%s,%f\n' % (met, met_r))
with open('%s/%s/%s/shap_values/summary/top_mets_decrease.txt' % (folder_name, dataset_name, output_folder), 'w') as f:
    for met in decrease:
        f.write('%s\n' % met)
with open('%s/%s/%s/shap_values/summary/top_mets_decrease.csv' % (folder_name, dataset_name, output_folder), 'w') as f:
    f.write('MET,R\n')
    for met, met_r in zip(decrease, decrease_r):
        f.write('%s,%f\n' % (met, met_r))

In [None]:
# Normaliser for the relative-importance column: the importance value stored for
# butyric acid ('objscreen_kegg # but')
top_met_score = overall_top_values[
    overall_top_features.index('objscreen_kegg # but')
]

In [None]:
# table of top metabolites; importance is expressed relative to top_met_score
# first compartment listed for each base metabolite ID ('met[x]' index entries),
# built once instead of rescanning the index per feature
first_compartment = {}
for met_id in df_metabolites.index.tolist():
    first_compartment.setdefault(met_id[:-3], met_id[-2])

with open('%s/%s/%s/shap_values/summary/top_mets.tsv' % (folder_name, dataset_name, output_folder), 'w') as f:
    f.write('KEGG ID\tVMH ID\tMETABOLITE NAME\tFORMULA\tRELATIVE IMPORTANCE\n')
    for i,feature in enumerate(overall_top_features):
        parts = feature.split(' # ')
        if parts[0] != 'objscreen_kegg':
            continue
        met = parts[1]
        if met not in first_compartment:
            # previously an IndexError on available_compartments[0]
            print('Metabolite %s not found in Recon3D list; skipped' % met)
            continue
        met_ = '%s[%s]' % (met, first_compartment[met])
        f.write('%s\t%s\t%s\t%s\t%f\n' % (df_metabolites.at[met_,'KEGG'], met, df_metabolites.at[met_,'NAME'], df_metabolites.at[met_,'FORMULA'], overall_top_values[i]/top_met_score))

#### Histology

In [None]:
# for every cancer type with at least two histologic subtypes, print the share of
# resistant (label 1) patients per subtype
for cohort in sorted(set(X_matrix['clinical # COHORT'].values.tolist())):

    # labels grouped by histologic subtype; dict keeps first-seen subtype order
    subtype_labels = {}

    # iterate over samples
    for i,sample in enumerate(samples):

        # histology is stored as 'COHORT > subtype'
        histologic = X_matrix.at[sample,'clinical # HISTOLOGIC']
        if not isinstance(histologic, str):
            continue
        parts = histologic.split(' > ')
        # values without a ' > subtype' part carry no subtype (previously an IndexError)
        if parts[0] != cohort or len(parts) < 2:
            continue
        subtype_labels.setdefault(parts[1], []).append(y_vector[i])

    histologies = list(subtype_labels.keys())
    values = list(subtype_labels.values())

    # if multiple histologies
    if len(histologies) >= 2:
        print('----------')
        print(cohort)
        for histology, labels in zip(histologies, values):
            n_resistant = len([x for x in labels if x==1])
            print('%d/%d (%0.1f%%) Resistant - %s' % (n_resistant, len(labels), n_resistant/len(labels)*100, histology))

In [None]:
# print the first 12 characters of every sample ID whose histology is
# 'BRCA > Mucinous Carcinoma'
for sample in samples:
    is_mucinous_brca = X_matrix.at[sample,'clinical # HISTOLOGIC'] == 'BRCA > Mucinous Carcinoma'
    if not is_mucinous_brca:
        continue
    print(sample[:12])

In [None]:

    
    
    # plot - individual patient dataset contribution
    dataset_contribution_individual = [[] for dataset in individual_datasets]
    dataset_contribution_individual_top = [[] for dataset in individual_datasets]
    for j,sample in enumerate(samples):
        sample_values = shap_value_abs.loc[sample].values.tolist()
        for i,dataset in enumerate(individual_datasets):
            dataset_contribution_individual[i].append(np.sum([sample_values[a] for a in range(len(merged_features)) if merged_features[a].split(' # ')[0]==dataset])/np.sum(sample_values))
        if np.sum([sample_values[a] for a in range(len(merged_features)) if merged_features[a].split(' # ')[0]==dataset])/np.sum(sample_values) > (1/len(individual_datasets)):
            dataset_contribution_individual_top[i].append(sample)
    plt.rcParams['font.family'] = 'Arial'
    fig = plt.figure(figsize=(3*len(individual_datasets),10))
    ax = fig.add_subplot(111)
    plt.rc('xtick', labelsize=20)
    plt.rc('ytick', labelsize=20)
    fig = sns.boxplot(data=dataset_contribution_individual, whis=1, fliersize=0, palette=colors)
    fig = sns.swarmplot(data=dataset_contribution_individual, color='black', alpha=0.25)
    plt.xlim([-0.5,len(individual_datasets)-0.5])
    fig.set(xticklabels=individual_datasets_labels)
    fig.spines['right'].set_visible(False)
    fig.spines['top'].set_visible(False)
    plt.ylabel('Percent Contribution', fontsize=20)
    vals = ax.get_yticks()
    decimals = [str(round(x*100,6)).split('.')[1] for x in vals]
    if len([x for x in decimals if x != '0']) > 0:
        number_of_places = np.max([len(x) for x in decimals])
    else:
        number_of_places = np.max([len(x) for x in decimals])-1
    if number_of_places == 0:
        ax.set_yticklabels(['{:,.0%}'.format(x) for x in vals])
    elif number_of_places == 1:
        ax.set_yticklabels(['{:,.1%}'.format(x) for x in vals])
    elif number_of_places == 2:
        ax.set_yticklabels(['{:,.2%}'.format(x) for x in vals])
    else:
        raise Exception('Number of decimal places = %d' % number_of_places)
    plt.savefig('%s/%s/%s/shap_values/summary/dataset_contribution_individual.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
    plt.close()

In [None]:
    
    # top N features - mean - low/med/high clinical
    for a in range(len(distribution_names)):
        
        # get values
        shap_value_abs_distribution = shap_value_abs.loc[[samples[i] for i in range(len(samples)) if distribution_labels[i]==a]]
        shap_value_abs_distribution_mean = shap_value_abs_distribution.mean(axis=0).values.tolist()
        shap_value_abs_distribution_std = shap_value_abs_distribution.std(axis=0).values.tolist()
        shap_value_abs_distribution_sterr = [x/np.sqrt(len([y for y in distribution_labels if y==a])) for x in shap_value_abs_distribution_std]
        
        # plot
        top_N_index = np.argsort(shap_value_abs_distribution_mean)[::-1][:N]
        top_N_mean = [shap_value_abs_distribution_mean[i] for i in top_N_index]
        top_N_sterr = [shap_value_abs_distribution_sterr[i] for i in top_N_index]
        top_N_datasets = [merged_features[i].split(' # ')[0] for i in top_N_index]
        top_N_features = [merged_features[i].split(' # ')[-1] for i in top_N_index]
        colors_ = []
        for dataset in top_N_datasets:
            colors_.append(colors[individual_datasets.index(dataset)])
        plt.rcParams['font.family'] = 'Arial'
        plt.rc('xtick', labelsize=12)
        plt.rc('ytick', labelsize=20)
        fig = plt.figure(figsize=(20,5))
        ax = fig.add_subplot(111)
        barlist = plt.bar(range(N), top_N_mean, yerr=top_N_sterr, capsize=5)
        for i in range(len(barlist)):
            barlist[i].set_color(colors_[i])
        plt.xticks(range(N), top_N_features, rotation=90)
        plt.xlim(-1,N)
        plt.ylabel(r'Mean |$\Delta$P|', fontsize=20)
        plt.title('Top %d Features' % N, fontsize=24)
        vals = ax.get_yticks()
        decimals = [str(round(x*100,6)).split('.')[1] for x in vals]
        if len([x for x in decimals if x != '0']) > 0:
            number_of_places = np.max([len(x) for x in decimals])
        else:
            number_of_places = np.max([len(x) for x in decimals])-1
        if number_of_places == 0:
            ax.set_yticklabels(['{:,.0%}'.format(x) for x in vals])
        elif number_of_places == 1:
            ax.set_yticklabels(['{:,.1%}'.format(x) for x in vals])
        elif number_of_places == 2:
            ax.set_yticklabels(['{:,.2%}'.format(x) for x in vals])
        else:
            raise Exception('Number of decimal places = %d' % number_of_places)
        plt.legend([barlist[colors_.index(x)] for x in colors], individual_datasets_labels, fontsize=16)
        plt.savefig('%s/%s/%s/shap_values/summary/top%d_clinical_%s.png' % (folder_name, dataset_name, output_folder, N, distribution_names[a]), bbox_inches='tight', dpi=400)
        plt.close()

#### Entrez gene lists for gene ontology

In [None]:
if analyze_shap_values:

    # `import urllib` (imports cell) does not by itself load these submodules
    import urllib.parse
    import urllib.request

    def _symbols_to_entrez(symbols, stepsize=400):
        """Look up Entrez Gene IDs for feature names via the bioDBnet db2db REST API.

        Names are queried in batches of ``stepsize``, with a 1 s pause before each
        request. Result rows whose 'Gene ID' is '-' (no match) or contains '//'
        (several IDs) are dropped.

        Returns:
            list of str: Gene IDs in response order (may contain duplicates).
        """
        found = []
        for start in range(0, len(symbols), stepsize):
            time.sleep(1)
            # URL-encode the batch (commas kept as separators): non-gene feature names
            # can contain spaces, which recent Python versions reject in a URL
            gene_list = urllib.parse.quote(','.join(symbols[start:start+stepsize]), safe=',')
            link = 'https://biodbnet-abcc.ncifcrf.gov/webServices/rest.php/biodbnetRestApi.json?method=db2db&input=genesymbolandsynonyms&inputValues=%s&outputs=geneid&taxonId=9606&format=row' % gene_list
            # close the HTTP response once read
            with urllib.request.urlopen(link) as response:
                data = json.loads(response.read())
            for row in data:
                if (row['Gene ID'] != '-') and ('//' not in row['Gene ID']):
                    found.append(str(row['Gene ID']))
        return found

    # gene lists for gene ontology - significant genes
    # (top-ranked features while the cumulative mean |SHAP| stays <= 0.95)
    N = len([x for x in shap_value_abs_mean_cumsum if x<=0.95])
    top_N_index = np.argsort(shap_value_abs_mean)[::-1][:N]
    top_N_features = [merged_features[i].split(' # ')[-1] for i in top_N_index]
    # NOTE(review): this overwrites overall_top_features with names stripped of their
    # 'dataset # ' prefix, while the metabolite cells filter it on the 'objscreen_kegg'
    # prefix and look up 'objscreen_kegg # but'; re-running those cells after this one
    # finds no metabolites. Confirm whether this reassignment is needed.
    overall_top_features = top_N_features.copy()
    found_entrez = _symbols_to_entrez(top_N_features)
    with open('%s/%s/%s/shap_values/summary/sig_genes.txt' % (folder_name, dataset_name, output_folder),'w') as f:
        for gene in sorted(set(found_entrez)):
            f.write('%s\n' % gene)

    # gene lists for gene ontology - reference genes (all features)
    N = len(merged_features)
    top_N_index = np.argsort(shap_value_abs_mean)[::-1][:N]
    top_N_features = [merged_features[i].split(' # ')[-1] for i in top_N_index]
    found_entrez = _symbols_to_entrez(top_N_features)
    with open('%s/%s/%s/shap_values/summary/ref_genes.txt' % (folder_name, dataset_name, output_folder),'w') as f:
        for gene in sorted(set(found_entrez)):
            f.write('%s\n' % gene)

#### Subset cohorts

In [None]:
if analyze_shap_values:

    # `import urllib` (imports cell) does not by itself load these submodules
    import urllib.parse
    import urllib.request

    # load cohort for each sample from the one-hot 'COHORT | <name>' columns of the clinical dataset
    # NOTE: pickle.load can execute code from the file; only use it on this trusted local dataset
    with open('_datasets/clinical.pickle', 'rb') as f:
        X_matrix_, y_vector_, categorical_conversion_ = pickle.load(f, encoding='latin1')
    # keyed by sample ID: the previous positional list assumed X_matrix_ has exactly the
    # rows of `samples`, in the same order, each with exactly one cohort
    sample_cohort = {}
    for sample in X_matrix_.index.tolist():
        for cohort in categorical_conversion_['COHORT']:
            if X_matrix_.at[sample,'COHORT | %s' % cohort] == 1:
                sample_cohort.setdefault(sample, cohort)
    cohorts = [sample_cohort.get(sample) for sample in samples]

    # iterate over cohorts
    for cohort in ['ACC', 'BLCA', 'BRCA', 'CESC', 'COAD', 'DLBC', 'GBM', 'HNSC', 'LGG', 'LIHC', 'LUAD', 'LUSC', 'PRAD', 'READ', 'SKCM', 'STAD', 'THCA', 'UCEC', 'UCS']:

        # subset patients within cohort
        samples_ = [sample for sample in samples if sample_cohort.get(sample)==cohort]
        shap_value_abs_ = shap_value_abs.loc[samples_]

        # mean and standard error for each feature, under cohort_* names so the
        # all-patient shap_value_abs_mean / _cumsum / _std / _sterr globals used by
        # other cells are not overwritten with the last cohort's values
        cohort_mean = shap_value_abs_.mean(axis=0).values.tolist()
        # NOTE(review): raw cumulative sum (not divided by its total, unlike the
        # per-patient cell), so the 0.95 cut-off below is in |ΔP| units; confirm intent
        cohort_mean_cumsum = np.cumsum(sorted(cohort_mean)[::-1])
        cohort_std = shap_value_abs_.std(axis=0).values.tolist()
        # standard error over the cohort's own size (previously len(samples))
        cohort_sterr = [x/np.sqrt(len(samples_)) for x in cohort_std]

        # top N features - mean
        N = len([x for x in cohort_mean_cumsum if x<=0.95])
        top_N_index = np.argsort(cohort_mean)[::-1][:N]
        top_N_mean = [cohort_mean[i] for i in top_N_index]
        top_N_sterr = [cohort_sterr[i] for i in top_N_index]
        top_N_datasets = [merged_features[i].split(' # ')[0] for i in top_N_index]
        top_N_features = [merged_features[i].split(' # ')[-1] for i in top_N_index]

        # gene lists for gene ontology - significant genes (bioDBnet symbol -> Entrez ID;
        # '-' means no match, '//' means several IDs, both dropped)
        found_entrez = []
        stepsize = 400
        for start in range(0, len(top_N_features), stepsize):
            time.sleep(1)
            # URL-encode the batch (commas kept): feature names may contain spaces
            gene_list = urllib.parse.quote(','.join(top_N_features[start:start+stepsize]), safe=',')
            link = 'https://biodbnet-abcc.ncifcrf.gov/webServices/rest.php/biodbnetRestApi.json?method=db2db&input=genesymbolandsynonyms&inputValues=%s&outputs=geneid&taxonId=9606&format=row' % gene_list
            with urllib.request.urlopen(link) as response:
                data = json.loads(response.read())
            for row in data:
                if (row['Gene ID'] != '-') and ('//' not in row['Gene ID']):
                    found_entrez.append(str(row['Gene ID']))
        with open('%s/%s/%s/shap_values/summary/cohort/%s_sig_genes.txt' % (folder_name, dataset_name, output_folder, cohort),'w') as f:
            for gene in sorted(set(found_entrez)):
                f.write('%s\n' % gene)

#### Subset patients

In [None]:
if analyze_shap_values:

    # `import urllib` (imports cell) does not by itself load these submodules
    import urllib.parse
    import urllib.request

    # per-patient output folder; the setup cell is not guaranteed to have created it
    patient_folder = '%s/%s/%s/shap_values/summary/patient' % (folder_name, dataset_name, output_folder)
    os.makedirs(patient_folder, exist_ok=True)

    # iterate over patients
    for sample in samples:

        # features with a strictly positive |SHAP| for this patient, in merged_features
        # order, then sorted by decreasing value (one row lookup instead of one .at per feature)
        row_values = shap_value_abs.loc[sample, merged_features].values.tolist()
        sig_features = [feature for feature, value in zip(merged_features, row_values) if value>0]
        sig_values = [value for value in row_values if value>0]
        sort_index = np.argsort(sig_values)[::-1]
        sig_features = [sig_features[i] for i in sort_index]
        sig_values = [sig_values[i] for i in sort_index]

        # cumulative share of this patient's total |SHAP|; left empty (N = 0) when no
        # feature contributes, instead of np.max failing on an empty array
        sig_values_cumsum = np.cumsum(sig_values)
        if sig_values_cumsum.size > 0:
            sig_values_cumsum /= np.max(sig_values_cumsum)

        # top N features covering up to 95% of the patient's total |SHAP|
        N = len([x for x in sig_values_cumsum if x<=0.95])
        top_N_features = [x.split(' # ')[-1] for x in sig_features[:N]]

        # gene lists for gene ontology - significant genes (bioDBnet symbol -> Entrez ID;
        # '-' means no match, '//' means several IDs, both dropped)
        found_entrez = []
        stepsize = 400
        for start in range(0, len(top_N_features), stepsize):
            time.sleep(1)
            # URL-encode the batch (commas kept): feature names may contain spaces
            gene_list = urllib.parse.quote(','.join(top_N_features[start:start+stepsize]), safe=',')
            link = 'https://biodbnet-abcc.ncifcrf.gov/webServices/rest.php/biodbnetRestApi.json?method=db2db&input=genesymbolandsynonyms&inputValues=%s&outputs=geneid&taxonId=9606&format=row' % gene_list
            with urllib.request.urlopen(link) as response:
                data = json.loads(response.read())
            for row in data:
                if (row['Gene ID'] != '-') and ('//' not in row['Gene ID']):
                    found_entrez.append(str(row['Gene ID']))
        with open('%s/%s_sig_genes.txt' % (patient_folder, sample),'w') as f:
            for gene in sorted(set(found_entrez)):
                f.write('%s\n' % gene)

#### Patient shap plot

In [None]:
# plotting function
def plot_patient_example(patient_name, actual_value, expected_value, features_, shap_, original_):
    
    # initialize figure
    plt.rcParams['font.family'] = 'Arial'
    fig = plt.figure(figsize=(20,.15*len(features_)))
    plt.xticks([])
    plt.yticks([])
    
    # patient name
    if actual_value == 0:
        plt.text(0.5,0.8,patient_name, horizontalalignment='center', verticalalignment='center', color='#5bc2ae', fontsize=18, weight='bold')
    else:
        plt.text(0.5,1,patient_name, horizontalalignment='center', verticalalignment='center', color='#d93f20', fontsize=18, weight='bold')
    
    # number line
    plt.plot([0.18,0.23],[0,0],'k-',linewidth=1)
    plt.plot([0.25,1],[0,0],'k-',linewidth=1)
    plt.plot([0.22,0.24],[-0.3,0.3],'k-',linewidth=1)
    plt.plot([0.24,0.26],[-0.3,0.3],'k-',linewidth=1)
    for x in [0.18,0.5,1]:
        plt.plot([x,x],[-0.1,0.1],'k-',linewidth=1)
    plt.plot([expected_value,expected_value],[-0.3,0.1],'k--',linewidth=1)
    plt.text(0.18,0.1,'0%', horizontalalignment='center', verticalalignment='bottom', color='#5bc2ae', fontsize=16)
    plt.text(0.5,0.1,'50%', horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=16)
    plt.text(1,0.1,'100%', horizontalalignment='center', verticalalignment='bottom', color='#d93f20', fontsize=16)
    plt.text(expected_value,0.1,'%0.1f%%' % (expected_value*100), horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=16)
    plt.text(expected_value,0.75,'Prior', horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=16)
    
    # data
    current_value = expected_value
    y_value = -0.3
    for i in range(len(features_)):
        if not pd.isna(original_[i]):
            if shap_[i] < 0:
                plt.arrow(current_value, y_value, shap_[i], 0, width=0.05, head_length= 0.2*np.abs(shap_[i]), length_includes_head=True, color='#5bc2ae')
                if type(original_[i]) == str:
                    plt.text(current_value+0.005, y_value, '%s = %s' % (features_[i], original_[i]), horizontalalignment='left', verticalalignment='center', fontsize=10)
                else:
                    plt.text(current_value+0.005, y_value, '%s (Imputed)' % features_[i], horizontalalignment='left', verticalalignment='center', fontsize=10)
            else:
                plt.arrow(current_value, y_value, shap_[i], 0, width=0.05, head_length= 0.2*np.abs(shap_[i]), length_includes_head=True, color='#d93f20')
                if type(original_[i]) == str:
                    plt.text(current_value-0.005, y_value, '%s = %s' % (features_[i], original_[i]), horizontalalignment='right', verticalalignment='center', fontsize=10)
                else:
                    plt.text(current_value-0.005, y_value, '%s (Imputed)' % features_[i], horizontalalignment='right', verticalalignment='center', fontsize=10)
            current_value += shap_[i]
            y_value += -0.3
        
    # end line
    plt.plot([current_value,current_value],[y_value+0.15,0.1],'k--',linewidth=1)
    if current_value < 0.5:
        plt.text(current_value,0.1,'%0.1f%%' % (current_value*100), horizontalalignment='center', verticalalignment='bottom', color='#5bc2ae', fontsize=16, weight='bold')
    else:
        plt.text(current_value,0.1,'%0.1f%%' % (current_value*100), horizontalalignment='center', verticalalignment='bottom', color='#d93f20', fontsize=16, weight='bold')
    plt.text(current_value,0.75,'Posterior', horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=16)
        
    # limits
    plt.xlim(-0.02,1.02)
    plt.ylim(y_value-0.05, 0.8)
    plt.axis('off')

In [None]:
# example: TCGA-DU-8165-01A — single-patient figure drawn with plot_patient_example
def _example_label(value):
    """Return the display text for one raw feature value in the example figure.

    Strings are returned unchanged and bools as 'True'/'False'. NaN is returned
    unchanged, so the plotting function skips that row. Whole numbers are shown
    without decimals. Anything else is rounded to 2 decimals with a ' TPM' suffix.
    """
    if type(value) == str:
        return value
    if type(value) == bool:
        return str(value)
    if np.isnan(value):
        return value
    if value == int(value):
        return str(int(value))
    return str(np.round(value, 2)) + ' TPM'


for a, sample in enumerate(samples):
    if sample != 'TCGA-DU-8165-01A':
        continue

    # non-zero SHAP values and raw values of the corresponding gene_all features
    features_ = [feature.split(' # ')[-1] for feature in merged_features if shap_value.at[sample, feature] != 0]
    columns = ['gene_all # %s' % x for x in features_]
    shap_ = shap_value.loc[sample][columns].tolist()
    original_ = [_example_label(value) for value in X_matrix.loc[sample][columns].tolist()]

    # largest absolute SHAP value first
    order = np.argsort(np.abs(shap_))[::-1]
    features_ = [features_[k] for k in order]
    shap_ = [shap_[k] for k in order]
    original_ = [original_[k] for k in order]

    # draw and save the figure
    plot_patient_example(sample, y_vector[a], expected_value[a], features_, shap_, original_)
    plt.savefig('%s/%s/%s/shap_values/patient_example.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
    plt.close()

In [None]:
# plotting function
def plot_patient(patient_name, actual_value, expected_value, features_, shap_, original_):
    
    # initialize figure
    plt.rcParams['font.family'] = 'Arial'
    fig = plt.figure(figsize=(20,.15*len(features_)))
    plt.xticks([])
    plt.yticks([])
    
    # patient name
    if actual_value == 0:
        plt.text(0.5,0.8,patient_name, horizontalalignment='center', verticalalignment='center', color='#5bc2ae', fontsize=18, weight='bold')
    else:
        plt.text(0.5,1,patient_name, horizontalalignment='center', verticalalignment='center', color='#d93f20', fontsize=18, weight='bold')
    
    # number line
    plt.plot([0,1],[0,0],'k-',linewidth=1)
    for x in [0,0.5,1]:
        plt.plot([x,x],[-0.1,0.1],'k-',linewidth=1)
    plt.plot([expected_value,expected_value],[-0.3,0.1],'k--',linewidth=1)
    plt.text(0,0.1,'0%', horizontalalignment='center', verticalalignment='bottom', color='#5bc2ae', fontsize=16)
    plt.text(0.5,0.1,'50%', horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=16)
    plt.text(1,0.1,'100%', horizontalalignment='center', verticalalignment='bottom', color='#d93f20', fontsize=16)
    plt.text(expected_value,0.1,'%0.1f%%' % (expected_value*100), horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=16)
    plt.text(expected_value,0.75,'Prior', horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=16)
    
    # data
    current_value = expected_value
    y_value = -0.3
    for i in range(len(features_)):
        if not pd.isna(original_[i]):
            if shap_[i] < 0:
                plt.arrow(current_value, y_value, shap_[i], 0, width=0.05, head_length= 0.2*np.abs(shap_[i]), length_includes_head=True, color='#5bc2ae')
                if type(original_[i]) == str:
                    plt.text(current_value+0.005, y_value, '%s = %s' % (features_[i], original_[i]), horizontalalignment='left', verticalalignment='center', fontsize=10)
                else:
                    plt.text(current_value+0.005, y_value, '%s (Imputed)' % features_[i], horizontalalignment='left', verticalalignment='center', fontsize=10)
            else:
                plt.arrow(current_value, y_value, shap_[i], 0, width=0.05, head_length= 0.2*np.abs(shap_[i]), length_includes_head=True, color='#d93f20')
                if type(original_[i]) == str:
                    plt.text(current_value-0.005, y_value, '%s = %s' % (features_[i], original_[i]), horizontalalignment='right', verticalalignment='center', fontsize=10)
                else:
                    plt.text(current_value-0.005, y_value, '%s (Imputed)' % features_[i], horizontalalignment='right', verticalalignment='center', fontsize=10)
            current_value += shap_[i]
            y_value += -0.3
        
    # end line
    plt.plot([current_value,current_value],[y_value+0.15,0.1],'k--',linewidth=1)
    if current_value < 0.5:
        plt.text(current_value,0.1,'%0.1f%%' % (current_value*100), horizontalalignment='center', verticalalignment='bottom', color='#5bc2ae', fontsize=16, weight='bold')
    else:
        plt.text(current_value,0.1,'%0.1f%%' % (current_value*100), horizontalalignment='center', verticalalignment='bottom', color='#d93f20', fontsize=16, weight='bold')
    plt.text(current_value,0.75,'Posterior', horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=16)
        
    # limits
    plt.xlim(-0.02,1.02)
    plt.ylim(y_value-0.05, 0.8)
    plt.axis('off')

In [None]:
# per-patient SHAP waterfall figures for every patient with y_vector == 1
patients_folder = '%s/%s/%s/shap_values/patients' % (folder_name, dataset_name, output_folder)
os.makedirs(patients_folder, exist_ok=True)


def _expression_label(value):
    """Return the display text for one raw gene_all feature value.

    Formatting matches the example-patient figure: whole numbers are shown
    without decimals, and other numbers are rounded to 2 decimals with a ' TPM'
    suffix. Strings pass through. Python and numpy bools become 'True'/'False'.
    Missing values are returned unchanged, so ``plot_patient`` skips those rows.
    """
    if isinstance(value, str):
        return value
    if isinstance(value, (bool, np.bool_)):
        return str(bool(value))
    if pd.isna(value):
        return value
    if float(value).is_integer():
        return str(int(value))
    return str(np.round(value, 2)) + ' TPM'


for a, sample in enumerate(samples):
    if y_vector[a] != 1:
        continue

    # get nonzero shap and feature values
    features_ = [feature.split(' # ')[-1] for feature in merged_features if shap_value.at[sample, feature] != 0]
    columns = ['gene_all # %s' % x for x in features_]
    shap_ = shap_value.loc[sample][columns].tolist()
    original_ = [_expression_label(value) for value in X_matrix.loc[sample][columns].tolist()]

    # order features based on absolute shap value (largest first)
    sort_index = np.argsort(np.abs(shap_))[::-1]
    features_ = [features_[i] for i in sort_index]
    shap_ = [shap_[i] for i in sort_index]
    original_ = [original_[i] for i in sort_index]

    # create figure
    plot_patient(sample, y_vector[a], expected_value[a], features_, shap_, original_)
    plt.savefig('%s/%s.png' % (patients_folder, sample), bbox_inches='tight', dpi=400)
    plt.close()

### Stacker

In [None]:
if analyze_stacker:

    # load stacker data from each split
    # NOTE: pickle.load can execute arbitrary code; only load files produced by this pipeline
    stacker_samples = []      # samples in order of first appearance across splits
    stacker_weights_ = []     # per sample: one list of meta-learner weights per dataset
    stacker_performance = []  # per sample: stacker_performance_weights[split] for every split it was tested in
    sample_position = {}      # sample -> index into the three lists above (replaces O(n) list.index)
    for i in range(n_splits):
        with open('%s/%s/stacker_%d.pickle' % (folder_name, dataset_name, i+1), 'rb') as f:
            X_test, X_test_all, y_test, y_pred, test_predictions, weights, test_best_classifier, bst = pickle.load(f)

        for j, sample in enumerate(X_test[0].index.tolist()):
            if sample not in sample_position:
                sample_position[sample] = len(stacker_samples)
                stacker_samples.append(sample)
                stacker_weights_.append([[] for _ in individual_datasets])
                stacker_performance.append([])
            pos = sample_position[sample]
            for k in range(len(individual_datasets)):
                stacker_weights_[pos][k].append(weights[j, k])
            # presumably a per-split performance score defined in an earlier cell — confirm
            stacker_performance[pos].append(stacker_performance_weights[i])

    # performance-weighted average of each sample's weights over the splits it was tested in
    stacker_weights = [
        [np.average(stacker_weights_[pos][k], weights=stacker_performance[pos]) for k in range(len(individual_datasets))]
        for pos in range(len(stacker_samples))
    ]

    # reorder to follow `samples`, the order used by the SHAP analysis cells
    stacker_weights = [stacker_weights[sample_position[sample]] for sample in samples]
    stacker_weights_ = stacker_weights.copy()

    # long-format table for the swarm plot: one row per (dataset, sample);
    # 'red' marks samples with distribution_labels == 0
    df = pd.DataFrame(
        [[i, stacker_weights[j][i], 'red' if distribution_labels[j] == 0 else 'black']
         for i in range(len(individual_datasets)) for j in range(len(samples))],
        columns=['Dataset', 'Value', 'Color'])

    # per-dataset lists of weights, for the box plot
    stacker_weights_grouped = [[x[i] for x in stacker_weights] for i in range(len(individual_datasets))]

In [None]:
# optimal threshold on the clinical meta-learner weight (stacker_weights_grouped[0]).
# Patients at or below the threshold are called positive (distribution_labels == 0).
# Candidate thresholds are midpoints between consecutive sorted weights. The first
# threshold with maximal accuracy is kept. Samples tied at a threshold are counted
# (the previous strict </> comparisons silently dropped them from the denominator).
values = stacker_weights_grouped[0]
ids = [1 if label == 0 else 0 for label in distribution_labels]
if len(ids) < 2:
    raise ValueError('need at least two patients to choose a threshold, got %d' % len(ids))
sort_index = np.argsort(values)
values = [values[i] for i in sort_index]
ids = [ids[i] for i in sort_index]

sorted_values = np.asarray(values, dtype=float)
sorted_ids = np.asarray(ids)
n_patients = len(sorted_ids)

# candidate thresholds: midpoints between neighbouring sorted weights
thresh = ((sorted_values[:-1] + sorted_values[1:]) / 2).tolist()

# confusion counts for every threshold at once (O(n log n) instead of O(n^2))
n_below = np.searchsorted(sorted_values, thresh, side='right')   # patients predicted positive
positives_before = np.concatenate([[0], np.cumsum(sorted_ids)])   # positives among the first k sorted
tp = positives_before[n_below]
fp = n_below - tp
fn = positives_before[-1] - tp
tn = (n_patients - n_below) - fn
acc = ((tp + tn) / n_patients).tolist()

max_acc = np.max(acc)
max_thresh = thresh[np.argmax(acc)]

In [None]:
# plot of stacker (meta learner) weights per dataset, with the optimal clinical-weight threshold
plt.rcParams['font.family'] = 'Arial'
plt.rc('xtick', labelsize=24)
plt.rc('ytick', labelsize=20)
fig = plt.figure(figsize=(3*len(individual_datasets), 10))
ax = fig.add_subplot(111)
# `colors` is the per-dataset palette from the configuration cell
sns.boxplot(data=stacker_weights_grouped, whis=1, fliersize=0, palette=colors, ax=ax)
# 'Color' holds colour names: map each name to itself instead of relying on the order
# in which 'red'/'black' first appear in df (a list palette can swap them)
sns.swarmplot(data=df, x='Dataset', y='Value', hue='Color', palette={'black': 'black', 'red': 'red'}, alpha=0.25, ax=ax)
ax.plot([-1, len(individual_datasets)], [max_thresh, max_thresh], 'k--')
ax.text(0.5, max_thresh+0.015, 'Optimal Threshold', fontsize=20)
ax.text(0.5, max_thresh-0.04, 'Accuracy: %0.1f%%' % (max_acc*100), fontsize=20)
ax.set_xlim([-0.5, len(individual_datasets)-0.5])
ax.set(xticklabels=individual_datasets_labels)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.set_ylabel('Meta Learner Weight', fontsize=24)
ax.set_ylim([0, 1])
ax.set_yticks([0, 0.2, 0.4, 0.6, 0.8, 1, max_thresh])
ax.set_yticklabels(['0%', '20%', '40%', '60%', '80%', '100%', '%0.1f%%' % (max_thresh*100)])
fig.savefig('%s/%s/%s/stacker/summary/weights.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
plt.close(fig)

#### Clinical factors differentiating stacker groups

In [None]:
# get factors: the feature matrix restricted to the analysed samples, plus the names of
# the numerical clinical columns and of every other 'clinical # ...' column (categorical)
X_stacker = X_matrix.loc[samples]
numerical_factors = ['clinical # ' + name for name in ('AGE', 'PACK YEARS', 'ALCOHOL DAYS PER WEEK', 'ALCOHOL PER DAY', 'KARNOFSKY')]
categorical_factors = []
for column in X_stacker.columns.tolist():
    is_clinical = column.split(' # ')[0] == 'clinical'
    if is_clinical and column not in numerical_factors:
        categorical_factors.append(column)

In [None]:
# numerical factors: Pearson correlation between each numerical clinical factor and the
# clinical meta-learner weight, over patients with a recorded value; shown as a table
numerical_factor_corr = []
for factor in numerical_factors:

    # get values
    values = X_stacker[factor].values.tolist()
    stackval = stacker_weights_grouped[0]
    keep_index = [i for i, x in enumerate(values) if not pd.isna(x)]
    values = [values[i] for i in keep_index]
    stackval = [stackval[i] for i in keep_index]

    # correlation coefficient (pearsonr needs at least two paired observations)
    if len(values) < 2:
        continue
    corr = pearsonr(values, stackval)
    numerical_factor_corr.append((factor.split(' # ')[-1], len(values), corr[0], corr[1]))

pd.DataFrame(numerical_factor_corr, columns=['factor', 'n', 'r', 'p'])

In [None]:
# categorical factors: one-way ANOVA of the clinical meta-learner weight across the levels
# of each categorical clinical factor (patients with a missing level excluded); shown as a table
categorical_factor_anova = []
for factor in categorical_factors:

    # get values
    values = X_stacker[factor].values.tolist()
    stackval = stacker_weights_grouped[0]
    keep_index = [i for i, x in enumerate(values) if not pd.isna(x)]
    values = [values[i] for i in keep_index]
    stackval = [stackval[i] for i in keep_index]

    # weights grouped by level, in order of first appearance
    grouped = collections.defaultdict(list)
    for level, weight in zip(values, stackval):
        grouped[level].append(weight)
    classes = list(grouped)
    val = [grouped[level] for level in classes]

    # ANOVA needs at least two groups; single-level factors are skipped instead of raising
    if len(val) < 2:
        continue
    F = f_oneway(*val)
    categorical_factor_anova.append((factor.split(' # ')[-1], len(classes), F.statistic, F.pvalue))

pd.DataFrame(categorical_factor_anova, columns=['factor', 'levels', 'F', 'p'])

In [None]:
# box/swarm plot of the clinical meta-learner weight for each level of the first
# categorical clinical factor; levels with fewer than 3 patients are left out
for factor in categorical_factors[:1]:

    # paired (level, clinical weight) values for patients with a recorded level
    values = X_stacker[factor].values.tolist()
    stackval = stacker_weights_grouped[0]
    keep_index = [i for i, x in enumerate(values) if not pd.isna(x)]
    values = [values[i] for i in keep_index]
    stackval = [stackval[i] for i in keep_index]

    # weights per level (levels sorted), keep levels with >= 3 patients, order by mean weight
    classes = sorted(collections.Counter(values))
    val = [[w for level, w in zip(values, stackval) if level == c] for c in classes]
    keep_index = [i for i, group in enumerate(val) if len(group) >= 3]
    classes = [classes[i] for i in keep_index]
    val = [val[i] for i in keep_index]
    sort_index = np.argsort([np.mean(group) for group in val])
    classes = [classes[i] for i in sort_index]
    val = [val[i] for i in sort_index]
    if not classes:
        continue

    # plot (dashed line: optimal threshold from the stacker analysis)
    plt.rcParams['font.family'] = 'Arial'
    fig = plt.figure(figsize=(len(val)/2, 5))
    ax = fig.add_subplot(111)
    plt.rc('xtick', labelsize=18)
    plt.rc('ytick', labelsize=20)
    sns.boxplot(data=val, whis=1, fliersize=0)
    sns.swarmplot(data=val, color='black', alpha=0.25)
    plt.plot([-1, len(classes)], [max_thresh, max_thresh], 'k--')
    plt.xlim([-0.5, len(val)-0.5])
    plt.xticks(list(range(len(classes))), classes, rotation=90)
    for side in ('right', 'top'):
        ax.spines[side].set_visible(False)
    plt.ylabel('Meta Learner Weight - Clinical', fontsize=20)

    # percentage tick labels with just enough decimals for the current tick values
    vals = ax.get_yticks()
    decimals = [str(round(x*100, 6)).split('.')[1] for x in vals]
    longest = np.max([len(d) for d in decimals])
    number_of_places = longest if any(d != '0' for d in decimals) else longest - 1
    percent_formats = {0: '{:,.0%}', 1: '{:,.1%}', 2: '{:,.2%}'}
    if number_of_places not in percent_formats:
        raise Exception('Number of decimal places = %d' % number_of_places)
    ax.set_yticklabels([percent_formats[number_of_places].format(x) for x in vals])
    plt.title(factor)
    plt.show()

#### Clinical factors differentiating clinical-contribution groups (Low / Medium / High)

In [None]:
# cohort composition: for each cohort, the fraction of its analysed patients in each
# clinical-contribution group (distribution_labels 0/1/2 -> Low/Medium/High)
factor = 'clinical # COHORT'

# factor counts; `x == x` drops NaN levels, as the original count-based filter did
factors = sorted({x for x in X_matrix[factor].values.tolist() if x == x})
position = {level: k for k, level in enumerate(factors)}
counts = np.zeros((len(factors), 3))
numbers = [0] * len(factors)
for i, sample in enumerate(samples):
    level = X_matrix.at[sample, factor]
    if level in position:
        counts[position[level], distribution_labels[i]] += 1
        numbers[position[level]] += 1

# drop cohorts with no analysed patient (0/0 fractions), then normalise rows to fractions
keep = [k for k, total in enumerate(numbers) if total > 0]
factors = [factors[k] for k in keep]
numbers = [numbers[k] for k in keep]
counts = counts[keep]
counts = counts / counts.sum(axis=1, keepdims=True)

# heat map
plt.rcParams["font.family"] = "Arial"
fig = plt.figure(figsize=(11, 2))
ax = sns.heatmap(counts.T, xticklabels=factors, yticklabels=['Low', 'Medium', 'High'], cmap=sns.diverging_palette(220, 10, n=1000), center=0.333, cbar=True)
# style the colorbar seaborn already drew (cbar=True); calling fig.colorbar again would add a second one
cbar = ax.collections[0].colorbar
cbar.set_ticks([0, 0.25, 0.5, 0.75, 1])
cbar.ax.tick_params(labelsize=16)
cbar.outline.set_visible(False)
cbar.set_label(r'Fraction of Patients', fontsize=16, labelpad=-20)
plt.yticks(rotation=0, fontsize=16)
plt.xticks(rotation=90, fontsize=16)
plt.show()

In [None]:
# clinical contribution per cohort: one box per cohort (analysed patients only),
# coloured by the contribution group (Low/Medium/High) its median falls into
factor = 'clinical # COHORT'
clinical_contribution = dataset_contribution_individual[0]

# clinical values: cohorts seen in X_matrix (`x == x` drops NaN) -> contributions of their analysed patients
factors = sorted({x for x in X_matrix[factor].values.tolist() if x == x})
position = {level: k for k, level in enumerate(factors)}
counts = [[] for _ in factors]
for i, sample in enumerate(samples):
    level = X_matrix.at[sample, factor]
    if level in position:
        counts[position[level]].append(clinical_contribution[i])

# drop cohorts with no analysed patient (empty box, NaN median)
nonempty = [k for k, c in enumerate(counts) if c]
factors = [factors[k] for k in nonempty]
counts = [counts[k] for k in nonempty]

# sort by median contribution
sort_index = np.argsort([np.median(x) for x in counts])
factors = [factors[i] for i in sort_index]
counts = [counts[i] for i in sort_index]

# thresholds: midpoints of the gaps between adjacent distribution_labels groups
threshold1 = np.mean([np.max([x for i, x in enumerate(clinical_contribution) if distribution_labels[i] == 0]),
                      np.min([x for i, x in enumerate(clinical_contribution) if distribution_labels[i] == 1])])
threshold2 = np.mean([np.max([x for i, x in enumerate(clinical_contribution) if distribution_labels[i] == 1]),
                      np.min([x for i, x in enumerate(clinical_contribution) if distribution_labels[i] == 2])])

# box colours, kept under a dedicated name so the per-dataset `colors` palette from
# the configuration cell (used by the stacker-weights plot) is not overwritten
group_palette = ['#FFFACD', '#F0E68C', '#BFBB99']
cohort_colors = []
for cohort_counts in counts:
    median = np.median(cohort_counts)
    if median < threshold1:
        cohort_colors.append(group_palette[0])
    elif median < threshold2:
        cohort_colors.append(group_palette[1])
    else:
        cohort_colors.append(group_palette[2])

# plot
plt.rcParams['font.family'] = 'Arial'
fig = plt.figure(figsize=(len(factors)/2, 5))
ax = fig.add_subplot(111)
plt.rc('xtick', labelsize=18)
plt.rc('ytick', labelsize=20)
sns.boxplot(data=counts, whis=1, fliersize=0, palette=cohort_colors, ax=ax)
sns.swarmplot(data=counts, color='black', alpha=0.25, ax=ax)
ax.plot([-1, len(factors)], [threshold1, threshold1], 'k--')
ax.plot([-1, len(factors)], [threshold2, threshold2], 'k--')
ax.text(len(factors)-0.1, threshold1/2, 'Low Clinical', fontsize=16, verticalalignment='center')
ax.text(len(factors)-0.1, np.mean([threshold1, threshold2]), 'Medium Clinical', fontsize=16, verticalalignment='center')
ax.text(len(factors)-0.1, np.mean([threshold2, 1]), 'High Clinical', fontsize=16, verticalalignment='center')
ax.set_xlim([-0.5, len(factors)-0.25])
plt.xticks(list(range(len(factors))), factors, rotation=90)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.set_ylabel('Percent Contribution - Clinical', fontsize=20)
ax.set_ylim([0, 1])

# percentage tick labels with as many decimals as the tick values need (no fixed upper limit)
vals = ax.get_yticks()
decimals = [str(round(x*100, 6)).split('.')[1] for x in vals]
number_of_places = np.max([len(x) for x in decimals])
if all(x == '0' for x in decimals):
    number_of_places -= 1
ax.set_yticklabels(['{:,.{}%}'.format(x, number_of_places) for x in vals])
fig.savefig('%s/%s/%s/shap_values/summary/clinical_cohort.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
plt.close(fig)

In [None]:
# chi-square test of association between each categorical clinical factor and the
# Low/Medium/High clinical-contribution group of the analysed patients
factor_name = []
factor_f = []
factor_p = []

# the per-factor figures used to be built and closed without being shown or saved;
# set to True to display them while exploring
show_factor_plots = False

clinical_contribution = dataset_contribution_individual[0]

# group thresholds: midpoints of the gaps between adjacent distribution_labels groups (loop-invariant)
threshold1 = np.mean([np.max([x for i, x in enumerate(clinical_contribution) if distribution_labels[i] == 0]),
                      np.min([x for i, x in enumerate(clinical_contribution) if distribution_labels[i] == 1])])
threshold2 = np.mean([np.max([x for i, x in enumerate(clinical_contribution) if distribution_labels[i] == 1]),
                      np.min([x for i, x in enumerate(clinical_contribution) if distribution_labels[i] == 2])])

for factor in categorical_factors:

    # clinical values per level (`x == x` drops NaN; levels without analysed patients are dropped)
    factors = sorted({x for x in X_matrix[factor].values.tolist() if x == x})
    position = {level: k for k, level in enumerate(factors)}
    counts = [[] for _ in factors]
    for i, sample in enumerate(samples):
        level = X_matrix.at[sample, factor]
        if level in position:
            counts[position[level]].append(clinical_contribution[i])
    nonempty = [k for k, c in enumerate(counts) if c]
    factors = [factors[k] for k in nonempty]
    counts = [counts[k] for k in nonempty]

    # sort levels by median contribution
    sort_index = np.argsort([np.median(x) for x in counts])
    factors = [factors[i] for i in sort_index]
    counts = [counts[i] for i in sort_index]

    # number of patients of each level in each contribution group
    groups = []
    for level_counts in counts:
        groups.append([0, 0, 0])
        for value in level_counts:
            if value <= threshold1:
                groups[-1][0] += 1
            elif value <= threshold2:
                groups[-1][1] += 1
            else:
                groups[-1][2] += 1

    # chi-square test on the non-empty part of the contingency table. Empty rows/columns
    # used to raise inside chi2_contingency and were silently reported as p=1. A table
    # with fewer than two rows or columns has no testable association (F=0, p=1).
    table = np.array(groups, dtype=int).reshape(-1, 3)
    table = table[:, table.sum(axis=0) > 0]
    factor_name.append(factor)
    if table.shape[0] < 2 or table.shape[1] < 2:
        factor_f.append(0)
        factor_p.append(1)
    else:
        f, p, dof, expected = chi2_contingency(table)
        factor_f.append(f)
        factor_p.append(p)

    # optional exploratory plot
    if show_factor_plots and factors:
        plt.rcParams['font.family'] = 'Arial'
        fig = plt.figure(figsize=(len(factors)/2, 5))
        ax = fig.add_subplot(111)
        plt.rc('xtick', labelsize=18)
        plt.rc('ytick', labelsize=20)
        sns.boxplot(data=counts, whis=1, fliersize=0, ax=ax)
        sns.swarmplot(data=counts, color='black', alpha=0.25, ax=ax)
        for threshold in (threshold1, threshold2):
            ax.plot([-1, len(factors)], [threshold, threshold], 'k--')
        ax.text(len(factors)-0.1, threshold1/2, 'Low Clinical', fontsize=16, verticalalignment='center')
        ax.text(len(factors)-0.1, np.mean([threshold1, threshold2]), 'Medium Clinical', fontsize=16, verticalalignment='center')
        ax.text(len(factors)-0.1, np.mean([threshold2, 1]), 'High Clinical', fontsize=16, verticalalignment='center')
        ax.set_xlim([-0.5, len(factors)-0.5])
        plt.xticks(list(range(len(factors))), factors, rotation=90)
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        ax.set_ylabel('Percent Contribution - Clinical', fontsize=20)
        vals = ax.get_yticks()
        decimals = [str(round(x*100, 6)).split('.')[1] for x in vals]
        number_of_places = np.max([len(x) for x in decimals])
        if all(x == '0' for x in decimals):
            number_of_places -= 1
        ax.set_yticklabels(['{:,.{}%}'.format(x, number_of_places) for x in vals])
        ax.set_title(factor)
        plt.show()
        plt.close(fig)

In [None]:
# keep only factors with p <= 0.05, most significant first (same order for all three lists)
sort_index = [idx for idx in np.argsort(factor_p) if factor_p[idx] <= 0.05]
factor_name, factor_f, factor_p = [[column[idx] for idx in sort_index] for column in (factor_name, factor_f, factor_p)]

In [None]:
# display the significant factor column names ('clinical # ...'), most significant first
factor_name

In [None]:
# human-readable labels for the significant factors, in the same order as `factor_name`
# NOTE(review): hand-written for one particular run; update whenever the significant-factor list changes
factor_name = ['Cancer Type','Histology','Pathologic T','Pathologic Stage','Clinical T','Temozolomide response','Clinical Stage','Pathologic N','Patient gender','Clinical N','Tumor location','Anastrozole response','Clinical grade','Carboplatin response']
# fail loudly instead of silently mislabelling bars/CSV rows when the results change
assert len(factor_name) == len(factor_p), (
    'hard-coded factor labels (%d) do not match the significant factors (%d)' % (len(factor_name), len(factor_p)))

In [None]:
# bar plot of -log10(p) for the significant clinical factors (dashed line: p = 0.05)
plt.rcParams['font.family'] = 'Arial'
plt.rc('xtick', labelsize=12)
plt.rc('ytick', labelsize=20)
fig = plt.figure(figsize=(8, 5))
ax = fig.add_subplot(111)
# p-values can underflow to exactly 0.0 (infinite -log10); clip to the smallest positive float
neg_log10_p = [-np.log10(max(p, np.finfo(float).tiny)) for p in factor_p]
# set_color on each bar set both face and edge colour; pass both directly
barlist = ax.bar(range(len(factor_name)), neg_log10_p, color='#FFD700', edgecolor='#FFD700')
ax.plot([-1, len(factor_name)], [-np.log10(0.05), -np.log10(0.05)], 'k--')
plt.xticks(range(len(factor_name)), factor_name, rotation=90, fontsize=14)
ax.set_xlim(-1, len(factor_name))
# NOTE: fixed y ticks; 1.301 = -log10(0.05)
plt.yticks([1.301, 20, 40, 60], labels=[1.301, 20, 40, 60])
ax.set_ylabel(r'-log$_{10}$(p-value)', fontsize=22)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
fig.savefig('%s/%s/%s/shap_values/summary/clinical_factor.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
plt.close(fig)

In [None]:
# Save the significant clinical factors (label, F statistic, p-value) as CSV.
# p-values use scientific notation. A fixed-point format such as '%.20f' prints any
# p < 5e-21 as all zeros, and the bar plot above reaches -log10(p) = 60.
# pandas quotes any FACTOR label that contains a comma.
clinical_factor_table = pd.DataFrame({
    'FACTOR': [name.split('clinical # ')[-1] for name in factor_name],
    'F': ['%f' % f_stat for f_stat in factor_f],
    'p': ['%.6e' % p_value for p_value in factor_p],
})
clinical_factor_table.to_csv(
    '%s/%s/%s/shap_values/summary/clinical_factor.csv' % (folder_name, dataset_name, output_folder),
    index=False)

In [None]:
# Clinical contribution per histologic subtype within a single cohort.
# Each box holds the per-sample clinical contributions (dataset_contribution_individual[0])
# of one subtype of `cohort`. The dashed lines are the low/medium/high thresholds derived
# from distribution_labels, and each box is shaded by the band its median falls in.
factor = 'clinical # HISTOLOGIC'
cohort = 'BRCA'
# Positions (after sorting subtypes by mean contribution) of the subtypes to plot.
keep_index = [0, 3, 4, 5]
clinical_contribution = dataset_contribution_individual[0]

# Distinct factor values belonging to this cohort. Values are expected to look like
# '<cohort> > <subtype>'. `v == v` is False only for NaN, so missing values are skipped.
factors = sorted({v for v in X_matrix[factor].values.tolist() if v == v})
factors = [x for x in factors if x.split(' > ')[0] == cohort]
if len(factors) >= 2:
    # Collect each sample's clinical contribution under its subtype.
    group_position = {value: k for k, value in enumerate(factors)}
    counts = [[] for _ in factors]
    for i, sample in enumerate(samples):
        value = X_matrix.at[sample, factor]
        if value in group_position:
            counts[group_position[value]].append(clinical_contribution[i])

    # Order subtypes by mean contribution (ascending).
    sort_index = np.argsort([np.mean(x) for x in counts])
    factors = [factors[i] for i in sort_index]
    counts = [counts[i] for i in sort_index]

    # Keep only the selected subtypes. Give a clear error instead of an IndexError when
    # the cohort has fewer subtypes than keep_index expects.
    if max(keep_index) >= len(factors):
        raise ValueError('keep_index %s needs at least %d %s subtypes, found %d'
                         % (keep_index, max(keep_index) + 1, cohort, len(factors)))
    factors = [factors[i].split(' > ')[1] for i in keep_index]
    counts = [counts[i] for i in keep_index]

    def _contributions_with_label(label):
        """Clinical contributions of the samples whose distribution label equals `label`."""
        return [x for i, x in enumerate(clinical_contribution) if distribution_labels[i] == label]

    # Band boundaries: midpoint between the max of one label group and the min of the next.
    threshold1 = np.mean([np.max(_contributions_with_label(0)), np.min(_contributions_with_label(1))])
    threshold2 = np.mean([np.max(_contributions_with_label(1)), np.min(_contributions_with_label(2))])

    # Shade each box by the band its median falls in (low / medium / high). The result is
    # stored in `box_colors` so the dataset palette `colors` set in the configuration
    # cell is not overwritten.
    band_colors = ['#FFFACD', '#F0E68C', '#BFBB99']
    box_colors = []
    for group in counts:
        group_median = np.median(group)
        if group_median < threshold1:
            box_colors.append(band_colors[0])
        elif group_median < threshold2:
            box_colors.append(band_colors[1])
        else:
            box_colors.append(band_colors[2])

    # plot
    plt.rcParams['font.family'] = 'Arial'
    fig = plt.figure(figsize=(len(factors)/2, 5))
    ax = fig.add_subplot(111)
    plt.rc('xtick', labelsize=18)
    plt.rc('ytick', labelsize=20)
    sns.boxplot(data=counts, whis=1, fliersize=0, palette=box_colors, ax=ax)
    sns.swarmplot(data=counts, color='black', alpha=0.25, ax=ax)
    n_groups = len(factors)
    for threshold in (threshold1, threshold2):
        ax.plot([-1, n_groups], [threshold, threshold], 'k--')
    ax.text(n_groups - 0.1, threshold1 / 2, 'Low Clinical', fontsize=16, verticalalignment='center')
    ax.text(n_groups - 0.1, np.mean([threshold1, threshold2]), 'Medium Clinical', fontsize=16, verticalalignment='center')
    ax.text(n_groups - 0.1, np.mean([threshold2, 1]), 'High Clinical', fontsize=16, verticalalignment='center')
    ax.set_xlim([-0.5, n_groups - 0.5])
    ax.set_xticks(list(range(n_groups)))
    ax.set_xticklabels(factors, rotation=90)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_ylabel('Percent Contribution - Clinical', fontsize=20)
    ax.set_ylim([0, 1])

    # Label the y ticks as percentages with just enough decimal places (0-2) for the tick values.
    vals = ax.get_yticks()
    decimals = [str(round(x*100, 6)).split('.')[1] for x in vals]
    number_of_places = np.max([len(d) for d in decimals])
    if all(d == '0' for d in decimals):
        number_of_places -= 1  # every tick is a whole percentage
    if number_of_places > 2:
        raise Exception('Number of decimal places = %d' % number_of_places)
    ax.set_yticklabels([('{:,.%d%%}' % number_of_places).format(x) for x in vals])

    fig.savefig('%s/%s/%s/shap_values/summary/clinical_%s.png' % (folder_name, dataset_name, output_folder, cohort), bbox_inches='tight', dpi=400)
    plt.close(fig)