In [None]:
import json
import os
import pickle
import random
import time
import urllib
import urllib.parse
import urllib.request

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import f_oneway
from scipy.stats import ttest_ind, spearmanr

from sklearn.metrics import pairwise_distances
from sklearn.cluster import KMeans

### Input

Folder and dataset names

In [None]:
# results folder read below: holds weightedlogloss.csv and <dataset_name>/shap_values_<k>.pickle
folder_name = 'withhyperopt_2'
# datasets to merge; each is loaded from _datasets/<name>.pickle and its columns are prefixed '<name> # '
individual_datasets = ['gene_all']
# display label per dataset (not referenced in the cells shown here)
individual_datasets_labels = ['Gene Expression']
# plot colour per dataset; colors[0] marks the evidence genes in the all-features plot
colors = ['#3DAEC5']
# number of splits: one 'split_<k>' row in weightedlogloss.csv and one shap_values_<k>.pickle per split
n_splits = 20

Output folder name

In [None]:
# sub-folder created under <folder_name>/<dataset_name>/ that receives every output of this notebook
output_folder = '_shap'

Analysis options

In [None]:
analyze_stacker = True  # create the stacker/ output folders (summary, clinical_features)
analyze_shap_values = True  # run the SHAP-value cells and create the shap_values/ output folders
analyze_shap_interactions = True  # create the shap_interactions/ output folder

### Create output folders

In [None]:
# name of dataset: the individual dataset names joined with '+'
dataset_name = '+'.join(individual_datasets)

# output folders
# os.makedirs(..., exist_ok=True) creates missing parents and tolerates re-runs, so each
# sub-folder is created even when the base output folder already exists (os.mkdir inside
# the old `if not isdir(base)` guard silently skipped them on a second run).
output_root = '%s/%s/%s' % (folder_name, dataset_name, output_folder)
output_subfolders = []

# stacker
if analyze_stacker:
    output_subfolders += ['stacker', 'stacker/summary', 'stacker/clinical_features']

# shap values ('summary/patient' receives the per-patient gene lists written further below)
if analyze_shap_values:
    output_subfolders += ['shap_values',
                          'shap_values/summary',
                          'shap_values/summary/cohort',
                          'shap_values/summary/patient',
                          'shap_values/features',
                          'shap_values/patients']

# shap interactions
if analyze_shap_interactions:
    output_subfolders += ['shap_interactions']

os.makedirs(output_root, exist_ok=True)
for subfolder in output_subfolders:
    os.makedirs('%s/%s' % (output_root, subfolder), exist_ok=True)

### Load dataset

In [None]:
# create merged dataset
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted dataset files.
def _load_prefixed_dataset(name):
    """Load `_datasets/<name>.pickle` and namespace its features by dataset.

    The pickle is unpacked as a 3-tuple (X, y, categorical_conversion). Every column of X
    and every key of categorical_conversion is renamed to '<name> # <original>'.

    Returns
    -------
    (X, y, categorical_conversion) with renamed columns / keys.
    """
    with open('_datasets/%s.pickle' % name, 'rb') as f:
        X, y, conversion = pickle.load(f, encoding='latin1')
    X.columns = ['%s # %s' % (name, x) for x in X.columns.tolist()]
    conversion = {'%s # %s' % (name, key): value for key, value in conversion.items()}
    return X, y, conversion

# the first dataset supplies y_vector; further datasets must list exactly the same samples
X_matrix, y_vector, categorical_conversion = _load_prefixed_dataset(individual_datasets[0])
for dataset in individual_datasets[1:]:
    X_extra, _, conversion_extra = _load_prefixed_dataset(dataset)
    if X_matrix.index.tolist() != X_extra.index.tolist():
        raise Exception("Dataset sample lists don't match")
    X_matrix = pd.concat([X_matrix, X_extra], axis=1)
    categorical_conversion = {**categorical_conversion, **conversion_extra}
samples = X_matrix.index.tolist()
features = X_matrix.columns.tolist()

# list of merged features: one-hot columns '<feature> | <value>' of a categorical feature
# collapse into a single '<feature>' entry (placed where its first one-hot column appeared)
merged_features = []
_seen_categorical = set()  # set lookup instead of scanning the growing list
for feature in features:
    base = feature.split(' | ')[0]
    if base not in categorical_conversion:
        merged_features.append(feature)
    elif base not in _seen_categorical:
        _seen_categorical.add(base)
        merged_features.append(base)

# merge feature matrix: each categorical feature becomes one object column holding the value
# whose one-hot column equals 1 (NaN if none; if several do, the last listed value wins)
X_merged = X_matrix.copy()
for feature, values in categorical_conversion.items():
    X_merged[feature] = np.nan
    X_merged[feature] = X_merged[feature].astype(object)
    for val in values:
        X_merged.loc[X_matrix['%s | %s' % (feature, val)] == 1, feature] = val
X_matrix = X_merged[merged_features]

### Load performance

In [None]:
# classifier - weighted log loss
df = pd.read_csv('%s/weightedlogloss.csv' % folder_name, index_col=0)
classifier_weighted_log_loss = df.loc[['split_%d' % (i+1) for i in range(n_splits)]][dataset_name].values.tolist()
classifier_performance_weights = [1/x for x in classifier_weighted_log_loss]

### SHAP values

#### Load data

In [None]:
if analyze_shap_values:

    # load per-split expected values and SHAP values
    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
    shap_expected = []
    shap_values = []
    for i in range(n_splits):
        with open('%s/%s/shap_values_%d.pickle' % (folder_name, dataset_name, i+1), 'rb') as f:
            shap_expected_, shap_values_ = pickle.load(f)
        shap_expected.append(shap_expected_)
        shap_values.append(shap_values_)

    # row position of each sample in each split's SHAP frame (first occurrence, like list.index).
    # Dict lookups replace the repeated `sample in frame.index.tolist()` scans, which were
    # O(samples^2) per split.
    split_positions = []
    for frame in shap_values:
        positions = {}
        for pos, sample in enumerate(frame.index.tolist()):
            positions.setdefault(sample, pos)
        split_positions.append(positions)

    # remove samples that aren't analyzed in any split
    keep_index = [i for i, sample in enumerate(samples)
                  if any(sample in positions for positions in split_positions)]
    samples = [samples[i] for i in keep_index]
    y_vector = [y_vector[i] for i in keep_index]
    X_matrix = X_matrix.iloc[keep_index]

    # combine expected values: performance-weighted mean over the splits that contain each sample
    _expected_value_sum = [0 for sample in samples]
    _shap_weight = [0 for sample in samples]
    for i, sample in enumerate(samples):
        for j in range(n_splits):
            if sample in split_positions[j]:
                _expected_value_sum[i] += classifier_performance_weights[j] * shap_expected[j][split_positions[j][sample]]
                _shap_weight[i] += classifier_performance_weights[j]
    expected_value = np.divide(_expected_value_sum, _shap_weight)

    # combine shap values: zero-pad every split to all samples x all merged features,
    # then take the same performance-weighted mean per sample
    for i in range(n_splits):
        samples_not_included = [x for x in samples if x not in split_positions[i]]
        shap_values[i] = pd.concat([shap_values[i], pd.DataFrame(data=0, index=samples_not_included, columns=shap_values[i].columns.tolist())])
        shap_values[i] = shap_values[i].loc[samples]
        present_features = set(shap_values[i].columns.tolist())
        features_not_included = [x for x in merged_features if x not in present_features]
        shap_values[i] = pd.concat([shap_values[i], pd.DataFrame(data=0, index=samples, columns=features_not_included)], axis=1, sort=False)
        shap_values[i] = shap_values[i][merged_features]
    shap_value = pd.DataFrame(data=0, index=samples, columns=merged_features)
    for i in range(n_splits):
        shap_value = np.add(shap_value, classifier_performance_weights[i]*shap_values[i])
    shap_value = shap_value.div(np.asarray(_shap_weight, dtype=float), axis=0)

    # calculate absolute shap values
    shap_value_abs = shap_value.abs()

    # zero shap values for imputed (missing in X_matrix) features, looked up by label
    imputed = X_matrix.loc[shap_value_abs.index, shap_value_abs.columns].isna().to_numpy()
    shap_value_abs = shap_value_abs.mask(imputed, 0)

    # normalize each sample's row to sum to 1 (difference between prior and posterior).
    # skipna=False keeps NaN propagation; an all-zero row becomes NaN (0/0).
    shap_value_abs = shap_value_abs.div(shap_value_abs.sum(axis=1, skipna=False), axis=0)

    # mean and standard error for each feature
    shap_value_abs_mean = shap_value_abs.mean(axis=0).values.tolist()
    shap_value_abs_mean_cumsum = np.cumsum(sorted(shap_value_abs_mean)[::-1])
    shap_value_abs_std = shap_value_abs.std(axis=0).values.tolist()
    shap_value_abs_sterr = [x/np.sqrt(len(samples)) for x in shap_value_abs_std]

#### All features

In [None]:
if analyze_shap_values:

    # genes with evidence (highlighted on the ranked curve)
    evidence = ['CDH10','TSC22D4','SEMA5B','SCN7A','PTCHD1','OLIG1','HIF3A','IL19','BNIP3','WIPI1']

    # top N features - mean: all features ranked by descending mean normalized |SHAP|
    top_N_index = np.argsort(shap_value_abs_mean)[::-1]
    top_N_mean = [shap_value_abs_mean[i] for i in top_N_index]
    top_N_sterr = [shap_value_abs_sterr[i] for i in top_N_index]
    top_N_datasets = [merged_features[i].split(' # ')[0] for i in top_N_index]
    top_N_features = [merged_features[i].split(' # ')[-1] for i in top_N_index]

    # counts used by the guide lines and annotations
    n_features = len(merged_features)
    n_cum95 = len([x for x in shap_value_abs_mean_cumsum if x<=0.95])  # top features whose cumulative sum stays <= 95%
    n_nonzero = len([x for x in shap_value_abs_mean if x>0])  # features with a nonzero mean

    def _set_percent_ticklabels(axis):
        """Relabel `axis`'s y ticks as percentages using the fewest decimals (0-2) the tick values need."""
        vals = axis.get_yticks()
        decimals = [str(round(x*100,6)).split('.')[1] for x in vals]
        if any(d != '0' for d in decimals):
            number_of_places = max(len(d) for d in decimals)
        else:
            number_of_places = max(len(d) for d in decimals) - 1
        if number_of_places not in (0, 1, 2):
            raise Exception('Number of decimal places = %d' % number_of_places)
        axis.set_yticklabels([('{:,.%d%%}' % number_of_places).format(x) for x in vals])

    # plot - all shap values
    plt.rcParams['font.family'] = 'Arial'
    plt.rc('xtick', labelsize=18)
    plt.rc('ytick', labelsize=18)
    fig = plt.figure(figsize=(10,5))
    ax = fig.add_subplot(111)
    ax.set_xscale('log')
    ax.plot(range(1,n_features+1), sorted(shap_value_abs_mean)[::-1], 'k.-', markersize=5)
    # highlight evidence genes; genes absent from the feature list are skipped instead of raising ValueError
    for gene in evidence:
        if gene in top_N_features:
            rank = top_N_features.index(gene)
            ax.plot(rank+1, top_N_mean[rank], '.', markersize=5, color=colors[0])
    plt.ylabel(r'Mean |$\Delta$P|', fontsize=20)
    plt.ylim(-0.0001,0.03)
    plt.yticks([0,0.01,0.02,0.03])
    plt.xlabel('Features', fontsize=20)

    # secondary axis: cumulative sum of the ranked means, with guide lines at 95% and at the last nonzero feature
    ax2 = ax.twinx()
    ax2.plot(range(1,n_features+1), shap_value_abs_mean_cumsum, '.-', color='#808080', markersize=5)
    ax2.plot([n_cum95,n_cum95],[0,0.95],'--',color='#808080')
    ax2.plot([n_cum95,1.1*n_features],[0.95,0.95],'--',color='#808080')
    ax2.plot([n_nonzero,n_nonzero],[0,1],'--',color='#808080')
    ax2.tick_params(axis='y', labelcolor='#808080')
    ax2.set_ylabel('Cumulative Sum', color='#808080', fontsize=20)
    ax2.set_ylim(0,1.01)
    ax2.set_yticks([0,0.25,0.5,0.75,0.95,1])

    # x ticks at the powers of ten below the feature count, plus the feature count itself
    num_int = len(str(n_features))
    xticks = [10**x for x in range(num_int)][:-1]+[n_features]
    plt.xticks(xticks, xticks)
    plt.xlim(0.95,1.05*n_features)
    plt.text(n_cum95,0.03,'%d '% n_cum95, weight='bold', ha='right', fontsize=20)
    plt.text(n_nonzero,0.03,'%d' % n_nonzero, ha='left', fontsize=20)
    _set_percent_ticklabels(ax)
    _set_percent_ticklabels(ax2)
    plt.savefig('%s/%s/%s/shap_values/summary/all_features.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
    plt.close()


In [None]:
if analyze_shap_values:

    # top features covering 95% of the cumulative mean |SHAP|. Computed here so this cell no longer
    # depends on `overall_top_features` from the later Entrez cell (NameError on Restart & Run All).
    n_top = len([x for x in shap_value_abs_mean_cumsum if x<=0.95])
    top_full_features = [merged_features[i] for i in np.argsort(shap_value_abs_mean)[::-1][:n_top]]
    overall_top_features = [x.split(' # ')[-1] for x in top_full_features]

    # gene resulting in increased or decreased probability with increased expression,
    # judged by the sign of the Spearman correlation between feature value and SHAP value
    increase = []
    decrease = []
    for full_feature, feature in zip(top_full_features, overall_top_features):

        # expression and shap values, dropping samples where either is missing
        expression = X_matrix[full_feature].values.tolist()
        shapvalue = shap_value[full_feature].values.tolist()
        valid = [i for i in range(len(expression)) if (not pd.isna(expression[i])) and (not pd.isna(shapvalue[i]))]
        expression = [expression[i] for i in valid]
        shapvalue = [shapvalue[i] for i in valid]

        # correlation coefficient; undefined (NaN) for constant input, in which case the
        # direction is unknown and the gene is put in neither list (it used to fall into `decrease`)
        r = spearmanr(expression, shapvalue).correlation
        if pd.isna(r):
            continue
        if r > 0:
            increase.append(feature)
        else:
            decrease.append(feature)

#### Entrez gene lists for gene ontology

In [None]:
if analyze_shap_values:

    def _write_entrez_ids(gene_symbols, path, stepsize=400):
        """Map gene symbols to Entrez Gene IDs with bioDBnet and write them to `path`.

        Symbols are queried in batches of `stepsize` (db2db, genesymbolandsynonyms -> geneid,
        taxonId=9606) with a 1 s pause before every request. Unmapped ('-') and ambiguous
        ('//'-joined) results are dropped; the unique IDs are written sorted, one per line.
        """
        found_entrez = []
        for start in range(0, len(gene_symbols), stepsize):
            time.sleep(1)
            # percent-encode symbols (e.g. ones containing '?' or '|') but keep the comma separators
            gene_list = urllib.parse.quote(','.join(gene_symbols[start:start+stepsize]), safe=',')
            link = 'https://biodbnet-abcc.ncifcrf.gov/webServices/rest.php/biodbnetRestApi.json?method=db2db&input=genesymbolandsynonyms&inputValues=%s&outputs=geneid&taxonId=9606&format=row' % gene_list
            with urllib.request.urlopen(link) as response:
                data = json.loads(response.read())
            for record in data:
                if (record['Gene ID'] != '-') and ('//' not in record['Gene ID']):
                    found_entrez.append(str(record['Gene ID']))
        with open(path, 'w') as f:
            for gene in sorted(set(found_entrez)):
                f.write('%s\n' % gene)

    summary_folder = '%s/%s/%s/shap_values/summary' % (folder_name, dataset_name, output_folder)

    # gene lists for gene ontology - significant genes (top features covering 95% of the cumulative mean)
    N = len([x for x in shap_value_abs_mean_cumsum if x<=0.95])
    top_N_index = np.argsort(shap_value_abs_mean)[::-1][:N]
    top_N_features = [merged_features[i].split(' # ')[-1] for i in top_N_index]
    overall_top_features = top_N_features.copy()
    _write_entrez_ids(top_N_features, '%s/sig_genes.txt' % summary_folder)

    # gene lists for gene ontology - significant increase genes
    _write_entrez_ids(increase, '%s/sig_genes_increase.txt' % summary_folder)

    # gene lists for gene ontology - significant decrease genes
    _write_entrez_ids(decrease, '%s/sig_genes_decrease.txt' % summary_folder)

    # gene lists for gene ontology - reference genes (every feature)
    N = len(merged_features)
    top_N_index = np.argsort(shap_value_abs_mean)[::-1][:N]
    top_N_features = [merged_features[i].split(' # ')[-1] for i in top_N_index]
    _write_entrez_ids(top_N_features, '%s/ref_genes.txt' % summary_folder)

#### Subset cohorts

In [None]:
if analyze_shap_values:

    # load cohort for each sample from the clinical dataset's one-hot 'COHORT | <name>' columns
    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
    with open('_datasets/clinical.pickle', 'rb') as f:
        clinical_matrix, _, clinical_conversion = pickle.load(f, encoding='latin1')
    cohort_of_sample = {}
    for sample in clinical_matrix.index.tolist():
        for cohort in clinical_conversion['COHORT']:
            if clinical_matrix.at[sample,'COHORT | %s' % cohort] == 1:
                cohort_of_sample[sample] = cohort
    # aligned with `samples` by sample ID; the clinical row order need not match `samples`,
    # which was also filtered to analyzed samples (positional matching was misaligned)
    cohorts = [cohort_of_sample.get(sample) for sample in samples]

    # iterate over cohorts
    for cohort in ['ACC', 'BLCA', 'BRCA', 'CESC', 'COAD', 'DLBC', 'GBM', 'HNSC', 'LGG', 'LIHC', 'LUAD', 'LUSC', 'PRAD', 'READ', 'SKCM', 'STAD', 'THCA', 'UCEC', 'UCS']:

        # subset patients within cohort
        cohort_samples = [samples[i] for i in range(len(samples)) if cohorts[i]==cohort]
        cohort_shap_abs = shap_value_abs.loc[cohort_samples]

        # cohort mean |SHAP|; local names so the overall summary statistics used by other cells are not overwritten
        cohort_mean = cohort_shap_abs.mean(axis=0).values.tolist()
        cohort_mean_cumsum = np.cumsum(sorted(cohort_mean)[::-1])

        # top features covering 95% of the cohort's cumulative mean
        cohort_N = len([x for x in cohort_mean_cumsum if x<=0.95])
        cohort_top_index = np.argsort(cohort_mean)[::-1][:cohort_N]
        cohort_top_features = [merged_features[i].split(' # ')[-1] for i in cohort_top_index]

        # gene lists for gene ontology - significant genes (bioDBnet symbol -> Entrez ID, batches of 400)
        found_entrez = []
        stepsize = 400
        for start in range(0, len(cohort_top_features), stepsize):
            time.sleep(1)
            # percent-encode symbols but keep the comma separators
            gene_list = urllib.parse.quote(','.join(cohort_top_features[start:start+stepsize]), safe=',')
            link = 'https://biodbnet-abcc.ncifcrf.gov/webServices/rest.php/biodbnetRestApi.json?method=db2db&input=genesymbolandsynonyms&inputValues=%s&outputs=geneid&taxonId=9606&format=row' % gene_list
            with urllib.request.urlopen(link) as response:
                data = json.loads(response.read())
            for record in data:
                if (record['Gene ID'] != '-') and ('//' not in record['Gene ID']):
                    found_entrez.append(str(record['Gene ID']))
        with open('%s/%s/%s/shap_values/summary/cohort/%s_sig_genes.txt' % (folder_name, dataset_name, output_folder, cohort),'w') as f:
            for gene in sorted(set(found_entrez)):
                f.write('%s\n' % gene)

#### Subset patients

In [None]:
if analyze_shap_values:

    # per-patient gene lists go here; this folder is not among those created up front, so make it now
    patient_folder = '%s/%s/%s/shap_values/summary/patient' % (folder_name, dataset_name, output_folder)
    os.makedirs(patient_folder, exist_ok=True)

    # iterate over patients
    for sample in samples:

        # features with positive normalized |SHAP| for this patient, strongest first
        row = shap_value_abs.loc[sample]
        positive = row[row > 0]
        sig_features = positive.index.tolist()
        sig_values = positive.values.tolist()
        sort_index = np.argsort(sig_values)[::-1]
        sig_features = [sig_features[i] for i in sort_index]
        sig_values = [sig_values[i] for i in sort_index]

        # cumulative share of the patient's total; stays empty when no feature contributes
        # (np.max of an empty array would raise ValueError)
        sig_values_cumsum = np.cumsum(sig_values)
        if len(sig_values_cumsum) > 0:
            sig_values_cumsum = sig_values_cumsum / np.max(sig_values_cumsum)

        # top features covering 95% of the patient's total
        N = len([x for x in sig_values_cumsum if x<=0.95])
        top_N_features = [x.split(' # ')[-1] for x in sig_features[:N]]

        # gene lists for gene ontology - significant genes (bioDBnet symbol -> Entrez ID, batches of 400)
        found_entrez = []
        stepsize = 400
        for start in range(0, len(top_N_features), stepsize):
            time.sleep(1)
            # percent-encode symbols but keep the comma separators
            gene_list = urllib.parse.quote(','.join(top_N_features[start:start+stepsize]), safe=',')
            link = 'https://biodbnet-abcc.ncifcrf.gov/webServices/rest.php/biodbnetRestApi.json?method=db2db&input=genesymbolandsynonyms&inputValues=%s&outputs=geneid&taxonId=9606&format=row' % gene_list
            with urllib.request.urlopen(link) as response:
                data = json.loads(response.read())
            for record in data:
                if (record['Gene ID'] != '-') and ('//' not in record['Gene ID']):
                    found_entrez.append(str(record['Gene ID']))
        with open('%s/%s_sig_genes.txt' % (patient_folder, sample),'w') as f:
            for gene in sorted(set(found_entrez)):
                f.write('%s\n' % gene)

#### Patient shap plot

In [None]:
# plotting function
def plot_patient_example(patient_name, actual_value, expected_value, features_, shap_, original_):
    
    # initialize figure
    plt.rcParams['font.family'] = 'Arial'
    fig = plt.figure(figsize=(20,.15*len(features_)))
    plt.xticks([])
    plt.yticks([])
    
    # patient name
    if actual_value == 0:
        plt.text(0.5,0.8,patient_name, horizontalalignment='center', verticalalignment='center', color='#5bc2ae', fontsize=18, weight='bold')
    else:
        plt.text(0.5,1,patient_name, horizontalalignment='center', verticalalignment='center', color='#d93f20', fontsize=18, weight='bold')
    
    # number line
    plt.plot([0.18,0.23],[0,0],'k-',linewidth=1)
    plt.plot([0.25,1],[0,0],'k-',linewidth=1)
    plt.plot([0.22,0.24],[-0.3,0.3],'k-',linewidth=1)
    plt.plot([0.24,0.26],[-0.3,0.3],'k-',linewidth=1)
    for x in [0.18,0.5,1]:
        plt.plot([x,x],[-0.1,0.1],'k-',linewidth=1)
    plt.plot([expected_value,expected_value],[-0.3,0.1],'k--',linewidth=1)
    plt.text(0.18,0.1,'0%', horizontalalignment='center', verticalalignment='bottom', color='#5bc2ae', fontsize=16)
    plt.text(0.5,0.1,'50%', horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=16)
    plt.text(1,0.1,'100%', horizontalalignment='center', verticalalignment='bottom', color='#d93f20', fontsize=16)
    plt.text(expected_value,0.1,'%0.1f%%' % (expected_value*100), horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=16)
    plt.text(expected_value,0.75,'Prior', horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=16)
    
    # data
    current_value = expected_value
    y_value = -0.3
    for i in range(len(features_)):
        if not pd.isna(original_[i]):
            if shap_[i] < 0:
                plt.arrow(current_value, y_value, shap_[i], 0, width=0.05, head_length= 0.2*np.abs(shap_[i]), length_includes_head=True, color='#5bc2ae')
                if type(original_[i]) == str:
                    plt.text(current_value+0.005, y_value, '%s = %s' % (features_[i], original_[i]), horizontalalignment='left', verticalalignment='center', fontsize=10)
                else:
                    plt.text(current_value+0.005, y_value, '%s (Imputed)' % features_[i], horizontalalignment='left', verticalalignment='center', fontsize=10)
            else:
                plt.arrow(current_value, y_value, shap_[i], 0, width=0.05, head_length= 0.2*np.abs(shap_[i]), length_includes_head=True, color='#d93f20')
                if type(original_[i]) == str:
                    plt.text(current_value-0.005, y_value, '%s = %s' % (features_[i], original_[i]), horizontalalignment='right', verticalalignment='center', fontsize=10)
                else:
                    plt.text(current_value-0.005, y_value, '%s (Imputed)' % features_[i], horizontalalignment='right', verticalalignment='center', fontsize=10)
            current_value += shap_[i]
            y_value += -0.3
        
    # end line
    plt.plot([current_value,current_value],[y_value+0.15,0.1],'k--',linewidth=1)
    if current_value < 0.5:
        plt.text(current_value,0.1,'%0.1f%%' % (current_value*100), horizontalalignment='center', verticalalignment='bottom', color='#5bc2ae', fontsize=16, weight='bold')
    else:
        plt.text(current_value,0.1,'%0.1f%%' % (current_value*100), horizontalalignment='center', verticalalignment='bottom', color='#d93f20', fontsize=16, weight='bold')
    plt.text(current_value,0.75,'Posterior', horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=16)
        
    # limits
    plt.xlim(-0.02,1.02)
    plt.ylim(y_value-0.05, 0.8)
    plt.axis('off')

In [None]:
# example: TCGA-DU-8165-01A
# (guarded: shap_value / expected_value only exist when the SHAP cells ran)
if analyze_shap_values:
    for a,sample in enumerate(samples):
        if sample=='TCGA-DU-8165-01A':

            # get nonzero shap and feature values; look up by the full '<dataset> # <feature>'
            # column name instead of re-prefixing a hard-coded 'gene_all'
            full_features = [feature for feature in merged_features if shap_value.at[sample,feature] != 0]
            features_ = [feature.split(' # ')[-1] for feature in full_features]
            shap_ = shap_value.loc[sample][full_features].tolist()
            original_ = X_matrix.loc[sample][full_features].tolist()

            # convert values to display strings: bools as-is, integral numbers without decimals,
            # other numbers rounded to 2 places with ' TPM'; missing values (NaN/None) stay missing
            for j in range(len(original_)):
                if not type(original_[j]) == str:
                    if type(original_[j]) == bool:
                        original_[j] = str(original_[j])
                    elif not pd.isna(original_[j]):
                        if original_[j] == int(original_[j]):
                            original_[j] = str(int(original_[j]))
                        else:
                            original_[j] = str(np.round(original_[j],2))+' TPM'

            # order features based on absolute shap value (largest first)
            sort_index = np.argsort(np.abs(shap_))[::-1]
            features_ = [features_[i] for i in sort_index]
            shap_ = [shap_[i] for i in sort_index]
            original_ = [original_[i] for i in sort_index]

            # create figure
            plot_patient_example(sample, y_vector[a], expected_value[a], features_, shap_, original_)
            plt.savefig('%s/%s/%s/shap_values/patient_example.png' % (folder_name, dataset_name, output_folder), bbox_inches='tight', dpi=400)
            plt.close()

In [None]:
# plotting function
def plot_patient(patient_name, actual_value, expected_value, features_, shap_, original_):
    
    # initialize figure
    plt.rcParams['font.family'] = 'Arial'
    fig = plt.figure(figsize=(20,.15*len(features_)))
    plt.xticks([])
    plt.yticks([])
    
    # patient name
    if actual_value == 0:
        plt.text(0.5,0.8,patient_name, horizontalalignment='center', verticalalignment='center', color='#5bc2ae', fontsize=18, weight='bold')
    else:
        plt.text(0.5,1,patient_name, horizontalalignment='center', verticalalignment='center', color='#d93f20', fontsize=18, weight='bold')
    
    # number line
    plt.plot([0,1],[0,0],'k-',linewidth=1)
    for x in [0,0.5,1]:
        plt.plot([x,x],[-0.1,0.1],'k-',linewidth=1)
    plt.plot([expected_value,expected_value],[-0.3,0.1],'k--',linewidth=1)
    plt.text(0,0.1,'0%', horizontalalignment='center', verticalalignment='bottom', color='#5bc2ae', fontsize=16)
    plt.text(0.5,0.1,'50%', horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=16)
    plt.text(1,0.1,'100%', horizontalalignment='center', verticalalignment='bottom', color='#d93f20', fontsize=16)
    plt.text(expected_value,0.1,'%0.1f%%' % (expected_value*100), horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=16)
    plt.text(expected_value,0.75,'Prior', horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=16)
    
    # data
    current_value = expected_value
    y_value = -0.3
    for i in range(len(features_)):
        if not pd.isna(original_[i]):
            if shap_[i] < 0:
                plt.arrow(current_value, y_value, shap_[i], 0, width=0.05, head_length= 0.2*np.abs(shap_[i]), length_includes_head=True, color='#5bc2ae')
                if type(original_[i]) == str:
                    plt.text(current_value+0.005, y_value, '%s = %s' % (features_[i], original_[i]), horizontalalignment='left', verticalalignment='center', fontsize=10)
                else:
                    plt.text(current_value+0.005, y_value, '%s (Imputed)' % features_[i], horizontalalignment='left', verticalalignment='center', fontsize=10)
            else:
                plt.arrow(current_value, y_value, shap_[i], 0, width=0.05, head_length= 0.2*np.abs(shap_[i]), length_includes_head=True, color='#d93f20')
                if type(original_[i]) == str:
                    plt.text(current_value-0.005, y_value, '%s = %s' % (features_[i], original_[i]), horizontalalignment='right', verticalalignment='center', fontsize=10)
                else:
                    plt.text(current_value-0.005, y_value, '%s (Imputed)' % features_[i], horizontalalignment='right', verticalalignment='center', fontsize=10)
            current_value += shap_[i]
            y_value += -0.3
        
    # end line
    plt.plot([current_value,current_value],[y_value+0.15,0.1],'k--',linewidth=1)
    if current_value < 0.5:
        plt.text(current_value,0.1,'%0.1f%%' % (current_value*100), horizontalalignment='center', verticalalignment='bottom', color='#5bc2ae', fontsize=16, weight='bold')
    else:
        plt.text(current_value,0.1,'%0.1f%%' % (current_value*100), horizontalalignment='center', verticalalignment='bottom', color='#d93f20', fontsize=16, weight='bold')
    plt.text(current_value,0.75,'Posterior', horizontalalignment='center', verticalalignment='bottom', color='black', fontsize=16)
        
    # limits
    plt.xlim(-0.02,1.02)
    plt.ylim(y_value-0.05, 0.8)
    plt.axis('off')

In [None]:
# iterate over patients with y_vector == 1 and save one waterfall plot each
# (guarded: shap_value / expected_value only exist when the SHAP cells ran)
if analyze_shap_values:
    for a,sample in enumerate(samples):
        if y_vector[a]==1:

            # get nonzero shap and feature values; look up by the full '<dataset> # <feature>'
            # column name instead of re-prefixing a hard-coded 'gene_all'
            full_features = [feature for feature in merged_features if shap_value.at[sample,feature] != 0]
            features_ = [feature.split(' # ')[-1] for feature in full_features]
            shap_ = shap_value.loc[sample][full_features].tolist()
            original_ = X_matrix.loc[sample][full_features].tolist()

            # convert values to display strings: bools as-is, integral numbers without decimals,
            # other numbers via str(); missing values (NaN/None) stay missing
            for j in range(len(original_)):
                if not type(original_[j]) == str:
                    if type(original_[j]) == bool:
                        original_[j] = str(original_[j])
                    elif not pd.isna(original_[j]):
                        if original_[j] == int(original_[j]):
                            original_[j] = str(int(original_[j]))
                        else:
                            original_[j] = str(original_[j])

            # order features based on absolute shap value (largest first)
            sort_index = np.argsort(np.abs(shap_))[::-1]
            features_ = [features_[i] for i in sort_index]
            shap_ = [shap_[i] for i in sort_index]
            original_ = [original_[i] for i in sort_index]

            # create figure
            plot_patient(sample, y_vector[a], expected_value[a], features_, shap_, original_)
            plt.savefig('%s/%s/%s/shap_values/patients/%s.png' % (folder_name, dataset_name, output_folder, sample), bbox_inches='tight', dpi=400)
            plt.close()